[Deep Learning] Custom Neural Network Layers (PyTorch)

Custom Neural Network Layers

1 Custom Layers Without Model Parameters

import torch
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)
    def forward(self, x):
        # Subtract the input's mean so the output is centered at zero
        return x - x.mean()

layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))
tensor([-2., -1.,  0.,  1.,  2.])
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
y.mean().item()
2.7939677238464355e-09
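The mean of y is a tiny number rather than exactly zero only because of floating-point precision. CenteredLayer defines forward() but registers no nn.Parameter, so only the nn.Linear(8, 128) above contributes trainable weights; a quick check (a small added sketch, not from the original post):

print(len(list(layer.parameters())))  # 0: CenteredLayer has no trainable parameters
print(len(list(net.parameters())))    # 2: the weight and bias of nn.Linear(8, 128)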

2 Custom Layers With Model Parameters

The key steps are to define the layer's parameters in __init__ and the forward-pass computation in forward(); besides registering each nn.Parameter directly as an attribute (see the sketch below), parameters can be kept in an nn.ParameterList or nn.ParameterDict, as in the examples that follow.
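As a minimal illustration (not from the original post; the name MyLinear and its shapes are hypothetical), parameters can be registered one by one simply by assigning nn.Parameter objects as attributes:

# Hypothetical example layer, for illustration only
class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        # Assigning an nn.Parameter as an attribute registers it with the module
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
    def forward(self, x):
        return torch.mm(x, self.weight) + self.bias

print(MyLinear(4, 2)(torch.ones(1, 4)).shape)
torch.Size([1, 2])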

# Defined via nn.ParameterList
class MyDense(nn.Module):
    def __init__(self):
        super(MyDense, self).__init__()
        # Three 4x4 weight matrices, with a 4x1 output projection appended afterwards
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.rand(4, 1)))
    def forward(self, x):
        # Multiply the input by every parameter in the list, in order
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
net = MyDense()
print(net)
MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)
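Chaining the four matrix multiplications maps a 4-feature input to a single output; a quick added check of the shape:

print(net(torch.ones(1, 4)).shape)
torch.Size([1, 1])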
# Defined via nn.ParameterDict
class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
            'linear1': nn.Parameter(torch.rand(4, 4)),
            'linear2': nn.Parameter(torch.rand(4, 1)),
        })
        # Entries can also be added later with update()
        self.params.update({'linear3': nn.Parameter(torch.rand(4, 2))})

    def forward(self, x, choice='linear1'):
        # 'choice' selects which registered parameter is used in the forward pass
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)
MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)
x = torch.ones(1,4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
tensor([[2.0827, 3.0864, 2.2150, 2.3705]], grad_fn=<MmBackward>)
tensor([[2.6715]], grad_fn=<MmBackward>)
tensor([[1.6972, 2.2113]], grad_fn=<MmBackward>)
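Entries stored in a ParameterList or ParameterDict are registered with the module, so they appear in named_parameters() and are visible to optimizers; an added check:

for name, param in net.named_parameters():
    print(name, param.shape)
params.linear1 torch.Size([4, 4])
params.linear2 torch.Size([4, 1])
params.linear3 torch.Size([4, 2])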
# Custom layers can be combined to build a model
net = nn.Sequential(MyDictDense(), MyDense())

print(net)
print(net(x))
Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[5.9880]], grad_fn=<MmBackward>)
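Because every parameter above is registered, net.parameters() hands them all to an optimizer; a minimal added training-step sketch using torch.optim.SGD and an illustrative all-zero target:

optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
target = torch.zeros(1, 1)                    # illustrative target, not from the original post
loss = ((net(x) - target) ** 2).mean()        # simple squared-error loss
loss.backward()
optimizer.step()                              # updates the parameters of both custom layers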
