PyTorch 基础3

onnx-runtime

Pytorch导出onnx

如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch

model = MyModel().eval()
dummy = torch.randn(1, 3, 224, 224)

torch.onnx.export(
model,#要转化的模型
dummy,#模型的任意一组输入
"model.onnx",#导出的ONNX文件名
input_names=['input'],#输入Tensor的名称
output_names=['output'],#输出Tensor的名称
opset_version=11,#ONNX算子集版本
do_constant_folding=True,
)

① dummy input 是干什么的?

给 ONNX 一个“样例输入”,让它知道模型的:

  • 输入 shape
  • 数据类型
  • 张量流动路径

没 dummy ONNX 就无法导出。

② opset_version = 11 为什么?

  • 11 最稳定
  • TensorRT、ONNXRuntime都兼容 11
  • YOLO/SegFormer 等模型都采用 opset 11/12/13

③ input_names / output_names 是什么?

影响 C++ 中绑定输入:

1
const char* input_name = session.GetInputName(0);

如果你在 PyTorch 导出时名字错了,C++ 就会找不到输入。