
One-hot encoding of category tensors in PyTorch

This article is authorized by the Jishi platform and was first published on the Jishi platform's public account. Do not reprint it without permission.

Foreword

One-hot encoding is very common in deep learning tasks, but it is not a natural way to store data, so in most cases we have to do the conversion ourselves. The idea is straightforward: map each class to its own 0-1 vector. The implementation, however, still takes some thought. PyTorch already provides a one_hot function in nn.functional that can be applied directly, but that need not stop us from thinking it through and practicing :>! This article therefore implements one-hot encoding with several commonly used methods. I hope it is useful.
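
For reference, the built-in function mentioned above can be called as in the following minimal sketch. The tensor `labels` and its values are made up for illustration; the call assumes an int64 input, which is what `F.one_hot` expects.

```python
import torch
import torch.nn.functional as F

# Hypothetical label map with shape (B=1, H=2, W=2); values are class indices.
labels = torch.tensor([[[0, 2], [1, 2]]])

# F.one_hot appends a new last dimension of size num_classes.
one_hot = F.one_hot(labels, num_classes=3)

print(one_hot.shape)  # torch.Size([1, 2, 2, 3])
print(one_hot.dtype)  # torch.int64
```

The result has the same b,h,w,num_classes layout that the hand-written versions in this article produce.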

The main ways are as follows:

  • for loop
  • scatter
  • index_select

for loop

This method is very intuitive: it simply assigns 1 to the specified positions of a blank (all-zero) tensor.
The key is how to construct the index.
The two versions below are essentially the same; they differ only in which dimension holds the classes while the tensor is being filled.

import math  # math.prod is used by the scatter version further below

import torch


def bhw_to_onehot_by_for(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = bhw_tensor.new_zeros(size=(num_classes, *bhw_tensor.shape))
    for i in range(num_classes):
        one_hot[i, bhw_tensor == i] = 1
    one_hot = one_hot.permute(1, 2, 3, 0)
    return one_hot


def bhw_to_onehot_by_for_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
    one_hot = bhw_tensor.new_zeros(size=(*bhw_tensor.shape, num_classes))
    for i in range(num_classes):
        one_hot[..., i][bhw_tensor == i] = 1
    return one_hot
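
To make the masking step concrete, here is a small worked sketch of the second variant on a hand-made input. The tensor `labels` is an illustrative assumption, not data from the article.

```python
import torch

# Hypothetical label map with shape (B=1, H=2, W=2).
labels = torch.tensor([[[0, 2], [1, 2]]])
num_classes = 3

# All-zero target with the class dimension last: (1, 2, 2, 3).
one_hot = labels.new_zeros((*labels.shape, num_classes))

for i in range(num_classes):
    mask = labels == i           # boolean mask of the pixels that belong to class i
    one_hot[..., i][mask] = 1    # one_hot[..., i] is a view, so this writes in place

print(one_hot[0, 0, 1])  # pixel labelled 2 -> tensor([0, 0, 1])
```

Because `new_zeros` inherits the dtype of `labels`, the result here is an integer tensor.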

scatter

This is probably the most common way to implement one_hot. The main function of scatter is to assign values to specified positions in a tensor.

Because it takes a specially constructed index tensor, it is more flexible. That flexibility also makes it harder to understand. The rule given in the official documentation is quite intuitive:

'''
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html

 * (int dim, Tensor index, Tensor src)
 * (int dim, Tensor index, Tensor src, *, str reduce)
 * (int dim, Tensor index, Number value)
 * (int dim, Tensor index, Number value, *, str reduce)
'''

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

The documentation describes the in-place version, scatter_, for the case where the replacement values come from src, a tensor. In our application we also use the in-place version, but the replacement value is a scalar passed as value.

From the rule above, the index tensor lets us place the value of src at (i,j,k) into a specified position of the caller (here, self). That position is obtained by taking the coordinates (i,j,k) and replacing the coordinate along dim with the value of index at (i,j,k). This also implies a requirement on index: it must have the same number of dimensions as self, and as src if src is a tensor. Below, a concrete scalar 1 is used, that is, src is replaced by value. This matches the idea of one-hot encoding well: for a sample of class i, position i is 1 and all other positions are 0. Applying scatter_ to an all-zero tensor therefore builds a one-hot tensor easily, by placing a 1 at the position given by the class index.
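
The rule for dim == 2 with a scalar value can be modelled in a few lines of plain Python. This is only a sketch on nested lists to illustrate the indexing rule; the helper name `scatter_value_dim2` and the sample data are hypothetical, and it is not how PyTorch implements scatter_.

```python
def scatter_value_dim2(target, index, value):
    """Model of target.scatter_(dim=2, index=index, value=value) on nested lists.

    Applies target[i][j][index[i][j][k]] = value for every (i, j, k) in index.
    """
    for i, plane in enumerate(index):
        for j, row in enumerate(plane):
            for cls in row:  # cls is index[i][j][k]
                target[i][j][cls] = value
    return target


labels = [[0, 2], [1, 2]]  # hypothetical class map, shape (2, 2)
num_classes = 3

# index needs the same number of dimensions as the target: shape (2, 2, 1).
index = [[[c] for c in row] for row in labels]
zeros = [[[0] * num_classes for _ in row] for row in labels]

one_hot = scatter_value_dim2(zeros, index, 1)
print(one_hot)  # [[[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, 0, 1]]]
```

Each label overwrites exactly one zero along the last dimension, which is the one-hot pattern described above.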

For our problem, the input tensor of class indices (shape B,H,W) is a natural choice for index. Two slightly different strategies follow from this idea:

def bhw_to_onehot_by_scatter(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
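    # Create the target on the same device as the input so scatter_ also works for GPU tensors.
    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes), device=bhw_tensor.device)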
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot


def bhw_to_onehot_by_scatter_V1(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
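    # Create the target on the same device as the input so scatter_ also works for GPU tensors.
    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes), device=bhw_tensor.device)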
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)
    return one_hot

The two forms differ in how they handle the shape, which leads to two different ways of applying scatter_.

In the first form, B, H and W are merged into a single dimension. The advantage is that indexing along the channel (class) dimension becomes easy to follow.

    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes), device=bhw_tensor.device)
    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)

Here the class dimension is separated from the other dimensions and placed last. It is the dimension passed as dim, which gives the following correspondence:

zero_tensor[abc, index[abc][d]] = value  # d=0

In the second form, the first three dimensions are kept as they are, and the class dimension is again placed last.

    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes), device=bhw_tensor.device)
    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)

The corresponding relationship at this time is as follows:

zero_tensor[a,b,c, index[a][b][c][d]] = value # d=0
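
As a quick sanity check, the two layouts can be compared with each other and with `F.one_hot`. This is a minimal sketch on random CPU data; the shapes and variable names are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

num_classes = 4
labels = torch.randint(0, num_classes, (2, 3, 5))  # int64 class indices, shape (B, H, W)

# Form 1: merge B, H, W into one dimension, scatter along dim=1, then restore the shape.
flat = torch.zeros(labels.numel(), num_classes)
flat.scatter_(dim=1, index=labels.reshape(-1, 1), value=1)
flat = flat.reshape(*labels.shape, num_classes)

# Form 2: keep B, H, W and scatter along the last dimension.
kept = torch.zeros(*labels.shape, num_classes)
kept.scatter_(dim=-1, index=labels[..., None], value=1)

# F.one_hot returns int64, so cast it before comparing with the float tensors.
reference = F.one_hot(labels, num_classes).float()
print(torch.equal(flat, kept), torch.equal(flat, reference))  # True True
```

Both forms write the same ones; only the shape of the index passed to scatter_ differs.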

The PyTorch image classification library timm uses a similar approach:

# https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

index_select

torch.index_select(input, dim, index, *, out=None) → Tensor

- input (Tensor) – the input tensor.
- dim (int) – the dimension in which we index
- index (IntTensor or LongTensor) – the 1-D tensor containing the indices to index

This function, as the name suggests, uses the index to select the sub-tensor of the specified dimension of the tensor.

To understand the motivation for this approach, we actually need to turn it around and look at the one-hot encoding in terms of class labels.

If the class indices are listed in ascending order, their one-hot encodings stacked together form an identity matrix, so each class corresponds to one specific row (or column) of the identity matrix. This is exactly what index_select does, so we can implement one_hot encoding with it: just use the class index to select the corresponding row or column. Here is an example:

def bhw_to_onehot_by_index_select(bhw_tensor: torch.Tensor, num_classes: int):
    """
    Args:
        bhw_tensor: b,h,w
        num_classes:
    Returns: b,h,w,num_classes
    """
    assert bhw_tensor.ndim == 3, bhw_tensor.shape
    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)
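    # Build the identity matrix on the input's device so index_select also works for GPU tensors.
    one_hot = torch.eye(num_classes, device=bhw_tensor.device).index_select(dim=0, index=bhw_tensor.reshape(-1))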
    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)
    return one_hot
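
The identity-matrix view can be checked on a tiny input. The sketch below also shows plain tensor indexing, `eye[labels]`, as an equivalent alternative; that variant is an addition for illustration and is not part of the benchmark in this article.

```python
import torch

# Hypothetical label map with shape (B=1, H=2, W=2).
labels = torch.tensor([[[0, 2], [1, 2]]])
num_classes = 3

eye = torch.eye(num_classes)  # row i is the one-hot vector of class i

# index_select needs a 1-D index, so flatten first and restore the shape afterwards.
by_index_select = eye.index_select(dim=0, index=labels.reshape(-1))
by_index_select = by_index_select.reshape(*labels.shape, num_classes)

# Indexing with an integer tensor picks the same rows and keeps the b,h,w layout.
by_indexing = eye[labels]

print(by_index_select[0, 0, 1])  # pixel labelled 2 -> tensor([0., 0., 1.])
```

Both results are float tensors of shape (1, 2, 2, 3) because `torch.eye` defaults to a floating-point dtype.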

Performance comparison

The complete code is available on GitHub.

The approximate relative performance of the different methods is shown below. Other programs were running in the background, so the numbers may not be very accurate; it is best to run the test yourself. As the results show, PyTorch's built-in F.one_hot is not the fastest on the CPU but is competitive on the GPU. Interestingly, index_select is the fastest in both settings.

PyTorch 1.10.0, GeForce RTX 2080 Ti

cpu
('bhw_to_onehot_by_for', 0.5411529541015625)
('bhw_to_onehot_by_for_V1', 0.4515676498413086)
('bhw_to_onehot_by_scatter', 0.0686192512512207)
('bhw_to_onehot_by_scatter_V1', 0.08529376983642578)
('bhw_to_onehot_by_index_select', 0.05156970024108887)
('F.one_hot', 0.07366824150085449)

gpu
('bhw_to_onehot_by_for', 0.005235433578491211)
('bhw_to_onehot_by_for_V1', 0.045584678649902344)
('bhw_to_onehot_by_scatter', 0.0025513172149658203)
('bhw_to_onehot_by_scatter_V1', 0.0024869441986083984)
('bhw_to_onehot_by_index_select', 0.002012014389038086)
('F.one_hot', 0.0024051666259765625)
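
For readers who want to repeat the measurement, here is a minimal timing sketch. It is not the script used for the numbers above; the helper `time_call` and its parameters are hypothetical. CUDA kernels run asynchronously, so for GPU tensors a synchronization function such as `torch.cuda.synchronize` should be passed as `sync`.

```python
import time


def time_call(fn, *args, repeats=10, sync=None):
    """Return the mean wall-clock time in seconds of one call to fn(*args).

    sync is an optional zero-argument callable that blocks until pending work
    has finished (for GPU code, torch.cuda.synchronize is a suitable choice).
    """
    fn(*args)  # warm-up call, excluded from the measurement
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if sync is not None:
        sync()  # wait for queued work before reading the clock
    return (time.perf_counter() - start) / repeats


# Stand-in workload so that the sketch runs without a GPU or PyTorch.
elapsed = time_call(sorted, list(range(1000, 0, -1)), repeats=5)
print(f"{elapsed:.6f} s per call")
```

To time one of the functions in this article, pass it as `fn` together with a label tensor and the number of classes.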

lart