Pytorch 自动生成格式化网络结构图
-
Pytorch 自动生成网络结构图
Pytorch没有TensorFlow类似的原生支持,尽管也可以用TensorBoardX进行可视化,但是其生成的网络图不适合直接展示。所以,有没有合适的自动化方案?
首先,可以利用一个开源库,functional-zoo
所需准备:
brew/apt-get/yum install graphviz
conda install python-graphviz
pip install torchviz
以我自己的网络为例:
import torch from torch.autograd import Variable from graphviz import Digraph from torchviz import make_dot from models import miracle_lineconv_net from config import Config opt = Config() x = Variable(torch.randn(128,2,41,9))#change 12 to the channel number of network input model = miracle_lineconv_net.MiracleLineConvNet(opt) y = model(x) g = make_dot(y) g.view()
- 如果依然有安装问题,继续下载
pip install git+https://github.com/szagoruyko/pytorchviz
官方效果如下图: