我们训练好一个目标检测模型,想要部署在生产环境中,查了很多资料,最终选择使用TorchServe来部署。TorchServe是AWS和Facebook联合开发的,所以没有多想,选择大厂的东西应该没有错,部署模型过程中踩了很多坑,最终成功在windows10上成功部署,后面会在linux服务器上部署。
1.安装
我电脑上安装的cuda版本为10.1,TorchServe对cuda10.1只支持到pytorch1.8.1或者更高版本。
首先安装依赖:下载serve(https://github.com/pytorch/serve#serve-a-model):
git clone https://github.com/pytorch/serve.git然后进入serve:
python ./ts_scripts/install_dependencies.py --cuda=cu101安装torchserve
pip install torchserve torch-model-archiver torch-workflow-archiver2.命令分析:
打包模型,打包成.mar文件
torch-model-archiver --model-name densenet161
--version 1.0
--model-file ./serve/examples/image_classifier/densenet_161/model.py
--serialized-file densenet161-8d451a50.pth
--export-path model_store
--extra-files ./serve/examples/image_classifier/index_to_name.json
--handler image_classifier
--model-name:模型名称,自定义,跟实际的模型无关
--version:版本,自定义
--model-file:指定模型文件,mode.py中只能包含一类,即你模型类,如分类模型,目标检测模型
--serialized-file:指定模型参数,保存的模型
--export-path:指定打包后的模型存储位置
--extra-files:是一个json文件,存放相关的参数,不一定用到,但是最好创建一个
--handler:处理文件,image_classifier是个python文件,里面需要写入包括数据处理的相关逻辑
其中mode.py中需要放入自己的模型,最好是一个类,多个类也可以,但是在handler文件中需要做相应的处理。handler文件非常重要,需要根据自己的任务进行重写。
启动模型:
torchserve --start --ncs --model-store model_store --models densenet161.mar3.handler文件(关键)
class BaseHandler(abc.ABC):
def __init__(self):
def initialize(self, context):
def _load_torchscript_model(self, model_pt_path):
def _load_pickled_model(self, model_dir, model_file, model_pt_path):
def preprocess(self, data):
def inference(self, data, *args, **kwargs):
def postprocess(self, data):
def handle(self, data, context):
def explain_handle(self, data_preprocess, raw_data):
def _is_explain(self):
上面代码为base_handler.py中的核心方法:
.\serve\ts\torch_handler\base_handler.py我们主要需要修改的就是preprocess和postprocess两个方法,去文件中一看就明白,preprocess是数据预处理的代码,这是非常有必要的,我们收到的数据并不能直接放入模型中,还需要进行一些处理,如图片的resize,归一化等等,而postprocess也是非常重要的,模型输出结果中包含大量的无关数据,我们可以进行处理,只返回直接的结果,如目标检测,只返回阈值高的检测框。
handler文件中,加载模型的时候修改地方:
def _load_pickled_model(self, model_dir, model_file, model_pt_path):
"""
Loads the pickle file from the given model path.
"""
model_def_path = os.path.join(model_dir, model_file)
if not os.path.isfile(model_def_path):
raise RuntimeError("Missing the model.py file")
module = importlib.import_module(model_file.split(".")[0])
model_class_definitions = list_classes_from_module(module)
model_class = model_class_definitions[2]
model = model_class(self.heads,
pretrained=False,
down_ratio=self.down_ratio,
final_kernel=1,
last_level=5,
head_conv=self.head_conv
)
if model_pt_path:
model = load_model(model,model_pt_path)
return modelmodel_class这里需要修改,因为我的mode.py文件中,模型不是一个类,而是包含多个类,这里就需要选择模型主类,遍历文件中所有的类,然后选择合适的model_class。
4.index_to_name.json文件说明
这个文件根据自己的需要来写,我其实没有用到,在里边随便写了几行:
{
"threshold": "0.4",
"classnums": "1"
}其实遇到了一些坑,突然记不起来了。