前言
尽管pytorch 已经集成了tensorboard的接口,但是你还要下载安装tensorboard工具。
下载tensorboard:
pip install tensorboard.
不行的话,再安装tensorboardX,是早些时候专门给pytorch用的tensorboard。
pip install tensorboardX。
效果
tensorboard用网页的方式把很多的信息都展现出来,比较方便。上方image和graph分别代表你训练的数据和你的深度学习网络结构图。
最简单的例子讲解
定义一个学习网络,来分类FashionMNIST,在SummaryWriter的时候,就开始用tensorboard了。
我会分段讲解,但是最好是先在文末拷贝整体代码再回来对照代码看。
首先import,和定义一些工具类,没什么好说的。get_num_correct函数是得到预测结果和label相同数目的函数。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
def get_num_correct(preds,labels):
return preds.argmax(dim=1).eq(labels).sum().item()
定义网络
class Network(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
self.fc1=nn.Linear(in_features=12*4*4,out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=60)
self.out = nn.Linear(in_features=60, out_features=10)
def forward(self, t):
t=F.relu(self.conv1(t))
t=F.max_pool2d(t,kernel_size=2,stride=2)
t = F.relu(self.conv2(t))
t = F.max_pool2d(t,kernel_size=2,stride=2)
t=t.flatten(start_dim=1)
t=F.relu(self.fc1(t))
t=F.relu(self.fc2(t))
t=self.out(t)
return t
main函数里面,通过pytorch的工具类torchvision导入MNIST数据集,然后用data loader加载进来,为训练做准备。
if __name__ == '__main__':
train_set=torchvision.datasets.FashionMNIST(
root='./data-source',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor()
])
)
train_loader=torch.utils.data.DataLoader(train_set,batch_size=100,shuffle=True)
(续上main函数)接着声明的summary writer就是用到tensorboard的类,tensorboard能够记录模型学过程中的很多量,然后用图表的方式显示出来。
#tensor board
tb=SummaryWriter()
network=Network()
#取出训练用图
images,labels=next(iter(train_loader))
grid=torchvision.utils.make_grid(images)
#想用tensorboard看什么,你就tb.add什么。image、graph、scalar等
tb.add_image('images', grid)
tb.add_graph(model=network,input_to_model=images)
tb.close()
exit(0)
写好代码之后,运行一遍,看有没有错误,有错误的地方tensorboard不会储存也不会显示。
运行之后这个目录下会出现runs目录,里面储存量tensorboard要显示的数据。
然后在这个目录下cmd,指定吧runs目录下的数据在tensorboard显示,开启tensorboard服务
tensorboard --logdir=runs
然后会出现这个
这样在浏览器访问本地服务6006端口就可以看到开头的效果了。
最后,完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
def get_num_correct(preds,labels):
return preds.argmax(dim=1).eq(labels).sum().item()
class Network(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
self.fc1=nn.Linear(in_features=12*4*4,out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=60)
self.out = nn.Linear(in_features=60, out_features=10)
def forward(self, t):
t=F.relu(self.conv1(t))
t=F.max_pool2d(t,kernel_size=2,stride=2)
t = F.relu(self.conv2(t))
t = F.max_pool2d(t,kernel_size=2,stride=2)
t=t.flatten(start_dim=1)
t=F.relu(self.fc1(t))
t=F.relu(self.fc2(t))
t=self.out(t)
return t
if __name__ == '__main__':
train_set=torchvision.datasets.FashionMNIST(
root='./data-source',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor()
])
)
train_loader=torch.utils.data.DataLoader(train_set,batch_size=100,shuffle=True)
#tensor board
tb=SummaryWriter()
network=Network()
#取出训练用图
images,labels=next(iter(train_loader))
grid=torchvision.utils.make_grid(images)
#想用tensorboard看什么,你就tb.add什么。image、graph、scalar等
tb.add_image('images', grid)
tb.add_graph(model=network,input_to_model=images)
tb.close()
exit(0)
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。