保存和提取神经网络

之前我们学习了如何去搭建训练神经网络,接下来就要学习如何将训练好的神经网络进行保存,在需要的时候进行提取呢?

保存神经网络有两种方法:

  • 保存整个神经网络
  • 保存神经网络的参数,不保存他的结构

提取也有两种方法:

  • 直接进行提取
  • 先创建一个跟提取的神经网络一模一样的神经网络,然后再将各个参数传入神经网络即可。

 小demo:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import  Variable

#fake data
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x.pow(2)+0.2*torch.rand(x.size())

x,y=Variable(x),Variable(y)
def save():
    net=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )

    optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
    loss_func=torch.nn.MSELoss()

    for i in range(100):
        prediction=net(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #保存整个网络
    torch.save(net,'net.pkl')
    #保存网络的参数
    torch.save(net.state_dict(),'net_params.pkl')
    plt.subplot(131)
    plt.title('net')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

def restore_net():
    #提取整个网络
    net1=torch.load('net.pkl')
    prediction=net1(x)
    plt.subplot(132)
    plt.title('net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

def restore_params():
    #先创建一个跟 原来网络一样的网络结构
    net2=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    #将参数传入网络
    net2.load_state_dict(torch.load('net_params.pkl'))
    prediction=net2(x)
    plt.subplot(133)
    plt.title('net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()
save()
restore_net()
restore_params()


版权声明:本文为xs_211314原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。