【pytorch转onnx,两个input】

import torch
model_path = './hitnet_sf_finalpass/version_40/checkpoints/epoch=6-step=44890.ckpt'
left_in = (torch.randn(1, 3, 480, 640, device='cuda'), torch.randn(1, 3, 480, 640, device='cuda'))
right_in = torch.randn(1, 3, 480, 640, device='cuda')
ckpt = torch.load(model_path)
model = PredictModel(**vars(args)).eval()
model.load_state_dict(ckpt['state_dict'])

# 给输入输出取个名字
input_names = ('input_1', 'input_2')
output_names = ["output_1"]

torch.onnx.export(model,
                  left_in,
                  "tinyhitnet.onnx",
                  opset_version=13, #注意版本选择
                  verbose=True,
                  input_names=input_names,
                  output_names=output_names)
print('export onnx model successful!')

版权声明:本文为xingtianyao原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。