ResNet50 转 TRT

《torch2trt》

模型的 accuracy 和 speed 是鱼和熊掌不可兼得,所以如何进行有效的模型加速就很重要。之前尝试过模型剪枝、量化和蒸馏的方法,发现除了蒸馏可以提升小模型的效果外,其余的效果大都停留在学术层面上,实际用起来的模型效果影响还是比较明显的。直接结合硬件层面对模型的加速是比较合理的,所以模型直接转 TensorRT 加速一直都是业界用的比较多方案,本文介绍一个 NVIDIA 开源的 torch2trt,目前用起来还是有很多的 bug,但是可以自定义 converter 也不至于遇到问题束手无策。之前转 TRT 都是 caffe/tf/torch -> onnx -> trt 这个链路的,中间还需要编译 TensorRT。而直接用 NVIDIA 的这个代码就比较方便一些。

Key Words:torch2trt


Beijing, 2022

https://github.com/NVIDIA-AI-IOT/torch2trt

Agile Pioneer

以官方的 resnet50 为例的转换程序如下:

import time
import torch
import numpy as np
import tensorrt as trt
from torch2trt import torch2trt
from torch2trt import tensorrt_converter
from torchvision.models.resnet import resnet50


# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# create some regular pytorch model...
model = resnet50(pretrained=True).eval().cuda()


# convert to TensorRT feeding sample data as input
start = time.time()
model_trt = torch2trt(model, [x], fp16_mode=True, strict_type_constraints=True, use_onnx=True)
print("Convert cost time: %s"%((time.time() - start)*1000))

安装 torch2trt

基本上没什么坑,当时遇到了点小问题,我在 issue 里面也放了解决的评论

先对比 trt fp32/16 和 pytorch 的diff


np.testing.assert_almost_equal(model_trt(x).data.cpu().numpy(), model(x).data.cpu().numpy(), decimal=2)

fp16 的 diff 还是挺大的,fp32还好

对比 trt fp32/fp16 和 pytorch 的性能


def test_latency(model, model_trt, x):
    ### warm_up
    model(x)
    model(x)
    model(x)
    start = time.time()
    for i in range(50):
        model(x)
    torch.cuda.synchronize()  # 同步 gpu 操作结束
    print("Torch Cost time: %s ms"%((time.time() - start) * 1000))
    ### warm_up
    model_trt(x)
    model_trt(x)
    model_trt(x)
    start = time.time()
    for i in range(50):
        model_trt(x)
    torch.cuda.synchronize()
    print("TRT Cost time: %s ms"%((time.time() - start) * 1000))

test_latency(model, model_trt, x)

不同型号的 GPU 不一样,就不放我的结果了

存储和加载 trt 模型

torch2trt模型保存并加载
torch.save(net_trt.state_dict(), 'resnet50_trt.pth')
model_trt = torch2trt.TRTModule()
model_trt.load_state_dict(torch.load('resnet50_trt.pth'))

版权声明:本文为racesu原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。