Model.modules和Model.children
首先我们先定义一个网络结构:
class Linear(nn.Module):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Conv2d(),
nn.BatchNorm2d(),
nn.ReLU())
self.decoder = nn.Sequential(
nn.Linear(),
nn.Interpolate())
def forward(self,x)
output = self.encoder(x)
output = self.decoder(output)
return output
然后我们实例化类:
model = Linear()
对于Model.modules,我们遍历整个模型:
会得到:
首先从linear层开始遍历:
nn.Sequential(
nn.Conv2d(),
nn.BatchNorm2d(),
nn.ReLU)
nn.Sequential(
nn.Linear(),
nn.Interpolate())
接着从self.encoder,self.decoder层开始向下遍历:
nn.Sequential(
nn.Conv2d(),
nn.BatchNorm2d(),
nn.ReLU)
nn.Conv2d(),
nn.BatchNorm2d(),
nn.ReLU
nn.Sequential(
nn.Linear(),
nn.Interpolate())
nn.Linear(),
nn.Interpolate()
而Model.children只会遍历self.encoder和self.encoder:
nn.Sequential(
nn.Conv2d(),
nn.BatchNorm2d(),
nn.ReLU())
nn.Sequential(
nn.Linear(),
nn.Interpolate())
版权声明:本文为qq_43733107原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。