PyTorch .pt and .pth Files
The .pt and .pth extensions are PyTorch conventions for files written with torch.save and read back with torch.load. They can hold model weights, optimizer state, and other serialized Python objects.
Functionality
Both extensions are functionally identical and are commonly used to store:
- Model Weights (state_dict): The most common use case.
- Complete Model Architecture: Including structure and parameters.
- Optimizer State: Useful for resuming training processes.
- Metadata: Any picklable Python object, such as an epoch counter or a dictionary of hyperparameters.
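The items listed above are often bundled into one checkpoint dictionary so that training can be resumed from a single file. The following is a minimal sketch, not a definitive recipe. It assumes a small nn.Linear stand-in model, an SGD optimizer, hypothetical key names (epoch, model_state, optimizer_state), and a PyTorch release that supports the weights_only argument of torch.load.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in model and optimizer, used only for illustration.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle everything needed to resume training into one dictionary.
# The key names are arbitrary choices for this sketch.
checkpoint = {
    "epoch": 5,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(checkpoint, path)

# weights_only=True restricts unpickling to tensors and basic containers.
restored = torch.load(path, weights_only=True)

# Rebuild the objects first, then load the saved state into them.
resumed_model = nn.Linear(4, 2)
resumed_model.load_state_dict(restored["model_state"])
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01)
resumed_optimizer.load_state_dict(restored["optimizer_state"])
start_epoch = restored["epoch"] + 1
```

Because only state dictionaries and plain values are stored, the file does not depend on how the model class is defined or where it lives in the codebase.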
Best Practice: Save Only state_dict
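It is highly recommended to save only the model weights (state_dict) rather than the entire model object. Pickling the whole model stores a reference to its class and module path, so loading can break when the code is refactored or moved. A state_dict is just a mapping from parameter names to tensors, which gives better flexibility and cross-version compatibility.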
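import torch
from torch import nn
# Placeholder architecture; substitute your own nn.Module subclass
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)
    def forward(self, x):
        return self.fc(x)

# Save model weights (Recommended)
model = SimpleModel()
torch.save(model.state_dict(), "model_weights.pt")
# Load weights into a new model (weights_only=True limits what unpickling may do)
new_model = SimpleModel()
new_model.load_state_dict(torch.load("model_weights.pt", weights_only=True))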
Security Warning
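Never load .pt or .pth files from untrusted sources. These files contain Python pickle data, and unpickling can execute arbitrary code. Passing weights_only=True to torch.load (the default in recent PyTorch releases) restricts loading to tensors and basic Python types, which reduces the risk but does not replace trusting the source.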
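To see why pickle data is dangerous, the sketch below uses only the standard library's pickle module. The Payload class is hypothetical and deliberately harmless: its __reduce__ method tells pickle to call eval on a fixed arithmetic expression when the data is loaded. A malicious file could name any callable in the same way.

```python
import pickle


class Payload:
    """Hypothetical object whose pickled form names a callable to run on load."""

    def __reduce__(self):
        # pickle stores "call eval with this argument" instead of the object.
        # A harmless expression is used here for demonstration purposes.
        return (eval, ("6 * 7",))


data = pickle.dumps(Payload())

# Unpickling invokes the stored callable; the result is whatever it returns,
# not a Payload instance.
result = pickle.loads(data)
print(result)
```

The call happens inside pickle.loads itself, before the caller has any chance to inspect the object, which is why the source of the file has to be trusted in advance.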
Cross-Device Compatibility
When moving models between CPU and GPU, specify the map_location parameter during loading:
# Load GPU-saved weights onto CPU
model.load_state_dict(torch.load("gpu_model.pt", map_location='cpu'))
# Load CPU-saved weights onto GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("cpu_model.pt", map_location=device))
Pro Tip
Always record the PyTorch version used to save the model to avoid compatibility issues later.
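One way to record the version is to store it in the saved file next to the weights. This is a minimal sketch under assumptions: the nn.Linear model is a stand-in, the key names torch_version and model_state are hypothetical, and the version is converted to a plain string so that it loads cleanly with weights_only=True.

```python
import os
import tempfile
import warnings

import torch
from torch import nn

model = nn.Linear(4, 2)  # hypothetical stand-in model

# Store the version as a plain string alongside the weights.
checkpoint = {
    "torch_version": str(torch.__version__),
    "model_state": model.state_dict(),
}
path = os.path.join(tempfile.mkdtemp(), "versioned.pt")
torch.save(checkpoint, path)

restored = torch.load(path, weights_only=True)
saved_version = restored["torch_version"]

# Warn rather than fail: a version mismatch is often harmless for a state_dict.
if saved_version != str(torch.__version__):
    warnings.warn(
        f"Checkpoint saved with PyTorch {saved_version}, "
        f"running {torch.__version__}"
    )
```

The same dictionary can carry other provenance details, such as the training date or a commit hash, if those help with debugging later.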