"); //-->
一、前言
在端侧部署时(如在移动设备、嵌入式芯片上),为了加速模型推理、减少功耗和资源开销,往往会将某些计算复杂的函数(如 exp、log、tanh、sigmoid、softmax 等)**转为查表操作**。查表算子在转成定点计算时不可避免地会出现误差,此时就需要定位引起精度下降的具体算子以及对其进行针对性的优化。
本文将讲述在使用地平线 QAT 链路基于 J6 系列平台进行模型部署时,对查表算子进行精度调优的相关手段,主要包括以下内容:
- 如何确定是定点查表导致的误差?
- 如何确定具体的定点查表算子导致的误差?

二、如何确定是定点查表导致的误差
当 QAT 模型的精度处于正常状态,然而 QAT 模型 export 出来的`qat.bc` 文件精度却出现异常情况时,并且在这种情况下,我们已经对导出 `qat.bc` 文件的整个流程进行了全面且细致的检验,确认该流程不存在任何错误。那么在这样的前提条件之下,我们此时就可以开始着手验证是否是由于查表算子的因素导致了精度下降。
`horizon_plugin_pytorch`提供了 api 来辅助进行查表算子精度的验证,具体思路是将 qat model 中所有的查表算子转成定点,然后在验证集上进行精度的评测,如果和 qat.bc 的现象一样都出现了比较严重的精度下降问题,那么就说明是因为查表算子导致的误差。
以下是`horizon_plugin_pytorch`提供的 api 的使用示例,如下所示:
import torch
from horizon_plugin_pytorch.quantization import prepare,set_fake_quantize,FakeQuantState
import copy
qat_model = prepare(
copy.deepcopy(float_model),
example_inputs=example_input,
qconfig_setter=default_calibration_qconfig_setter,)
print("--"*20+"Prepare qat model success"+"--"*20)
state_dict = torch.load(ckpt_path)
qat_model.load_state_dict(new_state_dict,strict=True)
print("--"*20+"Load qat ckpt success"+"--"*20)
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
#将qat model中所有的查表算子转成定点
qat_lut_model=copy.deepcopy(qat_model)
from horizon_plugin_pytorch.nn.qat.segment_lut import QuantizedQATSegmentLUT
QuantizedQATSegmentLUT.convert_segment_lut(qat_lut_model)
#评测查表转定点的qat_lut model的精度
evaluate(qat_lut_mode,val_dataloader, .....)
#如果相对于qat model精度下降比较严重,那么就说明是查表算子导致的问题三、如何确定引起误差的查表算子?
在确定是查表算子导致的精度误差后,我们需要进一步确定是哪些查表算子导致的误差。一般来说,模型中会包含多个多种查表算子。
具体方法是结合 QAT 精度 debug 工具`horizon_plugin_profile`来做 qat_model 和 qat_lut_model(查表转定点的 qat model)的精度 debug,然后根据敏感度来确定具体的查表算子。
以下是 QAT 精度 debug 工具的使用示例,如下所示:
from horizon_plugin_profiler import QuantAnalysis, ModelProfiler qa = QuantAnalysis( baseline_model=qat_model, analysis_model=qat_lut_model, analysis_model_type="fake_quant", device_ids=0, # GPU index,若不指定则在 CPU 上 out_dir=output_dir) qa.auto_find_bad_case(data_generator=val_dataloader,metric="L1") qa.run() qa.compare_per_layer() qa.sensitivity(metric="L1")
debug 工具运行完成后,在`out_dir`会生成系列产物,这里我们主要关注逐层相似度和输出敏感度,`compare_per_layer_out.csv`和`output_xxxx_L1_sensitive_ops.txt`文件。
在获得敏感度 txt 文件后,我们就根据敏感度顺序逐步来确认引起误差的算子,具体思路如下:
1. 将 qat mode 中所有的查表算子转成定点;
2. 然后将敏感度靠前的差表算子回退到浮点进行精度评测;
如果在将某个/类别查表算子回退到浮点以后,精度指标与 qat model 区别不大,那么就说明是这个/类查表导致的精度下降。
下面为具体的操作代码:
import torch
from horizon_plugin_pytorch.quantization import prepare,set_fake_quantize,FakeQuantState
import copy
qat_model = prepare(
copy.deepcopy(float_model),
example_inputs=example_input,
qconfig_setter=default_calibration_qconfig_setter,)
print("--"*20+"Prepare qat model success"+"--"*20)
state_dict = torch.load(ckpt_path)
qat_model.load_state_dict(new_state_dict,strict=True)
print("--"*20+"Load qat ckpt success"+"--"*20)
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
#将qat model中所有的查表算子转成定点
qat_lut_model=copy.deepcopy(qat_model)
from horizon_plugin_pytorch.nn.qat.segment_lut import QuantizedQATSegmentLUT
QuantizedQATSegmentLUT.convert_segment_lut(qat_lut_model)
#将敏感度靠前的算子回退到浮点
#"decoder.decoder._generated_log_1.log"为算子名称
qat_lut_model.get_submodule("decoder.decoder._generated_log_1.log").quantized_forward = False
....
#评测查表转定点的qat_lut model的精度
evaluate(qat_lut_mode,val_dataloader, .....)
#如果相对于qat model精度下降比较严重,那么就说明是查表算子导致的问题这边补充说明两点:
- 侧重关注`output_xxxx_L1_sensitive_ops.txt`中 L1 误差较大的算子,可以先回退 L1 数值较大的算子到浮点;
- 确定具体的查表算子可能要进行多次精度评测,建议采用二分法,即在第一次进行精度评测的时候尽可能选择较多的算子进行回退(理想情况是第一次精度就恢复),然后使用二分法进一步确定具体的算子;
### 在下一篇文章中,我们将演示在定位到具体算子后,如何进行精度调优,敬请期待!!!
专栏文章内容及配图由作者撰写发布,仅供工程师学习之用,如有侵权或者其他违规问题,请联系本站处理。 联系我们
相关推荐
加快实现自动驾驶(完整小组讨论)
自动驾驶正推动汽车行业加速布局人形机器人
Ouster推出 Rev8 OS 激光雷达系列 原生彩色激光雷达正式落地
曲面显示屏取代传统汽车挡风玻璃
实时训练驾驶人工智能
自动驾驶的现状与未来(节选)
面向算法硬件加速的FPGA实现方法
简单实用的单片机CRC 快速算法
加密算法之MD5算法
目标跟踪算法在红外热成像跟踪技术上的应用
计算机科学与技术反思录(2)
2035年自动驾驶出租车市场规模将达1680亿美元
[转帖]us/os就绪表的维护算法分析
采用Mean-Shift和Camshift算法相结合的火焰视频图像跟踪设计
CRC算法原理及C语言实现
携手ADI赢得未来
地平线征程 6 系列集成 Cadence Tensilica Vision DSP,实现规模化量产,合作加速智能驾驶解决方案部署
高阶智驾要落地,线控底盘为什么必须执行得准
PID算法
有关指纹算法
求FSK信号的解调算法,主要是铁路上的移频信号!
数字PID控制算法之一
掘金自动驾驶,不要把大坑当机会
76-81GHz自动驾驶CMOS RADAR
ADI:传感技术助力未来自动驾驶的发展
无线传感器网络低功耗分簇路由算法设计
基于LPC2138的血压测量算法开发平台电路图
特斯拉监督版FSD加入中国市场
数字PID控制及其改进算法的应用
vxwokrs下静态图像压缩算法(上)