头图

1.背景简介

当发现使用 plugin 精度 debug 工具定位到是某个 linear 敏感时,示例如下:

op_name                                sensitive_type    op_type                                                                          L1  quant_dtype    flops
-------------------------------------  ---------------   -----------------------------  ----------------  -------------------------  -------  -------------  --------------
model.layernorm.rsqrt                  activation        <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'>              6.52537  qint16         0(0%)
model.linear2                          weight            <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'>                       5.02445  qint8          3072000(0.00%)
model.layernorm.var_mean.pre_mean      activation        <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>  3.1683   qint16         0(0%)

可以发现,model.linear2 weight 排在了前面,且是 int8 量化。
接下来看下 baseline_statistic.txt 与 analysis_statistic.txt,其中有 model.linear2 的 input、weight、output 的数值分布范围,示例如下:

| Op Name                            | Mod Name       | Attr     | Min            | Max            | Mean           | Var        | Shape                       |
|---------------------------------------------------------------------------------------------------------------------------------------------------------------
| torch.nn.modules.linear.Linear     | model.linear2  | input    | 0.0000000      | 15.4210167     | 4.0793311      | 0.2532279  | torch.Size([2, 100, 256])   |
| torch.nn.modules.linear.Linear     | model.linear2  | weight   | -41.6590347    | 31.2311363     | -0.0053362     | 0.4427260  | torch.Size([60, 256])       |
| torch.nn.modules.linear.Linear     | model.linear2  | bias     | -0.4426649     | 0.3714900      | 0.0053294      | 0.0112585  | torch.Size([60])            |
| torch.nn.modules.linear.Linear     | model.linear2  | output   | -32.0065079    | 5.7881856      | 0.4558742      | 3.8736136  | torch.Size([2, 100, 60])    |

解决方案:使用 int16 来量化这个敏感 linear 的 weight。
如果必须要求 linear input weight output 都是 int16 量化,怎么办呢?

2.知识基础

在 征程 6E/M 上,地平线 BPU 对 linear 支持的情况如下:

本文发布时是这样的

图片

可以看到:input 和 weight 不能同时为 int16。

3.Linear input weight both int16

对于 linear input 和 weight 均需要 int16 量化的情况,可使用 broadcast mul sum 来替代验证,无需重训 float。

异同简介:broadcast_mul_sum_replace_linear 在 float 层面可以等价替换 linear,但在量化方式上存在区别:Linear weight 是 per channel 量化,weight 作为 mul 输入时,是 per tensor 量化。一般情况下:weight int8 perchannel 变成 per tensor int16,精度是正向优化。

替换方案:在 float 训练完成后替换,然后进行 calib+qat。

class SmallModel(nn.Module):
    def __init__(self, linear2_weight, linear2_bias):
        super(SmallModel, self).__init__()
        # 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化
        self.relu = nn.ReLU()
        # 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]
        # self.linear2 = nn.Linear(256, 60)
        self.linear2_weight = linear2_weight
        self.linear2_bias = linear2_bias
        # 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.quant_linear2_weight = QuantStub()
        self.quant_linear2_bias = QuantStub()
    
    def forward(self, x):
        x = self.quant(x)
        linear2_weight = self.quant_linear2_weight(self.linear2_weight)
        linear2_bias = self.quant_linear2_bias(self.linear2_bias)
        # 第一个 Linear
        x = self.linear1(x)  # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)  # [2, 100, 256]
        
        # 第二个 Linear
        # x = self.linear2(x)  # [2, 100, 60]
        # ===================================
        # 使用 broadcast mul + sum 替换linear
        # ===================================
        # 广播乘法:输入 [2, 100, 256] 与权重 [60, 256] 进行广播
        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
        # 按最后一个维度求和:sum 操作模拟线性层的加权求和
        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]
        # 加上偏置
        x = sum_output + linear2_bias  # [2, 100, 60]
        
        # 第三个 Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x

broadcast mul sum 替换方案,均支持 int16。

注意事项:如果 mul 的输出 绝大多数 数值都在 0 附近 -> MSE 校准受异常值影响较大 -> 输出 scale 非常大 -> 0 附近的大量小数值被舍入成 0 -> sum 和发生巨大偏差。

影响范围:mul 后面跟着 sigmoid 或 add+sigmoid 时影响很大。

解决方案:mul 输出设置 fixed scale 为 7/32767,因为 sigmoid 并不需要太大的输入,而 mul 的输出分布需要小 scale。

4.全流程示例

从表中可以看到,在 linear 需要 int16 量化的场景,input/output int16 对应的 latency 最短,其次是 weight output int16 input int8,最差的是三者都需要 int16,针对这三种情况,下面分别提供完整的例子供参考。

信息描述
图片

注意:非完全等价,仅作为参考

4.1 示例代码

import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from hbdk4.compiler import convert, save, hbm_perf, visualize, compile

import torch
import torch.nn as nn

# 定义网络结构
class SmallModel(nn.Module):
    def __init__(self, linear2_weight, linear2_bias):
        super(SmallModel, self).__init__()
        # 第一个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # 对最后一维进行归一化
        self.relu = nn.ReLU()
        # 第二个 Linear: 输入 [2, 100, 256] -> 输出 [2, 100, 60]
        # self.linear2 = nn.Linear(256, 60)
        self.linear2_weight = linear2_weight
        self.linear2_bias = linear2_bias
        # 第三个 Linear: 输入 [2, 100, 60] -> 输出 [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.quant_linear2_weight = QuantStub()
        self.quant_linear2_bias = QuantStub()
    
    def forward(self, x):
        x = self.quant(x)
        linear2_weight = self.quant_linear2_weight(self.linear2_weight)
        linear2_bias = self.quant_linear2_bias(self.linear2_bias)
        # 第一个 Linear
        x = self.linear1(x)  # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)  # [2, 100, 256]
        
        # 第二个 Linear
        # x = self.linear2(x)  # [2, 100, 60]
        # ===================================
        # 使用 broadcast mul + sum 替换linear
        # ===================================
        # 广播乘法:输入 [2, 100, 256] 与权重 [60, 256] 进行广播
        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
        # 按最后一个维度求和:sum 操作模拟线性层的加权求和
        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]
        # 加上偏置
        x = sum_output + linear2_bias  # [2, 100, 60]
        
        # 第三个 Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x

float_ckpt_path = "model_path/float-checkpoint.ckpt" 
float_state_dict = torch.load(float_ckpt_path)
# 遍历 OrderedDict,查找包含 "linear2" 的键
for key, value in float_state_dict.items():
    # if "linear2" in key:
    #     print(f"Key: {key}, Value: {value.shape}")
    if key == "linear2.weight":
        linear2_weight = value
    if key == "linear2.bias":
        linear2_bias = value

# example_input = torch.randn(2, 100, 256)
file_path = "random_data.pt"
example_input = torch.load(file_path)
model = SmallModel(linear2_weight, linear2_bias)
missing_keys, unexpected_keys = model.load_state_dict(float_state_dict, strict=False)
print("missing_keys & unexpected_keys:", missing_keys, '\n', unexpected_keys)

# 前向传播
output = model(example_input)
print("float输出数据:", output)
torch.save(output, "model_path/6_model_float_output.pt")
print("输入形状:", example_input.shape)
print("输出形状:", output.shape)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_M)

calib_model = prepare(model.eval(), example_input,
                      qconfig_setter=(
                          calibration_8bit_weight_16bit_act_qconfig_setter,
                          ),
                      )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input)

calib_model.eval()        
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input)
print("calib输出数据:", calib_out)
qat_bc = export(calib_model, example_input)
hb_quantized_model = convert(qat_bc, March.NASH_M)

4.2 比较替代方案的输出一致性

  • linear2 weight input output int16
float输出数据: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],
         [-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],
         [-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],
         ...,
         [-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],
         [-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],
         [-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],
       grad_fn=<ViewBackward0>)
       
calib输出数据: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],
         [-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],
         [-0.3052,  0.1488, -0.2677,  ...,  0.2348, -0.3479,  0.0132],
         ...,
         [-0.3988, -0.0393, -0.1662,  ..., -0.0055, -0.4117,  0.0484],
         [-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],
         [-0.2142, -0.1061, -0.0587,  ...,  0.3422, -0.1657,  0.2255]]],
       grad_fn=<ViewBackward0>)
  • broadcast mul sum int16
float输出数据: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],
         [-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],
         [-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],
         ...,
         [-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],
         [-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],
         [-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],
       grad_fn=<ViewBackward0>)
calib输出数据: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],
         [-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],
         [-0.3051,  0.1487, -0.2678,  ...,  0.2349, -0.3478,  0.0132],
         ...,
         [-0.3988, -0.0392, -0.1662,  ..., -0.0055, -0.4117,  0.0484],
         [-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],
         [-0.2142, -0.1061, -0.0586,  ...,  0.3423, -0.1657,  0.2255]]],
       grad_fn=<ViewBackward0>)

地平线智驾开发者
7 声望6 粉丝

地平线智能驾驶开发者社区旨在连接智能驾驶领域的开发者和对相关技术感兴趣的其他行业开发者、从业者。