Pytorch模型的网络结构可视化

import torch

from torchvision import models

from torchviz import make_dot

model = models.resnet50()

x = torch.randn(1, 3, 224, 224)

vis_graph = make_dot(model(x),params=dict(model.named_parameters()))

vise_graph.view()


版权声明:本文为y1556368418原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。