PyTorch 模型结构可视化

  • hiddenlayer
import hiddenlayer as h
import torch

class ConNet(torch.nn.Module):
    """Small convolutional classifier used as the visualisation example.

    Two conv/ReLU/pool stages (average pooling first, max pooling second)
    feed a two-layer fully connected head and a final 10-way linear layer.
    The first linear layer expects 32 * 7 * 7 flattened features, which is
    what a 1-channel 28x28 input yields after the two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Stage 1: 1 -> 16 channels; 3x3 conv with padding 1 keeps the
        # spatial size, the 2x2 average pool halves it.
        stage1_layers = [
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.conv1 = nn.Sequential(*stage1_layers)

        # Stage 2: 16 -> 32 channels, same conv geometry, 2x2 max pool.
        stage2_layers = [
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv2 = nn.Sequential(*stage2_layers)

        # Fully connected head: 32*7*7 -> 128 -> 64, ReLU after each layer.
        head_layers = [
            nn.Linear(in_features=32 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
        ]
        self.fc = nn.Sequential(*head_layers)

        # Output layer: 10 scores, no activation applied.
        self.out = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Return the 10 raw output scores for every sample in batch ``x``."""
        features = self.conv2(self.conv1(x))
        # Collapse everything except the batch dimension before the head.
        flat = features.reshape(features.shape[0], -1)
        return self.out(self.fc(flat))

# Instantiate the network whose structure will be drawn.
convnet = ConNet()

# Build the drawable graph object from the model and an all-zero
# single-sample, 1-channel 28x28 dummy input.
dummy_input = torch.zeros(1, 1, 28, 28)
vis_graph = h.build_graph(convnet, dummy_input)
# Use a copy of the built-in "blue" colour theme.
vis_graph.theme = h.graph.THEMES["blue"].copy()
# Path the rendered image is saved to.
vis_graph.save("./demo1")

(图:hiddenlayer 绘制的 ConNet 网络结构图,即上面保存到 ./demo1 的结果)

  • torchviz
import torch
from torchviz import make_dot

class ConNet(torch.nn.Module):
    """Example CNN whose autograd graph is rendered in this section.

    Layout: conv + ReLU + average pool, conv + ReLU + max pool, then a
    fully connected head (32*7*7 -> 128 -> 64) and a 10-unit linear output.
    """

    def __init__(self):
        super().__init__()

        # First convolutional stage: 1 input channel -> 16 feature maps.
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2),
        )

        # Second convolutional stage: 16 -> 32 feature maps.
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Fully connected head operating on the flattened feature maps.
        flattened_size = 32 * 7 * 7
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(flattened_size, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
        )

        # Final linear layer producing 10 outputs (no activation).
        self.out = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Run ``x`` through both conv stages, the head and the output layer."""
        hidden = self.conv1(x)
        hidden = self.conv2(hidden)
        batch_size = hidden.shape[0]
        # Keep the batch dimension, merge channel and spatial dimensions.
        hidden = hidden.reshape(batch_size, -1)
        hidden = self.fc(hidden)
        return self.out(hidden)

convnet = ConNet()
# Random single-sample, 1-channel 28x28 input; a forward pass is required
# because make_dot is handed the resulting output tensor.
x = torch.randn(size=(1, 1, 28, 28))
y = convnet(x)

# Name -> tensor mapping for make_dot: every named parameter plus the input.
conv_viz = make_dot(y, params=dict(convnet.named_parameters(), x=x))

# Render as PNG.
conv_viz.format = "png"

# Set the output directory on the graph object itself.  The previous
# `conv_viz_directory = "data"` only created an unused variable, so the
# directory setting was silently ignored.
conv_viz.directory = "data"
# Render the graph and open the result in the default viewer.
conv_viz.view()

(图:torchviz 绘制的 ConNet 计算图,PNG 格式)


版权声明:本文为cyj5201314原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。