Python PyTorch: How to Find the k-th and Top k Elements of a Tensor
Finding the k-th smallest element of a tensor and selecting its k largest elements are common operations in machine learning and data processing workflows, from selecting top predictions in classification models to implementing k-nearest neighbors algorithms. PyTorch provides two dedicated functions for these tasks: torch.kthvalue() and torch.topk().
This guide explains both functions with clear examples, covering 1D and multi-dimensional tensors, and demonstrates practical use cases.
Finding the k-th Smallest Element with torch.kthvalue()
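The torch.kthvalue() function returns the k-th smallest element of a tensor along a given dimension, together with the index of that element in the original tensor. The result matches position k of the values sorted in ascending order, counting from 1.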
Syntax:
torch.kthvalue(input, k, dim=-1, keepdim=False, *, out=None)
| Parameter | Description |
|---|---|
| input | The input tensor |
| k | Integer specifying which smallest element to find (1-based) |
| dim | Dimension along which to find the k-th value (default: last dimension) |
| keepdim | Whether the output tensor retains the reduced dimension |
Returns: A named tuple (values, indices) containing the k-th smallest value and its index.
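As a minimal sketch of the return value and the keepdim flag, the snippet below reads the named-tuple fields by name and compares output shapes. The tensor x is made up for illustration.

```python
import torch

# Illustrative input: 2 rows, 3 columns.
x = torch.tensor([[3.0, 1.0, 2.0],
                  [9.0, 7.0, 8.0]])

# The result is a named tuple, so fields can be read by name or unpacked.
result = torch.kthvalue(x, 2, dim=1)
print(result.values)   # 2nd smallest value of each row
print(result.indices)  # position of that value within each row

# keepdim=True keeps the reduced dimension with size 1.
kept = torch.kthvalue(x, 2, dim=1, keepdim=True)
print(result.values.shape, kept.values.shape)
```

Keeping the reduced dimension is mostly useful when the result has to broadcast against the input or be passed to torch.gather.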
Example: k-th Element in a 1D Tensor
import torch
tensor = torch.tensor([4, 5, -3, 9, 7])
print(f"Original tensor: {tensor}")
# Find the 1st smallest (minimum)
val, idx = torch.kthvalue(tensor, 1)
print(f"\n1st smallest -> Value: {val.item()}, Index: {idx.item()}")
# Find the 3rd smallest
val, idx = torch.kthvalue(tensor, 3)
print(f"3rd smallest -> Value: {val.item()}, Index: {idx.item()}")
# Find the 5th smallest (maximum)
val, idx = torch.kthvalue(tensor, 5)
print(f"5th smallest -> Value: {val.item()}, Index: {idx.item()}")
Output:
Original tensor: tensor([ 4, 5, -3, 9, 7])
1st smallest -> Value: -3, Index: 2
3rd smallest -> Value: 5, Index: 1
5th smallest -> Value: 9, Index: 3
The sorted order is [-3, 4, 5, 7, 9], so the 3rd smallest is 5, which is at index 1 in the original tensor.
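One way to check this reasoning is to compare kthvalue() against an explicit torch.sort(). This is only a sanity-check sketch: the two agree on the index here because all values are distinct.

```python
import torch

tensor = torch.tensor([4, 5, -3, 9, 7])
k = 3

sorted_vals, sorted_idx = torch.sort(tensor)  # ascending by default
kth_val, kth_idx = torch.kthvalue(tensor, k)

# kthvalue is 1-based, so it lines up with position k - 1 of the sorted result.
print(sorted_vals[k - 1].item(), kth_val.item())  # value from sort vs. kthvalue
print(sorted_idx[k - 1].item(), kth_idx.item())   # original index from sort vs. kthvalue
```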
Example: k-th Element Along a Dimension in a 2D Tensor
import torch
tensor = torch.tensor([
[10, 30, 20],
[60, 40, 50],
[90, 70, 80]
])
print(f"Original tensor:\n{tensor}\n")
# Find the 2nd smallest value along each row (dim=1)
values, indices = torch.kthvalue(tensor, 2, dim=1)
print(f"2nd smallest per row:")
print(f" Values: {values}")
print(f" Indices: {indices}")
# Find the 2nd smallest value along each column (dim=0)
values, indices = torch.kthvalue(tensor, 2, dim=0)
print(f"\n2nd smallest per column:")
print(f" Values: {values}")
print(f" Indices: {indices}")
Output:
Original tensor:
tensor([[10, 30, 20],
[60, 40, 50],
[90, 70, 80]])
2nd smallest per row:
Values: tensor([20, 50, 80])
Indices: tensor([2, 2, 2])
2nd smallest per column:
Values: tensor([60, 40, 50])
Indices: tensor([1, 1, 1])
dim=1 (along rows): For row [10, 30, 20], sorted is [10, 20, 30], so the 2nd smallest is 20.
dim=0 (along columns): For column [10, 60, 90], sorted is [10, 60, 90], so the 2nd smallest is 60.
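The returned indices refer to positions along the reduced dimension. A small sketch of how they can be used to look the values up again with torch.gather(), assuming keepdim=True so the index tensor has the shape gather expects:

```python
import torch

tensor = torch.tensor([[10, 30, 20],
                       [60, 40, 50],
                       [90, 70, 80]])

# keepdim=True gives indices of shape (1, 3): one row index per column.
values, indices = torch.kthvalue(tensor, 2, dim=0, keepdim=True)

# gather picks tensor[indices[0][j]][j] for every column j.
picked = torch.gather(tensor, 0, indices)
print(picked)
print(torch.equal(picked, values))
```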
Finding the Top k Elements with torch.topk()
The torch.topk() function returns the k largest (or smallest) elements from a tensor along a specified dimension, along with their indices.
Syntax:
torch.topk(input, k, dim=-1, largest=True, sorted=True, *, out=None)
| Parameter | Description |
|---|---|
| input | The input tensor |
| k | Number of top elements to return |
| dim | Dimension along which to find top k (default: last dimension) |
| largest | If True, returns the k largest; if False, returns the k smallest |
| sorted | Whether the returned elements are sorted (descending when largest=True, ascending when largest=False) |
Returns: A named tuple (values, indices).
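The effect of the sorted flag can be sketched as follows. With sorted=False the same k elements are selected, but their order should not be relied on, so the example sorts them in Python before printing. The tensor x is illustrative.

```python
import torch

x = torch.tensor([7, 1, 9, 4, 8])

ordered = torch.topk(x, 3)                  # sorted=True: largest first
unordered = torch.topk(x, 3, sorted=False)  # same elements, order not guaranteed

print(ordered.values)
print(sorted(unordered.values.tolist(), reverse=True))

# In both cases the indices point back into the original tensor.
print(torch.equal(x[ordered.indices], ordered.values))
```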
Example: Top k Elements in a 1D Tensor
import torch
tensor = torch.tensor([5.344, 8.343, -2.398, -0.995, 5.0, 30.421])
print(f"Original tensor: {tensor}")
# Find the top 3 largest elements
values, indices = torch.topk(tensor, 3)
print(f"\nTop 3 largest:")
print(f" Values: {values}")
print(f" Indices: {indices}")
# Find the top 3 smallest elements
values, indices = torch.topk(tensor, 3, largest=False)
print(f"\nTop 3 smallest:")
print(f" Values: {values}")
print(f" Indices: {indices}")
Output:
Original tensor: tensor([ 5.3440, 8.3430, -2.3980, -0.9950, 5.0000, 30.4210])
Top 3 largest:
Values: tensor([30.4210, 8.3430, 5.3440])
Indices: tensor([5, 1, 0])
Top 3 smallest:
Values: tensor([-2.3980, -0.9950, 5.0000])
Indices: tensor([2, 3, 4])
Example: Top k Along a Dimension in a 2D Tensor
import torch
tensor = torch.tensor([
[12, 45, 23, 67],
[89, 34, 56, 11],
[78, 90, 43, 21]
])
print(f"Original tensor:\n{tensor}\n")
# Top 2 largest per row (dim=1)
values, indices = torch.topk(tensor, 2, dim=1)
print(f"Top 2 per row:")
print(f" Values:\n{values}")
print(f" Indices:\n{indices}")
# Top 2 largest per column (dim=0)
values, indices = torch.topk(tensor, 2, dim=0)
print(f"\nTop 2 per column:")
print(f" Values:\n{values}")
print(f" Indices:\n{indices}")
Output:
Original tensor:
tensor([[12, 45, 23, 67],
[89, 34, 56, 11],
[78, 90, 43, 21]])
Top 2 per row:
Values:
tensor([[67, 45],
[89, 56],
[90, 78]])
Indices:
tensor([[3, 1],
[0, 2],
[1, 0]])
Top 2 per column:
Values:
tensor([[89, 90, 56, 67],
[78, 45, 43, 21]])
Indices:
tensor([[1, 2, 1, 0],
[2, 0, 2, 2]])
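The indices are often more useful than the values themselves. As one possible pattern (a sketch, not the only way to do it), they can be turned into a boolean mask that keeps the top 2 entries of each row and zeroes out the rest:

```python
import torch

tensor = torch.tensor([[12, 45, 23, 67],
                       [89, 34, 56, 11],
                       [78, 90, 43, 21]])

_, indices = torch.topk(tensor, 2, dim=1)

# Build a boolean mask that is True at the top-2 positions of each row.
mask = torch.zeros_like(tensor, dtype=torch.bool)
mask.scatter_(1, indices, torch.ones_like(indices, dtype=torch.bool))

# Keep the top-2 entries and replace everything else with 0.
top2_only = torch.where(mask, tensor, torch.zeros_like(tensor))
print(top2_only)
```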
Practical Example: Top k Predictions in Classification
A common real-world use of torch.topk() is finding the top k predicted classes from a model's output:
import torch
# Simulated model output (logits for 6 classes)
logits = torch.tensor([1.2, 3.8, 0.5, 4.1, 2.7, 0.9])
class_names = ["cat", "dog", "bird", "fish", "horse", "snake"]
# Apply softmax to get probabilities
probs = torch.softmax(logits, dim=0)
# Get top 3 predictions
top_probs, top_indices = torch.topk(probs, 3)
print("Top 3 predictions:")
for prob, idx in zip(top_probs, top_indices):
    print(f" {class_names[idx]:>6s}: {prob.item():.4f} ({prob.item()*100:.1f}%)")
Output:
Top 3 predictions:
fish: 0.4738 (47.4%)
dog: 0.3510 (35.1%)
horse: 0.1168 (11.7%)
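The same idea extends to a batch of predictions, for example to compute top-k accuracy. The sketch below uses made-up logits and labels, and topk_accuracy is a hypothetical helper name, not a PyTorch function. Softmax is skipped because it does not change the ranking of the logits.

```python
import torch

# Hypothetical batch: 4 samples, 5 classes (values chosen for illustration).
logits = torch.tensor([[0.1, 2.0, 0.3, 1.5, 0.2],
                       [3.0, 0.2, 0.1, 0.4, 2.5],
                       [0.5, 0.6, 0.7, 0.8, 0.9],
                       [1.0, 0.9, 3.0, 0.1, 2.0]])
labels = torch.tensor([1, 4, 0, 1])

def topk_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest logits."""
    _, top_idx = torch.topk(logits, k, dim=1)           # shape: (batch, k)
    hits = (top_idx == labels.unsqueeze(1)).any(dim=1)  # shape: (batch,)
    return hits.float().mean().item()

print(topk_accuracy(logits, labels, 1))  # 0.25
print(topk_accuracy(logits, labels, 2))  # 0.5
```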
Finding the Median with torch.kthvalue()
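You can use kthvalue() to find the median of a tensor. For a tensor of n elements, the median is the element of rank (n + 1) // 2; when n is even, this selects the lower of the two middle values: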
import torch
tensor = torch.tensor([15.0, 3.0, 8.0, 22.0, 11.0])
n = tensor.size(0)
# For odd-length tensors, median is the middle element
median_val, median_idx = torch.kthvalue(tensor, (n + 1) // 2)
print(f"Tensor: {tensor}")
print(f"Median: {median_val.item()} (at index {median_idx.item()})")
Output:
Tensor: tensor([15., 3., 8., 22., 11.])
Median: 11.0 (at index 4)
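For even-length tensors there are two middle elements. A short sketch of the two common conventions, using an illustrative six-element tensor: the lower middle value, which is also what torch.median() returns, and the average of both middle values.

```python
import torch

tensor = torch.tensor([15.0, 3.0, 8.0, 22.0, 11.0, 5.0])
n = tensor.size(0)

# (n + 1) // 2 selects the lower of the two middle elements when n is even;
# n // 2 + 1 selects the upper one. For odd n both ranks coincide.
lower = torch.kthvalue(tensor, (n + 1) // 2).values
upper = torch.kthvalue(tensor, n // 2 + 1).values

print(lower.item(), torch.median(tensor).item())  # lower median, both ways
print(((lower + upper) / 2).item())               # average of the two middle values
```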
Common Mistake: k Out of Range
Passing a k value larger than the size of the target dimension raises a RuntimeError:
import torch
tensor = torch.tensor([1, 2, 3])
# ❌ Wrong: k=5 but tensor only has 3 elements
try:
    values, indices = torch.topk(tensor, 5)
except RuntimeError as e:
    print(f"Error: {e}")
Output (the exact wording may vary between PyTorch versions):
Error: selected index k out of range
The fix: Ensure k does not exceed the tensor size along the target dimension:
import torch
tensor = torch.tensor([1, 2, 3])
# ✅ Correct: clamp k to the tensor size
k = min(5, tensor.size(0))
values, indices = torch.topk(tensor, k)
print(f"Top {k}: {values}")
Output:
Top 3: tensor([3, 2, 1])
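For multi-dimensional tensors, the clamp has to use the size of the dimension being reduced rather than size(0). A minimal sketch, where safe_topk is a hypothetical helper name introduced here for illustration:

```python
import torch

def safe_topk(x, k, dim=-1, largest=True):
    """Hypothetical helper: clamp k to the size of `dim` before calling torch.topk."""
    k = min(k, x.size(dim))
    return torch.topk(x, k, dim=dim, largest=largest)

x = torch.tensor([[1, 5, 3],
                  [4, 2, 6]])

print(safe_topk(x, 5, dim=1).values)  # k clamped to 3 (row length)
print(safe_topk(x, 5, dim=0).values)  # k clamped to 2 (column length)
```

Silently clamping k changes the output shape, so this only suits callers that can handle fewer than k results.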
Quick Comparison: kthvalue() vs. topk()
| Feature | torch.kthvalue() | torch.topk() |
|---|---|---|
| Returns | Single k-th smallest element | k largest (or smallest) elements |
| Sorting order | Ascending (finds k-th smallest) | Descending by default (configurable) |
| Use case | Finding a specific rank (median, percentile) | Getting the top/bottom k items |
| largest parameter | ❌ Not available (always smallest) | ✅ Toggle between largest/smallest |
| Output shape | Scalar (or reduced dimension) | k elements per dimension |
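The two functions are closely related. As a sketch of that relationship, the last entry of a sorted topk() call with largest=False is the k-th smallest element, so it should agree with kthvalue(). The indices also agree here because the values are distinct.

```python
import torch

tensor = torch.tensor([5.344, 8.343, -2.398, -0.995, 5.0, 30.421])
k = 3

smallest = torch.topk(tensor, k, largest=False)  # ascending: k smallest values
kth = torch.kthvalue(tensor, k)

# With sorted=True (the default), the last of the k smallest is the k-th smallest.
print(smallest.values[-1].item() == kth.values.item())
print(smallest.indices[-1].item() == kth.indices.item())
```

When only one rank is needed, kthvalue() avoids returning the other k - 1 elements.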
Conclusion
PyTorch provides efficient, GPU-accelerated functions for finding ranked elements in tensors:
torch.kthvalue(input, k) returns the k-th smallest element and its index, ideal for finding specific ranks like medians or percentiles.
torch.topk(input, k) returns the top k largest (or smallest with largest=False) elements and their indices, perfect for tasks like selecting top predictions, beam search, or feature selection.
Both functions support the dim parameter for operating along specific dimensions of multi-dimensional tensors, making them versatile tools for tensor analysis in deep learning workflows.