Python PyTorch: How to Slice a 3D Tensor in PyTorch

Slicing 3D tensors is a fundamental operation in deep learning and data processing with PyTorch. Whether you're working with batches of images, sequences of feature vectors, or volumetric data, understanding how to extract specific portions of a 3D tensor is essential. This guide explains the slicing syntax for 3D tensors, breaks down each dimension, and provides practical examples with outputs.

Understanding 3D Tensor Structure

A 3D tensor has three dimensions. What each one represents depends on your data, but they are often described as:

  • Dimension 0: the "depth" or batch dimension (e.g., separate samples)
  • Dimension 1: the "row" dimension (e.g., channels or sequence positions)
  • Dimension 2: the "column" dimension (e.g., individual values or features)

The following tensor is used in the examples throughout this guide:

import torch

# Create a 3D tensor with shape (2, 2, 8)
# 2 matrices, each with 2 rows and 8 columns
a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],

[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

print("Shape:", a.shape)
print(a)

Output:

Shape: torch.Size([2, 2, 8])
tensor([[[ 1,  2,  3,  4,  5,  6,  7,  8],
         [10, 11, 12, 13, 14, 15, 16, 17]],

        [[71, 72, 73, 74, 75, 76, 77, 78],
         [81, 82, 83, 84, 85, 86, 87, 88]]])

This tensor contains 2 matrices (dimension 0), each with 2 rows (dimension 1) and 8 values (dimension 2).
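Rather than counting brackets, you can query this structure directly. A minimal sketch using the standard tensor methods `dim()`, `size()`, and `numel()` plus the `shape` attribute; the unpacked variable names are illustrative only:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# Number of dimensions (rank) of the tensor
print("Number of dimensions:", a.dim())  # 3

# torch.Size behaves like a tuple, so it can be unpacked per dimension
num_matrices, num_rows, num_values = a.shape
print("Matrices:", num_matrices, "Rows:", num_rows, "Values per row:", num_values)

# size(d) returns the length of dimension d
print("Size of dimension 2:", a.size(2))  # 8

# Total element count is the product of the three sizes: 2 * 2 * 8
print("Total elements:", a.numel())  # 32
```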

Slicing Syntax

The general syntax for slicing a 3D tensor is:

tensor[dim0_start:dim0_end, dim1_start:dim1_end, dim2_start:dim2_end]
| Component | Controls | Example |
|---|---|---|
| dim0_start:dim0_end | Which matrices to select | 0:1 selects the first matrix |
| dim1_start:dim1_end | Which rows within each matrix | 0:2 selects both rows |
| dim2_start:dim2_end | Which values within each row | :5 selects the first 5 values |
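Combining the three example ranges from the table into a single expression selects the first matrix, both of its rows, and the first five values of each row. A short sketch:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# dim 0: 0:1 -> first matrix
# dim 1: 0:2 -> both rows
# dim 2: :5  -> first 5 values of each row
result = a[0:1, 0:2, :5]
print("Shape:", result.shape)  # Shape: torch.Size([1, 2, 5])
print(result)
```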
Tip: Just like Python list slicing, you can use : alone to select everything along a dimension, omit the start to begin at index 0, or omit the end to continue through the last element.
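The default bounds can be checked directly with `torch.equal`, which returns `True` when two tensors have the same shape and values. A minimal sketch:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# ":" covers the whole dimension and ":5" starts at index 0
short_form = a[:, :, :5]
explicit_form = a[0:2, 0:2, 0:5]
print(torch.equal(short_form, explicit_form))  # True

# "5:" runs from index 5 through the last element
tail = a[:, :, 5:]
print(torch.equal(tail, a[:, :, 5:8]))  # True

# As with Python lists, an end index past the last element is clipped, not an error
print(a[:, :, 5:100].shape)  # torch.Size([2, 2, 3])
```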

Slicing Examples

Select One Matrix, One Row, First 7 Values

import torch

a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

# First matrix (index 0), first row (index 0), first 7 values
result = a[0:1, 0:1, :7]
print("Shape:", result.shape)
print(result)

Output:

Shape: torch.Size([1, 1, 7])
tensor([[[1, 2, 3, 4, 5, 6, 7]]])
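Basic slices like this one are views: they share memory with the original tensor rather than copying it. The sketch below shows that writing into a slice changes `a`, and that `.clone()` gives an independent copy when that is not what you want:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# Basic slicing returns a view that shares storage with `a`
view = a[0:1, 0:1, :7]
view[0, 0, 0] = 100   # writes through to the original tensor
print(a[0, 0, 0])     # tensor(100)

# clone() makes an independent copy; modifying it leaves `a` untouched
copy = a[0:1, 0:1, :7].clone()
copy[0, 0, 0] = -1
print(a[0, 0, 0])     # still tensor(100)
```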

Select One Matrix, All Rows, First 3 Values

import torch

a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

# First matrix, both rows, first 3 values per row
result = a[0:1, 0:2, :3]
print("Shape:", result.shape)
print(result)

Output:

Shape: torch.Size([1, 2, 3])
tensor([[[ 1,  2,  3],
         [10, 11, 12]]])

Select All Matrices, One Specific Row

import torch

a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

# Both matrices, only the second row (index 1), all values
result = a[0:2, 1, 0:8]
print("Shape:", result.shape)
print(result)

Output:

Shape: torch.Size([2, 8])
tensor([[10, 11, 12, 13, 14, 15, 16, 17],
        [81, 82, 83, 84, 85, 86, 87, 88]])
Note: Specifying a single index (1) instead of a range (1:2) removes that dimension from the result, so the output is 2D instead of 3D.
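A side-by-side comparison makes the difference concrete: both expressions select the same values, but only the length-1 range keeps dimension 1 (with size 1). `squeeze(1)` removes that size-1 dimension again:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# Integer index on dimension 1: that dimension is dropped
dropped = a[0:2, 1, 0:8]
# Length-1 range on dimension 1: that dimension is kept with size 1
kept = a[0:2, 1:2, 0:8]

print("Index shape:", dropped.shape)  # torch.Size([2, 8])
print("Range shape:", kept.shape)     # torch.Size([2, 1, 8])

# Same values; only the shape differs
print(torch.equal(dropped, kept.squeeze(1)))  # True
```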

Select All Matrices, All Rows, Last 3 Values

import torch

a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

# All matrices, all rows, last 3 values
result = a[:, :, -3:]
print("Shape:", result.shape)
print(result)

Output:

Shape: torch.Size([2, 2, 3])
tensor([[[ 6,  7,  8],
         [15, 16, 17]],

        [[76, 77, 78],
         [86, 87, 88]]])

Negative indexing works as it does for Python lists: -3: means "from the third-to-last element through the end."
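Negative indices are not limited to the last dimension, and a range can mix a positive start with a negative stop. A short sketch:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# -1 on dimension 0 selects the last matrix (same as a[1])
last_matrix = a[-1]
print(torch.equal(last_matrix, a[1]))  # True

# -1 on dimension 1 selects the last row of every matrix (same as a[:, 1, :])
last_rows = a[:, -1, :]
print(torch.equal(last_rows, a[:, 1, :]))  # True

# 1:-1 drops the first and last value of each row
middle = a[:, :, 1:-1]
print(middle.shape)  # torch.Size([2, 2, 6])
```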

Select a Single Element

import torch

a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

# Second matrix, first row, fifth value (index 4)
value = a[1, 0, 4]
print("Value:", value)
print("Type:", type(value))

Output:

Value: tensor(75)
Type: <class 'torch.Tensor'>
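Indexing all three dimensions with integers returns a 0-dimensional tensor, not a plain Python number. When you need the number itself (for printing, logging, or arithmetic outside PyTorch), `.item()` converts a one-element tensor:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

value = a[1, 0, 4]

# A 0-dimensional tensor has an empty shape
print(value.shape)  # torch.Size([])
print(value.dim())  # 0

# .item() returns a plain Python number (an int for this integer tensor)
number = value.item()
print(number, type(number))  # 75 <class 'int'>
```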

Using Step in Slices

You can add a step value to skip elements, just like Python's start:stop:step syntax:

import torch

a = torch.tensor([
[[ 1, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 12, 13, 14, 15, 16, 17]],
[[71, 72, 73, 74, 75, 76, 77, 78],
[81, 82, 83, 84, 85, 86, 87, 88]]
])

# All matrices, all rows, every other value
result = a[:, :, ::2]
print("Shape:", result.shape)
print(result)

Output:

Shape: torch.Size([2, 2, 4])
tensor([[[ 1,  3,  5,  7],
         [10, 12, 14, 16]],

        [[71, 73, 75, 77],
         [81, 83, 85, 87]]])
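Unlike Python lists, PyTorch tensor slicing only accepts a positive step, so an expression such as a[:, :, ::-1] raises a ValueError. To reverse along a dimension, `torch.flip` is the usual replacement (it returns a copy, not a view). The sketch below also combines a step with a start index:

```python
import torch

a = torch.tensor([
    [[ 1,  2,  3,  4,  5,  6,  7,  8],
     [10, 11, 12, 13, 14, 15, 16, 17]],
    [[71, 72, 73, 74, 75, 76, 77, 78],
     [81, 82, 83, 84, 85, 86, 87, 88]]
])

# Negative steps are rejected by tensor slicing
try:
    a[:, :, ::-1]
except ValueError as err:
    print("ValueError:", err)

# torch.flip reverses the order of elements along the given dimensions
reversed_values = torch.flip(a, dims=[2])
print(reversed_values[0, 0])  # tensor([8, 7, 6, 5, 4, 3, 2, 1])

# start:stop:step together: every other value, starting at index 1
odd_positions = a[:, :, 1::2]
print(odd_positions[0, 0])  # tensor([2, 4, 6, 8])
```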

Common Mistake: Confusing Index vs. Range Slicing

Using a single index removes a dimension, while using a range preserves it. This distinction matters when your code expects a specific number of dimensions:

import torch

a = torch.tensor([
[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]
])

# Single index: removes dimension 0, result is 2D
result_index = a[0]
print("Single index shape:", result_index.shape)

# Range slice: preserves dimension 0, result is 3D
result_range = a[0:1]
print("Range slice shape: ", result_range.shape)

Output:

Single index shape: torch.Size([2, 3])
Range slice shape:  torch.Size([1, 2, 3])
Warning: If downstream operations (such as neural network layers) expect a 3D input, using a[0] instead of a[0:1] can cause a shape mismatch error. Use range slicing (0:1) to preserve the dimension, or call a[0].unsqueeze(0) to add it back.
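Both fixes from the warning produce the same 3D result. The sketch below checks that `a[0].unsqueeze(0)` matches `a[0:1]`, and also shows indexing with `None`, which likewise inserts a new size-1 dimension:

```python
import torch

a = torch.tensor([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]]
])

# a[0] drops dimension 0; unsqueeze(0) adds a size-1 dimension back at position 0
restored = a[0].unsqueeze(0)
print(restored.shape)  # torch.Size([1, 2, 3])

# Indexing with None also inserts a new size-1 dimension at that position
also_restored = a[0][None]
print(also_restored.shape)  # torch.Size([1, 2, 3])

# Both match the range slice, which never removed the dimension
print(torch.equal(restored, a[0:1]))  # True
```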

Using Boolean and Advanced Indexing

For more complex selections, you can use boolean masks:

import torch

a = torch.tensor([
[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]
])

# Select all elements greater than 5
mask = a > 5
result = a[mask]
print("Elements > 5:", result)

Output:

Elements > 5: tensor([ 6,  7,  8,  9, 10, 11, 12])

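Note that when the mask has the same shape as the tensor, boolean indexing returns a 1D tensor of the selected elements in row-major order, so the original 3D structure is lost.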
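When you need to keep more of the structure, a few alternatives help. A minimal sketch of three of them: a boolean mask over dimension 0 only (which selects whole matrices), `torch.where` (which keeps the full shape and fills unselected positions with a replacement value, here 0 as an arbitrary choice), and an integer index list (advanced indexing that picks specific positions along one dimension):

```python
import torch

a = torch.tensor([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]]
])

# 1) Mask over dimension 0 only: per-matrix sums are 21 and 57,
#    so only the second matrix is selected and trailing dims are kept
matrix_mask = a.sum(dim=(1, 2)) > 30
selected = a[matrix_mask]
print(selected.shape)  # torch.Size([1, 2, 3])

# 2) torch.where keeps the full shape, replacing unselected elements with 0
masked = torch.where(a > 5, a, torch.zeros_like(a))
print(masked.shape)  # torch.Size([2, 2, 3])

# 3) Integer index list: first and third value of every row
cols = a[:, :, [0, 2]]
print(cols.shape)  # torch.Size([2, 2, 2])
```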

Quick Reference

| Slice syntax | Selects | Resulting shape (for a (2, 2, 8) tensor) |
|---|---|---|
| a[0] | First matrix | (2, 8), dimension removed |
| a[0:1] | First matrix | (1, 2, 8), dimension preserved |
| a[:, 0, :] | First row of every matrix | (2, 8) |
| a[:, :, :3] | First 3 values in every row | (2, 2, 3) |
| a[:, :, -2:] | Last 2 values in every row | (2, 2, 2) |
| a[:, :, ::2] | Every other value | (2, 2, 4) |
| a[1, 0, 4] | Single element | () scalar (0-dimensional) tensor |

Slicing 3D tensors in PyTorch follows the same intuitive start:stop:step pattern as Python lists, extended across three dimensions. Mastering this syntax lets you efficiently extract batches, channels, feature subsets, and individual elements from complex tensor structures.