import torch
model_path = './hitnet_sf_finalpass/version_40/checkpoints/epoch=6-step=44890.ckpt'
left_in = (torch.randn(1, 3, 480, 640, device='cuda'), torch.randn(1, 3, 480, 640, device='cuda'))
right_in = torch.randn(1, 3, 480, 640, device='cuda')
ckpt = torch.load(model_path)
model = PredictModel(**vars(args)).eval()
model.load_state_dict(ckpt['state_dict'])
# 给输入输出取个名字
input_names = ('input_1', 'input_2')
output_names = ["output_1"]
torch.onnx.export(model,
left_in,
"tinyhitnet.onnx",
opset_version=13, #注意版本选择
verbose=True,
input_names=input_names,
output_names=output_names)
print('export onnx model successful!')版权声明:本文为xingtianyao原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。