批量修改pth文件里的参数名

#Batch modify pth files#
import torch

#导入pth文件
path_2='./model.pth'
model_2=dict(torch.load(path_2))

#原pth文件有的params 将其换成state_dict
model_2['state_dict'] = model_2.pop('params')

dict=[]
for k in model_2['state_dict'].keys():
    k_="{}".format(k) #在输出变量时加上引号
    dict.append(k)

#修改成新名字
for k in dict:
    #k是旧名
    k_="{}".format(k) #在输出变量时加上引号
    older_val=model_2['state_dict'][k_]
    #print(k_)

    #新名
    k_new="generator.{}".format(k_)
    #print(k_new)

    # 修改参数名,pop该方法返回从列表中移除的元素对象。
    model_2['state_dict'][k_new] = model_2['state_dict'].pop(k_)

torch.save(model_2,'./model_changed.pth')
print(model_2)

#merge path_1\path_2


import torch

path_1='/model_1.pth'
path_2='/model_2.pth'

model_1=torch.load(path_1)
model_2=torch.load(path_2)

for k,v in model_2.items():
   for i,j in v.items():
      model_1['state_dict'][i]=j
      #print(j)
      #print(i)
      #print(model_2['state_dict'][i])

torch.save(model_1,'./model_3.pth')


版权声明:本文为qq_44132116原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。