Pytorch获取中间层输出

钩子截流:

        这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
        函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。实际过程中建议先打印所有层的名字以做到精确提取。给出代码如下:

网络:


class net1(nn.Module):
    def __init__(self):
        super(net1, self).__init__()
 
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=0, bias=False),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 12

版权声明:本文为qq_45100200原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。