pytorch语法问题 loss.backword() 缺少入参?

我是一个母语java的后端开发,最近在学python和pytorh,对脚本语言不太熟悉。
我从教程中看到一个简单模型的训练过程的部分代码如下

    for epoch in range(opt.epochs):

        loss_mean = 0.
        correct = 0.
        total = 0.

        net.train()
        for i, data in enumerate(train_loader):

            # forward
            inputs, labels = data
            outputs = net(inputs)

            # backward
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()

            # update weights
            optimizer.step()

我的疑惑如下:
loss是net(模型)的输出经过损失函数计算得到的损失对象,optimizer是优化器。loss.backword()是反向传播,更新模型中参数的梯度。有了梯度之后,优化器根据梯度对模型参数进行更新。这些过程我大概是没有理解错的。
但是这个过程有些函数我感觉缺少入参。比如,loss.backward()应该是对net.parameter更新梯度属性,那应该是loss.backward(net.parameter); 以及optimizer.step()也是对net.parmeter更新,应该是optimizer.step(net.parameter)。
如果不传入参,那么一定是之前的某些过程中获得了net.parameter的引用,不然不可能更新到它。optimizer在声明的时候是这样写的

optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9)  # 选择优化器

这里可以看到在声明时传递了parameters的引用,我还能理解。但是loss是什么时候获得parameters的引用的呢?以及loss是不是真的持有net.parameters的引用呢?

阅读 759
1 个回答

在 forward 的每一步里,比如加法、乘法、Relu、等等,如何做 backward 就已经记录到了结果的 tensor 里。所以可以在最终结果 tensor 里直接调用 backward 。

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题