#Batch modify pth files#
import torch
#导入pth文件
path_2='./model.pth'
model_2=dict(torch.load(path_2))
#原pth文件有的params 将其换成state_dict
model_2['state_dict'] = model_2.pop('params')
dict=[]
for k in model_2['state_dict'].keys():
k_="{}".format(k) #在输出变量时加上引号
dict.append(k)
#修改成新名字
for k in dict:
#k是旧名
k_="{}".format(k) #在输出变量时加上引号
older_val=model_2['state_dict'][k_]
#print(k_)
#新名
k_new="generator.{}".format(k_)
#print(k_new)
# 修改参数名,pop该方法返回从列表中移除的元素对象。
model_2['state_dict'][k_new] = model_2['state_dict'].pop(k_)
torch.save(model_2,'./model_changed.pth')
print(model_2)
#merge path_1\path_2
import torch
path_1='/model_1.pth'
path_2='/model_2.pth'
model_1=torch.load(path_1)
model_2=torch.load(path_2)
for k,v in model_2.items():
for i,j in v.items():
model_1['state_dict'][i]=j
#print(j)
#print(i)
#print(model_2['state_dict'][i])
torch.save(model_1,'./model_3.pth')
版权声明:本文为qq_44132116原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。