How to Get the Data Type of a PyTorch Tensor in Python
When working with PyTorch tensors, knowing the data type (dtype) of your tensor is essential for ensuring compatibility in operations, debugging type-related errors, and optimizing memory usage.
In this guide, we'll show you how to check tensor data types, explain the available data types in PyTorch, and demonstrate how to create and convert tensors with specific types.
Quick Answer: Use .dtype
The .dtype attribute returns the data type of any PyTorch tensor:
import torch
tensor = torch.tensor([1.0, 2.0, 3.0])
print(tensor.dtype)
Output:
torch.float32
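`.dtype` returns a `torch.dtype` object, not a string, so you compare it against objects such as `torch.float32`. The short sketch below shows that, plus how to convert it with `str()` for logging. The variable name `dtype_name` is only for illustration:

```python
import torch

tensor = torch.tensor([1.0, 2.0, 3.0])

# .dtype is a torch.dtype object, not a plain string
print(type(tensor.dtype))                     # <class 'torch.dtype'>
print(isinstance(tensor.dtype, torch.dtype))  # True

# str() produces the familiar name, which is handy for logs and error messages
dtype_name = str(tensor.dtype)
print(dtype_name)                             # torch.float32
```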
Creating Tensors and Checking Their Types
Default Data Types
PyTorch infers the data type from the values you provide:
import torch
# Integer values → torch.int64 (default integer type)
int_tensor = torch.tensor([10, 20, 30])
print(f"Integer tensor: {int_tensor.dtype}")
# Float values → torch.float32 (default float type)
float_tensor = torch.tensor([1.5, 2.5, 3.5])
print(f"Float tensor: {float_tensor.dtype}")
# Boolean values → torch.bool
bool_tensor = torch.tensor([True, False, True])
print(f"Boolean tensor: {bool_tensor.dtype}")
# Mixed int and float → promotes to float
mixed_tensor = torch.tensor([1, 2.5, 3])
print(f"Mixed tensor: {mixed_tensor.dtype}")
Output:
Integer tensor: torch.int64
Float tensor: torch.float32
Boolean tensor: torch.bool
Mixed tensor: torch.float32
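When the input mixes integers and floats, PyTorch stores every value as a floating-point number using the default float type (float32, unless you change it with torch.set_default_dtype()), so no fractional part is lost. This is an example of type promotion, which is the same rule PyTorch applies when an operation combines tensors of different dtypes.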
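The same promotion rules apply to arithmetic between tensors. PyTorch can also report the promoted type without computing anything, through `torch.result_type()` and `torch.promote_types()`. The minimal sketch below assumes the global default dtype is still float32:

```python
import torch

# Float literals use the global default dtype (float32 unless changed)
print(torch.get_default_dtype())  # torch.float32

int_tensor = torch.tensor([1, 2, 3])          # int64
float_tensor = torch.tensor([0.5, 0.5, 0.5])  # float32

# Adding an int64 tensor to a float32 tensor promotes the result to float32
summed = int_tensor + float_tensor
print(summed.dtype)  # torch.float32

# Query the promoted type without performing the operation
print(torch.result_type(int_tensor, float_tensor))      # torch.float32
print(torch.promote_types(torch.int64, torch.float64))  # torch.float64

# A Python int scalar does not change an int64 tensor's dtype
scaled = int_tensor * 2
print(scaled.dtype)  # torch.int64
```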
Specifying Data Types Explicitly
Use the dtype parameter when creating a tensor to specify the exact type:
import torch
# Specify 32-bit float
a = torch.tensor([1, 2, 3], dtype=torch.float32)
print(f"float32: {a} → {a.dtype}")
# Specify 16-bit integer
b = torch.tensor([1, 2, 3], dtype=torch.int16)
print(f"int16: {b} → {b.dtype}")
# Specify double precision (64-bit float)
c = torch.tensor([1, 2, 3], dtype=torch.float64)
print(f"float64: {c} → {c.dtype}")
# Specify boolean
d = torch.tensor([0, 1, 0, 5], dtype=torch.bool)
print(f"bool: {d} → {d.dtype}")
Output:
float32: tensor([1., 2., 3.]) → torch.float32
int16: tensor([1, 2, 3], dtype=torch.int16) → torch.int16
float64: tensor([1., 2., 3.], dtype=torch.float64) → torch.float64
bool: tensor([False, True, False, True]) → torch.bool
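Factory functions such as `torch.zeros()`, `torch.ones()` and `torch.arange()` infer or default their dtype in the same way, and they accept the same `dtype` parameter. The brief illustration below assumes the default float dtype has not been changed:

```python
import torch

# Factory functions pick a dtype from their arguments or the global default
print(torch.zeros(3).dtype)     # torch.float32 (default float dtype)
print(torch.arange(5).dtype)    # torch.int64 (integer argument)
print(torch.arange(5.0).dtype)  # torch.float32 (float argument)

# They all accept an explicit dtype
ones_i32 = torch.ones(2, 2, dtype=torch.int32)
print(ones_i32.dtype)           # torch.int32

# *_like functions copy the dtype of the tensor they receive
like = torch.zeros_like(ones_i32)
print(like.dtype)               # torch.int32
```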
Common PyTorch Data Types
| Data Type | PyTorch Name | Alias | Description |
|---|---|---|---|
| Float (16-bit) | torch.float16 | torch.half | Half-precision float |
| Float (32-bit) | torch.float32 | torch.float | Single-precision float (default float type) |
| Float (64-bit) | torch.float64 | torch.double | Double-precision float |
| BFloat16 | torch.bfloat16 | - | 16-bit "brain" float: same exponent range as float32, lower precision |
| Int (8-bit) | torch.int8 | - | Signed 8-bit integer |
| Unsigned Int (8-bit) | torch.uint8 | - | Unsigned 8-bit integer |
| Int (16-bit) | torch.int16 | torch.short | Signed 16-bit integer |
| Int (32-bit) | torch.int32 | torch.int | Signed 32-bit integer |
| Int (64-bit) | torch.int64 | torch.long | Signed 64-bit integer (default integer type) |
| Boolean | torch.bool | - | True / False |
| Complex (64-bit) | torch.complex64 | torch.cfloat | Complex with float32 real and imaginary parts |
| Complex (128-bit) | torch.complex128 | torch.cdouble | Complex with float64 real and imaginary parts |
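Each alias refers to the same dtype as its canonical name. PyTorch can also report how many bytes an element takes and what range a dtype covers. The sketch below uses `element_size()`, `torch.finfo()` and `torch.iinfo()`. The list of dtypes is an arbitrary sample:

```python
import torch

# Aliases compare equal to their canonical names
print(torch.float == torch.float32)       # True
print(torch.long == torch.int64)          # True
print(torch.cdouble == torch.complex128)  # True

dtypes = [torch.int8, torch.int64, torch.float16, torch.bfloat16, torch.float32]

for dt in dtypes:
    # element_size() reports the number of bytes per element
    size = torch.zeros(1, dtype=dt).element_size()
    # finfo describes floating-point dtypes; iinfo describes integer dtypes
    info = torch.finfo(dt) if dt.is_floating_point else torch.iinfo(dt)
    print(f"{str(dt):16} {size} byte(s), max = {info.max}")
```

The output shows the trade-off behind bfloat16. It uses the same 2 bytes as float16, but its maximum value is far larger, close to float32's.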
Methods to Check Data Type
Method 1: Using .dtype Attribute
The most direct way to check the data type:
import torch
tensor = torch.tensor([1, 2, 3])
print(tensor.dtype)
Output:
torch.int64
Method 2: Using type() Method
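The .type() method returns a string containing the legacy tensor type name. This name encodes the dtype and whether the tensor lives on the CPU or a CUDA device (CUDA tensors report names such as torch.cuda.FloatTensor), but not the device index: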
import torch
tensor = torch.tensor([1.0, 2.0, 3.0])
print(tensor.type())
Output:
torch.FloatTensor
The mapping between .dtype and .type():
| .dtype | .type() |
|---|---|
| torch.float32 | torch.FloatTensor |
| torch.float64 | torch.DoubleTensor |
| torch.int32 | torch.IntTensor |
| torch.int64 | torch.LongTensor |
| torch.bool | torch.BoolTensor |
Method 3: Comparing .dtype Directly
Check if a tensor has a specific data type:
import torch
tensor = torch.tensor([1.0, 2.0, 3.0])
print(tensor.dtype == torch.float32) # True
print(tensor.dtype == torch.int64) # False
Output:
True
False
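`isinstance()` is still useful, but it only tells you whether an object is a tensor at all. To check the dtype itself, you have to compare it. The sketch below combines both checks and tests membership in a tuple of accepted dtypes. The helper `check_float_tensor` and the tuple `FLOAT_DTYPES` are hypothetical names for illustration:

```python
import torch

# Hypothetical set of accepted floating-point dtypes
FLOAT_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

def check_float_tensor(x):
    """Hypothetical helper: raise TypeError unless x is a floating-point tensor."""
    # isinstance() checks that x is a tensor, not which dtype it has
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"expected a torch.Tensor, got {type(x).__name__}")
    # Membership test against the accepted dtypes
    if x.dtype not in FLOAT_DTYPES:
        raise TypeError(f"expected a float dtype, got {x.dtype}")
    return x

check_float_tensor(torch.tensor([1.0, 2.0]))  # passes silently

try:
    check_float_tensor(torch.tensor([1, 2]))
except TypeError as e:
    print(e)  # expected a float dtype, got torch.int64
```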
Method 4: Using .is_floating_point() and Similar Methods
PyTorch provides convenience methods for broad type checks:
import torch
float_tensor = torch.tensor([1.0, 2.0])
int_tensor = torch.tensor([1, 2])
complex_tensor = torch.tensor([1+2j, 3+4j])
print(f"Float? {float_tensor.is_floating_point()}")
print(f"Float? {int_tensor.is_floating_point()}")
print(f"Complex? {complex_tensor.is_complex()}")
Output:
Float? True
Float? False
Complex? True
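`torch.dtype` objects offer the same broad checks through their `is_floating_point`, `is_complex` and `is_signed` attributes, so you can inspect a dtype without creating a tensor. `torch.is_floating_point()` is a function form of the tensor method. A quick sketch:

```python
import torch

# dtype attributes answer broad questions without needing a tensor
for dt in (torch.float16, torch.int64, torch.uint8):
    print(f"{str(dt):14} floating={dt.is_floating_point} signed={dt.is_signed}")

print(torch.complex64.is_complex)  # True

# Function form of tensor.is_floating_point()
x = torch.tensor([1, 2])
print(torch.is_floating_point(x))  # False
```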
Converting Between Data Types
Using .to() Method
import torch
original = torch.tensor([1, 2, 3])
print(f"Original: {original.dtype}")
# Convert to float32
float_tensor = original.to(torch.float32)
print(f"To float: {float_tensor.dtype}")
# Convert to float16
half_tensor = original.to(torch.float16)
print(f"To half: {half_tensor.dtype}")
Output:
Original: torch.int64
To float: torch.float32
To half: torch.float16
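Two behaviors of `.to()` are worth knowing:

* Converting floats to integers truncates toward zero; it does not round.
* Requesting the dtype a tensor already has returns the original tensor without copying, unless you pass `copy=True`.

This minimal sketch checks memory sharing through `data_ptr()`:

```python
import torch

# Float -> int conversion truncates toward zero (no rounding)
floats = torch.tensor([1.7, -1.7, 2.5])
ints = floats.to(torch.int64)
print(ints.tolist())  # [1, -1, 2]

# Same dtype requested: .to() returns the original tensor, so memory is shared
same = floats.to(torch.float32)
print(same.data_ptr() == floats.data_ptr())  # True

# copy=True forces a new tensor even when no conversion is needed
copied = floats.to(torch.float32, copy=True)
print(copied.data_ptr() == floats.data_ptr())  # False
```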
Using Convenience Methods
PyTorch provides shortcut methods for common conversions:
import torch
tensor = torch.tensor([1, 2, 3])
print(f"Original: {tensor.dtype}")
print(f".float(): {tensor.float().dtype}")
print(f".double(): {tensor.double().dtype}")
print(f".int(): {tensor.int().dtype}")
print(f".long(): {tensor.long().dtype}")
print(f".half(): {tensor.half().dtype}")
print(f".bool(): {tensor.bool().dtype}")
Output:
Original: torch.int64
.float(): torch.float32
.double(): torch.float64
.int(): torch.int32
.long(): torch.int64
.half(): torch.float16
.bool(): torch.bool
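These methods return a new tensor and leave the original unchanged. To match another tensor's dtype without naming it explicitly, you can use `type_as()` or `.to(other_tensor)`. A short sketch, where `reference` is an illustrative name:

```python
import torch

tensor = torch.tensor([1, 2, 3])

# Conversion returns a new tensor; the original keeps its dtype
as_float = tensor.float()
print(tensor.dtype, as_float.dtype)  # torch.int64 torch.float32

# Match the dtype of another tensor without spelling it out
reference = torch.tensor([0.5], dtype=torch.float64)
matched = tensor.type_as(reference)
print(matched.dtype)                 # torch.float64

# .to(other) matches both the dtype and the device of the other tensor
matched_to = tensor.to(reference)
print(matched_to.dtype)              # torch.float64
```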
Using .type() Method for Conversion
The .type() method can also convert a tensor when you pass it a target type, such as the legacy type object torch.FloatTensor:
import torch
tensor = torch.tensor([1, 2, 3])
converted = tensor.type(torch.FloatTensor)
print(f"Converted: {converted.dtype}")
Output:
Converted: torch.float32
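`.type()` also accepts the legacy type name as a string, or a `torch.dtype` directly. When called with no argument, it only reports the type name. A brief sketch follows. In new code, `.to(dtype)` is usually the more readable choice:

```python
import torch

tensor = torch.tensor([1, 2, 3])

# Legacy type name passed as a string
as_double = tensor.type("torch.DoubleTensor")
print(as_double.dtype)  # torch.float64

# A torch.dtype also works, which makes this equivalent to .to(dtype)
as_half = tensor.type(torch.float16)
print(as_half.dtype)    # torch.float16

# With no argument, .type() just returns the type name as a string
print(tensor.type())    # torch.LongTensor
```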
Practical Example: Checking Types in a Neural Network
Model parameters are float32 by default, so integer inputs must be converted before they reach the model:
import torch
import torch.nn as nn
# Create a simple model
model = nn.Linear(3, 1)
print(f"Model weights dtype: {model.weight.dtype}")
# Create input data
input_int = torch.tensor([[1, 2, 3]])
print(f"Input dtype: {input_int.dtype}")
# This fails: the model expects float32, but the input is int64
try:
    output = model(input_int)
except RuntimeError as e:
    print(f"Error: {e}")
# Fix: convert input to float
input_float = input_int.float()
print(f"Converted dtype: {input_float.dtype}")
output = model(input_float)
print(f"Output: {output}")
print(f"Output dtype: {output.dtype}")
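Output (the numeric result varies between runs because nn.Linear initializes its weights randomly, and the exact error wording can differ between PyTorch versions):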
Model weights dtype: torch.float32
Input dtype: torch.int64
Error: mat1 and mat2 must have the same dtype, but got Long and Float
Converted dtype: torch.float32
Output: tensor([[-1.6591]], grad_fn=<AddmmBackward0>)
Output dtype: torch.float32
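Rather than hard-coding `torch.float32`, you can read the expected input dtype from the model's own parameters. The code then keeps working if the model is later converted, for example with `.double()`. This minimal sketch assumes all parameters share a single dtype:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
input_int = torch.tensor([[1, 2, 3]])

# Take the dtype from the model's first parameter (assumes all parameters match)
param_dtype = next(model.parameters()).dtype
output = model(input_int.to(param_dtype))
print(output.dtype)  # torch.float32

# Converting the model changes the dtype its inputs must match
model_64 = nn.Linear(3, 1).double()
output_64 = model_64(input_int.to(next(model_64.parameters()).dtype))
print(output_64.dtype)  # torch.float64
```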
Comprehensive Type Inspection Function
import torch
def tensor_info(tensor, name="Tensor"):
    """Print comprehensive information about a tensor."""
    print(f"--- {name} ---")
    print(f"  Values: {tensor}")
    print(f"  Shape: {tensor.shape}")
    print(f"  Dtype: {tensor.dtype}")
    print(f"  Type: {tensor.type()}")
    print(f"  Device: {tensor.device}")
    print(f"  Floating: {tensor.is_floating_point()}")
    print(f"  Elements: {tensor.numel()}")
    print()
# Test with different tensors
tensor_info(torch.tensor([1, 2, 3]), "Integer Tensor")
tensor_info(torch.tensor([1.0, 2.0, 3.0]), "Float Tensor")
tensor_info(torch.zeros(2, 3, dtype=torch.float16), "Half-Precision Zeros")
Output:
--- Integer Tensor ---
  Values: tensor([1, 2, 3])
  Shape: torch.Size([3])
  Dtype: torch.int64
  Type: torch.LongTensor
  Device: cpu
  Floating: False
  Elements: 3

--- Float Tensor ---
  Values: tensor([1., 2., 3.])
  Shape: torch.Size([3])
  Dtype: torch.float32
  Type: torch.FloatTensor
  Device: cpu
  Floating: True
  Elements: 3

--- Half-Precision Zeros ---
  Values: tensor([[0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float16)
  Shape: torch.Size([2, 3])
  Dtype: torch.float16
  Type: torch.HalfTensor
  Device: cpu
  Floating: True
  Elements: 6
Quick Reference
| Task | Code |
|---|---|
| Check data type | tensor.dtype |
| Check type (string) | tensor.type() |
| Is floating point? | tensor.is_floating_point() |
| Convert to float32 | tensor.float() or tensor.to(torch.float32) |
| Convert to int64 | tensor.long() or tensor.to(torch.int64) |
| Create with specific dtype | torch.tensor([...], dtype=torch.float16) |
| Compare dtype | tensor.dtype == torch.float32 |
Conclusion
Getting the data type of a PyTorch tensor is straightforward with the .dtype attribute, which returns the exact type (e.g., torch.float32, torch.int64).
PyTorch supports a wide range of data types, from 8-bit integers to 128-bit complex numbers.
To convert between types, use .to(dtype) for explicit conversion or convenience methods like .float(), .long(), and .half() for common conversions.
Understanding tensor data types is essential for avoiding runtime errors in neural network operations, where type mismatches between model parameters and input data are a frequent source of bugs.