One-hot encoding of category tensors in PyTorch
This article has been authorized by the Jishi platform and first published on the public account of the Jishi platform. It may not be reprinted without permission.
- Original document: https://www.yuque.com/lart/ugkv9f/src5w8
- Code repository: https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch
Foreword
One-hot encoding is very common in deep learning tasks, but it is not a natural way to store data, so in most cases we have to do the conversion ourselves. The idea is straightforward: map each category to its own 0-1 vector. The concrete implementation, however, still takes some thought. In fact, PyTorch's nn.functional already provides a one_hot method that can be applied directly. But this does not stop us from thinking and practicing :>! This article therefore implements one-hot encoding with several commonly used methods. I hope it will be useful.
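For reference, the built-in function mentioned above can be called as in the following minimal sketch. The label values and num_classes are arbitrary choices made for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical label map of shape (b, h, w); values are class indices.
num_classes = 4
labels = torch.tensor([[[0, 1], [2, 3]]])  # shape (1, 2, 2), dtype int64

# F.one_hot expects an int64 tensor and appends the class dimension last.
one_hot = F.one_hot(labels, num_classes=num_classes)

print(one_hot.shape)     # torch.Size([1, 2, 2, 4])
print(one_hot[0, 1, 0])  # tensor([0, 0, 1, 0])
```

The result keeps the integer dtype of the input, so a cast such as .float() may be needed before it is used in a loss.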
The main ways are as follows:
- for loop
- scatter
- index_select
for loop
This method is very intuitive. Put simply, it assigns 1 to the specified positions of a blank (all-zero) tensor.
The key is how to set the index.
The two schemes below are essentially the same; they differ only in which dimension holds the categories.
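import torch


def bhw_to_onehot_by_for(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes: number of classes
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    # Class dimension first: (num_classes, b, h, w), same dtype and device as the input.
    one_hot = bhw_tensor.new_zeros(size=(num_classes, *bhw_tensor.shape))
    for i in range(num_classes):
        # Set channel i to 1 wherever the label equals i.
        one_hot[i, bhw_tensor == i] = 1
    # Move the class dimension to the end: (b, h, w, num_classes).
    one_hot = one_hot.permute(1, 2, 3, 0)
    return one_hot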
def bhw_to_onehot_by_for_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes: number of classes
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    # Class dimension last from the start, so no permute is needed.
    one_hot = bhw_tensor.new_zeros(size=(*bhw_tensor.shape, num_classes))
    for i in range(num_classes):
        one_hot[..., i][bhw_tensor == i] = 1
    return one_hot
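The core of both variants is assignment through a boolean mask. The following minimal sketch, with label values made up for illustration, shows what each loop iteration does on a tiny tensor:

```python
import torch

labels = torch.tensor([[[0, 2], [1, 2]]])  # hypothetical labels, shape (1, 2, 2)
num_classes = 3

one_hot = labels.new_zeros((*labels.shape, num_classes))
for i in range(num_classes):
    mask = labels == i         # bool tensor of shape (1, 2, 2)
    one_hot[..., i][mask] = 1  # write 1 into channel i where the label is i

print(one_hot[0, 0, 1])  # tensor([0, 0, 1])
```

Because one_hot[..., i] is a view of one_hot, the masked assignment writes into the original tensor rather than into a copy.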
scatter
This is probably the most common way of implementing one_hot. Its main function is to assign values to specified positions in a tensor.
It is more flexible because a specially constructed index tensor can be used as the index. Of course, this flexibility also makes it harder to understand. The explanation provided in the official documentation is quite intuitive:
'''
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
* (int dim, Tensor index, Tensor src)
* (int dim, Tensor index, Tensor src, *, str reduce)
* (int dim, Tensor index, Number value)
* (int dim, Tensor index, Number value, *, str reduce)
'''
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
The documentation explains the in-place form (scatter_) in which the replacement values come from a tensor src. In our application, we mainly use the in-place version with a scalar replacement value, value.
From the form above, we can see that by specifying the tensor argument index, we can place the value at (i,j,k) in src at a specified position of the method caller (here, self). That position is obtained by taking the coordinates (i,j,k) and replacing the coordinate along dim with the value of index at (i,j,k). This also reflects a requirement on the index tensor: its number of dimensions must match that of self and of src (if src is a tensor; below we use the concrete scalar value 1, i.e. src is replaced by value).
This fits the concept of one-hot very well, because one-hot means that for data of class i, position i is 1 and all other positions are 0. Applying scatter_ to an all-zero tensor therefore constructs a one-hot tensor easily: just place 1 at the position corresponding to the class number.
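To make the indexing rule concrete, the sketch below applies scatter_ with dim=1 to a small 2-D tensor and compares the result with a hand-written loop that follows the documented rule self[i][index[i][j]] = value. The index values are arbitrary and chosen only for illustration.

```python
import torch

index = torch.tensor([[2], [0], [1]])  # shape (3, 1), int64 class numbers
out = torch.zeros(3, 4)
out.scatter_(dim=1, index=index, value=1.0)

# Reference: apply the rule self[i][index[i][j]] = value by hand.
ref = torch.zeros(3, 4)
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        ref[i, index[i, j]] = 1.0

print(out)
print(torch.equal(out, ref))  # True
```

Each row of out ends up with a single 1 in the column named by index, which is exactly a one-hot row.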
For our problem, the input tensor containing the class numbers (shape B,H,W) is well suited to serve as index. Based on this idea, two different strategies can be conceived:
import math  # math.prod requires Python 3.8+


def bhw_to_onehot_by_scatter(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes: number of classes
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    # Allocate on the input's device so that scatter_ also works for CUDA inputs.
    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes), device=bhw_tensor.device)
    # Flattened view: one row per pixel, one column per class.
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot
def bhw_to_onehot_by_scatter_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes: number of classes
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    # Allocate on the input's device so that scatter_ also works for CUDA inputs.
    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes), device=bhw_tensor.device)
    # index gets a trailing axis of size 1 so its ndim matches one_hot.
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)
    return one_hot
The root of the difference between the two forms lies in how they treat the shape, which leads to different ways of applying scatter.
In the first form, B,H,W are merged into a single dimension. The advantage is that the indexing of the channel (category) dimension becomes intuitive.
one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))
one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)
Here, the category dimension is separated from the other dimensions and placed last. This dimension is the one specified by dim, which gives the following correspondence:
zero_tensor[abc, index[abc][d]] = value # d=0
In the second form, the first three dimensions are kept as they are, and the category dimension is again placed last.
one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))
one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)
The corresponding relationship at this time is as follows:
zero_tensor[a,b,c, index[a][b][c][d]] = value # d=0
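As a quick check that the two forms are equivalent, the following sketch runs both on the same random labels and compares the results. The tensor size and number of classes are arbitrary assumptions, and the code runs on the CPU for simplicity.

```python
import math
import torch

num_classes = 5
labels = torch.randint(0, num_classes, (2, 3, 4))  # hypothetical (b, h, w) labels

# Form 1: merge b, h, w into one axis, scatter along dim=1, then restore the shape.
flat = torch.zeros(math.prod(labels.shape), num_classes)
flat.scatter_(dim=1, index=labels.reshape(-1, 1), value=1)
flat = flat.reshape(*labels.shape, num_classes)

# Form 2: keep b, h, w and scatter along the last axis.
kept = torch.zeros(*labels.shape, num_classes)
kept.scatter_(dim=-1, index=labels[..., None], value=1)

print(torch.equal(flat, kept))  # True
```

Taking argmax over the last dimension of either result recovers the original labels, which is a convenient sanity check for any one-hot implementation.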
In addition, timm, the PyTorch library of classification models, uses a similar method:
# https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
index_select
torch.index_select(input, dim, index, *, out=None) → Tensor
- input (Tensor) – the input tensor.
- dim (int) – the dimension in which we index
- index (IntTensor or LongTensor) – the 1-D tensor containing the indices to index
This function, as the name suggests, uses the index to select the sub-tensor of the specified dimension of the tensor.
To understand the motivation for this approach, we need to turn things around and look at one-hot encoding from the perspective of the class labels.
The one-hot vectors of the class numbers, arranged in ascending order, form an identity matrix, so each category corresponds to a specific column (or row) of the identity matrix. Selecting such rows is exactly what index_select does.
So we can use it to implement one_hot encoding: just use the class number to index a specific column or row. Here is an example:
def bhw_to_onehot_by_index_select(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes: number of classes
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    # Row i of the identity matrix is the one-hot vector of class i.
    # The matrix is created on the input's device so that CUDA inputs also work.
    eye = torch.eye(num_classes, device=bhw_tensor.device)
    one_hot = eye.index_select(dim=0, index=bhw_tensor.reshape(-1))
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot
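The row-selection idea can be seen in isolation in the minimal sketch below. The label values are made up for illustration.

```python
import torch

num_classes = 3
eye = torch.eye(num_classes)         # row i is the one-hot vector of class i
labels = torch.tensor([2, 0, 1, 2])  # index must be a 1-D integer tensor

rows = eye.index_select(dim=0, index=labels)
print(rows)
# tensor([[0., 0., 1.],
#         [1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])
```

Since index_select only accepts a 1-D index, a B,H,W label map has to be flattened first and the result reshaped afterwards, as in the function above.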
Performance comparison
The complete code can be found on GitHub (see the code repository linked at the top).
The approximate relative performance of the different methods is shown below. Because other programs were running in the background, the numbers may not be very accurate, and it is recommended that you test them yourself. As you can see, PyTorch's built-in function is not very efficient on the CPU but performs well on the GPU. Interestingly, index_select performs remarkably well.
1.10.0 GeForce RTX 2080 Ti
cpu
('bhw_to_onehot_by_for', 0.5411529541015625)
('bhw_to_onehot_by_for_V1', 0.4515676498413086)
('bhw_to_onehot_by_scatter', 0.0686192512512207)
('bhw_to_onehot_by_scatter_V1', 0.08529376983642578)
('bhw_to_onehot_by_index_select', 0.05156970024108887)
('F.one_hot', 0.07366824150085449)
gpu
('bhw_to_onehot_by_for', 0.005235433578491211)
('bhw_to_onehot_by_for_V1', 0.045584678649902344)
('bhw_to_onehot_by_scatter', 0.0025513172149658203)
('bhw_to_onehot_by_scatter_V1', 0.0024869441986083984)
('bhw_to_onehot_by_index_select', 0.002012014389038086)
('F.one_hot', 0.0024051666259765625)
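If you want to test this yourself, a timing helper could look like the sketch below. This is not the script that produced the numbers above: the helper name time_fn, the input size, and the repeat count are all assumptions made for illustration.

```python
import time
import torch
import torch.nn.functional as F


def time_fn(fn, *args, repeats=10):
    """Return the mean wall-clock time of fn(*args) in seconds."""
    fn(*args)  # warm-up call, excluded from the measurement
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending GPU kernels to finish
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats


device = "cuda" if torch.cuda.is_available() else "cpu"
labels = torch.randint(0, 20, (4, 256, 256), device=device)  # arbitrary test size
elapsed = time_fn(F.one_hot, labels, 20)
print(device, elapsed)
```

The synchronize calls matter on the GPU because CUDA kernels are launched asynchronously. Without them, the timer may stop before the work has actually finished.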