1. Background
When the plugin precision debug tool flags a particular linear as sensitive, the output looks like this:

| op_name | sensitive_type | op_type | L1 | quant_dtype | flops |
| --- | --- | --- | --- | --- | --- |
| model.layernorm.rsqrt | activation | <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'> | 6.52537 | qint16 | 0(0%) |
| model.linear2 | weight | <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | 5.02445 | qint8 | 3072000(0.00%) |
| model.layernorm.var_mean.pre_mean | activation | <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'> | 3.1683 | qint16 | 0(0%) |

We can see that model.linear2's weight ranks near the top of the sensitivity list and is currently int8-quantized.
Next, check baseline_statistic.txt and analysis_statistic.txt, which record the value distributions of model.linear2's input, weight, and output, for example:

| Op Name | Mod Name | Attr | Min | Max | Mean | Var | Shape |
| --- | --- | --- | --- | --- | --- | --- | --- |
| torch.nn.modules.linear.Linear | model.linear2 | input | 0.0000000 | 15.4210167 | 4.0793311 | 0.2532279 | torch.Size([2, 100, 256]) |
| torch.nn.modules.linear.Linear | model.linear2 | weight | -41.6590347 | 31.2311363 | -0.0053362 | 0.4427260 | torch.Size([60, 256]) |
| torch.nn.modules.linear.Linear | model.linear2 | bias | -0.4426649 | 0.3714900 | 0.0053294 | 0.0112585 | torch.Size([60]) |
| torch.nn.modules.linear.Linear | model.linear2 | output | -32.0065079 | 5.7881856 | 0.4558742 | 3.8736136 | torch.Size([2, 100, 60]) |
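To see why this weight range is hard for int8, a rough back-of-the-envelope check helps (plain Python; it assumes symmetric per-tensor quantization for simplicity, whereas the actual qat Linear weight is per-channel int8 and fares somewhat better, but the trend is the same):

w_absmax = 41.6590347               # max(|Min|, |Max|) of model.linear2's weight, from the table above

scale_int8 = w_absmax / 127.0       # ~0.328: any weight with |w| < ~0.164 rounds to 0
scale_int16 = w_absmax / 32767.0    # ~0.00127: roughly 256x finer resolution

print(scale_int8, scale_int16)

With a weight variance of ~0.44 (std ~0.67) and mean near 0, a noticeable fraction of the weights sits below the int8 rounding threshold, which is consistent with linear2's weight appearing at the top of the sensitivity ranking.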
Solution: quantize the weight of this sensitive linear with int16.
But what if the linear's input, weight, and output all have to be quantized with int16?
2. Prerequisite knowledge
On 征程 6E/M, the Horizon BPU's int8/int16 support for linear is shown in the following support matrix:
(Support matrix screenshot, current as of the time of writing.)
From it we can see that input and weight cannot both be int16.
3. Linear input and weight both int16
When both the input and the weight of the linear need int16 quantization, the linear can be replaced with broadcast mul + sum for verification, with no need to retrain the float model.
Similarities and differences: at the float level, broadcast_mul_sum_replace_linear is an exact replacement for the linear, but the quantization scheme differs: a Linear's weight is quantized per channel, whereas the same weight fed into mul as an input is quantized per tensor. In general, moving the weight from per-channel int8 to per-tensor int16 is a net accuracy improvement.
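Whether that trade actually helps can be checked numerically. Below is a minimal standalone sketch (plain PyTorch, a random stand-in for linear2.weight with an injected outlier, and a hand-written fake-quantize rather than the plugin's observers):

import torch

torch.manual_seed(0)
w = torch.randn(60, 256) * 0.66      # stand-in for linear2.weight (std taken from the statistics table)
w[0, 0] = -41.66                     # inject the kind of outlier seen in the Min column

def fake_quant(t, scale, qmax):
    # symmetric fake-quantize: snap to the grid, clamp, dequantize
    return (t / scale).round().clamp(-qmax - 1, qmax) * scale

# per-channel int8: one scale per output channel (dim 0), as qat Linear weight uses
scale_pc = w.abs().amax(dim=1, keepdim=True) / 127.0
err_pc_int8 = (fake_quant(w, scale_pc, 127) - w).abs().mean()

# per-tensor int16: a single scale for the whole tensor, as the mul input gets
scale_pt = w.abs().max() / 32767.0
err_pt_int16 = (fake_quant(w, scale_pt, 32767) - w).abs().mean()

print(err_pc_int8, err_pt_int16)     # per-tensor int16 error comes out far smaller here

On data like this the per-tensor int16 reconstruction error is well below the per-channel int8 error, which matches the "usually a net win" statement above.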
Replacement plan: apply the replacement after float training is finished, then run calibration + QAT. The float-level drop-in looks like this:
import torch
import torch.nn as nn
from torch.quantization import DeQuantStub
from horizon_plugin_pytorch.quantization import QuantStub

class SmallModel(nn.Module):
    def __init__(self, linear2_weight, linear2_bias):
        super(SmallModel, self).__init__()
        # First Linear: input [2, 100, 256] -> output [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # normalize over the last dimension
        self.relu = nn.ReLU()
        # Second Linear: input [2, 100, 256] -> output [2, 100, 60]
        # self.linear2 = nn.Linear(256, 60)
        self.linear2_weight = linear2_weight
        self.linear2_bias = linear2_bias
        # Third Linear: input [2, 100, 60] -> output [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.quant_linear2_weight = QuantStub()
        self.quant_linear2_bias = QuantStub()

    def forward(self, x):
        x = self.quant(x)
        linear2_weight = self.quant_linear2_weight(self.linear2_weight)
        linear2_bias = self.quant_linear2_bias(self.linear2_bias)
        # First Linear
        x = self.linear1(x)    # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)       # [2, 100, 256]
        # Second Linear
        # x = self.linear2(x)  # [2, 100, 60]
        # ===================================
        # Replace the linear with broadcast mul + sum
        # ===================================
        # Broadcast multiply: input [2, 100, 256] against weight [60, 256]
        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
        # Sum over the last dimension: reproduces the linear's weighted sum
        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]
        # Add the bias
        x = sum_output + linear2_bias  # [2, 100, 60]
        # Third Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x
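Before quantization enters the picture, the replacement can be sanity-checked against a real linear at the float level. A standalone sketch (random data, shapes taken from the example above):

import torch
import torch.nn.functional as F

x = torch.randn(2, 100, 256)
w = torch.randn(60, 256)
b = torch.randn(60)

ref = F.linear(x, w, b)                                                        # [2, 100, 60]
rep = (x.reshape(2, 100, 1, 256) * w.reshape(1, 1, 60, 256)).sum(dim=-1) + b   # broadcast mul + sum

print(torch.allclose(ref, rep, atol=1e-4))   # True, up to accumulation-order round-off
print((ref - rep).abs().max())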
With the broadcast mul + sum replacement, input, weight, and output can all be quantized with int16.
Caveat: if the vast majority of the mul's output values sit near 0, MSE calibration is dominated by the outliers, the output scale becomes very large, the many small values near 0 get rounded to 0, and the sum ends up badly off.
Impact: this matters most when the mul is followed by sigmoid or add + sigmoid.
Workaround: set a fixed scale of 7/32767 on the mul output. Sigmoid does not need large inputs, while the mul output distribution needs a small scale.
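The failure mode and the fix are easy to reproduce with a hand-written int16 fake-quantize. The sketch below uses synthetic data and simply takes the outlier-driven max as a stand-in for what an outlier-dominated calibration would pick; the numbers are illustrative only:

import torch

torch.manual_seed(0)
# Synthetic mul output: mostly tiny values near 0, plus one row of large outliers
small = torch.randn(59, 256) * 3e-4
outlier = torch.full((1, 256), 50.0)
vals = torch.cat([small, outlier], dim=0)

def fq_int16(t, scale):
    # symmetric int16 fake-quantize: round, clamp, dequantize
    return (t / scale).round().clamp(-32768, 32767) * scale

scale_outlier = vals.abs().max() / 32767.0   # ~1.5e-3, dictated by the outlier row
scale_fixed = 7.0 / 32767.0                  # ~2.1e-4, the fixed scale suggested above

true_sum = small.sum(dim=-1)
err_outlier = (fq_int16(small, scale_outlier).sum(dim=-1) - true_sum).abs().mean()
err_fixed = (fq_int16(small, scale_fixed).sum(dim=-1) - true_sum).abs().mean()

# With the outlier-driven scale, almost every small value rounds to exactly 0,
# so the per-row sums collapse; the fixed small scale keeps them close to the truth.
print((fq_int16(small, scale_outlier) == 0).float().mean())
print(err_outlier, err_fixed)

How the fixed scale is actually attached to the mul output goes through the plugin's qconfig mechanism; the exact interface is not shown here.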
4. End-to-end example
Latency measurements show that, when the linear needs int16 quantization, input/output int16 gives the lowest latency, followed by weight and output int16 with input int8, and all three in int16 is the slowest. For these cases, a complete example is provided below for reference.
Note: the replacement is not exactly equivalent to the original linear; treat it as a reference only.
4.1 Example code
import torch
import torch.nn as nn
from torch.quantization import DeQuantStub
from horizon_plugin_pytorch import set_march, March
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from hbdk4.compiler import convert, save, hbm_perf, visualize, compile

set_march(March.NASH_M)

# Define the network
class SmallModel(nn.Module):
    def __init__(self, linear2_weight, linear2_bias):
        super(SmallModel, self).__init__()
        # First Linear: input [2, 100, 256] -> output [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # normalize over the last dimension
        self.relu = nn.ReLU()
        # Second Linear: input [2, 100, 256] -> output [2, 100, 60]
        # self.linear2 = nn.Linear(256, 60)
        self.linear2_weight = linear2_weight
        self.linear2_bias = linear2_bias
        # Third Linear: input [2, 100, 60] -> output [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.quant_linear2_weight = QuantStub()
        self.quant_linear2_bias = QuantStub()

    def forward(self, x):
        x = self.quant(x)
        linear2_weight = self.quant_linear2_weight(self.linear2_weight)
        linear2_bias = self.quant_linear2_bias(self.linear2_bias)
        # First Linear
        x = self.linear1(x)    # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)       # [2, 100, 256]
        # Second Linear
        # x = self.linear2(x)  # [2, 100, 60]
        # ===================================
        # Replace the linear with broadcast mul + sum
        # ===================================
        # Broadcast multiply: input [2, 100, 256] against weight [60, 256]
        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
        # Sum over the last dimension: reproduces the linear's weighted sum
        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]
        # Add the bias
        x = sum_output + linear2_bias  # [2, 100, 60]
        # Third Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x

float_ckpt_path = "model_path/float-checkpoint.ckpt"
float_state_dict = torch.load(float_ckpt_path)

# Walk the state dict and pick out the parameters belonging to linear2
for key, value in float_state_dict.items():
    # if "linear2" in key:
    #     print(f"Key: {key}, Value: {value.shape}")
    if key == "linear2.weight":
        linear2_weight = value
    if key == "linear2.bias":
        linear2_bias = value

# example_input = torch.randn(2, 100, 256)
file_path = "random_data.pt"
example_input = torch.load(file_path)

model = SmallModel(linear2_weight, linear2_bias)
missing_keys, unexpected_keys = model.load_state_dict(float_state_dict, strict=False)
print("missing_keys & unexpected_keys:", missing_keys, '\n', unexpected_keys)

# Float forward pass
output = model(example_input)
print("float output:", output)
torch.save(output, "model_path/6_model_float_output.pt")
print("input shape:", example_input.shape)
print("output shape:", output.shape)

# A global march indicating the target hardware version must be set before prepare/QAT.
set_march(March.NASH_M)

calib_model = prepare(
    model.eval(),
    example_input,
    qconfig_setter=(
        calibration_8bit_weight_16bit_act_qconfig_setter,
    ),
)

# Calibration pass
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input)

# Validation pass with fake quantization enabled
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input)
print("calib output:", calib_out)

qat_bc = export(calib_model, example_input)
hb_quantized_model = convert(qat_bc, March.NASH_M)
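Rather than only eyeballing the printed tensors in the next section, the float-vs-calibration gap can be quantified directly. A small optional addition on top of the script above, reusing the output and calib_out variables it already defines:

# Quantify the gap between the float output and the calibrated fake-quant output
diff = (calib_out - output).abs()
cos = torch.nn.functional.cosine_similarity(calib_out.flatten(), output.flatten(), dim=0)
print("max abs diff:", diff.max().item())
print("mean abs diff:", diff.mean().item())
print("cosine similarity:", cos.item())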
4.2 Comparing the output consistency of the replacement
- linear2 weight input output int16
float output: tensor([[[-0.3016, 0.1338, -0.5251, ..., -0.0551, -0.2093, -0.0308],
[-0.1969, -0.0131, -0.3287, ..., 0.3234, -0.0869, -0.0637],
[-0.3056, 0.1478, -0.2673, ..., 0.2355, -0.3487, 0.0134],
...,
[-0.3990, -0.0389, -0.1686, ..., -0.0046, -0.4131, 0.0482],
[-0.1059, 0.2431, -0.1886, ..., 0.0787, -0.3454, 0.0231],
[-0.2134, -0.1071, -0.0575, ..., 0.3434, -0.1661, 0.2248]]],
grad_fn=<ViewBackward0>)
calib output: tensor([[[-0.3038, 0.1370, -0.5269, ..., -0.0571, -0.2111, -0.0296],
[-0.1975, -0.0111, -0.3280, ..., 0.3215, -0.0884, -0.0637],
[-0.3052, 0.1488, -0.2677, ..., 0.2348, -0.3479, 0.0132],
...,
[-0.3988, -0.0393, -0.1662, ..., -0.0055, -0.4117, 0.0484],
[-0.1058, 0.2442, -0.1890, ..., 0.0780, -0.3447, 0.0240],
[-0.2142, -0.1061, -0.0587, ..., 0.3422, -0.1657, 0.2255]]],
grad_fn=<ViewBackward0>)
- broadcast mul sum int16
float output: tensor([[[-0.3016, 0.1338, -0.5251, ..., -0.0551, -0.2093, -0.0308],
[-0.1969, -0.0131, -0.3287, ..., 0.3234, -0.0869, -0.0637],
[-0.3056, 0.1478, -0.2673, ..., 0.2355, -0.3487, 0.0134],
...,
[-0.3990, -0.0389, -0.1686, ..., -0.0046, -0.4131, 0.0482],
[-0.1059, 0.2431, -0.1886, ..., 0.0787, -0.3454, 0.0231],
[-0.2134, -0.1071, -0.0575, ..., 0.3434, -0.1661, 0.2248]]],
grad_fn=<ViewBackward0>)
calib output: tensor([[[-0.3038, 0.1370, -0.5269, ..., -0.0571, -0.2111, -0.0296],
[-0.1975, -0.0111, -0.3280, ..., 0.3215, -0.0884, -0.0637],
[-0.3051, 0.1487, -0.2678, ..., 0.2349, -0.3478, 0.0132],
...,
[-0.3988, -0.0392, -0.1662, ..., -0.0055, -0.4117, 0.0484],
[-0.1058, 0.2442, -0.1890, ..., 0.0780, -0.3447, 0.0240],
[-0.2142, -0.1061, -0.0586, ..., 0.3423, -0.1657, 0.2255]]],
grad_fn=<ViewBackward0>)