pytorch网络可视化-查看每层网络的特征图

pytorch网络可视化-查看每层网络的特征图

当今发论文离不开图像的可视化

于是乎我就业余时间打算学习一下如何可视化每一层的图像

一般来说会选取tensorboard来进行可视化

这里采用另外一种方式

  def forward(self, x):
        outputs = []
        x = self.conv1(x)
        outputs.append(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        outputs.append(x)
        x = self.layer2(x)
        outputs.append(x)
        x = self.layer3(x)
        outputs.append(x)
        x = self.layer4(x)
        outputs.append(x)

        return outputs

这是resnet的网络我们将每一层的输出添加进入outputs里面方便后面可视化

可视化关键代码


# forward
out_put = model(img)
#获取输出列表这是一个列表,里面每个代表了每层的输出
for feature_map in out_put:
    #通过一个迭代器来遍历每个特征图
    # [N, C, H, W] -> [C, H, W]
    im = np.squeeze(feature_map.detach().numpy())#把tensor变成numpy
    # [C, H, W] -> [H, W, C]
    im = np.transpose(im, [1, 2, 0])
	#对图像的通道进行处理
    # show top 12 feature maps
    plt.figure()
    for i in range(12):
        ax = plt.subplot(3, 4, i+1)
        # [H, W, C]
        plt.imshow(im[:, :, i], cmap='gray')
    plt.show()

原图在这里插入图片描述
**

第一层

**
在这里插入图片描述

第二层

在这里插入图片描述

第三层

在这里插入图片描述

第四层

在这里插入图片描述

第五层

在这里插入图片描述


版权声明:本文为qq_26024067原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。