Python PyTorch: How to Sort the Elements of a Tensor
Sorting tensor elements is a fundamental operation in PyTorch, used in ranking predictions, selecting top-k values, organizing feature maps, and preparing data for algorithms that require ordered input. The torch.sort() function sorts along any dimension and returns both the sorted values and the indices those values had in the original tensor.
This guide covers sorting 1D and 2D tensors in ascending and descending order, along the rows or columns.
Understanding torch.sort() Syntax
torch.sort(input, dim=-1, descending=False)
| Parameter | Description | Default |
|---|---|---|
| input | The tensor to sort | Required |
| dim | The dimension along which to sort | -1 (last dimension) |
| descending | True for descending order, False for ascending | False |
Returns: a named tuple (values, indices), where:
- values: the sorted tensor
- indices: the position of each sorted element in the original tensor
The indices are particularly useful for reordering other tensors to match the sorted order.
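A minimal sketch of working with the return value: the result can be unpacked into two variables, or its fields can be read by name as .values and .indices. For a 1D tensor, indexing the original tensor with the returned indices reproduces the sorted values.

```python
import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])

# torch.sort returns a named tuple, so the fields can be accessed by name
result = torch.sort(tensor)
print(result.values)   # the sorted tensor
print(result.indices)  # original position of each sorted element

# Indexing the original tensor with the indices reproduces the sorted values
reordered = tensor[result.indices]
print(torch.equal(reordered, result.values))  # True
```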
Sorting a 1D Tensor
Ascending Order (Default)
import torch
tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
print("Original:", tensor)
values, indices = torch.sort(tensor)
print("Sorted values: ", values)
print("Original indices:", indices)
Output:
Original: tensor([ 42, -7, 15, 3, 0, 28, -12])
Sorted values: tensor([-12, -7, 0, 3, 15, 28, 42])
Original indices: tensor([6, 1, 4, 3, 2, 5, 0])
The indices tell you where each sorted value came from. For example, -12 (the smallest value) was at index 6 in the original tensor.
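Because the indices record where each value came from, they can also be used to undo the sort. The sketch below shows two equivalent approaches; the variable names restored and inverse are illustrative only.

```python
import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
values, indices = torch.sort(tensor)

# Write each sorted value back to the position it came from
restored = torch.empty_like(values)
restored[indices] = values
print(torch.equal(restored, tensor))  # True

# Alternatively, argsort of the indices gives the inverse permutation
inverse = torch.argsort(indices)
print(torch.equal(values[inverse], tensor))  # True
```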
Descending Order
import torch
tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
values, indices = torch.sort(tensor, descending=True)
print("Sorted values (desc): ", values)
print("Original indices: ", indices)
Output:
Sorted values (desc): tensor([ 42, 28, 15, 3, 0, -7, -12])
Original indices: tensor([0, 5, 2, 3, 4, 1, 6])
Sorting a 2D Tensor Along Columns (dim=0)
For a 2D tensor, setting dim=0 sorts each column independently. Values move up or down within their own column, so elements that shared a row in the original tensor can end up in different rows:
import torch
tensor = torch.tensor([
[43.0, 31.0, -92.0],
[ 3.0, -4.3, 53.0],
[-4.2, 7.0, -6.2]
])
print("Original:\n", tensor)
# Sort each column ascending
values, indices = torch.sort(tensor, dim=0)
print("\nSorted along columns (ascending):")
print("Values:\n", values)
print("Indices:\n", indices)
Output:
Original:
tensor([[ 43.0000, 31.0000, -92.0000],
[ 3.0000, -4.3000, 53.0000],
[ -4.2000, 7.0000, -6.2000]])
Sorted along columns (ascending):
Values:
tensor([[ -4.2000, -4.3000, -92.0000],
[ 3.0000, 7.0000, -6.2000],
[ 43.0000, 31.0000, 53.0000]])
Indices:
tensor([[2, 1, 0],
[1, 2, 2],
[0, 0, 1]])
Each column is sorted independently from smallest to largest. The indices show which original row each value came from.
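The link between values and indices along dim=0 can be checked with torch.gather, which picks elements along a dimension using an index tensor. For dim=0, gather returns out[i][j] = tensor[indices[i][j]][j], which should reproduce the sorted values exactly.

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2,  7.0, -6.2],
])
values, indices = torch.sort(tensor, dim=0)

# Along dim=0: values[i][j] == tensor[indices[i][j]][j]
gathered = torch.gather(tensor, 0, indices)
print(torch.equal(gathered, values))  # True
```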
Descending Order Along Columns
import torch
tensor = torch.tensor([
[43.0, 31.0, -92.0],
[ 3.0, -4.3, 53.0],
[-4.2, 7.0, -6.2]
])
print("Original:\n", tensor)
# Sort each column descending
values, indices = torch.sort(tensor, dim=0, descending=True)
print("Sorted along columns (descending):")
print("Values:\n", values)
print("Indices:\n", indices)
Output:
Original:
tensor([[ 43.0000, 31.0000, -92.0000],
[ 3.0000, -4.3000, 53.0000],
[ -4.2000, 7.0000, -6.2000]])
Sorted along columns (descending):
Values:
tensor([[ 43.0000, 31.0000, 53.0000],
[ 3.0000, 7.0000, -6.2000],
[ -4.2000, -4.3000, -92.0000]])
Indices:
tensor([[0, 0, 1],
[1, 2, 2],
[2, 1, 0]])
Sorting a 2D Tensor Along Rows (dim=1)
Setting dim=1 sorts each row independently:
import torch
tensor = torch.tensor([
[43.0, 31.0, -92.0],
[ 3.0, -4.3, 53.0],
[-4.2, 7.0, -6.2]
])
# Sort each row ascending
values, indices = torch.sort(tensor, dim=1)
print("Sorted along rows (ascending):")
print("Values:\n", values)
print("Indices:\n", indices)
Output:
Sorted along rows (ascending):
Values:
tensor([[-92.0000, 31.0000, 43.0000],
[ -4.3000, 3.0000, 53.0000],
[ -6.2000, -4.2000, 7.0000]])
Indices:
tensor([[2, 1, 0],
[1, 0, 2],
[2, 0, 1]])
Each row is sorted from smallest to largest. The indices show the original column position of each value within its row.
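Since dim defaults to -1 (the last dimension), calling torch.sort() on a 2D tensor without a dim argument sorts each row, exactly as dim=1 does. The sketch below confirms this and also checks the row-wise indices with torch.gather.

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2,  7.0, -6.2],
])

explicit = torch.sort(tensor, dim=1)
default = torch.sort(tensor)  # dim=-1 is the last dimension, i.e. dim=1 here

print(torch.equal(explicit.values, default.values))    # True
print(torch.equal(explicit.indices, default.indices))  # True

# Along dim=1: values[i][j] == tensor[i][indices[i][j]]
print(torch.equal(torch.gather(tensor, 1, explicit.indices), explicit.values))  # True
```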
Descending Order Along Rows
import torch
tensor = torch.tensor([
[43.0, 31.0, -92.0],
[ 3.0, -4.3, 53.0],
[-4.2, 7.0, -6.2]
])
# Sort each row descending
values, indices = torch.sort(tensor, dim=1, descending=True)
print("Sorted along rows (descending):")
print("Values:\n", values)
print("Indices:\n", indices)
Output:
Sorted along rows (descending):
Values:
tensor([[ 43.0000, 31.0000, -92.0000],
[ 53.0000, 3.0000, -4.3000],
[ 7.0000, -4.2000, -6.2000]])
Indices:
tensor([[0, 1, 2],
[2, 0, 1],
[1, 0, 2]])
Remember the dimension behavior:
- dim=0: sorts down each column; values move between rows but never leave their column
- dim=1: sorts across each row; values move between columns but never leave their row
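Another way to see the relationship: the columns of a 2D tensor are the rows of its transpose, so sorting along dim=0 gives the same result as transposing, sorting each row, and transposing back.

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2,  7.0, -6.2],
])

by_column = torch.sort(tensor, dim=0).values

# Sort the rows of the transpose, then transpose back
via_transpose = torch.sort(tensor.T, dim=1).values.T

print(torch.equal(by_column, via_transpose))  # True
```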
Using Sorted Indices to Reorder Another Tensor
The indices returned by torch.sort() are commonly used to reorder a related tensor in the same order:
import torch
scores = torch.tensor([85.0, 92.0, 78.0, 95.0, 88.0])
names = ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve']
# Sort scores descending and get the ranking
values, indices = torch.sort(scores, descending=True)
print("Ranking:")
for rank, idx in enumerate(indices.tolist(), 1):
    print(f" {rank}. {names[idx]} - {scores[idx].item()}")
Output:
Ranking:
1. Diana - 95.0
2. Bob - 92.0
3. Eve - 88.0
4. Alice - 85.0
5. Charlie - 78.0
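When the related data is itself a tensor, the indices can be applied directly with indexing. A minimal sketch, assuming a hypothetical table whose rows are (student ID, score) pairs that should be ordered by score, highest first:

```python
import torch

# Hypothetical data: column 0 is a student ID, column 1 is a score
table = torch.tensor([
    [101.0, 85.0],
    [102.0, 92.0],
    [103.0, 78.0],
    [104.0, 95.0],
    [105.0, 88.0],
])

# Sort only the score column, then reorder whole rows with the indices
_, order = torch.sort(table[:, 1], descending=True)
ranked = table[order]
print(ranked)
```

Sorting just the key column and indexing with the result keeps every row intact, whereas calling torch.sort(table, dim=0) would sort the ID and score columns independently and break the pairing.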
Getting Only the Top-K or Bottom-K Elements
If you only need the k largest or k smallest values, torch.topk() avoids sorting the entire tensor:
import torch
tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
# Top 3 largest values
top_values, top_indices = torch.topk(tensor, k=3)
print("Top 3:", top_values, "at indices", top_indices)
# Bottom 3 smallest values
bottom_values, bottom_indices = torch.topk(tensor, k=3, largest=False)
print("Bottom 3:", bottom_values, "at indices", bottom_indices)
Output:
Top 3: tensor([42, 28, 15]) at indices tensor([0, 5, 2])
Bottom 3: tensor([-12, -7, 0]) at indices tensor([6, 1, 4])
torch.topk() is usually faster than torch.sort() when k is small relative to the tensor size, because it only has to find the k extreme elements rather than fully order every element. Use torch.sort() when you need the complete sorted order.
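For tensors without duplicate values, the top-k result matches the first k entries of a full descending sort, which makes it easy to swap one call for the other. A short check:

```python
import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
k = 3

top = torch.topk(tensor, k=k)
full = torch.sort(tensor, descending=True)

# topk returns its k results in descending order by default (sorted=True)
print(torch.equal(top.values, full.values[:k]))    # True
print(torch.equal(top.indices, full.indices[:k]))  # True
```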
Common Mistake: Confusing dim=0 and dim=1
A frequent error is using the wrong dimension, which sorts along an unintended axis:
import torch
tensor = torch.tensor([
[10, 30, 20],
[60, 40, 50]
])
# WRONG if intent is to sort each row: dim=0 sorts each column
values, _ = torch.sort(tensor, dim=0)
print("dim=0 (sorts columns, not rows):")
print(values)
Output:
dim=0 (sorts columns, not rows):
tensor([[10, 30, 20],
[60, 40, 50]])
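Here each column was already in ascending order, so dim=0 returns the tensor unchanged and the rows stay unsorted. Because nothing visibly fails, this mistake is easy to miss when the intent was to sort within each row.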
The correct approach for row-wise sorting:
import torch
tensor = torch.tensor([
[10, 30, 20],
[60, 40, 50]
])
# CORRECT: dim=1 sorts within each row
values, _ = torch.sort(tensor, dim=1)
print("dim=1 (sorts each row):")
print(values)
Output:
dim=1 (sorts each row):
tensor([[10, 20, 30],
[40, 50, 60]])
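One way to catch a dim mix-up early is to assert that the result is ordered along the dimension you intended. A minimal sketch using a hypothetical helper, is_sorted_along, which compares neighbouring elements with torch.diff:

```python
import torch

def is_sorted_along(t: torch.Tensor, dim: int) -> bool:
    """Hypothetical helper: True if t is non-decreasing along dim."""
    return bool(torch.all(torch.diff(t, dim=dim) >= 0))

tensor = torch.tensor([
    [10, 30, 20],
    [60, 40, 50],
])

rows_wrong, _ = torch.sort(tensor, dim=0)
rows_right, _ = torch.sort(tensor, dim=1)

print(is_sorted_along(rows_wrong, dim=1))  # False: the rows were never sorted
print(is_sorted_along(rows_right, dim=1))  # True
```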
Quick Reference
| Goal | Code |
|---|---|
| Sort 1D ascending | torch.sort(tensor) |
| Sort 1D descending | torch.sort(tensor, descending=True) |
| Sort 2D along columns | torch.sort(tensor, dim=0) |
| Sort 2D along rows | torch.sort(tensor, dim=1) |
| Get only top-k values | torch.topk(tensor, k=3) |
| Get only bottom-k values | torch.topk(tensor, k=3, largest=False) |
| Get sorted indices only | torch.argsort(tensor) |
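torch.argsort() returns only the index tensor. For tensors without ties, it matches the indices from torch.sort() called with the same dim and descending arguments, as this short check on the 2D example from earlier shows:

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2,  7.0, -6.2],
])

idx = torch.argsort(tensor, dim=1, descending=True)
_, sort_idx = torch.sort(tensor, dim=1, descending=True)

print(torch.equal(idx, sort_idx))  # True
print(idx)
```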
The torch.sort() function returns both sorted values and their original indices, making it versatile for ranking, reordering, and selection operations across any tensor dimension.