Python PyTorch: How to Squeeze and Unsqueeze a Tensor

In deep learning, tensor dimensions must often be adjusted to match the expected input shapes of models, layers, or operations. PyTorch provides two essential methods for this: torch.squeeze() removes dimensions of size 1, and torch.unsqueeze() adds a new dimension of size 1 at a specified position.

This guide explains both operations with clear examples showing how tensor shapes change.

Understanding Squeeze and Unsqueeze

| Operation | What It Does | Example Shape Change |
|---|---|---|
| squeeze() | Removes all dimensions of size 1 | (1, 3, 1, 4) → (3, 4) |
| squeeze(dim=n) | Removes dimension n only if its size is 1 | (1, 3, 1, 4) with dim=0 → (3, 1, 4) |
| unsqueeze(dim=n) | Inserts a new dimension of size 1 at position n | (3, 4) with dim=0 → (1, 3, 4) |

These operations do not change the data - they only reshape how the data is organized into dimensions.
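In fact, both operations return views of the original tensor rather than copies. A minimal sketch to illustrate this (the shapes here are arbitrary):

```python
import torch

# Both operations return views: the result shares storage with the original
t = torch.randn(1, 3, 1, 4)
v = t.squeeze()

# Same underlying memory, different shape metadata
print(v.shape)                       # torch.Size([3, 4])
print(v.data_ptr() == t.data_ptr())  # True

# Because it is a view, writing through the result modifies the original
v[0, 0] = 42.0
print(t[0, 0, 0, 0].item())          # 42.0
```

This makes both operations essentially free, but it also means in-place edits to the reshaped tensor affect the original.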

Squeezing a Tensor with torch.squeeze()

Syntax

torch.squeeze(input, dim=None)
| Parameter | Description |
|---|---|
| input | The input tensor |
| dim | Optional integer specifying which dimension to squeeze. If omitted, all dimensions of size 1 are removed |

Removing All Size-1 Dimensions

When called without a dim argument, squeeze() removes every dimension that has size 1:

import torch

tensor = torch.randn(3, 1, 2, 1, 4)
print("Before squeeze:", tensor.shape)

squeezed = torch.squeeze(tensor)
print("After squeeze: ", squeezed.shape)

Output:

Before squeeze: torch.Size([3, 1, 2, 1, 4])
After squeeze: torch.Size([3, 2, 4])

Both dimensions of size 1 (at positions 1 and 3) are removed. The data is unchanged - only the shape is different.

Squeezing a Specific Dimension

When you specify dim, only that particular dimension is removed - and only if its size is 1. If the specified dimension has a size greater than 1, the tensor is returned unchanged:

import torch

tensor = torch.randn(3, 1, 2, 1, 4)
print("Original shape:", tensor.shape)

# Squeeze dimension 1 (size is 1 → removed)
result = torch.squeeze(tensor, dim=1)
print("Squeeze dim=1:", result.shape)

# Squeeze dimension 0 (size is 3 → NOT removed)
result = torch.squeeze(tensor, dim=0)
print("Squeeze dim=0:", result.shape)

# Squeeze dimension 3 (size is 1 → removed)
result = torch.squeeze(tensor, dim=3)
print("Squeeze dim=3:", result.shape)

Output:

Original shape: torch.Size([3, 1, 2, 1, 4])
Squeeze dim=1: torch.Size([3, 2, 1, 4])
Squeeze dim=0: torch.Size([3, 1, 2, 1, 4])
Squeeze dim=3: torch.Size([3, 1, 2, 4])
Tip: Squeezing a specific dimension is safer than squeezing all dimensions. In deep learning pipelines, you often want to remove only the batch dimension or a specific singleton dimension, not all of them. Using dim prevents accidentally collapsing dimensions you need.

Unsqueezing a Tensor with torch.unsqueeze()

Syntax

torch.unsqueeze(input, dim)
| Parameter | Description |
|---|---|
| input | The input tensor |
| dim | Required integer specifying where to insert the new dimension of size 1 |

Unlike squeeze(), the dim parameter is required for unsqueeze().

Adding a Dimension to a 1D Tensor

Unsqueezing a 1D tensor creates a 2D tensor - either a row vector or a column vector depending on the dimension:

import torch

tensor = torch.arange(5, dtype=torch.float)
print("Original:", tensor.shape, tensor)

# Insert dimension at position 0 → row vector (1×5)
row = torch.unsqueeze(tensor, dim=0)
print("\nUnsqueeze dim=0:", row.shape)
print(row)

# Insert dimension at position 1 → column vector (5×1)
col = torch.unsqueeze(tensor, dim=1)
print("\nUnsqueeze dim=1:", col.shape)
print(col)

Output:

Original: torch.Size([5]) tensor([0., 1., 2., 3., 4.])

Unsqueeze dim=0: torch.Size([1, 5])
tensor([[0., 1., 2., 3., 4.]])

Unsqueeze dim=1: torch.Size([5, 1])
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.]])

Adding a Batch Dimension

A common use case is adding a batch dimension (position 0) to a single sample before passing it to a model:

import torch

# Single image: 3 channels, 32x32 pixels
image = torch.randn(3, 32, 32)
print("Single image shape:", image.shape)

# Add batch dimension for model input
batch = image.unsqueeze(0)
print("With batch dimension:", batch.shape)

Output:

Single image shape: torch.Size([3, 32, 32])
With batch dimension: torch.Size([1, 3, 32, 32])

Using Negative Dimensions

Negative values for dim count from the end. -1 inserts the new dimension at the last position:

import torch

tensor = torch.randn(3, 4)
print("Original:", tensor.shape)

# dim=-1 inserts at the last position
result = tensor.unsqueeze(-1)
print("Unsqueeze dim=-1:", result.shape)

# dim=-2 inserts at the second-to-last position
result = tensor.unsqueeze(-2)
print("Unsqueeze dim=-2:", result.shape)

Output:

Original: torch.Size([3, 4])
Unsqueeze dim=-1: torch.Size([3, 4, 1])
Unsqueeze dim=-2: torch.Size([3, 1, 4])
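torch.squeeze() accepts negative dim values in the same way. A short sketch (shapes chosen for illustration):

```python
import torch

t = torch.randn(3, 1, 4, 1)

# dim=-1 refers to the last dimension (size 1 here, so it is removed)
print(t.squeeze(-1).shape)   # torch.Size([3, 1, 4])

# dim=-3 refers to the size-1 dimension at position 1
print(t.squeeze(-3).shape)   # torch.Size([3, 4, 1])
```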

Squeeze and Unsqueeze Are Inverses

Squeezing and unsqueezing are complementary operations. You can undo an unsqueeze with a squeeze (and vice versa):

import torch

original = torch.randn(3, 4)
print("Original: ", original.shape)

# Add a dimension
expanded = original.unsqueeze(1)
print("After unsqueeze:", expanded.shape)

# Remove it
restored = expanded.squeeze(1)
print("After squeeze: ", restored.shape)

# Verify the data is unchanged
print("Data unchanged: ", torch.equal(original, restored))

Output:

Original:  torch.Size([3, 4])
After unsqueeze: torch.Size([3, 1, 4])
After squeeze:  torch.Size([3, 4])
Data unchanged:  True

Common Mistake: Squeezing the Batch Dimension Accidentally

When working with batches of size 1, calling squeeze() without specifying a dimension can accidentally remove the batch dimension:

import torch

# Model output: batch of 1 sample, 10 classes
output = torch.randn(1, 10)
print("Original shape:", output.shape)

# WRONG: squeeze without dim removes ALL size-1 dimensions including batch
squeezed = torch.squeeze(output)
print("After squeeze():", squeezed.shape)

Output:

Original shape: torch.Size([1, 10])
After squeeze(): torch.Size([10])

The batch dimension is gone, which may break downstream code that expects a 2D tensor.

The safer approach is to specify the dimension explicitly:

# CORRECT: squeeze only the dimension you intend to remove
# dim=1 has size 10, so nothing is removed and the batch dimension survives
squeezed = torch.squeeze(output, dim=1)
print("After squeeze(dim=1):", squeezed.shape)  # torch.Size([1, 10]) - unchanged
Warning: Avoid calling torch.squeeze() without specifying dim in production code. When the batch size happens to be 1, it silently removes the batch dimension, causing shape mismatches that are difficult to debug. Always specify which dimension to squeeze.

Using Tensor Methods Instead of Functions

Both operations can also be called as tensor methods instead of standalone functions:

import torch

tensor = torch.randn(1, 3, 1, 4)

# Function style
squeezed = torch.squeeze(tensor, dim=0)
unsqueezed = torch.unsqueeze(tensor, dim=2)

# Method style (equivalent)
squeezed = tensor.squeeze(dim=0)
unsqueezed = tensor.unsqueeze(dim=2)

print("Squeezed:", squeezed.shape)
print("Unsqueezed:", unsqueezed.shape)

Output:

Squeezed: torch.Size([3, 1, 4])
Unsqueezed: torch.Size([1, 3, 1, 1, 4])

Both styles produce identical results. The method style (tensor.squeeze()) is generally preferred for readability.

Quick Reference

| Operation | Code | Shape Change |
|---|---|---|
| Remove all size-1 dims | tensor.squeeze() | (1, 3, 1, 4) → (3, 4) |
| Remove specific size-1 dim | tensor.squeeze(dim=0) | (1, 3, 4) → (3, 4) |
| Add dim at position 0 | tensor.unsqueeze(0) | (3, 4) → (1, 3, 4) |
| Add dim at last position | tensor.unsqueeze(-1) | (3, 4) → (3, 4, 1) |
| Add batch dimension | image.unsqueeze(0) | (C, H, W) → (1, C, H, W) |
| Remove batch dimension | output.squeeze(0) | (1, N) → (N,) |

Squeeze and unsqueeze are simple but essential operations for managing tensor shapes in PyTorch. They ensure your tensors have the correct number of dimensions for broadcasting, model input, loss computation, and other operations that are strict about shape requirements.
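As a closing illustration of the broadcasting use case mentioned above (the tensors and scale factors here are made up for the example), unsqueeze is often used to line up shapes so that a smaller tensor broadcasts against a larger one:

```python
import torch

# Per-channel scaling via broadcasting
image = torch.randn(3, 32, 32)          # (C, H, W)
scale = torch.tensor([0.5, 1.0, 2.0])   # one factor per channel, shape (3,)

# Reshape scale to (3, 1, 1) so it broadcasts across H and W
scaled = scale.unsqueeze(-1).unsqueeze(-1) * image
print(scaled.shape)  # torch.Size([3, 32, 32])

# Each channel was multiplied by its own factor
print(torch.allclose(scaled[2], 2.0 * image[2]))  # True
```

Without the two unsqueeze calls, multiplying a (3,) tensor with a (3, 32, 32) tensor would fail, because broadcasting aligns shapes from the trailing dimension.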