Python PyTorch: How to Compare Two Tensors
Comparing tensors is a fundamental operation in PyTorch, whether you're validating model outputs, debugging computations, writing unit tests, or implementing conditional logic in neural networks. PyTorch provides several methods for tensor comparison, from element-wise checks to overall equality verification.
In this guide, you'll learn how to compare two tensors in PyTorch using torch.eq(), other comparison functions, and practical techniques for real-world scenarios including floating-point tolerance.
Element-Wise Equality with torch.eq()
The torch.eq() function compares two tensors element by element and returns a boolean tensor of the same shape (the broadcast shape, if the inputs differ in shape). Each position contains True if the corresponding elements are equal and False otherwise.
Syntax:
torch.eq(input, other, *, out=None)
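The second argument does not have to be a tensor: a plain Python number is compared against every element of input. A minimal illustration (the tensor values are arbitrary):

```python
import torch

scores = torch.tensor([1, 2, 3, 2])

# A plain number is compared against every element of the tensor
is_two = torch.eq(scores, 2)
print(is_two.tolist())  # [False, True, False, True]
print(is_two.dtype)     # torch.bool
```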
Comparing 1-D Tensors
import torch
first = torch.tensor([4.4, 2.4, -9.1, -5.31, 5.3])
second = torch.tensor([4.4, 5.5, -9.1, -5.31, 43.0])
print(f"First: {first}")
print(f"Second: {second}")
print(f"Equal: {torch.eq(first, second)}")
Output:
First: tensor([ 4.4000, 2.4000, -9.1000, -5.3100, 5.3000])
Second: tensor([ 4.4000, 5.5000, -9.1000, -5.3100, 43.0000])
Equal: tensor([ True, False, True, True, False])
Elements at positions 0, 2, and 3 are equal, while positions 1 and 4 differ.
Comparing 2-D Tensors
import torch
first = torch.tensor([
[7, -2, 3],
[29, 9, -5],
[2, -8, 34]
])
second = torch.tensor([
[7, -5, 3],
[26, 9, -4],
[3, -8, 43]
])
print("First tensor:")
print(first)
print("\nSecond tensor:")
print(second)
print("\nElement-wise comparison:")
print(torch.eq(first, second))
Output:
First tensor:
tensor([[ 7, -2,  3],
        [29,  9, -5],
        [ 2, -8, 34]])

Second tensor:
tensor([[ 7, -5,  3],
        [26,  9, -4],
        [ 3, -8, 43]])

Element-wise comparison:
tensor([[ True, False,  True],
        [False,  True, False],
        [False,  True, False]])
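To turn the boolean grid into coordinates, one option is torch.nonzero() on the negated mask, which returns one [row, column] pair per mismatch; summing the mask along a dimension counts matches per row. A short sketch reusing the same two tensors:

```python
import torch

first = torch.tensor([[7, -2, 3], [29, 9, -5], [2, -8, 34]])
second = torch.tensor([[7, -5, 3], [26, 9, -4], [3, -8, 43]])

equal = torch.eq(first, second)

# One [row, col] pair for every position where the tensors differ
diff_positions = torch.nonzero(~equal)
print(diff_positions.tolist())  # [[0, 1], [1, 0], [1, 2], [2, 0], [2, 2]]

# Number of matching elements in each row (True counts as 1)
matches_per_row = equal.sum(dim=1)
print(matches_per_row.tolist())  # [2, 1, 1]
```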
Checking Complete Equality with torch.equal()
While torch.eq() returns an element-wise boolean tensor, torch.equal() returns a single boolean indicating whether the two tensors are entirely identical (same shape, same values).
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 2, 4])
print(f"a equals b: {torch.equal(a, b)}")
print(f"a equals c: {torch.equal(a, c)}")
Output:
a equals b: True
a equals c: False
torch.eq() vs torch.equal()
| Function | Returns | Use Case |
|---|---|---|
| torch.eq(a, b) | Boolean tensor (element-wise) | Finding where tensors differ |
| torch.equal(a, b) | Single boolean | Checking if tensors are entirely identical |
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 9, 3])
print(f"torch.eq(): {torch.eq(a, b)}") # Element-wise details
print(f"torch.equal(): {torch.equal(a, b)}") # Overall verdict
Output:
torch.eq(): tensor([ True, False, True])
torch.equal(): False
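A tempting shortcut is to reduce torch.eq() with .all(). That agrees with torch.equal() for same-shaped tensors, but because torch.eq() broadcasts, the shortcut can report True for tensors of different shapes. A small sketch with arbitrary values:

```python
import torch

a = torch.tensor([1, 1, 1])
b = torch.tensor([1])  # shape [1] broadcasts against shape [3]

# Element-wise comparison broadcasts b, so every position matches
shortcut = torch.eq(a, b).all().item()
print(shortcut)  # True

# torch.equal() requires identical shapes, so it reports a mismatch
strict = torch.equal(a, b)
print(strict)    # False
```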
Other Comparison Operators
PyTorch provides a full set of element-wise comparison functions beyond equality:
import torch
a = torch.tensor([10, 20, 30, 40, 50])
b = torch.tensor([15, 20, 25, 40, 55])
print(f"a: {a}")
print(f"b: {b}")
print(f"a > b: {torch.gt(a, b)}")
print(f"a < b: {torch.lt(a, b)}")
print(f"a >= b: {torch.ge(a, b)}")
print(f"a <= b: {torch.le(a, b)}")
print(f"a != b: {torch.ne(a, b)}")
Output:
a: tensor([10, 20, 30, 40, 50])
b: tensor([15, 20, 25, 40, 55])
a > b: tensor([False, False, True, False, False])
a < b: tensor([ True, False, False, False, True])
a >= b: tensor([False, True, True, True, False])
a <= b: tensor([ True, True, False, True, True])
a != b: tensor([ True, False, True, False, True])
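The boolean tensors these functions return are most useful as masks. For instance, torch.where() takes values from one tensor where the condition is True and from the other elsewhere; with torch.gt() as the condition this yields an element-wise maximum, which torch.maximum() also computes directly. A sketch reusing a and b from above:

```python
import torch

a = torch.tensor([10, 20, 30, 40, 50])
b = torch.tensor([15, 20, 25, 40, 55])

# Take a's value where a > b, otherwise b's value
larger = torch.where(torch.gt(a, b), a, b)
print(larger.tolist())  # [15, 20, 30, 40, 55]

# The built-in element-wise maximum gives the same answer
print(torch.equal(larger, torch.maximum(a, b)))  # True
```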
You can also use Python operators directly on tensors:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 5, 3])
print(a == b) # Same as torch.eq(a, b)
print(a > b) # Same as torch.gt(a, b)
print(a != b) # Same as torch.ne(a, b)
Output:
tensor([ True, False, True])
tensor([False, False, False])
tensor([False, True, False])
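Comparison results can be combined with the element-wise operators & (and), | (or), and ~ (not). Because & and | bind more tightly than comparison operators in Python, each comparison needs its own parentheses. A minimal example (the variable name in_range is just illustrative):

```python
import torch

a = torch.tensor([1, 2, 3])

# Parentheses are required: without them, Python parses a > 1 & a < 3
# as the chained comparison a > (1 & a) < 3, which fails for multi-element tensors
in_range = (a > 1) & (a < 3)
print(in_range.tolist())     # [False, True, False]

# ~ flips every element of a boolean tensor
print((~in_range).tolist())  # [True, False, True]
```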
Approximate Comparison with torch.allclose()
Floating-point arithmetic introduces tiny rounding errors that make exact comparison unreliable. torch.allclose() checks if all elements are equal within a tolerance.
import torch
a = torch.tensor([0.1 + 0.2])
b = torch.tensor([0.3])
print(f"a value: {a.item():.20f}")
print(f"b value: {b.item():.20f}")
print(f"Exact equal: {torch.equal(a, b)}")
print(f"Approximately: {torch.allclose(a, b)}")
Output:
a value: 0.30000001192092895508
b value: 0.30000001192092895508
Exact equal: True
Approximately: True
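In float32, PyTorch's default floating-point dtype, 0.1 + 0.2 and 0.3 round to the same value, so even torch.equal() reports True here. That is luck rather than a guarantee. A more revealing example, with small but real differences: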
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00001, 2.00001, 3.00001])
print(f"Exact equal: {torch.equal(a, b)}")
print(f"Close (default tol): {torch.allclose(a, b)}")
print(f"Close (looser tol): {torch.allclose(a, b, atol=1e-4)}")
Output:
Exact equal: False
Close (default tol): False
Close (looser tol): True
Syntax:
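torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
- atol: absolute tolerance
- rtol: relative tolerance
- equal_nan: when True, NaN values in matching positions count as equal

Each pair of elements passes when |input - other| <= atol + rtol * |other|. The relative term scales with other only, so swapping the two arguments can occasionally change the result.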
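torch.isclose() takes the same tolerance parameters but returns one boolean per position, which makes it easy to see how the rule applies element by element. The sketch below uses float64 and a purely relative tolerance (atol=0.0) as illustrative choices, so the same absolute gap fails for the small value and passes for the large one:

```python
import torch

# float64 keeps the arithmetic well away from float32 rounding effects
a = torch.tensor([1.0, 1000.0], dtype=torch.float64)
b = torch.tensor([1.0005, 1000.0005], dtype=torch.float64)

rtol, atol = 1e-5, 0.0  # purely relative tolerance for this illustration

# Element-wise version of allclose(): one boolean per position
close = torch.isclose(a, b, rtol=rtol, atol=atol)

# The same rule written out by hand: |a - b| <= atol + rtol * |b|
manual = (a - b).abs() <= atol + rtol * b.abs()

print(close.tolist())                              # [False, True]
print(torch.equal(close, manual))                  # True
print(torch.allclose(a, b, rtol=rtol, atol=atol))  # False
```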
Always use torch.allclose() when comparing tensors that result from neural network computations, matrix operations, or any floating-point arithmetic. Exact comparison with torch.equal() will often fail due to rounding.
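The rounding error that float32 happened to hide in the 0.1 + 0.2 example becomes visible in float64, where the two values no longer coincide. A sketch of the same check with an explicit dtype:

```python
import torch

# float64 keeps enough precision for the rounding error in 0.1 + 0.2 to survive
a = torch.tensor([0.1 + 0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

print(f"a value: {a.item():.20f}")  # 0.30000000000000004441
print(f"b value: {b.item():.20f}")  # 0.29999999999999998890
print(f"Exact equal: {torch.equal(a, b)}")       # False
print(f"Approximately: {torch.allclose(a, b)}")  # True
```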
Counting Differences
To find out how many elements differ and where they are:
import torch
predictions = torch.tensor([1, 0, 1, 1, 0, 1, 0, 1])
labels = torch.tensor([1, 0, 0, 1, 0, 1, 1, 1])
# Element-wise comparison
matches = torch.eq(predictions, labels)
num_correct = matches.sum().item()
total = matches.numel()
accuracy = num_correct / total * 100
print(f"Matches: {matches}")
print(f"Correct: {num_correct}/{total}")
print(f"Accuracy: {accuracy:.1f}%")
# Find mismatch positions
mismatches = torch.where(~matches)[0]
print(f"Mismatches at indices: {mismatches.tolist()}")
Output:
Matches: tensor([ True, True, False, True, True, True, False, True])
Correct: 6/8
Accuracy: 75.0%
Mismatches at indices: [2, 6]
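For floating-point tensors, the same counting approach works if torch.isclose() replaces torch.eq(), and the largest absolute gap is a useful number when choosing a tolerance. A sketch with made-up values (expected and actual are hypothetical names):

```python
import torch

expected = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
actual = torch.tensor([1.0, 2.1, 3.0, 4.0000001], dtype=torch.float64)

# Positions whose gap exceeds the default tolerances (rtol=1e-05, atol=1e-08)
close = torch.isclose(actual, expected)
bad_idx = torch.where(~close)[0]
print(f"Mismatches: {(~close).sum().item()} at {bad_idx.tolist()}")  # 1 at [1]

# Largest absolute gap, handy for picking a sensible atol
max_diff = (actual - expected).abs().max().item()
print(f"Max abs difference: {max_diff:.7f}")  # 0.1000000
```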
Comparing Tensors on Different Devices
Attempting to compare tensors on different devices (CPU vs GPU) raises a RuntimeError:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3]).cuda()  # On GPU (requires a CUDA-capable device)
torch.eq(a, b) # RuntimeError
Output:
RuntimeError: Expected all tensors to be on the same device
Move both tensors to the same device before comparing:
# Move GPU tensor to CPU
print(torch.eq(a, b.cpu()))
# Or move CPU tensor to GPU
# print(torch.eq(a.cuda(), b))
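If the same code has to run on machines with and without a GPU, one common pattern is to move one tensor to the other's device with .to(), which does nothing when it is already there. A sketch, assuming a falls back to CPU when CUDA is unavailable:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = torch.tensor([1, 2, 3])                 # on CPU
b = torch.tensor([1, 2, 3], device=device)  # on GPU if available, else CPU

# .to(a.device) is a no-op when b already lives on a's device
result = torch.eq(a, b.to(a.device))
print(result.tolist())  # [True, True, True]
```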
Comparing Tensors with Different Shapes
When tensors have different shapes, PyTorch applies broadcasting rules, which may produce unexpected results:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print("Shapes:", a.shape, b.shape)
print("Comparison (broadcast):")
print(torch.eq(a, b))
Output:
Shapes: torch.Size([3]) torch.Size([2, 3])
Comparison (broadcast):
tensor([[ True, True, True],
[False, False, False]])
The 1-D tensor a is broadcast across the rows of b. Use torch.equal() if you want strict shape matching; it returns False for differently shaped tensors.
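For unit tests, torch.testing.assert_close() is a stricter alternative: it does not broadcast, it checks shape, dtype, and device by default, and it applies dtype-dependent default tolerances, raising an AssertionError with a descriptive message on failure. A minimal sketch with arbitrary float values:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

# Shapes differ, so this raises instead of broadcasting
try:
    torch.testing.assert_close(a, b)
    shape_error = None
except AssertionError as err:
    shape_error = str(err)

print("Shape check failed:", shape_error is not None)  # True

# Same shape, tiny float32 difference: passes under the default tolerances
torch.testing.assert_close(a, a + 1e-7)
print("Close values passed")
```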
Quick Comparison of Methods
| Method | Returns | Handles Float Tolerance | Best For |
|---|---|---|---|
| torch.eq() | Boolean tensor | ❌ | Element-wise equality details |
| torch.equal() | Single boolean | ❌ | Overall exact equality check |
| torch.allclose() | Single boolean | ✅ | Floating-point comparison |
| torch.gt(), torch.lt(), etc. | Boolean tensor | ❌ | Element-wise magnitude comparison |
| ==, >, < operators | Boolean tensor | ❌ | Inline, readable comparisons |
Conclusion
PyTorch provides a comprehensive set of tools for comparing tensors:
- Use torch.eq() to see exactly which elements match between two tensors; it returns a boolean tensor with the result at each position.
- Use torch.equal() for a simple yes/no check on whether two tensors are entirely identical.
- Use torch.allclose() when comparing floating-point tensors that may have small rounding differences, which is essential for validating neural network outputs and gradient computations.
- Use torch.gt(), torch.lt(), torch.ge(), torch.le(), and torch.ne() for the other element-wise comparisons (ordering and inequality).
For most deep learning workflows, torch.equal() for exact checks and torch.allclose() for approximate checks will cover your comparison needs.