PyTorch .pt 和 .pth 文件
.pt 和 .pth 文件格式是 PyTorch 特有的标准,用于保存和加载各种类型的数据。
功能
这两个扩展名在功能上是完全相同的,通常用于存储:
- 模型权重 (
state_dict):最常见的用例。 - 完整的模型架构:包括结构和参数。
- 优化器状态:对于恢复训练过程非常有用。
- 元数据:任何通过 PyTorch 序列化机制存储的 Python 对象。
最佳实践:仅保存 state_dict
强烈建议仅保存模型权重 (state_dict),而不是整个模型对象。这提供了更好的灵活性和跨版本兼容性。
import torch
# 定义模型实例
model = SimpleModel()
# 保存模型权重 (推荐)
torch.save(model.state_dict(), "model_weights.pt")
# 将权重加载到新模型中
new_model = SimpleModel()
new_model.load_state_dict(torch.load("model_weights.pt"))
安全警告
切勿加载来自不可信来源的 .pt 或 .pth 文件。这些文件包含 Python pickle 数据,加载时可能会执行任意代码。
跨设备兼容性
在 CPU 和 GPU 之间移动模型时,请在加载期间指定 map_location 参数:
# 加载在 GPU 上保存的权重到 CPU
model.load_state_dict(torch.load("gpu_model.pt", map_location='cpu'))
# 加载在 CPU 上保存的权重到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("cpu_model.pt", map_location=device))
小贴士
始终记录用于保存模型的 PyTorch 版本,以避免日后的兼容性问题。