pytroch网络训练中断后,根据断点再次训练!!!

背景:训练集较大,训练过程需要时间过长,但是我们集群环境超级不稳定。常常,训练到部分,又得重新开始训练。

一、模型的保存

torch.save主要参数就是:需要保存的权重对象 + 保存路径

torch.save(utils_x.makeDict(Model.state_dict()), 'XX+present.pkl'))

二、模型的加载

torch.load主要参数就是:文件路径 + 指定存放位置:cpu or gpu

self.Model.load_state_dict(torch.load('XX+present.pkl', map_location='cpu'))

三、模型训练过程中断点的设置

断点过程中需要保存:网络权重、优化器权重、以及epoch,便于继续训练恢复。【起初,我是每一个epoch就保存一个权重文件。后来发现还是没有必要,可以设置例如5个epoch保存一次,防止过大内存空间的消耗】

checkpoint = {
     "net":self.Model.state_dict(),
      'optimizer':self.optimizer.state_dict(),
      "epoch":idx
  }
      if not os.path.isdir('./checkpoint'):
           os.makedirs("./checkpoint")
       torch.save(checkpoint, './checkpoint/ckpt_best_%s.pth' % (str(idx+1)))

四、模型继续训练过程中断点内容的恢复

注意start_epoch的设定,以保证再次训练时,epoch的次数匹配。

start_epoch = -1
Resume = False        # 控制是否是恢复训练。False:初次训练。True:继续训练!!!!!
#Resume = True
#! 模型断点的设置  
if Resume:
    path_checkpoint = "./checkpoint/ckpt_best_5.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    self.Model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    self.optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch'] # 设置开始的epoch     

PS:终于下定决心解决完这个问题啦,开心!!【虽然,代码不是很难,但是自己解决还是挺有成就感的。】


版权声明:本文为qq_37844044原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。