# How can I extend the input of a PyTorch loss function to higher dimensions?

```python
import torch

b_xent = torch.nn.CrossEntropyLoss()
```

## Example that works

```python
a = torch.tensor([[0.02, 0.3], [0.3, 0.3], [0.3, 0.3]])
b = torch.tensor([0, 1, 1])
b_xent(a, b)
# Out[3]: tensor(0.7680)
```
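In the 2-D case, `input` has shape (N, C) (one row of C logits per sample) and `target` has shape (N,) holding integer class indices. The sketch below reuses the values above and checks that the loss is the mean, over samples, of the negative log-softmax taken along dim 1 at each sample's target class. This matches the `nll_loss(log_softmax(input, 1), ...)` call visible in the traceback. The names `manual` and `via_nll` are illustrative only.

```python
import torch
import torch.nn.functional as F

b_xent = torch.nn.CrossEntropyLoss()

a = torch.tensor([[0.02, 0.3], [0.3, 0.3], [0.3, 0.3]])  # (N=3, C=2) logits
b = torch.tensor([0, 1, 1])                               # (N=3,) class indices

loss = b_xent(a, b)

# By hand: log-softmax over the class dim (dim=1), pick each sample's
# target-class log-probability, negate, and average over N.
log_probs = F.log_softmax(a, dim=1)
manual = -log_probs[torch.arange(a.shape[0]), b].mean()

# Functional form matching the traceback's cross_entropy implementation
via_nll = F.nll_loss(F.log_softmax(a, dim=1), b)

print(loss, manual, via_nll)
```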

## Examples that raise an error

```python
a = a.unsqueeze(0)
b = b.unsqueeze(0)
a.shape  # Out[20]: torch.Size([1, 3, 2])
b.shape  # Out[21]: torch.Size([1, 3])
b_xent(a, b)
```
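A likely cause: for inputs with more than two dimensions, `CrossEntropyLoss` still treats dim 1 as the class dimension, i.e. it expects `input` of shape (N, C, d1, ...) and `target` of shape (N, d1, ...). A tensor of shape [1, 3, 2] is therefore read as N=1, C=3, d1=2, so the target would need shape (1, 2). A minimal sketch, assuming the intended layout is (batch, positions, classes) = (1, 3, 2): move the class dim to position 1 with `permute` before calling the loss. With the default `reduction='mean'` the result should equal the 2-D loss.

```python
import torch

b_xent = torch.nn.CrossEntropyLoss()

a = torch.tensor([[0.02, 0.3], [0.3, 0.3], [0.3, 0.3]])
b = torch.tensor([0, 1, 1])

# Assumed layout: 1 sequence, 3 positions, 2 classes per position.
a3 = a.unsqueeze(0)  # (1, 3, 2) = (N, L, C)
b3 = b.unsqueeze(0)  # (1, 3)    = (N, L)

# CrossEntropyLoss reads dim 1 as classes, so reorder to (N, C, L) = (1, 2, 3).
loss_3d = b_xent(a3.permute(0, 2, 1), b3)

# Reference value from the original 2-D call
loss_2d = b_xent(a, b)
print(loss_3d, loss_2d)
```

`a3.transpose(1, 2)` would work equally well here; `permute` just makes the full axis order explicit.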

```python
a = torch.randn((1, 273, 512))
b = torch.randn((1, 273))
b_xent(a, b)
```

The second example raises:

```text
Traceback (most recent call last):
  File "D:\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-b7d4d1dd28a0>", line 3, in <module>
    b_xent(a,b)
  File "D:\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\lib\site-packages\torch\nn\modules\loss.py", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "D:\lib\site-packages\torch\nn\functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "D:\lib\site-packages\torch\nn\functional.py", line 2273, in nll_loss
    raise ValueError('Expected target size {}, got {}'.format(
ValueError: Expected target size (1, 512), got torch.Size([1, 273])
```
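The same dim-1 rule explains this traceback: [1, 273, 512] is read as N=1, C=273, d1=512, hence "Expected target size (1, 512)". There is a second issue: `torch.randn` produces a float target, whereas class-index targets must be integers (int64) in [0, C); PyTorch 1.10 and later instead interpret a float target as class probabilities with the same shape as the input. The sketch below assumes 273 is a sequence length and 512 is the number of classes (the post does not say which) and shows two equivalent fixes: permute to (N, C, L), or flatten to the plain 2-D case. If 273 really is the number of classes, no reshaping is needed and the integer target must have shape (1, 512).

```python
import torch

b_xent = torch.nn.CrossEntropyLoss()

N, L, C = 1, 273, 512            # assumption: 273 positions, 512 classes
a = torch.randn(N, L, C)         # logits, (N, L, C)
b = torch.randint(0, C, (N, L))  # int64 class indices in [0, C), (N, L)

# Option 1: move the class dim to position 1 -> input (N, C, L), target (N, L)
loss_permute = b_xent(a.permute(0, 2, 1), b)

# Option 2: flatten to the 2-D case -> input (N*L, C), target (N*L,)
loss_flat = b_xent(a.reshape(-1, C), b.reshape(-1))

print(loss_permute, loss_flat)
```

Flattening is often the simpler choice when other code already expects (batch, C) logits; permuting keeps the per-position layout, which matters if you later switch to `reduction='none'` and want an (N, L) loss map.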
