请问下torch loss function输入的维度,如何去扩展高维?

前提:

b_xent=torch.nn.CrossEntropyLoss()

不报错案例

没问题1

a=torch.tensor([[0.02,0.3],[0.3,0.3],[0.3,0.3]])
b=torch.tensor([0,1,1])
b_xent(a,b)

Out[3]: tensor(0.7680)

报错案例

若直接扩展维度,报错

a=a.unsqueeze(0)
b=b.unsqueeze(0)
a.shape  # Out[20]: torch.Size([1, 3, 2])
b.shape  # Out[21]: torch.Size([1, 3])
b_xent(a,b)

维度出错2

a=torch.randn((1,273,512))
b=torch.randn((1,273))
b_xent(a,b)

报错内容:

Traceback (most recent call last):
  File "D:\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-b7d4d1dd28a0>", line 3, in <module>
    b_xent(a,b)
  File "D:\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\lib\site-packages\torch\nn\modules\loss.py", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "D:\lib\site-packages\torch\nn\functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "D:\lib\site-packages\torch\nn\functional.py", line 2273, in nll_loss
    raise ValueError('Expected target size {}, got {}'.format(
ValueError: Expected target size (1, 512), got torch.Size([1, 273])

那么如何在input和label都升高维度,并且使用CELoss呢?

回复
阅读 1.3k
1 个回答

在计算交叉熵损失函数前,一般都是使用view()把输入的input拉成[m,c],c为分类,label拉成一维,再计算交叉熵,有点不太明白为什么要升高维度,可以看一下pytorh中关于交叉熵损失函数的介绍
交叉熵损失函数

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
宣传栏