ONNX Runtime:跨平台模型推理引擎
--- title: "ONNX Runtime:跨平台模型推理引擎" description: "掌握ONNX格式转换和ONNX Runtime推理,实现模型的跨平台无缝部署" tags: ["ONNX", "跨平台推理", "模型转换"] category: "llm" icon: "🧠"
ONNX Runtime:跨平台模型推理引擎
ONNX简介
ONNX(Open Neural Network Exchange)是一种开放的模型交换格式,旨在实现不同深度学习框架之间的模型互操作性。ONNX Runtime是微软开发的高性能推理引擎,支持在CPU、GPU、NPU等多种硬件上运行ONNX模型。
ONNX的核心价值:
- 框架无关:支持PyTorch、TensorFlow、JAX等主流框架的模型转换
- 跨平台:Windows、Linux、macOS、移动端统一部署
- 硬件加速:通过Execution Provider适配不同硬件
- 生产级优化:图优化、量化、算子融合等高级特性
PyTorch模型导出为ONNX
基础导出方法
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型
model = AutoModelForCausalLM.from_pretrained("bert-base-chinese")
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model.eval()
# 创建虚拟输入
dummy_input = tokenizer("测试文本", return_tensors="pt", padding="max_length", max_length=128)
# 导出ONNX
torch.onnx.export(
model,
(dummy_input["input_ids"], dummy_input["attention_mask"]),
"bert_model.onnx",
opset_version=14,
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"logits": {0: "batch_size", 1: "sequence_length"}
}
)
使用Optimum库导出LLM
Optimum库简化了Hugging Face模型的ONNX导出流程:
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
# 一键导出LLM为ONNX格式
model = ORTModelForCausalLM.from_pretrained(
"Qwen/Qwen-1.8B-Chat",
export=True,
provider="CUDAExecutionProvider"
)
# 保存导出的模型
model.save_pretrained("./qwen_onnx")
# 加载并推理
tokenizer = AutoTokenizer.from_pretrained("./qwen_onnx")
inputs = tokenizer("你好,请介绍一下自己", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
ONNX Runtime推理
CPU推理
import onnxruntime as ort
import numpy as np
# 创建推理会话
session = ort.InferenceSession(
"bert_model.onnx",
providers=["CPUExecutionProvider"]
)
# 准备输入
input_ids = np.array([101, 872, 3221, 102](/notes/101-872-3221-102), dtype=np.int64)
attention_mask = np.array([1, 1, 1, 1](/notes/1-1-1-1), dtype=np.int64)
# 执行推理
outputs = session.run(
["logits"],
{"input_ids": input_ids, "attention_mask": attention_mask}
)
print("输出形状:", outputs[0].shape)
GPU推理(CUDA)
# CUDA加速推理
session = ort.InferenceSession(
"model.onnx",
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)
# GPU上的推理与CPU相同,框架自动处理设备转移
outputs = session.run(["logits"], input_dict)
模型优化与量化
图优化
ONNX Runtime提供自动图优化,移除冗余节点并融合算子:
import onnxruntime as ort
# 启用所有优化级别
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession(
"model.onnx",
sess_options=session_options,
providers=["CUDAExecutionProvider"]
)
INT8动态量化
from onnxruntime.quantization import quantize_dynamic, QuantType, quantize_static
# 动态量化(简单快速)
quantize_dynamic(
model_input="model.onnx",
model_output="model_int8.onnx",
weight_type=QuantType.QInt8
)
# 静态量化(精度更高,需要校准数据)
from onnxruntime.quantization import CalibrationDataReader
class CalibrationReader(CalibrationDataReader):
def __init__(self, tokenizer, dataset):
self.tokenizer = tokenizer
self.dataset = dataset
def get_next(self):
if self.current_idx >= len(self.dataset):
return None
text = self.dataset[self.current_idx]
encoding = self.tokenizer(text, return_tensors="np", padding="max_length", max_length=128)
self.current_idx += 1
return {
"input_ids": encoding["input_ids"],
"attention_mask": encoding["attention_mask"]
}
# 执行静态量化
quantize_static(
model_input="model.onnx",
model_output="model_static_int8.onnx",
calibration_data_reader=reader
)
跨平台部署实战
Windows桌面应用
import onnxruntime as ort
import tkinter as tk
class LLMApp:
def __init__(self):
self.session = ort.InferenceSession(
"model.onnx",
providers=["DmlExecutionProvider", "CPUExecutionProvider"] # DirectML加速
)
self.tokenizer = AutoTokenizer.from_pretrained("model_path")
def predict(self, prompt):
inputs = self.tokenizer(prompt, return_tensors="np")
outputs = self.session.run(None, dict(inputs))
return self.tokenizer.decode(outputs[0][0])
Android移动端
// Android使用ONNX Runtime Mobile
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
val env = OrtEnvironment.getEnvironment()
val session = env.createSession("model_quantized.onnx")
val inputTensor = OnnxTensor.createTensor(env, inputIds)
val results = session.run(mapOf("input_ids" to inputTensor))
性能对比与选择建议
| 场景 | 推荐方案 | 原因 |
|---|---|---|
| Windows桌面 | DirectML | 硬件兼容性好 |
| Linux服务器 | CUDA | GPU性能最优 |
| 移动端 | NNAPI/CoreML | 原生硬件加速 |
| 嵌入式 | CPU + INT8 | 资源受限环境 |
ONNX Runtime凭借其跨平台能力和丰富的优化选项,是模型从研究到生产部署的重要桥梁。选择合适的Execution Provider和量化策略,可在各种硬件上实现高效的LLM推理。