PyTorch: Reading and Writing Tensors and Models

Reading and writing Tensors

import torch 
import torch.nn as nn

Saving a single Tensor

x = torch.ones(3)
torch.save(x, 'x.pt')
x2 = torch.load('x.pt')
print(x2)
tensor([1., 1., 1.])
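`torch.save` also accepts a file-like object in place of a path. The sketch below does an in-memory round trip with the standard library's `io.BytesIO`, which is convenient in unit tests where you don't want to write files to disk.

```python
import io
import torch

x = torch.ones(3)

# Serialize into an in-memory buffer instead of a file on disk
buf = io.BytesIO()
torch.save(x, buf)

# Rewind before reading, otherwise torch.load would start at the end of the buffer
buf.seek(0)
x2 = torch.load(buf)

print(torch.equal(x, x2))  # True
```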

Saving a list of Tensors

x = torch.ones(3)
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
print(xy_list)
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
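Tensors saved together in one `torch.save` call keep their storage-sharing relationships. The sketch below follows the example in the PyTorch serialization notes: a view saved alongside its base tensor is still a view of that tensor after loading. The names `numbers` and `evens` are only illustrative.

```python
import io
import torch

numbers = torch.arange(1, 10)
evens = numbers[1::2]  # a view into numbers: tensor([2, 4, 6, 8])

buf = io.BytesIO()
torch.save([numbers, evens], buf)
buf.seek(0)
loaded_numbers, loaded_evens = torch.load(buf)

# Modifying the loaded view in place also changes the loaded base tensor,
# because both were restored onto the same underlying storage
loaded_evens *= 2
print(loaded_numbers)  # tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])
```

The flip side is that saving a small view of a large tensor writes the entire underlying storage. Call `.clone()` on the view first if you only want to store the slice.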

Saving a dictionary that maps strings to Tensors

x = torch.ones(3)
y = torch.zeros(4)
torch.save({'x':x, 'y':y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
print(xy)
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
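A dictionary is the usual way to bundle several named tensors into one file. If the file may have been written on a GPU machine, the `map_location` argument of `torch.load` can place every tensor on the CPU when loading. This is a minimal sketch, and the key names are illustrative.

```python
import io
import torch

data = {'x': torch.ones(3), 'y': torch.zeros(4)}

buf = io.BytesIO()
torch.save(data, buf)
buf.seek(0)

# map_location="cpu" loads every tensor onto the CPU, even if it was saved from a GPU
loaded = torch.load(buf, map_location="cpu")

for name, t in loaded.items():
    print(name, tuple(t.shape), t.device)
# x (3,) cpu
# y (4,) cpu
```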

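Reading and writing models: two approaches

  • Approach 1: save and load only the model parameters (state_dict)
  • Approach 2: save and load the entire model (model)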

state_dict: a dictionary object (an OrderedDict) that maps each parameter name to its parameter Tensor
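A state_dict holds more than the learnable parameters. It also contains the module's registered buffers, such as the running statistics of batch normalization. The quick check below uses `nn.BatchNorm1d` purely as an illustration; it is not part of the MLP that follows.

```python
import torch.nn as nn

bn = nn.BatchNorm1d(2)

# weight and bias are parameters; the remaining entries are buffers
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']

# named_parameters() returns only the learnable parameters
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
```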

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
net.state_dict()
OrderedDict([('hidden.weight', tensor([[-0.0612, -0.2414,  0.3345],
                      [ 0.3538, -0.1846,  0.1380]])),
             ('hidden.bias', tensor([-0.5306,  0.2777])),
             ('output.weight', tensor([[-0.5396, -0.1019]])),
             ('output.bias', tensor([-0.3461]))])
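The keys above combine the attribute names from `__init__` (`hidden`, `output`) with each layer's parameter names (`weight`, `bias`). `nn.Linear(in_features, out_features)` stores its weight with shape `(out_features, in_features)`. The check below redeclares the MLP class from above so that the snippet runs on its own.

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))

net = MLP()
# ReLU has no parameters, so "act" contributes no entries
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# hidden.weight (2, 3)
# hidden.bias (2,)
# output.weight (1, 2)
# output.bias (1,)
```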
1. Saving and loading the state_dict
PATH = "./net.pt" # the recommended file extension is .pt or .pth
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
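`load_state_dict` copies the saved values into the freshly constructed `net2`, so both networks should now give identical outputs. The check below is a minimal sketch. It uses an in-memory buffer instead of `PATH` and redeclares MLP so the snippet is self-contained.

By default `load_state_dict` runs with `strict=True`, and the object it returns lists any `missing_keys` and `unexpected_keys`. Passing `weights_only=True` to `torch.load` (available since PyTorch 1.13) restricts unpickling to tensors and plain containers.

```python
import io
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))

net = MLP()
buf = io.BytesIO()
torch.save(net.state_dict(), buf)
buf.seek(0)

net2 = MLP()
result = net2.load_state_dict(torch.load(buf, weights_only=True))
print(result.missing_keys, result.unexpected_keys)  # [] []

# Switch to inference mode; this matters for dropout or batch-norm layers
net2.eval()
x = torch.rand(2, 3)
with torch.no_grad():
    print(torch.equal(net(x), net2(x)))  # True
```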
2. Saving and loading the entire model
PATH = "model.pth"
torch.save(net, PATH)
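net2 = torch.load(PATH, weights_only=False)  # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a full model; only load files you trust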
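Saving the whole model pickles the object. That means the MLP class must be importable under the same module path when `torch.load` runs, and the file stays tied to that code layout.

To resume training, a common pattern (described in the PyTorch tutorial on saving a general checkpoint) is to save a dictionary holding both state_dicts plus bookkeeping values. This is a sketch under illustrative assumptions: the key names, the SGD optimizer and the epoch counter are choices made here, not requirements.

```python
import io
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))

net = MLP()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving
loss = net(torch.rand(4, 3)).sum()
loss.backward()
optimizer.step()

buf = io.BytesIO()
torch.save({
    'epoch': 1,  # illustrative bookkeeping value
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, buf)
buf.seek(0)

# Rebuild the objects first, then restore their states from the checkpoint
checkpoint = torch.load(buf)
net2 = MLP()
optimizer2 = torch.optim.SGD(net2.parameters(), lr=0.1, momentum=0.9)
net2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```

Because only state_dicts and plain Python values are stored, this file does not depend on pickling the MLP class itself.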

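Copyright notice: this is an original article by qq_41995258, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.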