pytorch网络可视化-查看每层网络的特征图
当今发论文离不开图像的可视化
于是乎我就业余时间打算学习一下如何可视化每一层的图像
一般来说会选取tensorboard来进行可视化
这里采用另外一种方式
def forward(self, x):
outputs = []
x = self.conv1(x)
outputs.append(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
outputs.append(x)
x = self.layer2(x)
outputs.append(x)
x = self.layer3(x)
outputs.append(x)
x = self.layer4(x)
outputs.append(x)
return outputs
这是resnet的网络我们将每一层的输出添加进入outputs里面方便后面可视化
可视化关键代码
# forward
out_put = model(img)
#获取输出列表这是一个列表,里面每个代表了每层的输出
for feature_map in out_put:
#通过一个迭代器来遍历每个特征图
# [N, C, H, W] -> [C, H, W]
im = np.squeeze(feature_map.detach().numpy())#把tensor变成numpy
# [C, H, W] -> [H, W, C]
im = np.transpose(im, [1, 2, 0])
#对图像的通道进行处理
# show top 12 feature maps
plt.figure()
for i in range(12):
ax = plt.subplot(3, 4, i+1)
# [H, W, C]
plt.imshow(im[:, :, i], cmap='gray')
plt.show()
原图
**
第一层
**
第二层

第三层

第四层

第五层

版权声明:本文为qq_26024067原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。