pytorch中的 nn.ModuleList 和 nn.Sequential

nn.ModuleList() 和 nn.Sequential() 都可以用来搭建神经网络。nn.ModuleList()函数是用来存储各个模块,前后模块是没有关联的。

class net(nn.Module):
    def __init__(self):
        super(net6, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
 
    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x
 
net = net()
print(net)
# net(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

*可以将列表解包

class net(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear_list = [nn.Linear(10, 10) for i in range(3)]
        self.linears = nn.Sequential(*self.linear_list)  ###  *可以将列表解包
 
    def forward(se

版权声明:本文为qq_40107571原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。