pth 转 onnx 再转 pb(以 RESA 模型为例)

1. pth 转 onnx
以下转 onnx 的代码参考自 https://blog.csdn.net/xywy2008/article/details/115400323
import os
import torch
import argparse
 
import json
import torch
from models.resa import RESANet
from utils.config import Config
from datasets import build_dataloader
 
def load_network_specified(net, model_dir, logger=None):
    """Load the compatible subset of a checkpoint's weights into ``net``.

    Args:
        net: module whose state dict is updated in place.
        model_dir: path of the checkpoint file (despite the name, a file that
            is handed to ``torch.load``); the loaded object must hold the
            weights under the ``'net'`` key.
        logger: optional object with an ``info`` method; when given, every
            skipped checkpoint entry is reported through it.

    Checkpoint entries that are absent from ``net`` or whose size differs are
    skipped. Entries of ``net`` that receive no checkpoint value keep their
    current values, because loading is non-strict.
    """
    checkpoint_weights = torch.load(model_dir, map_location='cpu')['net']
    own_weights = net.state_dict()
    compatible = {}
    for name, tensor in checkpoint_weights.items():
        # Membership is tested first so .size() is only compared for shared keys.
        if name in own_weights and tensor.size() == own_weights[name].size():
            compatible[name] = tensor
        elif logger:
            logger.info('skip weights: ' + name)
    net.load_state_dict(compatible, strict=False)
 
def parse_args(argv=None):
    """Parse the command-line options of the export script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default, ``None``, keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with ``model_path`` (str) and ``input_shape``
        (list of int).

    Raises:
        ValueError: if unrecognised arguments are supplied.
        SystemExit: raised by argparse when a required option is missing or
            malformed.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Both options are mandatory: the export cannot derive the .onnx file
    # name without model_path, nor build the dummy input without input_shape.
    # Declaring them required yields a clear usage error instead of a later
    # TypeError on None.
    parser.add_argument('--model_path', required=True,
                        help="""the pytorch model pth file path""")
    parser.add_argument('--input_shape', nargs='+', type=int, required=True,
                        help="""the model input shape, e.g. 1 3 224 224""")

    # parse_known_args is used so every unknown argument can be reported
    # before failing.
    args, unknown_args = parser.parse_known_args(argv)
    if len(unknown_args) > 0:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")

    return args
 
def load_model(model_path, input_shape, config_path='./configs/culane.py'):
    """Build a RESA network from a config file and load checkpoint weights into it.

    Args:
        model_path: path of the ``.pth`` checkpoint, passed to
            ``load_network_specified``.
        input_shape: not used by this function; kept so existing callers
            keep working (the caller builds its own dummy input).
        config_path: config file handed to ``Config.fromfile``. It defaults
            to the CULane config the script originally hard-coded, so other
            RESA configs can now be exported without editing the code.

    Returns:
        The RESANet instance, moved to the CPU.
    """
    cfg = Config.fromfile(config_path)
    resa = RESANet(cfg)
    # Checkpoint entries whose name or size does not match the network are
    # skipped, and the rest are loaded non-strictly.
    load_network_specified(resa, model_path)
    resa.cpu()
    return resa
 
    
def main():
    """Export the RESA checkpoint named on the command line to ONNX.

    The .onnx file is written next to the .pth checkpoint, with the same base
    name and an ``.onnx`` extension.
    """
    args = parse_args()
    print("model path ", args.model_path, ", shape ", args.input_shape)
    # Without these options, os.path.splitext(None) / tuple(None) below would
    # fail with an unhelpful TypeError, so stop early with a clear message.
    if args.model_path is None or not args.input_shape:
        print("Both --model_path and --input_shape must be given")
        return

    # Load the model.
    model = load_model(args.model_path, args.input_shape)

    if model is None:
        print("Load model failed")
        return

    # Switch the model to inference mode before exporting.
    model.eval()
    # Random dummy input; the exporter runs the model on it, and its shape
    # becomes the input shape recorded in the ONNX graph.
    inputs = torch.randn(tuple(args.input_shape))
    # The ONNX file goes next to the checkpoint: same name, .onnx extension.
    export_onnx_file = os.path.splitext(args.model_path)[0] + '.onnx'

    print(export_onnx_file)

    # Export with ONNX
    torch.onnx.export(model, inputs, export_onnx_file, verbose=True)
    
if __name__ == "__main__":
    # main() parses the command line itself, so nothing else is needed here.
    main()

运行命令:

python3 pt2onnx_resa.py --model_path ./culane_resnet50.pth --input_shape 1 3 288 800

运行后即可在 .pth 文件所在目录下得到同名的 onnx 模型(本例为 ./culane_resnet50.onnx)。
2. onnx 模型转 pb
参考 https://gitee.com/ascend/tools/tree/master/pt2pb


版权声明:本文为qq_45013882原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。