Reading and Writing Tensors
import torch
import torch.nn as nn
Saving a single Tensor
x = torch.ones(3)
torch.save(x, 'x.pt')
x2 = torch.load('x.pt')
print(x2)
tensor([1., 1., 1.])
Saving a list of Tensors
x = torch.ones(3)
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
print(xy_list)
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
Saving a dictionary that maps strings to Tensors
x = torch.ones(3)
y = torch.zeros(4)
torch.save({'x':x, 'y':y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
print(xy)
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
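The three cases above all follow the same save-then-load pattern, so a quick round-trip check can confirm that what comes back matches what was written. The sketch below is only an illustration: it writes to an in-memory `io.BytesIO` buffer instead of a file (an assumption made here to avoid leaving files on disk), and the names `tensors`, `buffer`, `restored` and `round_trip_ok` are hypothetical.

```python
import io

import torch

# Save a dictionary of tensors into an in-memory buffer instead of a file.
tensors = {"x": torch.ones(3), "y": torch.zeros(4)}
buffer = io.BytesIO()
torch.save(tensors, buffer)

# Rewind the buffer before reading it back.
buffer.seek(0)
restored = torch.load(buffer)

# torch.equal is True only when both shape and values match.
round_trip_ok = all(torch.equal(tensors[k], restored[k]) for k in tensors)
print(round_trip_ok)
```

The same check works with a file path such as 'xy_dict.pt' in place of the buffer.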
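Reading and writing models: two approaches
- Method 1: save and load only the model parameters (state_dict)
- Method 2: save and load the entire model
state_dict: a dictionary object that maps each parameter name to its parameter Tensor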
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
net = MLP()
net.state_dict()
OrderedDict([('hidden.weight', tensor([[-0.0612, -0.2414, 0.3345],
[ 0.3538, -0.1846, 0.1380]])),
('hidden.bias', tensor([-0.5306, 0.2777])),
('output.weight', tensor([[-0.5396, -0.1019]])),
('output.bias', tensor([-0.3461]))])
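To make the structure of the state_dict easier to see, the following sketch rebuilds the same MLP and lists only the name and shape of each entry. The variable `shapes` is a hypothetical helper introduced here. Note that the ReLU layer contributes nothing, since it has no learnable parameters.

```python
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))


net = MLP()

# Map each parameter name to the shape of its tensor.
shapes = {name: tuple(t.shape) for name, t in net.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
```

The keys are built from the attribute names assigned in `__init__`, which is why they read `hidden.weight`, `hidden.bias`, and so on.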
1. Saving and loading the state_dict
PATH = "./net.pt"  # the recommended file extension is .pt or .pth
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
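One way to convince yourself that the parameters really were restored is to feed both networks the same input and compare the outputs. This is a minimal sketch under a few assumptions: it writes to a temporary directory rather than "./net.pt", the names `result`, `X` and `same_output` are made up for the example, and it relies on `load_state_dict` returning an object with `missing_keys` and `unexpected_keys`, which holds for reasonably recent PyTorch releases.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))


net = MLP()
net2 = MLP()  # fresh instance with its own random initial weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "net.pt")
    torch.save(net.state_dict(), path)
    # load_state_dict copies the saved tensors into net2's parameters.
    result = net2.load_state_dict(torch.load(path))

# After loading, both networks should give the same output for the same input.
X = torch.randn(2, 3)
same_output = torch.allclose(net(X), net2(X))
print(same_output)
```

Because only the parameters are stored, the receiving side must construct an `MLP` with the same layer names and shapes before calling `load_state_dict`.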
2. Saving and loading the entire model
PATH = "model.pth"
torch.save(net, PATH)
net2 = torch.load(PATH)
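A caveat on this second method: saving the whole model pickles the object, so the model's class must be importable when it is loaded. In addition, starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True` and will refuse a fully pickled model unless `weights_only=False` is passed. The sketch below assumes a PyTorch version that has the `weights_only` argument (1.13 or later) and uses an `nn.Sequential` stand-in for the MLP so that no custom class needs to be importable. Only load files you trust this way, since unpickling can execute arbitrary code.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the MLP above, built from standard layers only.
net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(net, path)
    # weights_only=False permits unpickling full objects; use it only on trusted files.
    net2 = torch.load(path, weights_only=False)

# The reloaded model should behave identically to the original.
X = torch.randn(2, 3)
same_output = torch.allclose(net(X), net2(X))
print(type(net2).__name__, same_output)
```

For most purposes the state_dict approach is the more portable of the two, because it does not tie the saved file to a particular class layout.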
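Copyright notice: this is an original article by qq_41995258, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.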