Pytorch机器学习的一般训练方法整理


一、前言

  1. 该训练方法是根据李沐老师的 d2l 包整理出来
  2. 将所有涉及到训练部分的 d2l 包方法都抽取出来进行了逐行注释说明
  3. 新增训练权重的自动保存功能
  4. 新增断点续训功能
  5. 重写了可视化部分的代码
  6. 数据读取部分可以参考此文章中针对训练数据的批量增强与加载 https://blog.csdn.net/weixin_43721000/article/details/126286690

二、代码

1.调用方法

训练时需要调用此方法,指定如下参数

train_ch6(
    net=net,                    # 指定模型
    num_epochs=num_epochs,      # 指定迭代次数
    lr=lr,                      # 指定学习率
    train_iter=train_iter,      # 指定训练集
    test_iter=validation_iter,  # 指定验证集
    resume=True                 # 是否断点续训
)

2.具体实现

训练方法如下图

from torch import nn
import torch
import time
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib_inline import backend_inline


class Timer(object):
    '''
    计时工具类
    '''
    def __init__(self):
        """Defined in :numref:`subsec_linear_model`"""
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer and record the time in a list."""
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time."""
        return np.array(self.times).cumsum().tolist()


class TrainVision(object):
    '''
    训练可视化工具类
    '''
    def __init__(self):

        # svg模式
        backend_inline.set_matplotlib_formats('svg')

        # 用于显示正常中文标签
        plt.rcParams['font.sans-serif'] = ['SimHei']

        # 在 1*1 的画布 fig 上创建图纸 ax
        self.fig, self.ax = plt.subplots(1, 1, figsize=(3.5, 2.5))
        # 展示网格线
        self.ax.grid()
        # x轴标签
        self.ax.set_xlabel('epochs')
        # y轴标签
        self.ax.set_ylabel('acc & loss')
        # epoch、训练精度、训练损失、测试精度 的累加数组
        self.train_acc = {'x': [], 'y': []}
        self.train_loss = {'x': [], 'y': []}
        self.test_acc = {'x': [], 'y': []}
        # 是否加载过图例的标记
        self.is_init_legend = False

    def add(self, epoch_x, train_acc_y1, train_loss_y2, test_acc_y3):

        # 加入位置信息数组
        if epoch_x:
            if train_acc_y1:
                self.train_acc['x'].append(epoch_x)
                self.train_acc['y'].append(train_acc_y1)
            if train_loss_y2:
                self.train_loss['x'].append(epoch_x)
                self.train_loss['y'].append(train_loss_y2)
            if test_acc_y3:
                self.test_acc['x'].append(epoch_x)
                self.test_acc['y'].append(test_acc_y3)

        # 绘制
        self.ax.plot(self.train_acc['x'], self.train_acc['y'], color='blue', label='train loss')
        self.ax.plot(self.train_loss['x'], self.train_loss['y'], color='red', label='train acc')
        self.ax.plot(self.test_acc['x'], self.test_acc['y'], color='green', label='test acc')

        # 图例仅在首次加载时创建
        if not self.is_init_legend:
            plt.legend()
            self.is_init_legend = True

        plt.draw()
        plt.pause(0.001)


def accuracy(y_hat, y):
    '''
    根据预测值和真实值计算模型准确率
    :param y_hat:
    :param y:
    :return:
    '''
    # 如果 y_hat 大于一维,那么认为输出结果是概率 [[1.37, -2.13], [0.76, 3.92]...] 不是标签值[0, 1...],需要转换成标签值
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, dim=1)
    # 获得准确率矩阵
    # tensor([False, True, False, False...])
    cmp = y_hat == y
    # sum求 True 的个数【True的值是1,False的值是0,因此sum之后就是True的个数】
    return float(torch.sum(cmp))


def evaluate_accuracy_gpu(net, data_iter, device=None):
    '''
    计算验证集准确率
    先用验证集数据跑一遍模型得到 y_hat,再用 y_hat 和 y 跑 accuracy() 方法,计算模型在验证集上的准确率
    :param net:
    :param data_iter:
    :param device:
    :return:
    '''
    # 自动适配模型使用的训练设备 -------------------------------
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # -----------------------------------------------------

    # 验证集的 总准确率 与 验证样本数 的累加
    metric = {
        'sum_val_acc': 0,           # 验证集总准确率
        'sum_val_sample': 0,        # 验证样本数
    }

    with torch.no_grad():
        for X, y in data_iter:

            # 样本和标签移动到训练设备 -----------------------------------
            if isinstance(X, list):
                # Required for BERT Fine-tuning (to be covered later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)

            # 前向传播预测结果
            y_hat = net(X)
            # 计算准确率,累加 准确率 和 样本数
            metric['sum_val_acc'] += accuracy(y_hat, y)
            metric['sum_val_sample'] += y.shape[0]

    # 返回平均准确率
    return metric['sum_val_acc'] / metric['sum_val_sample']


def train_ch6(net, train_iter, test_iter, num_epochs, lr, resume):
    '''
    训练方法
    :param net: 
    :param train_iter: 
    :param test_iter: 
    :param num_epochs: 
    :param lr: 
    :param resume: 
    :return: 
    '''
    
    # # 随机初始化模型权重 ----------------------------------------------------------
    # # 视情况而定,如果模型传入时加载了预训练权重,那么这里不用解开注释,否则权重会被覆盖为随机值
    # # 如果想从零开始训练,那么解开注释随机生成标准化权重
    # def init_weights(m):
    #     if type(m) == nn.Linear or type(m) == nn.Conv2d:
    #         nn.init.xavier_uniform_(m.weight)
    # net.apply(init_weights)
    # # -------------------------------------------------------------------------


    # 指定训练设备,有gpu用gpu,没用就用cpu
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    print('----- 训练设备:{} -----', device)

    # 模型移动到训练设备
    net.to(device)

    # 优化器【根据具体问题而定,一般用Adam就行】
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # 损失函数【根据具体问题而定】
    loss = nn.CrossEntropyLoss()

    # 断点续训  ---------------------------------------------------------------------
    # 起始训练epoch数
    start_epochs = 0
    # 是否需要从上次的状态继续训练
    if resume:
        # 加载项目目录下的weight文件夹中最新的权重 -----------------------------
        model_dir = os.path.join(os.getcwd(), 'weights')
        weight_filename = os.listdir(model_dir)[-1]
        weights_path = os.path.join(model_dir, weight_filename)
        print('----- 断点续训,加载训练进度: {} -----'.format(weights_path))
        # 加载训练断点
        checkpoint = torch.load(weights_path)
        # 加载模型权重
        net.load_state_dict(checkpoint['model_state_dict'])
        # 加载优化器进度
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # 加载 epoch
        start_epochs = checkpoint['epoch']
        print('----- 加载完毕 -----')
    # -------------------------------------------------------------------

    # 绘制训练进度
    tv = TrainVision()

    timer, num_batches = Timer(), len(train_iter)

    # 训练,遍历 epoch
    print('----- 开始训练 -----')
    for epoch in range(start_epochs, num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples

        # 初始化长度为3的list累加当前 epoch 的 训练总损失、总准确率 和 样本数量
        metric = {
            'sum_train_loss': 0,
            'sum_train_acc': 0,
            'sum_sample': 0,
        }

        # 开启训练模式,此模式下会计算更新梯度
        net.train()

        # 训练,遍历 batch
        for i, (X, y) in enumerate(train_iter):

            # 计时开始【每批次训练耗时统计】
            timer.start()

            # 优化器梯度清零
            optimizer.zero_grad()

            # 样本和标签移动到训练设备
            X, y = X.to(device), y.to(device)

            # 前向传播
            y_hat = net(X)

            # 真实值与标签计算损失
            l = loss(y_hat, y)

            # 反向传播计算梯度
            l.backward()

            # 更新模型权重
            optimizer.step()

            # 累加当前 epoch 的 训练总损失、训练总准确率 和 样本总数 ------------------------------------------------------------------
            with torch.no_grad():   # 不计入梯度
                # metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
                metric['sum_train_loss'] += l * X.shape[0]
                metric['sum_train_acc'] += accuracy(y_hat, y)
                metric['sum_sample'] += X.shape[0]

            # 计算当前 epoch 的 训练平均损失
            train_l = float(metric['sum_train_loss'] / metric['sum_sample'])
            # 计算当前epoch 的 训练平均准确率
            train_acc = float(metric['sum_train_acc'] / metric['sum_sample'])

            # 每个 epoch 执行5次,将当前 epoch 的 训练平均损失 与 训练平均准确率 打印到控制台并绘制曲线
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                # 控制台输出,打印当前 epoch 的 训练平均损失 与 训练平均准确率
                print("epochs: {}, batches: {}, loss: {}, train_acc: {}".format(epoch+1, i+1, train_l, train_acc))
                # 画图,绘制当前 epoch 的 训练平均损失 与 训练平均准确率
                # print(type(train_l), type(train_acc))
                tv.add(epoch + (i + 1) / num_batches, train_l, train_acc, None)
            # --------------------------------------------------------------------------------------------------------------

            # 计时结束【每批次训练耗时统计】
            timer.stop()
            
        # 每个 epoch 计算验证集平均准确率
        test_acc = evaluate_accuracy_gpu(net, test_iter)

        # 画图,绘制验证集平均准确率
        tv.add(epoch + 1, None, None, test_acc)

        # 每个 epoch 保存模型权重 ---------------------------------------------------------------
        checkpoint = {
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch
        }
        net_name = 'resnet34'
        weight_dir = os.path.join(os.getcwd(), 'weights')
        if not os.path.exists(weight_dir):      # weight文件夹不存在就创建
            os.mkdir(weight_dir)
        path_checkpoint = os.path.join(weight_dir, '{}__time_{}__epoch_{}__acc_{}__val_{}__loss_{}.t7'.format(net_name, time.strftime('%Y.%m.%d-%H.%M.%S', time.localtime()), epoch, train_acc, test_acc, train_l))
        print('----- 保存训练断点:{} -----'.format(path_checkpoint))
        torch.save(checkpoint, path_checkpoint)
        print('----- 保存完毕 -----')
        # ------------------------------------------------------------------------------------

    # 打印最后一个 epoch 的 训练平均损失、训练平均精度、验证平均准确率
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')
    # 打印设备运行速度,样本总数*epoch/总用时
    print(f'{metric["sum_sample"] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')


版权声明:本文为weixin_43721000原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。