WTS 桥接
.wts (Weights) 文件格式是一种纯文本格式的权重存储文件,主要作为 TensorRT 的中间格式使用。
功能
与 .pt 或 .onnx 等二进制格式不同,.wts 文件以人类可读的文本格式存储神经网络权重参数。这特别适用于:
- 自定义层:当 ONNX 不支持特定层时,你可以手动提取权重并构建自定义 TensorRT 构建器。
- 调试:以人类可读的格式检查权重。
- 细粒度优化控制:控制权重转换的每个方面。
文件结构
.wts 文件遵循简单的格式:
<层数>
<层名称_1> <权重数量_1> <权重值_1...>
<层名称_2> <权重数量_2> <权重值_2...>
...
示例:
3
conv1.weight 324 0.123 -0.456 0.789 ...
conv1.bias 16 0.01 -0.02 0.03 ...
fc1.weight 10240 0.234 -0.567 0.890 ...
从 PyTorch 创建 WTS 文件
import torch
def generate_wts(pytorch_model, output_file):
named_parameters = pytorch_model.named_parameters()
num_layers = len(list(named_parameters))
with open(output_file, 'w') as f:
f.write(f"{num_layers}\n")
for name, param in pytorch_model.named_parameters():
param = param.cpu().detach().numpy().astype('float32')
flat_weights = param.flatten()
num_weights = len(flat_weights)
f.write(f"{name} {num_weights} ")
weights_str = " ".join([str(w) for w in flat_weights])
f.write(weights_str + "\n")
在 C++ 中解析 WTS 文件
#include <fstream>
#include <map>
#include <vector>
std::map<std::string, std::vector<float>> loadWeights(const std::string& file) {
std::map<std::string, std::vector<float>> weight_map;
std::ifstream input(file);
int num_layers;
input >> num_layers;
for (int i = 0; i < num_layers; ++i) {
std::string layer_name;
int num_weights;
input >> layer_name >> num_weights;
std::vector<float> weights(num_weights);
for (int j = 0; j < num_weights; ++j) {
input >> weights[j];
}
weight_map[layer_name] = weights;
}
return weight_map;
}
建议
虽然 .wts 提供了最大的透明度,但它应仅在标准的 .pt → .onnx → .engine 路径失败时作为备选方案使用。
优缺点
| 优点 | 缺点 |
|---|---|
| 透明度:易于调试和验证权重值。 | 文件体积大:比二进制格式大 3-4 倍。 |
| 细粒度控制:非常适合自定义层转换。 | 无架构信息:仅存储权重,不存储图。 |
| 框架独立:不受特定版本绑定。 | 精度损失:文本转换期间会有微小误差。 |