跳到主要内容

WTS 桥接

.wts (Weights) 文件格式是一种纯文本格式的权重存储文件,主要作为 TensorRT 的中间格式使用。

功能

.pt.onnx 等二进制格式不同,.wts 文件以人类可读的文本格式存储神经网络权重参数。这特别适用于:

  • 自定义层:当 ONNX 不支持特定层时,你可以手动提取权重并构建自定义 TensorRT 构建器。
  • 调试:以人类可读的格式检查权重。
  • 细粒度优化控制:控制权重转换的每个方面。

文件结构

.wts 文件遵循简单的格式:

<层数>
<层名称_1> <权重数量_1> <权重值_1...>
<层名称_2> <权重数量_2> <权重值_2...>
...

示例:

3
conv1.weight 324 0.123 -0.456 0.789 ...
conv1.bias 16 0.01 -0.02 0.03 ...
fc1.weight 10240 0.234 -0.567 0.890 ...

从 PyTorch 创建 WTS 文件

import torch

def generate_wts(pytorch_model, output_file):
named_parameters = pytorch_model.named_parameters()
num_layers = len(list(named_parameters))

with open(output_file, 'w') as f:
f.write(f"{num_layers}\n")

for name, param in pytorch_model.named_parameters():
param = param.cpu().detach().numpy().astype('float32')
flat_weights = param.flatten()
num_weights = len(flat_weights)

f.write(f"{name} {num_weights} ")
weights_str = " ".join([str(w) for w in flat_weights])
f.write(weights_str + "\n")

在 C++ 中解析 WTS 文件

#include <fstream>
#include <map>
#include <vector>

std::map<std::string, std::vector<float>> loadWeights(const std::string& file) {
std::map<std::string, std::vector<float>> weight_map;
std::ifstream input(file);

int num_layers;
input >> num_layers;

for (int i = 0; i < num_layers; ++i) {
std::string layer_name;
int num_weights;
input >> layer_name >> num_weights;

std::vector<float> weights(num_weights);
for (int j = 0; j < num_weights; ++j) {
input >> weights[j];
}

weight_map[layer_name] = weights;
}

return weight_map;
}
建议

虽然 .wts 提供了最大的透明度,但它应仅在标准的 .pt.onnx.engine 路径失败时作为备选方案使用。

优缺点

优点缺点
透明度:易于调试和验证权重值。文件体积大:比二进制格式大 3-4 倍。
细粒度控制:非常适合自定义层转换。无架构信息:仅存储权重,不存储图。
框架独立:不受特定版本绑定。精度损失:文本转换期间会有微小误差。