import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot
batch_size = 3
learning_rate =0.0002
epoch = 50
resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)
我想从 pytorch 模型中可视化 resnet
。我该怎么做?我尝试使用 torchviz
但出现错误:
'ResNet' object has no attribute 'grad_fn'
原文由 raaj 发布,翻译遵循 CC BY-SA 4.0 许可协议
make_dot
需要一个变量(即带有grad_fn
的张量),而不是模型本身。尝试: