背景:训练集较大,训练过程需要时间过长,但是我们集群环境超级不稳定。常常,训练到部分,又得重新开始训练。
一、模型的保存
torch.save主要参数就是:需要保存的权重对象 + 保存路径
torch.save(utils_x.makeDict(Model.state_dict()), 'XX+present.pkl'))
二、模型的加载
torch.load主要参数就是:文件路径 + 指定存放位置:cpu or gpu
self.Model.load_state_dict(torch.load('XX+present.pkl', map_location='cpu'))
三、模型训练过程中断点的设置
断点过程中需要保存:网络权重、优化器权重、以及epoch,便于继续训练恢复。【起初,我是每一个epoch就保存一个权重文件。后来发现还是没有必要,可以设置例如5个epoch保存一次,防止过大内存空间的消耗】
checkpoint = {
"net":self.Model.state_dict(),
'optimizer':self.optimizer.state_dict(),
"epoch":idx
}
if not os.path.isdir('./checkpoint'):
os.makedirs("./checkpoint")
torch.save(checkpoint, './checkpoint/ckpt_best_%s.pth' % (str(idx+1)))
四、模型继续训练过程中断点内容的恢复
注意start_epoch的设定,以保证再次训练时,epoch的次数匹配。
start_epoch = -1
Resume = False # 控制是否是恢复训练。False:初次训练。True:继续训练!!!!!
#Resume = True
#! 模型断点的设置
if Resume:
path_checkpoint = "./checkpoint/ckpt_best_5.pth" # 断点路径
checkpoint = torch.load(path_checkpoint) # 加载断点
self.Model.load_state_dict(checkpoint['net']) # 加载模型可学习参数
self.optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
start_epoch = checkpoint['epoch'] # 设置开始的epoch
PS:终于下定决心解决完这个问题啦,开心!!【虽然,代码不是很难,但是自己解决还是挺有成就感的。】
版权声明:本文为qq_37844044原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。