
Python PyTorch: How to Sort the Elements of a Tensor

Sorting tensor elements is a fundamental operation in PyTorch. It is used for ranking predictions, selecting top-k values, organizing feature maps, and preparing data for algorithms that require ordered input. The torch.sort() function sorts along any dimension and returns both the sorted values and the indices those values had in the original tensor.

This guide covers sorting 1D and 2D tensors in ascending and descending order, along rows or columns, and the related torch.topk() and torch.argsort() functions.

Understanding torch.sort() Syntax

torch.sort(input, dim=-1, descending=False)
Parameter    Description                                       Default
input        The tensor to sort                                Required
dim          The dimension along which to sort                 -1 (last dimension)
descending   True for descending order, False for ascending    False

Returns: A named tuple (values, indices) where:

  • values - the sorted tensor
  • indices - the positions of each sorted element in the original tensor

The indices are particularly useful for reordering other tensors to match the sorted order.
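A short sketch of both ways to read the result, using a small made-up tensor: the named-tuple fields values and indices can be accessed by name instead of unpacking, and for a 1D tensor, indexing the original tensor with indices reproduces values.

```python
import torch

t = torch.tensor([5, 2, 9, 1])

# torch.sort returns a named tuple; its fields can be read by name or unpacked
result = torch.sort(t)
print(result.values)   # tensor([1, 2, 5, 9])
print(result.indices)  # tensor([3, 1, 0, 2])

# For a 1D tensor, indexing the original with the indices reproduces the sorted values
assert torch.equal(t[result.indices], result.values)
```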

Sorting a 1D Tensor

Ascending Order (Default)

import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
print("Original:", tensor)

values, indices = torch.sort(tensor)
print("Sorted values: ", values)
print("Original indices:", indices)

Output:

Original: tensor([ 42,  -7,  15,   3,   0,  28, -12])
Sorted values:  tensor([-12,  -7,   0,   3,  15,  28,  42])
Original indices: tensor([6, 1, 4, 3, 2, 5, 0])

The indices tell you where each sorted value came from. For example, -12 (the smallest value) was at index 6 in the original tensor.
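The example values above are all distinct. When a tensor contains duplicates, the default sort does not promise any particular order among equal elements, so their indices may vary. The sketch below (made-up data) passes stable=True, a keyword available in recent PyTorch releases, so equal values keep their original relative order and the indices become predictable.

```python
import torch

# Example data with duplicate values
t = torch.tensor([3, 1, 3, 2, 1])

# stable=True keeps equal elements in their original relative order
values, indices = torch.sort(t, stable=True)
print(values)   # tensor([1, 1, 2, 3, 3])
print(indices)  # tensor([1, 4, 3, 0, 2])
```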

Descending Order

import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])

values, indices = torch.sort(tensor, descending=True)
print("Sorted values (desc): ", values)
print("Original indices: ", indices)

Output:

Sorted values (desc):  tensor([ 42,  28,  15,   3,   0,  -7, -12])
Original indices:  tensor([0, 5, 2, 3, 4, 1, 6])

Sorting a 2D Tensor Along Columns (dim=0)

For a 2D tensor, setting dim=0 sorts each column independently. Each column's values are reordered on their own, so a row of the result can combine values that came from different rows of the original tensor:

import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2]
])
print("Original:\n", tensor)

# Sort each column ascending
values, indices = torch.sort(tensor, dim=0)
print("\nSorted along columns (ascending):")
print("Values:\n", values)
print("Indices:\n", indices)

Output:

Original:
tensor([[ 43.0000,  31.0000, -92.0000],
        [  3.0000,  -4.3000,  53.0000],
        [ -4.2000,   7.0000,  -6.2000]])

Sorted along columns (ascending):
Values:
tensor([[ -4.2000,  -4.3000, -92.0000],
        [  3.0000,   7.0000,  -6.2000],
        [ 43.0000,  31.0000,  53.0000]])
Indices:
tensor([[2, 1, 0],
        [1, 2, 2],
        [0, 0, 1]])

Each column is sorted independently from smallest to largest. The indices show which original row each value came from.
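One way to confirm what the indices mean is to rebuild the sorted values from them. A minimal check using torch.gather along dim=0, which looks up tensor[indices[i][j]][j] for every position (i, j):

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2],
])

values, indices = torch.sort(tensor, dim=0)

# For dim=0, values[i][j] == tensor[indices[i][j]][j];
# torch.gather performs exactly this lookup along the given dimension
rebuilt = torch.gather(tensor, 0, indices)
print(torch.equal(rebuilt, values))  # True
```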

Descending Order Along Columns

import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2]
])
print("Original:\n", tensor)

# Sort each column descending
values, indices = torch.sort(tensor, dim=0, descending=True)
print("Sorted along columns (descending):")
print("Values:\n", values)
print("Indices:\n", indices)

Output:

Original:
tensor([[ 43.0000,  31.0000, -92.0000],
        [  3.0000,  -4.3000,  53.0000],
        [ -4.2000,   7.0000,  -6.2000]])
Sorted along columns (descending):
Values:
tensor([[ 43.0000,  31.0000,  53.0000],
        [  3.0000,   7.0000,  -6.2000],
        [ -4.2000,  -4.3000, -92.0000]])
Indices:
tensor([[0, 0, 1],
        [1, 2, 2],
        [2, 1, 0]])
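Sorting along dim=0 reorders each column on its own, which breaks rows apart. If you instead want to keep whole rows together and order them by a single column, a common pattern (built from torch.argsort() and indexing, not an option of torch.sort() itself) is to argsort that column and index the rows with the result. A sketch that orders the example rows by their first column:

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2],
])

# Order of the rows when ranked by the first column (ascending)
row_order = torch.argsort(tensor[:, 0])
print(row_order)  # tensor([2, 1, 0])

# Indexing with the row order moves whole rows, so each row stays intact
rows_sorted = tensor[row_order]
print(rows_sorted)
```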

Sorting a 2D Tensor Along Rows (dim=1)

Setting dim=1 sorts each row independently:

import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2]
])

# Sort each row ascending
values, indices = torch.sort(tensor, dim=1)
print("Sorted along rows (ascending):")
print("Values:\n", values)
print("Indices:\n", indices)

Output:

Sorted along rows (ascending):
Values:
tensor([[-92.0000,  31.0000,  43.0000],
        [ -4.3000,   3.0000,  53.0000],
        [ -6.2000,  -4.2000,   7.0000]])
Indices:
tensor([[2, 1, 0],
        [1, 0, 2],
        [2, 0, 1]])

Each row is sorted from smallest to largest. The indices show the original column position of each value within its row.
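The row-wise indices can also carry along data that belongs to each element. In this sketch, labels is a hypothetical tensor with the same shape as tensor; torch.gather along dim=1 applies each row's sort order to the matching row of labels.

```python
import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2],
])
# Hypothetical per-element labels, same shape as `tensor`
labels = torch.tensor([
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
])

values, indices = torch.sort(tensor, dim=1)

# gather along dim=1 applies each row's ordering to the matching row of `labels`
labels_sorted = torch.gather(labels, 1, indices)
print(labels_sorted)
# tensor([[12, 11, 10],
#         [21, 20, 22],
#         [32, 30, 31]])
```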

Descending Order Along Rows

import torch

tensor = torch.tensor([
    [43.0, 31.0, -92.0],
    [ 3.0, -4.3, 53.0],
    [-4.2, 7.0, -6.2]
])

# Sort each row descending
values, indices = torch.sort(tensor, dim=1, descending=True)
print("Sorted along rows (descending):")
print("Values:\n", values)
print("Indices:\n", indices)

Output:

Sorted along rows (descending):
Values:
tensor([[ 43.0000,  31.0000, -92.0000],
        [ 53.0000,   3.0000,  -4.3000],
        [  7.0000,  -4.2000,  -6.2000]])
Indices:
tensor([[0, 1, 2],
        [2, 0, 1],
        [1, 0, 2]])

Tip: remember how dim behaves for a 2D tensor:

  • dim=0: sorts down each column (each column's entries are reordered across the rows)
  • dim=1: sorts across each row (each row's entries are reordered across the columns)
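Because dim defaults to -1 (the last dimension), calling torch.sort() on a 2D tensor without a dim argument sorts each row. A quick check with a small example tensor:

```python
import torch

tensor = torch.tensor([
    [10, 30, 20],
    [60, 40, 50],
])

# With no dim argument, torch.sort uses dim=-1, the last dimension.
# For a 2D tensor that is dim=1, so each row is sorted.
default_values, _ = torch.sort(tensor)
row_values, _ = torch.sort(tensor, dim=1)
print(torch.equal(default_values, row_values))  # True
```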

Using Sorted Indices to Reorder Another Tensor

The indices returned by torch.sort() are commonly used to put related data (another tensor, or a plain Python list such as the names below) into the same order:

import torch

scores = torch.tensor([85.0, 92.0, 78.0, 95.0, 88.0])
names = ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve']

# Sort scores descending; indices[i] is the original position of the i-th best score
values, indices = torch.sort(scores, descending=True)

print("Ranking:")
for rank, idx in enumerate(indices.tolist(), 1):
    # Look up the name and score stored at that original position
    print(f"{rank}. {names[idx]} - {scores[idx].item()}")

Output:

Ranking:
1. Diana - 95.0
2. Bob - 92.0
3. Eve - 88.0
4. Alice - 85.0
5. Charlie - 78.0
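The same indices can reorder another tensor directly, because a tensor can be indexed with an index tensor. In this sketch, student_ids is a hypothetical tensor aligned element-by-element with scores:

```python
import torch

scores = torch.tensor([85.0, 92.0, 78.0, 95.0, 88.0])
# Hypothetical companion tensor: one ID per score, in the same order
student_ids = torch.tensor([101, 102, 103, 104, 105])

_, indices = torch.sort(scores, descending=True)

# Indexing with the sort indices puts the companion tensor in ranked order
ranked_ids = student_ids[indices]
print(ranked_ids)  # tensor([104, 102, 105, 101, 103])
```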

Getting Only the Top-K or Bottom-K Elements

If you only need the k largest or smallest values, torch.topk() is usually more efficient than sorting the entire tensor:

import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])

# Top 3 largest values
top_values, top_indices = torch.topk(tensor, k=3)
print("Top 3:", top_values, "at indices", top_indices)

# Bottom 3 smallest values
bottom_values, bottom_indices = torch.topk(tensor, k=3, largest=False)
print("Bottom 3:", bottom_values, "at indices", bottom_indices)

Output:

Top 3: tensor([42, 28, 15]) at indices tensor([0, 5, 2])
Bottom 3: tensor([-12,  -7,   0]) at indices tensor([6, 1, 4])

Note: torch.topk() can be faster than torch.sort() when k is small, because it only has to select the k largest (or smallest) elements instead of ordering the whole tensor. Use torch.sort() when you need the complete sorted order.
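A quick way to see how the two functions relate: with its default sorted=True, torch.topk() returns the same values as the first k entries of a descending sort. This check reuses the example tensor from above:

```python
import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])
k = 3

top_values, _ = torch.topk(tensor, k=k)
sorted_values, _ = torch.sort(tensor, descending=True)

# topk's values match the first k entries of a full descending sort
print(torch.equal(top_values, sorted_values[:k]))  # True
```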

Common Mistake: Confusing dim=0 and dim=1

A frequent error is using the wrong dimension, which sorts along an unintended axis:

import torch

tensor = torch.tensor([
    [10, 30, 20],
    [60, 40, 50]
])

# WRONG if intent is to sort each row: dim=0 sorts each column
values, _ = torch.sort(tensor, dim=0)
print("dim=0 (sorts columns, not rows):")
print(values)

Output:

dim=0 (sorts columns, not rows):
tensor([[10, 30, 20],
        [60, 40, 50]])

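This sorts each column independently. Because every column here was already in ascending order, dim=0 returns the tensor unchanged and the rows are still unsorted (the first row remains [10, 30, 20]). If the goal was to sort within each row, this is the wrong dimension.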

The correct approach for row-wise sorting:

import torch

tensor = torch.tensor([
    [10, 30, 20],
    [60, 40, 50]
])

# CORRECT: dim=1 sorts within each row
values, _ = torch.sort(tensor, dim=1)
print("dim=1 (sorts each row):")
print(values)

Output:

dim=1 (sorts each row):
tensor([[10, 20, 30],
        [40, 50, 60]])
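Neither dim=0 nor dim=1 produces a single ordering over all elements of a 2D tensor. If that is what you need, one option is to flatten the tensor, sort it, and reshape afterward. A sketch with a small made-up tensor, chosen so the result differs from a row-wise sort:

```python
import torch

# Example data (not from the sections above)
tensor = torch.tensor([
    [10, 60, 20],
    [50, 40, 30],
])

# flatten() turns the 2D tensor into 1D so every element is ordered together
flat_values, flat_indices = torch.sort(tensor.flatten())
print(flat_values)   # tensor([10, 20, 30, 40, 50, 60])
print(flat_indices)  # tensor([0, 2, 5, 4, 3, 1])

# Reshape back to the original shape if a 2D layout is still needed
global_sorted = flat_values.reshape(tensor.shape)
print(global_sorted)

# For comparison, a row-wise sort orders each row separately
row_sorted, _ = torch.sort(tensor, dim=1)
print(row_sorted)
```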

Quick Reference

Goal                        Code
Sort 1D ascending           torch.sort(tensor)
Sort 1D descending          torch.sort(tensor, descending=True)
Sort 2D along columns       torch.sort(tensor, dim=0)
Sort 2D along rows          torch.sort(tensor, dim=1)
Get only top-k values       torch.topk(tensor, k=3)
Get only bottom-k values    torch.topk(tensor, k=3, largest=False)
Get sorted indices only     torch.argsort(tensor)
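torch.argsort() returns just the indices part of what torch.sort() produces. A short sketch with the 1D example tensor, also showing that indexing with those indices recovers the sorted values:

```python
import torch

tensor = torch.tensor([42, -7, 15, 3, 0, 28, -12])

# argsort returns only the indices, without the sorted values
order = torch.argsort(tensor, descending=True)
print(order)  # tensor([0, 5, 2, 3, 4, 1, 6])

# Indexing with those indices recovers the sorted values
print(tensor[order])  # tensor([ 42,  28,  15,   3,   0,  -7, -12])
```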

The torch.sort() function returns both sorted values and their original indices, making it versatile for ranking, reordering, and selection operations across any tensor dimension.