1. pth转onnx
以下转onnx代码参考https://blog.csdn.net/xywy2008/article/details/115400323
import os
import torch
import argparse
import json
import torch
from models.resa import RESANet
from utils.config import Config
from datasets import build_dataloader
def load_network_specified(net, model_dir, logger=None):
    """Load matching checkpoint weights into ``net``, skipping the rest.

    The checkpoint file is expected to be a mapping whose ``'net'`` entry
    holds the parameter name -> tensor state dict. An entry is skipped when
    its name is absent from ``net.state_dict()`` or its size differs from
    the corresponding entry there.

    Args:
        net: module whose ``state_dict()`` / ``load_state_dict()`` are used.
        model_dir: path of the checkpoint file (despite the name, it is
            passed straight to ``torch.load``).
        logger: optional object with an ``info`` method; when truthy, each
            skipped weight name is reported through it.

    Raises:
        KeyError: if the loaded checkpoint has no ``'net'`` entry.
    """
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be passed here.
    pretrained_net = torch.load(model_dir, map_location='cpu')['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        if k not in net_state or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: %s', k)
            continue
        state[k] = v
    # strict=False: parameters that were skipped above are simply not
    # updated instead of causing a missing-key error.
    net.load_state_dict(state, strict=False)
def parse_args():
    """Parse the converter's command-line options from ``sys.argv``.

    ``--model_path`` and ``--input_shape`` are both mandatory: the rest of
    the script uses them unconditionally, so a missing one is reported as a
    usage error here rather than as an unrelated traceback later.

    Returns:
        argparse.Namespace with ``model_path`` (str) and ``input_shape``
        (list of int).

    Raises:
        ValueError: if any unrecognised argument is present; each one is
            printed first.
        SystemExit: raised by argparse when a required option is missing
            or a shape component is not an integer.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model_path', required=True,
                        help="""the pytorch model pth file path""")
    parser.add_argument('--input_shape', nargs='+', type=int, required=True,
                        help="""the model input shape, e.g. 1 3 224 224""")
    # parse_known_args is used so that every unknown argument can be listed
    # before failing, instead of argparse stopping at the first one.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")
    return args
def load_model(model_path, input_shape, config_path='./configs/culane.py'):
    """Build a RESANet from a config file and load checkpoint weights into it.

    Args:
        model_path: path of the ``.pth`` checkpoint, forwarded to
            ``load_network_specified``.
        input_shape: not used by this function; kept so existing callers
            that pass it keep working.
        config_path: config file handed to ``Config.fromfile``. Defaults to
            the CULane config that was previously hard-coded.

    Returns:
        The constructed network after ``.cpu()`` has been called on it.
    """
    cfg = Config.fromfile(config_path)
    resa = RESANet(cfg)
    load_network_specified(resa, model_path)
    resa.cpu()
    return resa
def main():
    """Convert the checkpoint named on the command line to an ONNX file.

    Parses the arguments, loads the model, feeds it a random tensor of the
    requested shape and exports the result next to the checkpoint, using
    the same file name with an ``.onnx`` suffix.
    """
    args = parse_args()
    print("model path ", args.model_path, ", shape ", args.input_shape)

    # Load the model.
    net = load_model(args.model_path, args.input_shape)
    if net is None:
        print("Load model failed")
        return

    # Switch the model to inference mode.
    net.eval()

    # Create the input tensor.
    dummy_input = torch.randn(tuple(args.input_shape))

    # The ONNX file is written beside the PyTorch checkpoint: same name,
    # with the extension replaced by ".onnx".
    stem, _ = os.path.splitext(args.model_path)
    onnx_path = stem + '.onnx'
    print(onnx_path)

    # Export with ONNX
    torch.onnx.export(net, dummy_input, onnx_path, verbose=True)
if __name__ == "__main__":
    # main() parses the command line itself, so a separate parse_args()
    # call here would only repeat the work and discard the result.
    main()
运行以下命令:
python3 pt2onnx_resa.py --model_path ./culane_resnet50.pth --input_shape 1 3 288 800
运行后即可得到onnx模型(与pth文件位于同一目录,文件名相同,后缀为.onnx)。
2.onnx模型转pb
参考https://gitee.com/ascend/tools/tree/master/pt2pb
版权声明:本文为qq_45013882原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。