caffe模型 转 pytorch 模型

最近基于 caffe2onnx 做了部分修改,完成了 caffe 模型转 pytorch 模型的代码。

 

主代码:需要自己构建 pytorch 的 Net 架构,同时 Net 各层的名字要与 caffe 各层的名字对应(caffe 层名中的 "/" 会被替换为 "_")。

    # Load the Caffe network description and its weights.
    # NOTE(review): LoadCaffeModel is defined elsewhere; presumably it returns
    # the parsed prototxt and caffemodel messages — confirm against its source.
    graph, params = LoadCaffeModel(caffe_graph_path,caffe_params_path)
    # print(graph)  # uncomment to inspect the loaded graph

    # NET is the hand-written PyTorch architecture; its module names must
    # correspond to the Caffe layer names for the weights to be matched.
    net_pytorch = NET()

    net_pytorch.eval()


    print('start convert')
    # The converter fills net_pytorch in place via load_state_dict.
    caffe2pytorch = Caffe2Pytorch(graph, params, net_pytorch)
    caffe2pytorch.convert()
    print("convert finish.")
    # Persist only the converted weights (state dict), not the module object.
    torch.save(net_pytorch.state_dict(), save_path)
    print("save finish.")

转换类 Caffe2Pytorch 的代码

class Caffe2Pytorch():
    """Copy the parameters of a Caffe model into a PyTorch network.

    Caffe layers are paired with PyTorch parameters by name: a Caffe layer
    called ``X`` (with ``/`` replaced by ``_``) feeds every state-dict entry
    whose key starts with ``"X."``. Within one layer the blobs are assigned
    positionally, in the iteration order of ``pytorch_net.state_dict()``.
    """

    def __init__(self, net, model, pytorch_net):
        """Store the layer lists and the target PyTorch network.

        Args:
            net: Parsed network definition exposing ``layer`` / ``layers``.
            model: Parsed trained model exposing ``layer`` / ``layers``.
            pytorch_net: Target network; must provide ``state_dict()`` and
                ``load_state_dict()``.
        """
        # Layer lists of the network definition and of the trained model.
        # netLayerCaffe is stored but not read by the methods of this class.
        self.netLayerCaffe = self.GetNetLayerCaffe(net)
        self.netModelCaffe = self.GetNetModelCaffe(model)

        # Model input names and input shapes (not filled in by this class).
        self.model_input_name = []
        self.model_input_shape = []

        self.pytorch_net = pytorch_net
        # Converted tensors, keyed by PyTorch state-dict name.
        self.state_dict = {}

    def GetNetLayerCaffe(self, net):
        """Return whichever of ``net.layers`` / ``net.layer`` is populated.

        Exactly one of the two fields must be non-empty. Otherwise an error
        is printed and ``-1`` is returned (legacy sentinel kept for backward
        compatibility).
        """
        if len(net.layer) == 0 and len(net.layers) != 0:
            return net.layers
        elif len(net.layer) != 0 and len(net.layers) == 0:
            return net.layer
        else:
            print("prototxt layer error")
            return -1

    def GetNetModelCaffe(self, model):
        """Return whichever of ``model.layers`` / ``model.layer`` is populated.

        Exactly one of the two fields must be non-empty. Otherwise an error
        is printed and ``-1`` is returned (legacy sentinel kept for backward
        compatibility).
        """
        if len(model.layer) == 0 and len(model.layers) != 0:
            return model.layers
        elif len(model.layer) != 0 and len(model.layers) == 0:
            return model.layer
        else:
            print("caffemodel layer error")
            return -1

    def match(self, caffe_layer_name, pS, pD):
        """Assign one Caffe layer's blobs to the matching PyTorch entries.

        Args:
            caffe_layer_name: Layer name, already normalised by the caller.
            pS: List of blob shapes (one sequence of ints per blob).
            pD: List of flat blob data sequences, parallel to ``pS``.

        State-dict keys starting with ``caffe_layer_name + "."`` receive the
        blobs in order; results are stored in ``self.state_dict``.
        """
        # A layer without blobs has nothing to copy. Returning early also
        # prevents indexing into an empty pD when the layer name happens to
        # prefix a PyTorch parameter name.
        if len(pS) == 0:
            return

        prefix = caffe_layer_name + "."
        index = 0
        for name in self.pytorch_net.state_dict():
            if name.startswith(prefix):
                print("match success:  caffe name:", caffe_layer_name,  " py name:", name)
                newD = np.array(list(pD[index]))
                shape = tuple(pS[index])
                newD = newD.reshape(shape)
                self.state_dict[name] = torch.from_numpy(newD)
                index += 1

            # Stop once every blob of this layer has been consumed.
            if index == len(pS):
                break

    def convert(self):
        """Convert every Caffe layer and load the result into the network.

        BatchNorm/BN layers get their statistics rescaled, and when the next
        layer is a ``Scale`` layer its blobs are placed in front of the
        statistics, giving the order scale blobs, mean, variance.

        Returns:
            Tuple ``(ParamShape, ParamData)`` of the last processed layer
            (two empty lists when the model has no layers).
        """
        ParamShape = []
        ParamData = []
        num_layers = len(self.netModelCaffe)
        # Walk the caffemodel layers and collect each layer's parameters.
        for layer_idx, model_layer in enumerate(self.netModelCaffe):
            Params = copy.deepcopy(model_layer.blobs)
            ParamShape = [p.shape.dim for p in Params]
            ParamData = [p.data for p in Params]

            if model_layer.type == "BatchNorm" or model_layer.type == "BN":
                if len(ParamShape) == 3:
                    # Blobs are [mean, var, s]: mean and var must be divided
                    # by the moving-average factor s.
                    ParamShape = ParamShape[:-1]
                    factor = Params[-1].data[0]
                    if factor != 0:
                        mean = [q / factor for q in Params[0].data]
                    else:
                        # Caffe uses a scale of 0 when the factor is 0;
                        # dividing would raise ZeroDivisionError.
                        mean = [0.0 for _ in Params[0].data]
                    var = [q / (factor + 1e-5) for q in Params[1].data]
                    ParamData = [mean, var]  # with s
                elif len(ParamShape) == 2 and len(ParamShape[0]) == 4:
                    # Two 4-D blobs without a factor: keep only dimension 1
                    # of each shape.
                    ParamShape = [[ParamShape[0][1]], [ParamShape[1][1]]]
                    ParamData = [
                        [q / 1. for q in Params[0].data],
                        [q / (1. + 1e-5) for q in Params[1].data],
                    ]

                # Fold a directly following Scale layer into this layer. The
                # bounds check covers a BatchNorm that is the final layer.
                next_idx = layer_idx + 1
                if (next_idx < num_layers
                        and self.netModelCaffe[next_idx].type == "Scale"):
                    Params = copy.deepcopy(self.netModelCaffe[next_idx].blobs)
                    ParamShape1 = [p.shape.dim for p in Params]
                    ParamData1 = [p.data for p in Params]

                    ParamShape1.extend(ParamShape)
                    ParamData1.extend(ParamData)
                    ParamShape = ParamShape1
                    ParamData = ParamData1

            print("caffe param name:", model_layer.name, " param shape :", ParamShape)

            # "/" is replaced by "_" so the name can match a PyTorch
            # state-dict key prefix.
            layer_name = model_layer.name
            layer_name = layer_name.replace("/", "_")
            self.match(layer_name, ParamShape, ParamData)

        # strict=False: entries that received no Caffe blob keep their values.
        self.pytorch_net.load_state_dict(self.state_dict, strict=False)

        return ParamShape, ParamData

 


版权声明:本文为u011808673原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。