ONNX 概述
Open Neural Network Exchange (ONNX) 是一個開放標準,專門設計用於促進機器學習模型在不同框架之間的互操作性。
為什麼使用 ONNX?
在深度學習生態系統中,不同的框架(例如 PyTorch、TensorFlow、MXNet)使用不同的模型表示方法。這種碎片化導致了以下問題:
- 框架鎖定:難以在不同的訓練工具之間移動模型。
- 部署複雜性:生產環境通常需要特定的推理引擎。
- 工具鏈分散:不同的框架需要單獨的優化工具。
ONNX 通過提供開放的模型表示標準解決了這些問題。
核心原則
- 計算圖表示:模型表示為節點(算子)和邊(數據流)的有向無環圖 (DAG)。
- 算子標準化:卷積、激活等算子的標準化定義確保了在不同框架下行為一致。
- 單一文件格式:架構、權重和元數據均存儲在一個
.onnx文件中。
關鍵特性
| 特性 | 描述 |
|---|---|
| 互操作性 | 模型可以從一個框架導出並導入到另一個框架。 |
| 優化友好 | 硬件加速器(如 TensorRT、OpenVINO)原生支持 ONNX。 |
| 版本控制 | 顯式的版本控制確保向後兼容性。 |
實際用法
從 PyTorch 導出
import torch
# 創建模型實例
model = SimpleModel()
# 創建虛擬輸入
dummy_input = torch.randn(1, 784)
# 導出為 ONNX 格式
torch.onnx.export(
model, # 要導出的模型
dummy_input, # 示例模型輸入
"simple_model.onnx", # 輸出文件名
opset_version=13, # ONNX 算子集版本
input_names=['input'], # 輸入節點名稱
output_names=['output'], # 輸出節點名稱
dynamic_axes={'input': {0: 'batch_size'}, # 動態維度
'output': {0: 'batch_size'}}
)
使用 ONNX Runtime 進行推理
import onnxruntime as ort
import numpy as np
# 加載 ONNX 模型
ort_session = ort.InferenceSession("simple_model.onnx")
# 準備輸入數據
input_data = np.random.randn(1, 784).astype(np.float32)
# 運行推理
outputs = ort_session.run(None, {'input': input_data})
建議
對於大多數深度學習工作流,ONNX 是連接訓練與高性能推理的主要橋樑。