torch多GPU模型的训练与保存

使用多gpu训练时

model = torch.nn.DataParallel(model, device_ids=[1, 2, 3, 4])

若模型采用多GPU训练,则在模型保存时:

torch.save(model.module.state_dict(), model_out_path)

若单GPU则:

torch.save(mode.state_dict(), model_out_path)


版权声明:本文为weixin_42105432原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。