PyTorch实现梯度累加变相扩大batch

本篇文章,主要针对显卡过小或者不容易实现大batch来求梯度更新网络参数

1)常规情况下,pytorch求梯度和进行网络参数更新

    outputs = model(inputs)
    loss = criterion(outputs,inputs)

    optimizer.zero_grad() #清空梯度
    loss.backward()       #反向传播,求梯度
    optimizer.step()      #根据优化器更新网络参数

2)显卡过小或者不容易实现大batch来求梯度更新网络参数,但又想试一下呢,可以按照以下代码进行模拟

    outputs = model(inputs)
    loss = criterion(outputs,inputs)

    loss = loss/batch_size   #相当于平均了loss
    loss.backward()          #求梯度,后面没有马上清0

    if cnt%batch_size==0:
        optimizer.step()     #根据累计到batch_size个梯度,进行网络参数更新
        optimizer.zero_grad()#梯度清0   


版权声明:本文为u013289254原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。