跳到主要内容

PyTorch .pt 和 .pth 文件

.pt.pth 文件格式是 PyTorch 特有的标准,用于保存和加载各种类型的数据。

功能

这两个扩展名在功能上是完全相同的,通常用于存储:

  • 模型权重 (state_dict):最常见的用例。
  • 完整的模型架构:包括结构和参数。
  • 优化器状态:对于恢复训练过程非常有用。
  • 元数据:任何通过 PyTorch 序列化机制存储的 Python 对象。

最佳实践:仅保存 state_dict

强烈建议仅保存模型权重 (state_dict),而不是整个模型对象。这提供了更好的灵活性和跨版本兼容性。

import torch

# 定义模型实例
model = SimpleModel()

# 保存模型权重 (推荐)
torch.save(model.state_dict(), "model_weights.pt")

# 将权重加载到新模型中
new_model = SimpleModel()
new_model.load_state_dict(torch.load("model_weights.pt"))
安全警告

切勿加载来自不可信来源的 .pt.pth 文件。这些文件包含 Python pickle 数据,加载时可能会执行任意代码。

跨设备兼容性

在 CPU 和 GPU 之间移动模型时,请在加载期间指定 map_location 参数:

# 加载在 GPU 上保存的权重到 CPU
model.load_state_dict(torch.load("gpu_model.pt", map_location='cpu'))

# 加载在 CPU 上保存的权重到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("cpu_model.pt", map_location=device))
小贴士

始终记录用于保存模型的 PyTorch 版本,以避免日后的兼容性问题。