一、快速搭建
这是上一篇博客中所用的搭建方法,看起来确实复杂。
主要思路:我们用 class 继承了一个 torch 中的神经网络结构, 然后对其进行了修改。
#-----建立神经网络-----
class Net(torch.nn.Module): #继承torch的组件
#class(类)是面向对象编程的基本概念,是一种自定义数据结构类型 命名为Net
#init搭建这个信息层所需要的信息
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__() # 继承 __init__ 功能
#首先找到Net的父类(比如是类nn.Module),然后把类Net的对象self转换为类nn.Module的对象,然后“被转换”的类nn.Module对象调用自己的init函数
self.hidden = torch.nn.Linear(n_feature, n_hidden) # 隐藏层线性输出
self.predict = torch.nn.Linear(n_hidden, n_output) # 输出层线性输出
#nn.Linear():用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
#【此处注释一】#
#把前面的内容一个一个传递到这里进行组合,搭流程图就是forward所做的事情。
def forward(self, x): #x为输入信息
# 正向传播输入值, 神经网络分析出输出值
x = F.relu(self.hidden(x)) # activation function for hidden layer
#relu是小于0的改为0,大于0的不管的激励函数
x = self.predict(x) # linear output
#此步骤用来输出
return x
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
#输入值1个,神经元10个,输出1个。
print(net) # net architecture
"""
Net (
(hidden): Linear (1 -> 10)
#此处意思为hidden linear 从一个输入到10个神经元
(predict): Linear (10 -> 1)
#此处意思为hidden linear 从10个神经元输出一个
)
"""
现在给出一种快速搭建的简单方法:
net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
#输入值1个,神经元10个,输出1个。
print(net)
我们会发现 方法2中net
多显示了一些内容, 这是为什么呢? 原来他把激励函数也一同纳入进去了, 但是 net1
中, 激励函数实际上是在 forward()
功能中才被调用的. 这也就说明了, 相比 net2
, net1
的好处就是, 你可以根据你的个人需要更加个性化你自己的前向传播过程, 比如(RNN). 不过如果你不需要七七八八的过程, 相信 net2
这种形式更适合你.
二、saveAndLoad
想说啥都在代码注释里
import torch
import matplotlib.pyplot as plt
# torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
# The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
# x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
def save():
# 神经网络搭建
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
#训练
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 绘图
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# 2 个方法
torch.save(net1, 'net.pkl') # save entire net
torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters
#提取网络
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x)
# 画图
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
#只提取网络参数
def restore_params():
# restore only the parameters in net1 to net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# 画图
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
# save net1
save()
# restore entire net (may slow)
restore_net()
# restore only the net parameters
restore_params()
三、批训练
import torch
import torch.utils.data as Data
torch.manual_seed(1) # reproducible
#每批训练5个
BATCH_SIZE = 5
# BATCH_SIZE = 8
x = torch.linspace(1, 10, 10) # 从1到10的十个点
y = torch.linspace(10, 1, 10) # 从10到1的十个点
#把x,y放到数据库里
torch_dataset = Data.TensorDataset(x, y)
#loader使训练分成小批
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 要不要打乱数据 (打乱比较好)
num_workers=2, # 多线程来读数据
)
def show_batch():
for epoch in range(3): # 数据总体训练3次
for step, (batch_x, batch_y) in enumerate(loader): # 每一步 loader 释放一小批数据用来学习
# 训练数据
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
show_batch()
最后两行的解释:
一个python文件通常有两种使用方法,第一是作为脚本直接执行,第二是 import 到其他的 python 脚本中被调用(模块重用)执行。因此 if __name__ == 'main': 的作用就是控制这两种情况执行代码的过程,在 if __name__ == 'main': 下的代码只有在第一种情况下(即文件作为脚本直接执行)才会被执行,而 import 到其他脚本中是不会被执行的。
版权声明:本文为qq_41754350原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。