Python NumPy: How to Index a 3D NumPy Array Using a 2D Index Array

When working with multi-dimensional data in NumPy, you often need to select elements from a 3D array where the index along one axis varies by position. A typical case is picking a different "depth" layer at each row-column position, with the layer indices stored in a separate 2D array.

NumPy's np.take_along_axis() function is specifically designed for this type of advanced indexing.

In this guide, you will learn how this technique works step by step, see practical examples, and understand how to apply it to your own data processing tasks.

The Problem

Consider a 3D array with shape (3, 3, 3), which you can think of as 3 layers of 3×3 grids. You want to select one element per (row, column) position from these layers, where a separate 2D array tells you which layer to pick from at each position.

import numpy as np

# 3D array: shape (3, 3, 3), 3 layers of 3x3 grids
val_arr = np.arange(27).reshape(3, 3, 3)
print("3D array:")
print(val_arr)

Output:

3D array:
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]

The 2D index array specifies which layer (axis 0) to select at each (row, column) position:

# 2D index array: shape (3, 3) (which layer to pick at each position)
z_indices = np.array([
    [1, 0, 2],
    [0, 0, 1],
    [2, 0, 1]
])
print("Index array:")
print(z_indices)

Output:

Index array:
[[1 0 2]
 [0 0 1]
 [2 0 1]]

For position (0, 0), the index is 1, meaning select from layer 1 → value 9. For position (0, 1), the index is 0, meaning select from layer 0 → value 1. And so on.
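
To make the goal concrete, the selection rule can be written as a plain double loop. This is only a reference sketch for checking results on small arrays, not the recommended approach; the name `expected` is introduced here purely for illustration.

```python
import numpy as np

val_arr = np.arange(27).reshape(3, 3, 3)
z_indices = np.array([
    [1, 0, 2],
    [0, 0, 1],
    [2, 0, 1],
])

# Visit every (row, col) position and read the value from the
# layer named by z_indices[row, col].
n_rows, n_cols = z_indices.shape
expected = np.empty((n_rows, n_cols), dtype=val_arr.dtype)
for row in range(n_rows):
    for col in range(n_cols):
        layer = z_indices[row, col]
        expected[row, col] = val_arr[layer, row, col]

print(expected)
```

The loop is easy to read but runs in Python for every element, which is why a vectorized NumPy call is preferable for larger arrays.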

The Solution: np.take_along_axis()

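The np.take_along_axis() function selects elements from an array along a specified axis using indices from another array. The key requirement is that both arrays have the same number of dimensions; on every axis other than the one being indexed, the index array's shape must match the source array's shape or be broadcastable to it.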
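
Before moving to three dimensions, a small 2D case may help show the same-number-of-dimensions rule. The arrays `a` and `idx` below are made up for this sketch; `idx` is deliberately kept 2D, with one column index per row.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])

# One column index per row, stored as shape (2, 1) so that
# idx has as many dimensions as a.
idx = np.array([[1],
                [0]])

picked = np.take_along_axis(a, idx, axis=1)

print(picked.shape)
print(picked)
```

The result keeps the shape of the index array, which is why a squeeze step is usually needed afterwards.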

Step 1: Expand the Index Array Dimensions

The 3D array has 3 dimensions, but the index array has only 2. Use np.expand_dims() to add the missing dimension:

# Original shape: (3, 3) → expanded shape: (1, 3, 3)
z_indices_expanded = np.expand_dims(z_indices, axis=0)

print(f"Original index shape: {z_indices.shape}")
print(f"Expanded index shape: {z_indices_expanded.shape}")

Output:

Original index shape: (3, 3)
Expanded index shape: (1, 3, 3)
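
For reference, np.expand_dims() is not the only way to add the axis. The sketch below compares it with a few equivalent spellings; the variable names `a` through `d` are illustrative only.

```python
import numpy as np

z_indices = np.array([
    [1, 0, 2],
    [0, 0, 1],
    [2, 0, 1],
])

a = np.expand_dims(z_indices, axis=0)   # the form used in this guide
b = z_indices[np.newaxis, :, :]         # explicit new leading axis
c = z_indices[None, ...]                # None is an alias for np.newaxis
d = z_indices.reshape(1, 3, 3)          # works, but hard-codes the shape

for candidate in (a, b, c, d):
    print(candidate.shape)
```

All four produce shape (1, 3, 3). np.expand_dims() tends to read most clearly when the axis is chosen by a variable rather than written by hand.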

Step 2: Select Elements With take_along_axis

Now use np.take_along_axis() to select elements along axis 0 (the "layer" axis):

result = np.take_along_axis(val_arr, z_indices_expanded, axis=0)

print(f"Result shape: {result.shape}")
print("Result:")
print(result)

Output:

Result shape: (1, 3, 3)
Result:
[[[ 9  1 20]
  [ 3  4 14]
  [24  7 17]]]

Step 3: Remove the Extra Dimension

The result has an unnecessary leading dimension of size 1 (its shape is (1, 3, 3)). Use np.squeeze() to collapse it back to (3, 3):

# Squeeze axis 0
final_result = np.squeeze(result, axis=0)

print(f"Final shape: {final_result.shape}") # (3, 3)
print("Result:")
print(final_result)

Output:

Final shape: (3, 3)
Result:
[[ 9  1 20]
 [ 3  4 14]
 [24  7 17]]

Verification

Let's verify a few values manually:

print("Verification:")
# Compare against final_result, the squeezed (3, 3) array from Step 3
print(f"Position (0,0), index 1: val_arr[1,0,0] = {val_arr[1,0,0]} (result: {final_result[0,0]})")
print(f"Position (0,1), index 0: val_arr[0,0,1] = {val_arr[0,0,1]} (result: {final_result[0,1]})")
print(f"Position (0,2), index 2: val_arr[2,0,2] = {val_arr[2,0,2]} (result: {final_result[0,2]})")
print(f"Position (2,0), index 2: val_arr[2,2,0] = {val_arr[2,2,0]} (result: {final_result[2,0]})")

Output:

Verification:
Position (0,0), index 1: val_arr[1,0,0] = 9 (result: 9)
Position (0,1), index 0: val_arr[0,0,1] = 1 (result: 1)
Position (0,2), index 2: val_arr[2,0,2] = 20 (result: 20)
Position (2,0), index 2: val_arr[2,2,0] = 24 (result: 24)
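
| Position | Index | Layer                      | Value |
|----------|-------|----------------------------|-------|
| (0, 0)   | 1     | Layer 1 → val_arr[1, 0, 0] | 9     |
| (0, 1)   | 0     | Layer 0 → val_arr[0, 0, 1] | 1     |
| (0, 2)   | 2     | Layer 2 → val_arr[2, 0, 2] | 20    |
| (2, 0)   | 2     | Layer 2 → val_arr[2, 2, 0] | 24    |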

Complete Implementation

import numpy as np

# Create a 3D array of shape (3, 3, 3)
val_arr = np.arange(27).reshape(3, 3, 3)

# Create a 2D array of indices of shape (3, 3)
z_indices = np.array([
    [1, 0, 2],
    [0, 0, 1],
    [2, 0, 1]
])

# Step 1: Expand dimensions to match the 3D array
z_indices_expanded = np.expand_dims(z_indices, axis=0)

# Step 2: Select elements along axis 0
result = np.take_along_axis(val_arr, z_indices_expanded, axis=0)

# Step 3: Remove the extra dimension
result = np.squeeze(result, axis=0)

print(result)

Output:

[[ 9  1 20]
 [ 3  4 14]
 [24  7 17]]

Indexing Along Different Axes

The same technique works along any axis. The axis you specify in take_along_axis determines which dimension the indices select from.

Indexing Along the Last Axis (axis=2)

Select different column indices at each (layer, row) position:

import numpy as np

val_arr = np.arange(27).reshape(3, 3, 3)

# Which column to select at each (layer, row) position
col_indices = np.array([
    [0, 2, 1],
    [1, 0, 2],
    [2, 1, 0]
])

# Expand along the last axis (axis=2)
col_indices_expanded = np.expand_dims(col_indices, axis=-1)

result = np.take_along_axis(val_arr, col_indices_expanded, axis=2)
result = np.squeeze(result, axis=-1)

print(result)

Output:

[[ 0  5  7]
 [10 12 17]
 [20 22 24]]

Indexing Along the Middle Axis (axis=1)

Select different row indices at each (layer, column) position:

import numpy as np

val_arr = np.arange(27).reshape(3, 3, 3)

# Which row to select at each (layer, column) position
row_indices = np.array([
    [2, 0, 1],
    [0, 1, 2],
    [1, 2, 0]
])

# Expand along axis 1
row_indices_expanded = np.expand_dims(row_indices, axis=1)

result = np.take_along_axis(val_arr, row_indices_expanded, axis=1)
result = np.squeeze(result, axis=1)

print(result)

Output:

[[ 6  1  5]
 [ 9 13 17]
 [21 25 20]]
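
Since the expand, take, and squeeze steps always use the same axis, they can be wrapped in one small helper. The function name `take_with_index` is hypothetical and chosen for this sketch; it is not part of NumPy.

```python
import numpy as np

def take_with_index(arr, idx, axis):
    """Pick one element along `axis` for every position in `idx`.

    Assumes `idx` has one dimension fewer than `arr` and that its shape
    equals `arr.shape` with `axis` removed.
    """
    expanded = np.expand_dims(idx, axis=axis)
    picked = np.take_along_axis(arr, expanded, axis=axis)
    return np.squeeze(picked, axis=axis)

val_arr = np.arange(27).reshape(3, 3, 3)
z_indices = np.array([[1, 0, 2], [0, 0, 1], [2, 0, 1]])
col_indices = np.array([[0, 2, 1], [1, 0, 2], [2, 1, 0]])
row_indices = np.array([[2, 0, 1], [0, 1, 2], [1, 2, 0]])

by_layer = take_with_index(val_arr, z_indices, axis=0)
by_row = take_with_index(val_arr, row_indices, axis=1)
by_col = take_with_index(val_arr, col_indices, axis=2)

print(by_layer)
print(by_row)
print(by_col)
```

The helper does not validate the shape of `idx`, so a mismatched index array may broadcast silently instead of raising an error; adding an explicit shape check is a reasonable extension.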

Alternative: Using np.choose() for Axis 0

For indexing along axis 0 specifically, np.choose() provides a simpler syntax:

import numpy as np

val_arr = np.arange(27).reshape(3, 3, 3)
z_indices = np.array([[1, 0, 2], [0, 0, 1], [2, 0, 1]])

result = np.choose(z_indices, val_arr)
print(result)

Output:

[[ 9  1 20]
 [ 3  4 14]
 [24  7 17]]

When to use np.choose() vs. np.take_along_axis()

  • np.choose(): Simpler syntax, but it only selects along axis 0 and caps the number of choices (32 in NumPy 1.x releases).
  • np.take_along_axis(): More general - works along any axis and has no such cap. Recommended for most use cases.
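
A third option, separate from the two functions compared here, is NumPy's integer-array ("fancy") indexing with explicit row and column grids. The sketch below is a cross-check under that different technique; the names `rows`, `cols`, `via_fancy`, and `via_take` are illustrative.

```python
import numpy as np

val_arr = np.arange(27).reshape(3, 3, 3)
z_indices = np.array([[1, 0, 2], [0, 0, 1], [2, 0, 1]])

# Index grids: rows[i, j] == i and cols[i, j] == j
rows, cols = np.indices(z_indices.shape)

# For each (i, j): val_arr[z_indices[i, j], i, j]
via_fancy = val_arr[z_indices, rows, cols]

# Same selection with take_along_axis, for comparison
via_take = np.take_along_axis(val_arr, z_indices[np.newaxis], axis=0)[0]

print(via_fancy)
print(np.array_equal(via_fancy, via_take))
```

Fancy indexing avoids the expand and squeeze steps but needs one index grid per remaining axis, so it becomes more verbose as the number of dimensions grows.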

Practical Example: Selecting Best Scores Across Categories

Imagine you have test scores for 4 students across 3 subjects, measured over 2 time periods. You know which time period each student performed best in for each subject:

import numpy as np

# Scores: shape (2, 4, 3) (2 periods, 4 students, 3 subjects)
scores = np.array([
    [[85, 90, 78],   # Period 0: Student 0
     [92, 88, 95],   # Period 0: Student 1
     [76, 82, 89],   # Period 0: Student 2
     [88, 91, 84]],  # Period 0: Student 3

    [[89, 87, 82],   # Period 1: Student 0
     [90, 94, 91],   # Period 1: Student 1
     [80, 79, 92],   # Period 1: Student 2
     [85, 93, 88]]   # Period 1: Student 3
])

# Best period for each student-subject: shape (4, 3)
best_period = np.array([
    [1, 0, 1],  # Student 0: best in period 1, 0, 1
    [0, 1, 0],  # Student 1: best in period 0, 1, 0
    [1, 0, 1],  # Student 2
    [0, 1, 1]   # Student 3
])

# Select best scores
best_period_expanded = np.expand_dims(best_period, axis=0)
best_scores = np.take_along_axis(scores, best_period_expanded, axis=0)
best_scores = np.squeeze(best_scores, axis=0)

print("Best scores per student per subject:")
print(best_scores)

Output:

Best scores per student per subject:
[[89 90 82]
 [92 94 95]
 [80 82 92]
 [88 93 88]]
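
In practice the index array often comes from a reduction such as np.argmax() rather than being typed by hand. The sketch below assumes "best" means the highest score, derives the period indices from the same `scores` data, and cross-checks the selection against a direct maximum.

```python
import numpy as np

scores = np.array([
    [[85, 90, 78], [92, 88, 95], [76, 82, 89], [88, 91, 84]],  # Period 0
    [[89, 87, 82], [90, 94, 91], [80, 79, 92], [85, 93, 88]],  # Period 1
])

# Index of the highest-scoring period per (student, subject): shape (4, 3)
best_period = np.argmax(scores, axis=0)

# Same expand / take / squeeze pattern as before
best_scores = np.squeeze(
    np.take_along_axis(scores, np.expand_dims(best_period, axis=0), axis=0),
    axis=0,
)

print(best_period)
print(np.array_equal(best_scores, scores.max(axis=0)))
```

When only the maximum values are needed, scores.max(axis=0) is simpler. The index-based route pays off when the same indices must also select from a second array of matching shape, such as timestamps or labels.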

Common Mistake: Dimension Mismatch

Forgetting to expand the index array's dimensions causes an error:

import numpy as np

val_arr = np.arange(27).reshape(3, 3, 3)
z_indices = np.array([[1, 0, 2], [0, 0, 1], [2, 0, 1]])

# ❌ Index array has 2 dimensions, val_arr has 3
try:
    result = np.take_along_axis(val_arr, z_indices, axis=0)
except ValueError as e:
    print(f"Error: {e}")

Output:

Error: `indices` and `arr` must have the same number of dimensions

Fix: Always expand the index array to match the number of dimensions:

# ✅ Expand to 3 dimensions first
z_indices_expanded = np.expand_dims(z_indices, axis=0)
result = np.take_along_axis(val_arr, z_indices_expanded, axis=0)
result = np.squeeze(result, axis=0)
print("Fixed result:")
print(result)

Output:

Fixed result:
[[ 9  1 20]
 [ 3  4 14]
 [24  7 17]]

Conclusion

Indexing a 3D NumPy array using indices stored in a 2D array is accomplished with three steps:

  • expand the index array's dimensions with np.expand_dims(),
  • select elements with np.take_along_axis(),
  • squeeze the result to remove the extra dimension.

This technique works along any axis and extends to arrays with more dimensions, provided the index array has one dimension fewer than the source array before it is expanded. That makes it useful for data selection tasks in scientific computing, image processing, and machine learning.
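
As a sketch of the higher-dimensional case, the same three steps can be applied to a 4D array with a 3D index array. The shapes, the random data, and the choice of axis 1 below are arbitrary assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# 4D source array and a 3D index array that selects along axis 1
arr = rng.integers(0, 100, size=(2, 3, 4, 5))
idx = rng.integers(0, arr.shape[1], size=(2, 4, 5))

picked = np.squeeze(
    np.take_along_axis(arr, np.expand_dims(idx, axis=1), axis=1),
    axis=1,
)

# Element-by-element check against direct indexing
for i, j, k in np.ndindex(idx.shape):
    assert picked[i, j, k] == arr[i, idx[i, j, k], j, k]

print(picked.shape)
```

The index array's shape is the source shape with the selected axis removed, here (2, 4, 5), and the result has that same shape.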