
Python PyTorch: How to Find the k-th and Top k Elements of a Tensor

Finding the k-th smallest element or the top k largest elements in a tensor is a common operation in machine learning and data processing workflows, from selecting top predictions in classification models to implementing k-nearest neighbors algorithms. PyTorch provides two dedicated functions for these tasks: torch.kthvalue() and torch.topk().

This guide explains both functions with clear examples, covering 1D and multi-dimensional tensors, and demonstrates practical use cases.

Finding the k-th Smallest Element with torch.kthvalue()

The torch.kthvalue() function returns the k-th smallest element along a dimension (the element that would sit at position k if the values were sorted in ascending order), along with its index in the original tensor.

Syntax:

```python
torch.kthvalue(input, k, dim=None, keepdim=False, out=None)
```

| Parameter | Description |
|---|---|
| input | The input tensor |
| k | Integer specifying which smallest element to find (1-based) |
| dim | Dimension along which to find the k-th value (default: last dimension) |
| keepdim | Whether the output tensor retains the reduced dimension |

Returns: A named tuple (values, indices) containing the k-th smallest value and its index.
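Because the result is a named tuple, you can also read its fields by name instead of unpacking it. The sketch below (the small 2x3 tensor is just illustrative data) shows that, plus how keepdim changes the output shape:

```python
import torch

x = torch.tensor([[10, 30, 20],
                  [60, 40, 50]])

# Access the named-tuple fields instead of unpacking
result = torch.kthvalue(x, 2, dim=1)
print(result.values)   # 2nd smallest value in each row
print(result.indices)  # position of that value within its row

# keepdim=True keeps the reduced dimension with size 1
kept = torch.kthvalue(x, 2, dim=1, keepdim=True)
print(result.values.shape)  # torch.Size([2])
print(kept.values.shape)    # torch.Size([2, 1])
```

Keeping the dimension is handy when you want to broadcast the result back against the original tensor, for example to compare every element with its row's k-th value.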

Example: k-th Element in a 1D Tensor

```python
import torch

tensor = torch.tensor([4, 5, -3, 9, 7])
print(f"Original tensor: {tensor}")

# Find the 1st smallest (minimum)
val, idx = torch.kthvalue(tensor, 1)
print(f"\n1st smallest -> Value: {val.item()}, Index: {idx.item()}")

# Find the 3rd smallest
val, idx = torch.kthvalue(tensor, 3)
print(f"3rd smallest -> Value: {val.item()}, Index: {idx.item()}")

# Find the 5th smallest (maximum)
val, idx = torch.kthvalue(tensor, 5)
print(f"5th smallest -> Value: {val.item()}, Index: {idx.item()}")
```

Output:

```
Original tensor: tensor([ 4,  5, -3,  9,  7])

1st smallest -> Value: -3, Index: 2
3rd smallest -> Value: 5, Index: 1
5th smallest -> Value: 9, Index: 3
```

The sorted order is [-3, 4, 5, 7, 9], so the 3rd smallest is 5, which is at index 1 in the original tensor.
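One way to double-check that reasoning is to sort the tensor yourself and read position k - 1. Sorting is used here only for verification. kthvalue() does not need a sorted input.

```python
import torch

tensor = torch.tensor([4, 5, -3, 9, 7])
k = 3

# Sort ascending; sort() also returns the original index of each element
sorted_vals, sorted_idx = torch.sort(tensor)

val, idx = torch.kthvalue(tensor, k)

# k is 1-based, so the k-th smallest sits at position k - 1 of the sorted tensor
print(sorted_vals)                            # tensor([-3,  4,  5,  7,  9])
print(val.item(), sorted_vals[k - 1].item())  # 5 5
print(idx.item(), sorted_idx[k - 1].item())   # 1 1
```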

Example: k-th Element Along a Dimension in a 2D Tensor

```python
import torch

tensor = torch.tensor([
    [10, 30, 20],
    [60, 40, 50],
    [90, 70, 80]
])
print(f"Original tensor:\n{tensor}\n")

# Find the 2nd smallest value along each row (dim=1)
values, indices = torch.kthvalue(tensor, 2, dim=1)
print("2nd smallest per row:")
print(f" Values: {values}")
print(f" Indices: {indices}")

# Find the 2nd smallest value along each column (dim=0)
values, indices = torch.kthvalue(tensor, 2, dim=0)
print("\n2nd smallest per column:")
print(f" Values: {values}")
print(f" Indices: {indices}")
```

Output:

```
Original tensor:
tensor([[10, 30, 20],
        [60, 40, 50],
        [90, 70, 80]])

2nd smallest per row:
 Values: tensor([20, 50, 80])
 Indices: tensor([2, 2, 2])

2nd smallest per column:
 Values: tensor([60, 40, 50])
 Indices: tensor([1, 1, 1])
```
  • dim=1 (along rows): For row [10, 30, 20], sorted is [10, 20, 30], so the 2nd smallest is 20.
  • dim=0 (along columns): For column [10, 60, 90], sorted is [10, 60, 90], so the 2nd smallest is 60.
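Since dim defaults to the last dimension, omitting it on a 2D tensor is the same as passing dim=1. Reducing along dim=0 gives the same result as transposing the tensor and reducing along its rows. The following sketch checks both claims on the tensor above:

```python
import torch

tensor = torch.tensor([
    [10, 30, 20],
    [60, 40, 50],
    [90, 70, 80]
])

# Omitting dim uses the last dimension (dim=-1, i.e. dim=1 for a 2D tensor)
default_dim = torch.kthvalue(tensor, 2)
explicit_dim = torch.kthvalue(tensor, 2, dim=1)
print(torch.equal(default_dim.values, explicit_dim.values))  # True

# Columns of the tensor are rows of its transpose
by_column = torch.kthvalue(tensor, 2, dim=0)
via_transpose = torch.kthvalue(tensor.T, 2, dim=1)
print(torch.equal(by_column.values, via_transpose.values))    # True
print(torch.equal(by_column.indices, via_transpose.indices))  # True
```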

Finding the Top k Elements with torch.topk()

The torch.topk() function returns the k largest (or smallest) elements from a tensor along a specified dimension, along with their indices.

Syntax:

```python
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
```

| Parameter | Description |
|---|---|
| input | The input tensor |
| k | Number of top elements to return |
| dim | Dimension along which to find the top k (default: last dimension) |
| largest | If True, returns the k largest; if False, returns the k smallest |
| sorted | If True, the returned elements are in sorted order (descending when largest=True, ascending when largest=False) |

Returns: A named tuple (values, indices).
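With its default arguments, topk() returns the same values as sorting in descending order and keeping the first k entries. The following is a minimal check of that equivalence. The tensor is illustrative data with no ties, so the indices match as well.

```python
import torch

x = torch.tensor([3.0, 1.0, 4.0, 1.5, 9.0, 2.6])
k = 3

top = torch.topk(x, k)  # largest=True, sorted=True by default

# Full descending sort, then keep the first k entries
desc = torch.sort(x, descending=True)

print(top.values)  # tensor([9., 4., 3.])
print(torch.equal(top.values, desc.values[:k]))    # True
print(torch.equal(top.indices, desc.indices[:k]))  # True
```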

Example: Top k Elements in a 1D Tensor

```python
import torch

tensor = torch.tensor([5.344, 8.343, -2.398, -0.995, 5.0, 30.421])
print(f"Original tensor: {tensor}")

# Find the top 3 largest elements
values, indices = torch.topk(tensor, 3)
print("\nTop 3 largest:")
print(f" Values: {values}")
print(f" Indices: {indices}")

# Find the top 3 smallest elements
values, indices = torch.topk(tensor, 3, largest=False)
print("\nTop 3 smallest:")
print(f" Values: {values}")
print(f" Indices: {indices}")
```

Output:

```
Original tensor: tensor([ 5.3440,  8.3430, -2.3980, -0.9950,  5.0000, 30.4210])

Top 3 largest:
 Values: tensor([30.4210,  8.3430,  5.3440])
 Indices: tensor([5, 1, 0])

Top 3 smallest:
 Values: tensor([-2.3980, -0.9950,  5.0000])
 Indices: tensor([2, 3, 4])
```
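With sorted=False, topk() still selects the same k elements, but sorted order is not guaranteed, so code should not depend on the order. The sketch below reuses the tensor from the example and compares the two settings as sets. If you need an order later, you can sort the small k-element result afterwards.

```python
import torch

tensor = torch.tensor([5.344, 8.343, -2.398, -0.995, 5.0, 30.421])

sorted_top = torch.topk(tensor, 3)                  # descending order guaranteed
unsorted_top = torch.topk(tensor, 3, sorted=False)  # same elements, order not guaranteed

# Compare as sets of indices, since the unsorted order may differ
print(set(sorted_top.indices.tolist()) == set(unsorted_top.indices.tolist()))  # True

# Sort only the k selected values, then reorder their indices to match
vals, order = torch.sort(unsorted_top.values, descending=True)
idxs = unsorted_top.indices[order]
print(idxs)  # tensor([5, 1, 0])
```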

Example: Top k Along a Dimension in a 2D Tensor

```python
import torch

tensor = torch.tensor([
    [12, 45, 23, 67],
    [89, 34, 56, 11],
    [78, 90, 43, 21]
])
print(f"Original tensor:\n{tensor}\n")

# Top 2 largest per row (dim=1)
values, indices = torch.topk(tensor, 2, dim=1)
print("Top 2 per row:")
print(f" Values:\n{values}")
print(f" Indices:\n{indices}")

# Top 2 largest per column (dim=0)
values, indices = torch.topk(tensor, 2, dim=0)
print("\nTop 2 per column:")
print(f" Values:\n{values}")
print(f" Indices:\n{indices}")
```

Output:

```
Original tensor:
tensor([[12, 45, 23, 67],
        [89, 34, 56, 11],
        [78, 90, 43, 21]])

Top 2 per row:
 Values:
tensor([[67, 45],
        [89, 56],
        [90, 78]])
 Indices:
tensor([[3, 1],
        [0, 2],
        [1, 0]])

Top 2 per column:
 Values:
tensor([[89, 90, 56, 67],
        [78, 45, 43, 21]])
 Indices:
tensor([[1, 2, 1, 0],
        [2, 0, 2, 2]])
```
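The returned indices are positions along dim. You can pass them directly to torch.gather() with the same dim to pull the corresponding values out of the original tensor, or out of any tensor with the same shape. Here is a short sketch using the per-row result above:

```python
import torch

tensor = torch.tensor([
    [12, 45, 23, 67],
    [89, 34, 56, 11],
    [78, 90, 43, 21]
])

values, indices = torch.topk(tensor, 2, dim=1)

# gather picks tensor[i, indices[i, j]] for every (i, j)
gathered = torch.gather(tensor, 1, indices)
print(gathered)
print(torch.equal(gathered, values))  # True
```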

Practical Example: Top k Predictions in Classification

A common real-world use of torch.topk() is finding the top k predicted classes from a model's output:

```python
import torch

# Simulated model output (logits for 6 classes)
logits = torch.tensor([1.2, 3.8, 0.5, 4.1, 2.7, 0.9])
class_names = ["cat", "dog", "bird", "fish", "horse", "snake"]

# Apply softmax to get probabilities
probs = torch.softmax(logits, dim=0)

# Get top 3 predictions
top_probs, top_indices = torch.topk(probs, 3)

print("Top 3 predictions:")
for prob, idx in zip(top_probs, top_indices):
    # idx is a 0-d tensor; .item() converts it to a Python int for list indexing
    print(f" {class_names[idx.item()]:>6s}: {prob.item():.4f} ({prob.item()*100:.1f}%)")
```

Output:

```
Top 3 predictions:
   fish: 0.4738 (47.4%)
    dog: 0.3510 (35.1%)
  horse: 0.1168 (11.7%)
```
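The same idea extends to a batch of predictions, for example when computing top-k accuracy. Softmax preserves ordering, so ranking the raw logits gives the same top-k classes. In the sketch below, the logits and the targets tensor are made-up illustrative values, not output from a real model. A sample counts as correct if its true class appears anywhere among its k highest-scoring classes.

```python
import torch

# Hypothetical logits for a batch of 3 samples over 6 classes
logits = torch.tensor([
    [1.2, 3.8, 0.5, 4.1, 2.7, 0.9],
    [2.0, 0.1, 3.5, 0.3, 1.1, 0.4],
    [0.2, 0.5, 0.1, 0.3, 2.2, 1.9],
])
targets = torch.tensor([1, 0, 3])  # hypothetical true class per sample

k = 2
# Top-k class indices for each sample (dim=1 is the class dimension)
_, topk_idx = torch.topk(logits, k, dim=1)

# Compare every top-k index with the target, then check whether any matched
hits = (topk_idx == targets.unsqueeze(1)).any(dim=1)
topk_accuracy = hits.float().mean().item()

print(topk_idx)
print(hits)
print(f"Top-{k} accuracy: {topk_accuracy:.3f}")  # Top-2 accuracy: 0.667
```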

Finding the Median with torch.kthvalue()

You can use kthvalue() to find the median of a tensor:

```python
import torch

tensor = torch.tensor([15.0, 3.0, 8.0, 22.0, 11.0])
n = tensor.size(0)

# For odd-length tensors, median is the middle element
median_val, median_idx = torch.kthvalue(tensor, (n + 1) // 2)
print(f"Tensor: {tensor}")
print(f"Median: {median_val.item()} (at index {median_idx.item()})")
```

Output:

```
Tensor: tensor([15.,  3.,  8., 22., 11.])
Median: 11.0 (at index 4)
```
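An even number of elements has two middle values. In that case, kthvalue(tensor, (n + 1) // 2) picks the lower one, which matches what torch.median() returns. By default, torch.quantile(tensor, 0.5) interpolates linearly between the two middle values. The following sketch uses an illustrative six-element tensor:

```python
import torch

tensor = torch.tensor([15.0, 3.0, 8.0, 22.0, 11.0, 6.0])  # sorted: [3, 6, 8, 11, 15, 22]
n = tensor.size(0)

lower_median, _ = torch.kthvalue(tensor, (n + 1) // 2)  # k = 3 -> 8.0
upper_median, _ = torch.kthvalue(tensor, n // 2 + 1)    # k = 4 -> 11.0

print(lower_median.item())                         # 8.0
print(torch.median(tensor).item())                 # 8.0 (the lower median)
print(torch.quantile(tensor, 0.5).item())          # 9.5 (linear interpolation)
print(((lower_median + upper_median) / 2).item())  # 9.5
```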

Common Mistake: k Out of Range

Passing a k value larger than the dimension size raises a runtime error:

```python
import torch

tensor = torch.tensor([1, 2, 3])

# ❌ Wrong: k=5 but tensor only has 3 elements
try:
    values, indices = torch.topk(tensor, 5)
except RuntimeError as e:
    print(f"Error: {e}")
```

Output (the exact message depends on the PyTorch version; recent releases print something like):

```
Error: selected index k out of range
```
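torch.kthvalue() has a similar restriction. Because k is 1-based, its valid range runs from 1 to the size of the dimension, so k=0 also fails. The sketch below tries both edges. It also catches IndexError as a precaution, in case a given PyTorch version raises a different exception type.

```python
import torch

tensor = torch.tensor([1, 2, 3])

failed = []
for k in (0, 4):  # valid k for kthvalue here is 1, 2, or 3
    try:
        torch.kthvalue(tensor, k)
    except (RuntimeError, IndexError) as e:
        failed.append(k)
        print(f"k={k}: {type(e).__name__}")

# The smallest and largest valid k both work
print(torch.kthvalue(tensor, 1).values.item())  # 1
print(torch.kthvalue(tensor, 3).values.item())  # 3
```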

The fix: Ensure k does not exceed the tensor size along the target dimension:

```python
import torch

tensor = torch.tensor([1, 2, 3])

# ✅ Correct: clamp k to the tensor size
k = min(5, tensor.size(0))
values, indices = torch.topk(tensor, k)
print(f"Top {k}: {values}")
```

Output:

```
Top 3: tensor([3, 2, 1])
```
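For multi-dimensional tensors, clamp k against the size of the dimension you are reducing rather than against the first dimension. In the following sketch, topk_clamped is a hypothetical helper name, not part of PyTorch:

```python
import torch

def topk_clamped(x: torch.Tensor, k: int, dim: int = -1, largest: bool = True):
    """Hypothetical helper: like torch.topk, but never asks for more than x.size(dim) elements."""
    k = min(k, x.size(dim))
    return torch.topk(x, k, dim=dim, largest=largest)

x = torch.tensor([[4, 1, 7],
                  [2, 9, 5]])

# dim=1 has size 3, so k=5 is clamped to 3
print(topk_clamped(x, 5, dim=1).values)
# dim=0 has size 2, so k=5 is clamped to 2
print(topk_clamped(x, 5, dim=0).values)
```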

Quick Comparison: kthvalue() vs. topk()

| Feature | torch.kthvalue() | torch.topk() |
|---|---|---|
| Returns | The single k-th smallest element (one per slice along dim) | The k largest (or smallest) elements |
| Sorting order | Ascending (finds k-th smallest) | Descending by default (configurable) |
| Use case | Finding a specific rank (median, percentile) | Getting the top/bottom k items |
| largest parameter | ❌ Not available (always smallest) | ✅ Toggle between largest/smallest |
| Output shape | The dim dimension is removed (kept with size 1 if keepdim=True) | The dim dimension has size k |
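The two functions are closely related. The k-th smallest value is the last entry of topk(..., largest=False). To get the k-th largest value with kthvalue(), count from the other end and ask for position n - k + 1. The sketch below checks both relationships on the tensor from the first example:

```python
import torch

tensor = torch.tensor([4, 5, -3, 9, 7])  # sorted: [-3, 4, 5, 7, 9]
n = tensor.size(0)
k = 2

# k-th smallest: kthvalue directly, or the last of the k smallest from topk
kth_small = torch.kthvalue(tensor, k).values
via_topk_small = torch.topk(tensor, k, largest=False).values[-1]
print(kth_small.item(), via_topk_small.item())  # 4 4

# k-th largest: last of the k largest, or the (n - k + 1)-th smallest
kth_large = torch.topk(tensor, k).values[-1]
via_kthvalue = torch.kthvalue(tensor, n - k + 1).values
print(kth_large.item(), via_kthvalue.item())  # 7 7
```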

Conclusion

PyTorch provides efficient, GPU-accelerated functions for finding ranked elements in tensors:

  • torch.kthvalue(input, k) returns the k-th smallest element and its index - ideal for finding specific ranks like medians or percentiles.
  • torch.topk(input, k) returns the top k largest (or smallest with largest=False) elements and their indices - perfect for tasks like selecting top predictions, beam search, or feature selection.

Both functions support the dim parameter for operating along specific dimensions of multi-dimensional tensors, making them versatile tools for tensor analysis in deep learning workflows.