
Python PyTorch: How to Compare Two Tensors

Comparing tensors is a fundamental operation in PyTorch - whether you're validating model outputs, debugging computations, writing unit tests, or implementing conditional logic in neural networks. PyTorch provides several methods for tensor comparison, from element-wise checks to overall equality verification.

In this guide, you'll learn how to compare two tensors in PyTorch using torch.eq(), other comparison functions, and practical techniques for real-world scenarios including floating-point tolerance.

Element-Wise Equality with torch.eq()

The torch.eq() function compares two tensors element by element and returns a boolean tensor of the same shape. Each position contains True if the corresponding elements are equal and False otherwise.

Syntax:

torch.eq(input, other, *, out=None)

Comparing 1-D Tensors

import torch

first = torch.tensor([4.4, 2.4, -9.1, -5.31, 5.3])
second = torch.tensor([4.4, 5.5, -9.1, -5.31, 43.0])

print(f"First: {first}")
print(f"Second: {second}")
print(f"Equal: {torch.eq(first, second)}")

Output:

First: tensor([ 4.4000,  2.4000, -9.1000, -5.3100,  5.3000])
Second: tensor([ 4.4000,  5.5000, -9.1000, -5.3100, 43.0000])
Equal: tensor([ True, False,  True,  True, False])

Elements at positions 0, 2, and 3 are equal, while positions 1 and 4 differ.
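The second argument does not have to be a tensor. torch.eq() also accepts a plain number and compares it against every element. The minimal sketch below uses a made-up values tensor. It also confirms that the result has dtype torch.bool and the same shape as the input:

```python
import torch

values = torch.tensor([3, 7, 3, 1])

# Compare every element against the scalar 3
mask = torch.eq(values, 3)

print(mask)        # tensor([ True, False,  True, False])
print(mask.dtype)  # torch.bool
print(mask.shape)  # torch.Size([4])
```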

Comparing 2-D Tensors

import torch

first = torch.tensor([
    [7, -2, 3],
    [29, 9, -5],
    [2, -8, 34]
])

second = torch.tensor([
    [7, -5, 3],
    [26, 9, -4],
    [3, -8, 43]
])

print("First tensor:")
print(first)

print("\nSecond tensor:")
print(second)

print("\nElement-wise comparison:")
print(torch.eq(first, second))

Output:

First tensor:
tensor([[ 7, -2,  3],
        [29,  9, -5],
        [ 2, -8, 34]])

Second tensor:
tensor([[ 7, -5,  3],
        [26,  9, -4],
        [ 3, -8, 43]])

Element-wise comparison:
tensor([[ True, False,  True],
        [False,  True, False],
        [False,  True, False]])
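Because the result is an ordinary boolean tensor, you can reduce it along a dimension. The sketch below uses a variant of the tensors above. The first row of second is changed so that it matches first exactly, which guarantees one fully equal row. The code then checks which rows match completely and how many elements match in each row:

```python
import torch

first = torch.tensor([[7, -2, 3],
                      [29, 9, -5],
                      [2, -8, 34]])
# Same as the second tensor above, except that row 0 now matches exactly
second = torch.tensor([[7, -2, 3],
                       [26, 9, -4],
                       [3, -8, 43]])

matches = torch.eq(first, second)

# all(dim=1) is True only for rows whose elements all match
rows_equal = matches.all(dim=1)
# sum(dim=1) counts the matching elements in each row
matches_per_row = matches.sum(dim=1)

print(rows_equal)       # tensor([ True, False, False])
print(matches_per_row)  # tensor([3, 1, 1])
```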

Checking Complete Equality with torch.equal()

While torch.eq() returns an element-wise boolean tensor, torch.equal() returns a single boolean indicating whether the two tensors are entirely identical (same shape, same values).

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 2, 4])

print(f"a equals b: {torch.equal(a, b)}")
print(f"a equals c: {torch.equal(a, c)}")

Output:

a equals b: True
a equals c: False

Key Difference: torch.eq() vs torch.equal()

| Function | Returns | Use Case |
|---|---|---|
| torch.eq(a, b) | Boolean tensor (element-wise) | Finding where tensors differ |
| torch.equal(a, b) | Single boolean | Checking whether tensors are entirely identical |

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 9, 3])

print(f"torch.eq(): {torch.eq(a, b)}") # Element-wise details
print(f"torch.equal(): {torch.equal(a, b)}") # Overall verdict

Output:

torch.eq(): tensor([ True, False,  True])
torch.equal(): False
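One practical consequence is that torch.eq(a, b).all() is not a drop-in replacement for torch.equal(a, b). torch.eq() broadcasts its inputs, so tensors of different shapes can still compare as all-equal. torch.equal() requires the shapes to match. A small sketch with made-up tensors:

```python
import torch

a = torch.tensor([1, 2, 3])
d = torch.tensor([[1, 2, 3]])  # shape (1, 3) instead of (3,)

# Element-wise comparison broadcasts d against a, so every element matches
print(bool(torch.eq(a, d).all()))  # True
# torch.equal also requires identical shapes
print(torch.equal(a, d))           # False
```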

Other Comparison Operators

PyTorch provides a full set of element-wise comparison functions beyond equality:

import torch

a = torch.tensor([10, 20, 30, 40, 50])
b = torch.tensor([15, 20, 25, 40, 55])

print(f"a: {a}")
print(f"b: {b}")
print(f"a > b: {torch.gt(a, b)}")
print(f"a < b: {torch.lt(a, b)}")
print(f"a >= b: {torch.ge(a, b)}")
print(f"a <= b: {torch.le(a, b)}")
print(f"a != b: {torch.ne(a, b)}")

Output:

a: tensor([10, 20, 30, 40, 50])
b: tensor([15, 20, 25, 40, 55])
a > b: tensor([False, False,  True, False, False])
a < b: tensor([ True, False, False, False,  True])
a >= b: tensor([False,  True,  True,  True, False])
a <= b: tensor([ True,  True, False,  True,  True])
a != b: tensor([ True, False,  True, False,  True])
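The boolean tensors returned by these functions work directly as masks. You can index a tensor with a mask to pull out the matching elements. You can also pass a mask to torch.where() to pick values from either tensor. A short sketch that reuses a and b from the example above:

```python
import torch

a = torch.tensor([10, 20, 30, 40, 50])
b = torch.tensor([15, 20, 25, 40, 55])

# Boolean indexing keeps only the elements where the mask is True
print(a[torch.gt(a, b)])  # tensor([30])

# torch.where(cond, x, y) takes x where cond is True and y elsewhere;
# with cond = a >= b this yields the element-wise maximum
larger = torch.where(torch.ge(a, b), a, b)
print(larger)  # tensor([15, 20, 30, 40, 55])
```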

You can also use Python operators directly on tensors:

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 5, 3])

print(a == b) # Same as torch.eq(a, b)
print(a > b) # Same as torch.gt(a, b)
print(a != b) # Same as torch.ne(a, b)

Output:

tensor([ True, False,  True])
tensor([False, False, False])
tensor([False,  True, False])
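You can combine comparison results with the element-wise operators & (and), | (or), and ~ (not). Python's own and, or, and if need a single truth value. Using them on a multi-element boolean tensor therefore raises a RuntimeError, so reduce the tensor with .any() or .all() first. A sketch with the same a and b:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 5, 3])

# Parentheses matter: & and | bind more tightly than == and >
both = (a == b) & (a > 1)
either = (a == b) | (b > 4)
print(both)    # tensor([False, False,  True])
print(either)  # tensor([ True,  True,  True])

# Reduce to a single boolean before using the result in an if statement
if (a != b).any():
    print("a and b differ somewhere")

ambiguous_error = None
try:
    if a == b:  # ambiguous: three booleans, not one
        pass
except RuntimeError as err:
    ambiguous_error = err
    print(f"RuntimeError: {err}")
```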

Approximate Comparison with torch.allclose()

Floating-point arithmetic introduces tiny rounding errors that make exact comparison unreliable. torch.allclose() checks if all elements are equal within a tolerance.

import torch

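# Do the addition in float64 tensor arithmetic so the rounding error is kept
a = torch.tensor([0.1], dtype=torch.float64) + torch.tensor([0.2], dtype=torch.float64)
b = torch.tensor([0.3], dtype=torch.float64)

print(f"a value: {a.item():.20f}")
print(f"b value: {b.item():.20f}")
print(f"Exact equal: {torch.equal(a, b)}")
print(f"Approximately: {torch.allclose(a, b)}")

Output:

a value: 0.30000000000000004441
b value: 0.29999999999999998890
Exact equal: False
Approximately: True

The example uses float64 on purpose. In PyTorch's default float32 dtype, 0.1 + 0.2 happens to round to exactly the same value as 0.3, so the exact comparison would succeed and hide the problem.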

A second example, in which the result depends on the tolerance you choose:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00001, 2.00001, 3.00001])

print(f"Exact equal: {torch.equal(a, b)}")
print(f"Close (default tol): {torch.allclose(a, b)}")
print(f"Close (looser tol): {torch.allclose(a, b, atol=1e-4)}")

Output:

Exact equal: False
Close (default tol): False
Close (looser tol): True
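NaN needs special care. Under IEEE 754 rules, a NaN never compares equal to anything, including another NaN. Both the exact check and the default approximate check therefore treat tensors with NaN in the same positions as different. torch.allclose() provides an equal_nan flag for this case. A small sketch:

```python
import torch

x = torch.tensor([1.0, float("nan")])
y = torch.tensor([1.0, float("nan")])

print(torch.eq(x, y))                        # tensor([ True, False])
print(torch.equal(x, y))                     # False
print(torch.allclose(x, y))                  # False
print(torch.allclose(x, y, equal_nan=True))  # True
```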

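Syntax:

torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)

The check passes only if |input - other| <= atol + rtol * |other| holds for every element, where:
  • rtol - relative tolerance, scaled by the magnitude of other
  • atol - absolute tolerance
  • equal_nan - whether NaN values in the same position count as equal (default False)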
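The sketch below shows exactly how the tolerances combine. It recomputes by hand the per-element test documented for torch.isclose() and torch.allclose(), |input - other| <= atol + rtol * |other|. It uses the a and b tensors from the earlier example and compares the result with torch.isclose(), the element-wise variant that returns a boolean tensor. The relative term grows with |other|, so the element near 1.0 gets the smallest allowance. In float32 its actual difference is marginally larger than that allowance, which makes it the only element that fails:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00001, 2.00001, 3.00001])
rtol, atol = 1e-05, 1e-08  # the documented defaults

# Per-element tolerance test: |input - other| <= atol + rtol * |other|
manual = (a - b).abs() <= atol + rtol * b.abs()

print(manual)                                     # tensor([False,  True,  True])
print(torch.isclose(a, b, rtol=rtol, atol=atol))  # same result
# allclose is True only if every element passes
print(torch.allclose(a, b, rtol=rtol, atol=atol) == bool(manual.all()))  # True
```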
Tip:

Prefer torch.allclose() over torch.equal() when comparing tensors produced by neural network computations, matrix operations, or any other floating-point arithmetic. Exact comparison often fails because of rounding, even when two results are mathematically identical.
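For unit tests, PyTorch also ships torch.testing.assert_close(). Instead of returning a boolean, it raises an AssertionError that describes the mismatch. Its default tolerances are chosen per dtype, and by default it also checks that dtype and device match. A minimal sketch with made-up values:

```python
import torch

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected / 3 * 3  # may differ from expected by a rounding error

# Passes silently when the tensors are close within the default tolerances
torch.testing.assert_close(actual, expected)

mismatch_reported = False
try:
    torch.testing.assert_close(torch.tensor([1.0, 2.5, 3.0]), expected)
except AssertionError as err:
    mismatch_reported = True
    print(err)  # describes which elements mismatch and by how much
```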

Counting Differences

To find out how many elements differ and where they are:

import torch

predictions = torch.tensor([1, 0, 1, 1, 0, 1, 0, 1])
labels = torch.tensor([1, 0, 0, 1, 0, 1, 1, 1])

# Element-wise comparison
matches = torch.eq(predictions, labels)
num_correct = matches.sum().item()
total = matches.numel()
accuracy = num_correct / total * 100

print(f"Matches: {matches}")
print(f"Correct: {num_correct}/{total}")
print(f"Accuracy: {accuracy:.1f}%")

# Find mismatch positions
mismatches = torch.where(~matches)[0]
print(f"Mismatches at indices: {mismatches.tolist()}")

Output:

Matches: tensor([ True,  True, False,  True,  True,  True, False,  True])
Correct: 6/8
Accuracy: 75.0%
Mismatches at indices: [2, 6]
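In a classification setting, predictions usually come from model outputs rather than being given directly. A common pattern is to take argmax over the class dimension and then average the element-wise comparison. The sketch below uses made-up logits for four samples and two classes:

```python
import torch

# Hypothetical model outputs: one row per sample, one column per class
logits = torch.tensor([[2.0, 0.5],
                       [0.1, 1.2],
                       [0.3, 0.9],
                       [1.5, 1.4]])
labels = torch.tensor([0, 1, 0, 0])

preds = logits.argmax(dim=1)  # index of the largest logit in each row
correct = preds == labels     # element-wise comparison, dtype torch.bool

# Convert booleans to 0.0/1.0 and average to get the accuracy
accuracy = correct.float().mean().item()
print(preds)     # tensor([0, 1, 1, 0])
print(accuracy)  # 0.75
```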

Comparing Tensors on Different Devices

Device Mismatch Error

Attempting to compare tensors on different devices (CPU vs GPU) raises a RuntimeError:

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3]).cuda() # On GPU

torch.eq(a, b) # RuntimeError

Output:

RuntimeError: Expected all tensors to be on the same device

Move both tensors to the same device before comparing:

# Move GPU tensor to CPU
print(torch.eq(a, b.cpu()))

# Or move CPU tensor to GPU
# print(torch.eq(a.cuda(), b))
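Comparison code often has to run on machines with and without a GPU. A common approach is to pick the device at runtime and move one tensor onto the other's device with .to(). A minimal sketch that falls back to the CPU when CUDA is not available:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.tensor([1, 2, 3])                 # created on the CPU
b = torch.tensor([1, 2, 3], device=device)  # CPU or GPU, depending on the machine

# .to(a.device) is a no-op when b is already on a's device
result = torch.eq(a, b.to(a.device))
print(result)  # tensor([True, True, True])
```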

Comparing Tensors with Different Shapes

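When tensors have different but broadcast-compatible shapes, torch.eq() and the other element-wise comparison functions apply broadcasting rules, which may produce unexpected results. Incompatible shapes raise a RuntimeError instead: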

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print("Shapes:", a.shape, b.shape)
print("Comparison (broadcast):")
print(torch.eq(a, b))

Output:

Shapes: torch.Size([3]) torch.Size([2, 3])
Comparison (broadcast):
tensor([[ True,  True,  True],
        [False, False, False]])

The 1-D tensor a is broadcast across the rows of b. Use torch.equal() if you want strict shape matching - it returns False for differently shaped tensors.
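Sometimes you need an element-wise result but want shape mismatches to fail loudly instead of broadcasting. In that case you can check the shapes yourself before comparing. The helper below, strict_eq, is a hypothetical name used for illustration and is not a PyTorch function:

```python
import torch

def strict_eq(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Element-wise equality that refuses to broadcast (hypothetical helper)."""
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}")
    return torch.eq(x, y)

a = torch.tensor([1, 2, 3])
b = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(strict_eq(a, torch.tensor([1, 0, 3])))  # tensor([ True, False,  True])

error_message = None
try:
    strict_eq(a, b)
except ValueError as err:
    error_message = str(err)
print(error_message)  # shape mismatch: (3,) vs (2, 3)
```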

Quick Comparison of Methods

| Method | Returns | Handles Float Tolerance | Best For |
|---|---|---|---|
| torch.eq() | Boolean tensor | No | Element-wise equality details |
| torch.equal() | Single boolean | No | Overall exact equality check |
| torch.allclose() | Single boolean | Yes | Floating-point comparison |
| torch.gt(), torch.lt(), etc. | Boolean tensor | No | Element-wise magnitude comparison |
| ==, >, < operators | Boolean tensor | No | Inline, readable comparisons |

Conclusion

PyTorch provides a comprehensive set of tools for comparing tensors:

  • Use torch.eq() to see exactly which elements match between two tensors - it returns a boolean tensor showing the result at each position.
  • Use torch.equal() for a simple yes/no check on whether two tensors are entirely identical.
  • Use torch.allclose() when comparing floating-point tensors that may have small rounding differences - essential for validating neural network outputs and gradient computations.
  • Use torch.gt(), torch.lt(), torch.ge(), torch.le(), torch.ne() for element-wise magnitude comparisons.

For most deep learning workflows, torch.equal() for exact checks and torch.allclose() for approximate checks will cover your comparison needs.