nn.ModuleList() 和 nn.Sequential() 都可以用来搭建神经网络。nn.ModuleList()函数是用来存储各个模块,前后模块是没有关联的。
class net(nn.Module):
def __init__(self):
super(net6, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
def forward(self, x):
for layer in self.linears:
x = layer(x)
return x
net = net()
print(net)
# net(
# (linears): ModuleList(
# (0): Linear(in_features=10, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# (2): Linear(in_features=10, out_features=10, bias=True)
# )
# )
*可以将列表解包
class net(nn.Module):
def __init__(self):
super(net7, self).__init__()
self.linear_list = [nn.Linear(10, 10) for i in range(3)]
self.linears = nn.Sequential(*self.linear_list) ### *可以将列表解包
def forward(se
版权声明:本文为qq_40107571原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。