跳到主要内容

ONNX 概述

Open Neural Network Exchange (ONNX) 是一个开放标准,专门设计用于促进机器学习模型在不同框架之间的互操作性。

为什么使用 ONNX?

在深度学习生态系统中,不同的框架(例如 PyTorch、TensorFlow、MXNet)使用不同的模型表示方法。这种碎片化导致了以下问题:

  • 框架锁定:难以在不同的训练工具之间移动模型。
  • 部署复杂性:生产环境通常需要特定的推理引擎。
  • 工具链分散:不同的框架需要单独的优化工具。

ONNX 通过提供开放的模型表示标准解决了这些问题。

核心原则

  • 计算图表示:模型表示为节点(算子)和边(数据流)的有向无环图 (DAG)。
  • 算子标准化:卷积、激活等算子的标准化定义确保了在不同框架下行为一致。
  • 单一文件格式:架构、权重和元数据均存储在一个 .onnx 文件中。

关键特性

特性描述
互操作性模型可以从一个框架导出并导入到另一个框架。
优化友好硬件加速器(如 TensorRT、OpenVINO)原生支持 ONNX。
版本控制显式的版本控制确保向后兼容性。

实际用法

从 PyTorch 导出

import torch

# 创建模型实例
model = SimpleModel()

# 创建虚拟输入
dummy_input = torch.randn(1, 784)

# 导出为 ONNX 格式
torch.onnx.export(
model, # 要导出的模型
dummy_input, # 示例模型输入
"simple_model.onnx", # 输出文件名
opset_version=13, # ONNX 算子集版本
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态维度
'output': {0: 'batch_size'}}
)

使用 ONNX Runtime 进行推理

import onnxruntime as ort
import numpy as np

# 加载 ONNX 模型
ort_session = ort.InferenceSession("simple_model.onnx")

# 准备输入数据
input_data = np.random.randn(1, 784).astype(np.float32)

# 运行推理
outputs = ort_session.run(None, {'input': input_data})
建议

对于大多数深度学习工作流,ONNX 是连接训练与高性能推理的主要桥梁。