PyTorch .pt 和 .pth 文件
.pt 和 .pth 文件格式是 PyTorch 特有的標準,用於保存和加載各種類型的數據。
功能
這兩個擴展名在功能上是完全相同的,通常用於存儲:
- 模型權重 (
state_dict):最常見的用例。 - 完整的模型架構:包括結構和參數。
- 優化器狀態:對於恢復訓練過程非常有用。
- 元數據:任何通過 PyTorch 序列化機制存儲的 Python 對象。
最佳實踐:僅保存 state_dict
強烈建議僅保存模型權重 (state_dict),而不是整個模型對象。這提供了更好的靈活性和跨版本兼容性。
import torch
# 定義模型實例
model = SimpleModel()
# 保存模型權重 (推薦)
torch.save(model.state_dict(), "model_weights.pt")
# 將權重加載到新模型中
new_model = SimpleModel()
new_model.load_state_dict(torch.load("model_weights.pt"))
安全警告
切勿加載來自不可信來源的 .pt 或 .pth 文件。這些文件包含 Python pickle 數據,加載時可能會執行任意代碼。
跨設備兼容性
在 CPU 和 GPU 之間移動模型時,請在加載期間指定 map_location 參數:
# 加載在 GPU 上保存的權重到 CPU
model.load_state_dict(torch.load("gpu_model.pt", map_location='cpu'))
# 加載在 CPU 上保存的權重到 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("cpu_model.pt", map_location=device))
小貼士
始終記錄用於保存模型的 PyTorch 版本,以避免日後的兼容性問題。