Python NumPy: How to Set Axis for Rows and Columns
Understanding how the axis parameter works in NumPy is essential for performing efficient array operations. Whether you need to compute sums, means, or other aggregations array-wise, column-wise, or row-wise, the axis parameter controls the direction of the computation. This guide explains what each axis value means and how to use it correctly, with examples and their outputs.
Understanding Axes in NumPy
In NumPy, a 2D array has two axes:
- axis=0 refers to the vertical direction (along rows, i.e., operating on each column).
- axis=1 refers to the horizontal direction (along columns, i.e., operating on each row).
- axis=None flattens the array and operates on all elements at once.
A common source of confusion is that axis=0 does not mean "row-wise." Instead, it means the operation collapses the rows and produces one result per column. Similarly, axis=1 collapses the columns and produces one result per row.
Think of the axis parameter as the axis that gets collapsed (removed) during the operation, not the axis that is preserved.
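The "collapsed axis" rule is easiest to verify by looking at result shapes. The sketch below uses a hypothetical 3D array (the name cube and its shape are illustrative, not part of the examples in this guide) to show that the axis passed to sum() is the one missing from the output shape.

```python
import numpy as np

# Hypothetical 3D array for illustration: shape (2, 3, 4).
cube = np.arange(24).reshape(2, 3, 4)

# Summing along an axis removes that axis from the result's shape.
shape_axis0 = cube.sum(axis=0).shape  # axis 0 (length 2) is removed
shape_axis1 = cube.sum(axis=1).shape  # axis 1 (length 3) is removed
shape_axis2 = cube.sum(axis=2).shape  # axis 2 (length 4) is removed

print("Input shape:", cube.shape)
print("axis=0 ->", shape_axis0)
print("axis=1 ->", shape_axis1)
print("axis=2 ->", shape_axis2)
```

The same rule applies to the 2D arrays used in the rest of this guide: the dimension named by axis disappears and the other one remains.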
Key Functions Used
Before diving into the examples, here are the core NumPy functions referenced in this guide:
| Function | Description |
|---|---|
| np.array(object) | Creates a NumPy array from a list, tuple, or other iterable. |
| ndarray.reshape(rows, columns) | Reshapes the array to the specified dimensions. Use -1 for one dimension to let NumPy infer it automatically. |
| ndarray.sum(axis) | Computes the sum of elements along the specified axis. |
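As a brief illustration of how these three functions fit together, the following sketch builds a small array, reshapes it with -1 so NumPy infers the column count, and sums along axis=0. The variable names (flat, grid, col_totals) are chosen here for illustration only.

```python
import numpy as np

# Twelve values, 1 through 12, as a 1D array.
flat = np.array(range(1, 13))

# Ask for 3 rows and let NumPy infer the number of columns (4).
grid = flat.reshape(3, -1)

# Collapse the rows: one total per column.
col_totals = grid.sum(axis=0)

print("Inferred shape:", grid.shape)
print(grid)
print("Column totals:", col_totals)
```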
Array-Wise Calculation with axis=None
When you set axis=None, NumPy treats the entire array as a flat sequence of elements and returns a single scalar value. This is the default behavior of most aggregation functions.
import numpy as np
nparray = np.array([
    [1, 2, 3],
    [11, 22, 33],
    [4, 5, 6],
    [8, 9, 10],
    [20, 30, 40]
])
print("Array shape:", nparray.shape)
print(nparray)
# Sum all elements in the array (axis=None)
output = nparray.sum(axis=None)
print("\nSum array-wise:", output)
Output:
Array shape: (5, 3)
[[ 1  2  3]
 [11 22 33]
 [ 4  5  6]
 [ 8  9 10]
 [20 30 40]]
Sum array-wise: 204
Every element in the array is added together: 1 + 2 + 3 + 11 + ... + 40 = 204.
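To confirm that axis=None is the default and behaves like summing a flattened copy, a minimal check along these lines can help. It reuses the same array as above; the names total_default, total_none, and total_flat are illustrative.

```python
import numpy as np

nparray = np.array([
    [1, 2, 3],
    [11, 22, 33],
    [4, 5, 6],
    [8, 9, 10],
    [20, 30, 40]
])

total_default = nparray.sum()            # no axis given
total_none = nparray.sum(axis=None)      # axis=None stated explicitly
total_flat = nparray.ravel().sum()       # flatten first, then sum

# All three approaches give the same scalar.
print(total_default, total_none, total_flat)
```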
Column-Wise Calculation with axis=0
Setting axis=0 collapses the row dimension. The operation moves down each column and produces one result per column. For a 2D array with shape (5, 3), the result will have shape (3,).
import numpy as np
nparray = np.array([
    [1, 2, 3],
    [11, 22, 33],
    [4, 5, 6],
    [8, 9, 10],
    [20, 30, 40]
])
print(nparray)
# Sum along axis=0 (column-wise)
output = nparray.sum(axis=0)
print("\nSum column-wise:", output)
Output:
[[ 1  2  3]
 [11 22 33]
 [ 4  5  6]
 [ 8  9 10]
 [20 30 40]]
Sum column-wise: [44 68 92]
Here is how each column sum is computed:
- Column 0: 1 + 11 + 4 + 8 + 20 = 44
- Column 1: 2 + 22 + 5 + 9 + 30 = 68
- Column 2: 3 + 33 + 6 + 10 + 40 = 92
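The column arithmetic above can also be written as an explicit loop, which makes it visible that axis=0 iterates over the row index while holding the column fixed. This is only an illustrative equivalent for understanding, not how NumPy computes the result internally; the name manual_col_sums is hypothetical.

```python
import numpy as np

nparray = np.array([
    [1, 2, 3],
    [11, 22, 33],
    [4, 5, 6],
    [8, 9, 10],
    [20, 30, 40]
])

n_rows, n_cols = nparray.shape

# For each column, walk down the rows and accumulate a total.
manual_col_sums = []
for col in range(n_cols):
    total = 0
    for row in range(n_rows):
        total += nparray[row, col]
    manual_col_sums.append(int(total))

print("Manual loop:   ", manual_col_sums)
print("nparray.sum(0):", nparray.sum(axis=0).tolist())
```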
Row-Wise Calculation with axis=1
Setting axis=1 collapses the column dimension. The operation moves across each row and produces one result per row. For a 2D array with shape (5, 3), the result will have shape (5,).
import numpy as np
nparray = np.array([
    [1, 2, 3],
    [11, 22, 33],
    [4, 5, 6],
    [8, 9, 10],
    [20, 30, 40]
])
print(nparray)
# Sum along axis=1 (row-wise)
output = nparray.sum(axis=1)
print("\nSum row-wise:", output)
Output:
[[ 1  2  3]
 [11 22 33]
 [ 4  5  6]
 [ 8  9 10]
 [20 30 40]]
Sum row-wise: [ 6 66 15 27 90]
Each element in the result corresponds to the sum of one row:
- Row 0: 1 + 2 + 3 = 6
- Row 1: 11 + 22 + 33 = 66
- Row 2: 4 + 5 + 6 = 15
- Row 3: 8 + 9 + 10 = 27
- Row 4: 20 + 30 + 40 = 90
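If you want to keep the collapsed axis as a dimension of length 1 instead of dropping it, sum() accepts keepdims=True. The sketch below goes slightly beyond the examples in this guide: it uses the (5, 1) result to divide each row by its own total, which relies on NumPy broadcasting. The names row_totals and row_shares are illustrative.

```python
import numpy as np

nparray = np.array([
    [1, 2, 3],
    [11, 22, 33],
    [4, 5, 6],
    [8, 9, 10],
    [20, 30, 40]
])

# keepdims=True keeps axis 1 as a length-1 dimension: shape (5, 1), not (5,).
row_totals = nparray.sum(axis=1, keepdims=True)

# A (5, 3) array divided by a (5, 1) array: each row is divided by its own total.
row_shares = nparray / row_totals

print("row_totals shape:", row_totals.shape)
print(row_shares.round(3))
```

Without keepdims=True, the row totals would have shape (5,), which does not line up with the rows of a (5, 3) array in a division like this.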
Common Mistake: Confusing axis=0 and axis=1
One of the most frequent errors when working with NumPy axes is mixing up axis=0 (column-wise) and axis=1 (row-wise). Consider the following scenario where you want the sum of each row but accidentally use axis=0:
import numpy as np
data = np.array([
    [10, 20, 30],
    [40, 50, 60]
])
# WRONG: trying to get row sums but using axis=0
row_sums = data.sum(axis=0)
print("Incorrect row sums:", row_sums)
Output:
Incorrect row sums: [50 70 90]
The result has 3 elements (one per column), not 2 (one per row). The correct approach is:
# CORRECT: use axis=1 for row-wise sums
row_sums = data.sum(axis=1)
print("Correct row sums:", row_sums)
Output:
Correct row sums: [ 60 150]
Remember: axis=0 collapses rows (result per column), and axis=1 collapses columns (result per row). The axis number indicates which dimension is removed, not which one is kept.
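One way to catch this mix-up early is to check the result's shape against the dimension you expected to keep. The helper below, row_sums_checked, is a hypothetical function written for this illustration, not part of NumPy.

```python
import numpy as np

def row_sums_checked(data):
    """Hypothetical helper: sum each row and verify there is one result per row."""
    result = data.sum(axis=1)
    expected_shape = (data.shape[0],)
    if result.shape != expected_shape:
        raise ValueError(
            f"expected shape {expected_shape}, got {result.shape}"
        )
    return result

data = np.array([
    [10, 20, 30],
    [40, 50, 60]
])

row_sums = row_sums_checked(data)
print("Row sums:", row_sums)
print("Matches number of rows:", row_sums.shape[0] == data.shape[0])
```

The shape check cannot fail here because the helper itself uses axis=1; the same comparison is more useful when the axis value comes from elsewhere in your code. Note that it does not help with square arrays, where both axes have the same length.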
Using axis with Other NumPy Functions
The axis parameter is not exclusive to sum(). It works consistently across many NumPy aggregation functions:
import numpy as np
data = np.array([
    [5, 1, 9],
    [3, 7, 2],
    [8, 4, 6]
])
print("Mean (column-wise, axis=0):", np.mean(data, axis=0))
print("Max (row-wise, axis=1):", np.max(data, axis=1))
print("Min (array-wise, axis=None):", np.min(data, axis=None))
print("Std (column-wise, axis=0):", np.std(data, axis=0).round(2))
Output:
Mean (column-wise, axis=0): [5.33333333 4.         5.66666667]
Max (row-wise, axis=1): [9 7 8]
Min (array-wise, axis=None): 1
Std (column-wise, axis=0): [2.05 2.45 2.87]
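The axis argument also accepts negative values (counting from the last axis) and, for many reduction functions, a tuple of axes. The following sketch reuses the same 3x3 array to illustrate both, along with np.argmax, which returns positions rather than values. The variable names are illustrative.

```python
import numpy as np

data = np.array([
    [5, 1, 9],
    [3, 7, 2],
    [8, 4, 6]
])

# axis=-1 means "the last axis", which is axis=1 for a 2D array.
max_last_axis = np.max(data, axis=-1)

# A tuple collapses several axes at once; (0, 1) covers both axes of a 2D array.
sum_both_axes = np.sum(data, axis=(0, 1))

# argmax along axis=0: for each column, the row index of its largest value.
argmax_per_column = np.argmax(data, axis=0)

print("Max along last axis:", max_last_axis)
print("Sum over axes (0, 1):", sum_both_axes)
print("Row index of max per column:", argmax_per_column)
```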
Quick Reference Summary
| Axis Value | Direction | Collapses | Result Shape (for (m, n) input) |
|---|---|---|---|
| None | All elements | Both dimensions | Scalar |
| 0 | Down columns | Row dimension | (n,) |
| 1 | Across rows | Column dimension | (m,) |
Understanding how the axis parameter works in NumPy allows you to write concise, efficient, and correct array operations. Once you internalize that the axis value represents the dimension being collapsed, working with multi-dimensional data becomes significantly more intuitive.