跳至主要内容

ONNX 概述

Open Neural Network Exchange (ONNX) 是一個開放標準,專門設計用於促進機器學習模型在不同框架之間的互操作性。

為什麼使用 ONNX?

在深度學習生態系統中,不同的框架(例如 PyTorch、TensorFlow、MXNet)使用不同的模型表示方法。這種碎片化導致了以下問題:

  • 框架鎖定:難以在不同的訓練工具之間移動模型。
  • 部署複雜性:生產環境通常需要特定的推理引擎。
  • 工具鏈分散:不同的框架需要單獨的優化工具。

ONNX 通過提供開放的模型表示標準解決了這些問題。

核心原則

  • 計算圖表示:模型表示為節點(算子)和邊(數據流)的有向無環圖 (DAG)。
  • 算子標準化:卷積、激活等算子的標準化定義確保了在不同框架下行為一致。
  • 單一文件格式:架構、權重和元數據均存儲在一個 .onnx 文件中。

關鍵特性

特性描述
互操作性模型可以從一個框架導出並導入到另一個框架。
優化友好硬件加速器(如 TensorRT、OpenVINO)原生支持 ONNX。
版本控制顯式的版本控制確保向後兼容性。

實際用法

從 PyTorch 導出

import torch

# 創建模型實例
model = SimpleModel()

# 創建虛擬輸入
dummy_input = torch.randn(1, 784)

# 導出為 ONNX 格式
torch.onnx.export(
model, # 要導出的模型
dummy_input, # 示例模型輸入
"simple_model.onnx", # 輸出文件名
opset_version=13, # ONNX 算子集版本
input_names=['input'], # 輸入節點名稱
output_names=['output'], # 輸出節點名稱
dynamic_axes={'input': {0: 'batch_size'}, # 動態維度
'output': {0: 'batch_size'}}
)

使用 ONNX Runtime 進行推理

import onnxruntime as ort
import numpy as np

# 加載 ONNX 模型
ort_session = ort.InferenceSession("simple_model.onnx")

# 準備輸入數據
input_data = np.random.randn(1, 784).astype(np.float32)

# 運行推理
outputs = ort_session.run(None, {'input': input_data})
建議

對於大多數深度學習工作流,ONNX 是連接訓練與高性能推理的主要橋樑。