TorchExplorer是一个交互式探索神经网络的可视化工具,他的主要功能如下:

TorchExplorer是一款创新的人工智能工具,专为使用非常规神经网络架构的研究人员设计。可以在本地或者wandb中生成交互式Vega自定义图表,提供网络结构的模块级可视化。在左边的面板可以模块级方式展现神经网络架构,帮助研究人员导航网络结构。在右边的图中节点表示输入/输出占位符或在转发过程中调用的特定子模块,可以深入检查模块,直方图可视化数据分布。

节点之间的边缘表示数据处理流向,并且提供对输入/输出张量,梯度规范和参数梯度的信息。最主要的是它擅长处理非标准网络架构,这样我们看代码就方便多了,以下是官网的一个演示gif

TorchExplorer需要graphviz,所以先安装graphviz

 sudo apt-get install libgraphviz-dev graphviz
 pip install torchexplorer

然后就可以使用了:

 import torch
 import torchvision
 import torchexplorer
 
 model = torchvision.models.resnet18(pretrained=False)
 dummy_X = torch.randn(5, 3, 32, 32)
 
 # Only log input/output and parameter histograms, if you don't want even these set log=[].
 torchexplorer.watch(model, log_freq=1, log=['io', 'params'], backend='standalone')
 # Do one forwards and backwards pass
 model(dummy_X).sum().backward()
 # Your model will be available at http://localhost:5000

这里需要注意的是,需要一个完整的前向和反向传播的过程,这样他才可以得到需要的信息,结果如下:

我实验了一下,这对我们看模型代码来说是一个非常好的工具,它可以让我们更深入的了解模型的架构和工作方式,推荐大家试一试,项目地址:

https://avoid.overfit.cn/post/e7be6a62915445e8ab8af0b40f13019a


deephub
119 声望91 粉丝