pytorch加载分布式方式训练的模型

出现的问题

  1. 使用load_state_dict函数,如果strict=False,模型可以加载,但是模型的输出不对
  2. 如果strict=True,但是没有新建字典,模型无法加载

问题解决

  • 保存的模型使用分布式方式训练,如果加载后模型不用分布式,则需要修改模型的key.
  • 分布式模型的权值名字前面有module.,非分布式不包含。
    比如分布式模型的权值名称名称为module.blocks.0.norm1.weight,非分布式模型权值名称名称blocks.0.norm1.weight
  • 需要重新建一个字典,用去掉module.的名字座位key,值不变

代码示例

chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
print('Loading checkpoint', chk_filename)
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
checkpoint = checkpoint['model_pos']
new_checkpoint = {} ## 新建一个字典来访模型的权值
# print(checkpoint)
for k,value in checkpoint.items():
    key = k.split('module.')[-1]
    new_checkpoint[key] = value
    # print(k,key)

# model_pos.load_state_dict(checkpoint['model_pos'], strict=True)
model_pos.load_state_dict(new_checkpoint, strict=True)

版权声明:本文为Blankit1原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。