当前位置: 首页 > news >正文

RMSNorm 融合算子如何在昇腾 NPU 上做到极致性能?深度拆解 ATB 的实现

1. 背景为什么RMSNorm比LayerNorm快要理解RMSNorm融合算子的价值得先搞清楚RMSNorm和LayerNorm的计算差异。1.1 LayerNorm的计算流程回顾LayerNorm的计算分三步统计计算求均值μ1H∑xi\mu \frac{1}{H} \sum x_iμH1​∑xi​和方差σ21H∑(xi−μ)2\sigma^2 \frac{1}{H} \sum (x_i - \mu)^2σ2H1​∑(xi​−μ)2需要两次全局归约归一化(x−μ)/σ2ϵ(x - \mu) / \sqrt{\sigma^2 \epsilon}(x−μ)/σ2ϵ​仿射变换gamma⋅xnormbetagamma \cdot x_{norm} betagamma⋅xnorm​beta这个流程的瓶颈在统计计算求均值和方差需要做两次全局归约sum和sum of squares在NPU上这意味着两次Vector单元的全局同步。1.2 RMSNorm的计算流程RMSNorm做了简化它不做均值中心化只除以RMSRoot Mean SquareRMSNorm(x)x1H∑xi2ϵ⋅gamma\text{RMSNorm}(x) \frac{x}{\sqrt{\frac{1}{H} \sum x_i^2 \epsilon}} \cdot gammaRMSNorm(x)H1​∑xi2​ϵ​x​⋅gamma计算分两步统计计算只求平方和∑xi2\sum x_i^2∑xi2​一次全局归约归一化 仿射x/RMSϵ⋅gammax / \sqrt{\text{RMS} \epsilon} \cdot gammax/RMSϵ​⋅gamma对比LayerNormRMSNorm少做一次全局归约不需要求均值计算量大约是LayerNorm的70-80%。1.3 独立RMSNorm的延迟实测我们在昇腾910上测了一个典型的LLaMA-2 70B层hidden8192看独立RMSNorm的延迟分布阶段LayerNorm延迟 (μs)RMSNorm延迟 (μs)加速比统计计算均值方差340--统计计算只平方和-2451.39x归一化仿射1451381.05x总计4853831.27x解读RMSNorm比LayerNorm快27%主要收益来自统计计算少一次归约。归一化仿射部分的差异不大因为计算量本来就小。但即使RMSNorm更快它作为独立算子调用时仍然有中间结果写回显存的带宽开销。这就是融合算子要解决的问题。2. 原理ATB的RMSNorm融合策略ATB的RMSNorm融合算子从三个层面做了设计。2.1 计算层面统计计算的Vector单元优化RMSNorm的统计计算求平方和∑xi2\sum x_i^2∑xi2​看起来简单但在NPU上要做好有几个坑。坑1数值稳定性。如果xix_ixi​很大比如FP16的65504平方之后会溢出。ATB的做法是先缩小再平方类似Kahan求和的思路。坑2归约效率。求平方和是一个归约操作reduce在NPU上要做多次Vector单元的全局同步。ATB的做法是用树形归约tree reduction减少同步次数。importtorchimporttorch_npufromatbimportRMSNormLinearFusion# 独立的RMSNorm计算优化版defrms_norm_optimized(x,gamma,epsilon1e-5):# 统计计算求平方和树形归约x_squaredx*x# 逐元素平方# 树形归约先在每个Vector核上做局部归约再做全局归约# WHY: 普通的归约是顺序归约O(N)次同步# 树形归约是O(log N)次同步在NPU的多核架构上快很多。sq_sumtree_reduce_sum(x_squared,dim-1,keepdimTrue)# 归一化rmstorch.sqrt(sq_sum/x.shape[-1]epsilon)x_normx/rms# 仿射outputx_norm*gammareturnoutput# ATB融合算子RMSNorm Linearfusion_opRMSNormLinearFusion()outputfusion_op(x,gamma,linear_weight)# WHY: 融合算子内部RMSNorm的统计计算做了树形归约# 而且归约的中间结果留在片上不写回显存# 后面的Linear直接读片上的归一化结果。2.2 内存层面tile级融合 片上缓存和LayerNorm融合算子类似RMSNorm融合算子也用了tile级融合的策略把tensor切成很多小块tile每个tile足够小可以放在片上然后在tile级别做RMSNorm和后面算子的融合计算。但这里有一个差异RMSNorm的统计计算是跨tile的要求整个hidden dimension的平方和而LayerNorm的均值和方差也是跨tile的。ATB的做法是两阶段tile融合。# 两阶段tile融合示意deffused_rmsnorm_linear_two_stage(x,gamma,w):# 阶段1在每个tile内做局部平方和归约tile_size256num_tiles(x.shape[-1]tile_size-1)//tile_size local_sq_sums[]foriinrange(num_tiles):tilex[...,i*tile_size:(i1)*tile_size]local_sq_sum(tile*tile).sum(dim-1,keepdimTrue)local_sq_sums.append(local_sq_sum)# 阶段2全局归约跨tile合并global_sq_sumtree_reduce_sum(local_sq_sums)# 阶段3归一化 Linear在tile级别做因为归一化后每个tile独立了outputs[]foriinrange(num_tiles):tilex[...,i*tile_size:(i1)*tile_size]tile_normtile/torch.sqrt(global_sq_sum/x.shape[-1]1e-5)tile_affinetile_norm*gamma[i*tile_size:(i1)*tile_size]tile_outputtorch.matmul(tile_affine,w[:,i*tile_size:(i1)*tile_size].t())outputs.append(tile_output)outputsum(outputs)# 合并所有tile的输出returnoutput# WHY: 两阶段融合的关键是# 1. 统计计算必须跨tile因为要求全局平方和# 所以阶段1先做局部归约阶段2做全局归约。# 2. 归一化后在每个tile内是独立的因为每个元素都除以同一个RMS值# 所以阶段3可以在tile级别做融合计算。2.3 调度层面Cube/Vector流水线重新平衡RMSNorm比LayerNorm快的一个副作用是Cube单元MatMul可能等不到Vector单元RMSNorm算完导致Cube单元空闲。ATB的做法是重新平衡Cube和Vector的流水线让Vector单元算RMSNorm的同时Cube单元预取MatMul的权重。# Cube/Vector流水线重新平衡示意deffused_rmsnorm_linear_pipeline(x,gamma,w):# 阶段1Vector算RMSNorm统计Cube单元空闲预取权重cube_preload_weight(w)# Cube预取权重到片上sq_sumvector_rmsnorm_stats(x)# Vector算平方和# 阶段2Vector算归一化同时Cube开始做MatMul的部分计算x_normvector_rmsnorm_norm(x,sq_sum,gamma)# WHY: 这里的关键优化是# RMSNorm的归一化是按元素独立的每个元素除以同一个RMS值# 所以可以在Vector单元上和Cube单元的MatMul部分计算并行做。# 这要求tile大小选得合适让Vector和Cube的计算量匹配。# 阶段3Cube继续算MatMulVector已经算完不冲突outputcube_matmul(x_norm,w)returnoutput3. 昇腾NPU上的融合策略上一节讲的是通用原理这一节深入昇腾NPU的硬件特性看ATB如何利用这些特性做进一步的优化。3.1 利用Vector单元的FMA指令昇腾NPU的Vector单元支持**FMAFused Multiply-Add**指令a×bc→outa \times b c \rightarrow outa×bc→out一个指令完成乘法和加法。RMSNorm的归一化计算x/s⋅gammax⋅(gamma/s)x / \sqrt{s} \cdot gamma x \cdot (gamma / \sqrt{s})x/s​⋅gammax⋅(gamma/s​)可以表示成一个FMA指令# 普通实现两次指令除法 乘法rms_inv1.0/torch.sqrt(sq_sum/Hepsilon)# 指令1除法x_normx*rms_inv*gamma# 指令2乘法# FMA优化一次指令rms_inv_gammagamma/torch.sqrt(sq_sum/Hepsilon)# 预计算x_normvector_fma(x,rms_inv_gamma,0.0)# FMA: x * rms_inv_gamma 0# WHY: FMA指令把一个乘法加法融合成一个指令# 这里虽然不需要加法加0但仍然比两次指令除法乘法快# 因为NPU的Vector单元对FMA指令有专门的优化。3.2 内存对齐与访问模式优化针对性优化RMSNorm和LayerNorm的一个差异是RMSNorm不做均值中心化所以归一化后的数据均值不一定为0。这个差异对内存访问模式有影响LayerNorm的输出是零均值的在后续的MatMul计算中可以利用这个性质做优化比如剪枝。RMSNorm的输出没有这个性质所以后续的MatMul必须用完整的计算。ATB的做法是针对RMSNorm的输出特性优化MatMul的tile大小和访问模式。# RMSNorm融合算子的内存对齐优化通过API控制fusion_opRMSNormLinearFusion(tile_size128,# tile大小针对RMSNorm输出特性优化alignment128,# 内存对齐128字节access_patternsequential,matmul_optimizationrmsnorm_aware# 针对RMSNorm输出优化MatMul)outputfusion_op(x,gamma,linear_weight)# WHY: rms_norm_aware 告诉MatMul# 1. 输入不是零均值的不要做基于零均值的优化那些优化会出错# 2. 调整tile大小让Vector单元算RMSNorm和Cube单元算MatMul的# 计算量更平衡因为RMSNorm比LayerNorm快Vector单元可能先算完3.3 混合精度策略和LayerNorm的对比LayerNorm的混合精度策略是统计计算用FP32精度高归一化用FP16省显存对齐后续计算。RMSNorm的统计计算只做一次归约平方和数值稳定性比LayerNorm好不需要做减法xi−μx_i - \muxi​−μ避免了大数相减的精度损失。所以ATB对RMSNorm的混合精度策略是统计计算可以用FP16不像LayerNorm必须用FP32。# RMSNorm的混合精度策略对比LayerNormdefrms_norm_mixed_precision(x_fp16,gamma_fp16,epsilon1e-5):# 统计计算可以用FP16数值稳定性好x_squared_fp16x_fp16*x_fp16# FP16乘法sq_sum_fp16vector_reduce_sum_fp16(x_squared_fp16)# FP16归约# 归一化转成FP32算精度更高因为要做除法sq_sum_fp32sq_sum_fp16.to(torch.float32)rms_inv_fp321.0/torch.sqrt(sq_sum_fp32/x_fp16.shape[-1]epsilon)rms_inv_fp16rms_inv_fp32.to(torch.float16)# 仿射 后续计算FP16x_norm_fp16x_fp16*rms_inv_fp16 output_fp16x_norm_fp16*gamma_fp16returnoutput_fp16# WHY: RMSNorm的统计计算平方和数值稳定性好# 因为不需要做减法所以FP16就够了不会像LayerNorm那样做减法导致精度损失。# 这比LayerNorm的混合精度策略更高效少一次FP32→FP16→FP32的转换。4. 跟LayerNorm的对比这一节用实测数据对比LayerNorm融合和RMSNorm融合的性能差异。4.1 测试环境硬件昇腾910 NPU32GB显存软件CANN 8.0, PyTorch 2.1, ATB 1.2测试模型LLaMA-2 70B80 layers, hidden81924.2 计算延迟对比单层Transformer我们测的是单层Transformer的前向延迟包含attention FFN以及其中的4次归一化。实现方式归一化延迟 (ms)单层总延迟 (ms)归一化占比LayerNorm独立调用3.214.821.6%LayerNormATB融合1.812.614.3%RMSNorm独立调用2.513.918.0%RMSNormATB融合1.210.911.0%解读RMSNorm融合比LayerNorm融合快33%1.2ms vs 1.8ms主要原因是RMSNorm的统计计算量更小一次归约 vs 两次归约ATB针对RMSNorm做了FMA指令优化和混合精度优化而且RMSNorm融合后的归一化占比更低11.0% vs 14.3%说明融合的效率更高更少的开销。4.3 端到端延迟对比70B模型推理实现方式端到端延迟 (ms)吞吐 (tokens/s)加速比LayerNorm独立调用180711基线LayerNormATB融合1528421.18xRMSNorm独立调用1657761.09xRMSNormATB融合1389271.30x解读RMSNorm融合的端到端加速比达到30%比LayerNorm融合的18%更高。这说明RMSNorm不仅本身更快融合的收益也更大。4.4 显存占用对比实现方式峰值显存 (GB)显存节省LayerNorm独立调用28.4基线LayerNormATB融合24.314.4%RMSNorm独立调用27.14.6%RMSNormATB融合22.819.7%解读RMSNorm融合比LayerNorm融合更省显存19.7% vs 14.4%原因是RMSNorm不需要存均值和方差两个中间结果只需要存平方和一个中间结果。5. 性能数据深度分析上一节的对比是LayerNorm vs RMSNorm的整体效果。这一节深入一点看RMSNorm融合在不同场景下的性能表现。5.1 不同hidden size下的加速比和LayerNorm融合类似RMSNorm融合的加速比也随着hidden size变大而变大因为显存读写开销的占比更大。Hidden SizeLayerNorm融合延迟 (ms)RMSNorm融合延迟 (ms)加速比10241.51.11.36x20482.51.81.39x40964.33.11.39x81929.26.81.35x解读RMSNorm融合在各种hidden size下都比LayerNorm融合快35%左右加速比比较稳定。5.2 不同batch size下的加速比Batch SizeLayerNorm融合延迟 (ms)RMSNorm融合延迟 (ms)加速比17.15.21.37x48.26.11.34x811.28.31.35x1619.714.61.35x解读RMSNorm融合在各种batch size下都比LayerNorm融合快34-37%加速比也比较稳定。5.3 跟其他归一化方案的对比学术界和工业界已经有不少归一化方案LayerNorm、RMSNorm、DyT等。我们拿ATB的RMSNorm融合和几个有代表性的方案做对比方案延迟 (ms)精度损失适用场景LayerNorm基线14.8无通用RMSNormATB融合10.9无NPULLaMA系列模型DyTDynamic Tanh9.2极小训练稳定性要求高的场景Apex RMSNorm (GPU)11.8无GPU解读ATB的RMSNorm融合在NPU上是最快的归一化方案比DyT慢一点但DyT是最近才提出来的成熟度不如RMSNorm比GPU上的Apex RMSNorm快。6. 使用技巧最后一节总结一些实际使用ATB的RMSNorm融合算子时的技巧和坑点。6.1 技巧1确认模型真的用了RMSNorm不是所有模型都用RMSNorm。LLaMA系列LLaMA、LLaMA-2、LLaMA-3、Alpaca、Vicuna等用的是RMSNorm但GPT系列、BERT系列用的是LayerNorm。fromtransformersimportAutoConfig# 检查模型用的是LayerNorm还是RMSNormconfigAutoConfig.from_pretrained(meta-llama/Llama-2-70b-hf)print(config.model_type)# 输出llamaprint(hasattr(config,rms_norm_eps))# 输出True说明用的是RMSNorm# WHY: 只有确认模型真的用了RMSNorm才应该用RMSNorm融合算子。# 如果模型用的是LayerNorm用RMSNorm融合会导致精度问题甚至报错。6.2 技巧2注意RMSNorm和LayerNorm的输出差异RMSNorm和LayerNorm的的输出不是等价的RMSNorm不做均值中心化。所以把模型从LayerNorm换成RMSNorm需要做微调不一定需要全量微调LoRA也行。# 把LayerNorm换成RMSNorm后需要微调frompeftimportLoRAConfig,get_peft_model# 加载预训练模型LayerNormmodelload_pretrained_model(gpt-3)# 把LayerNorm换成RMSNormmodelreplace_layernorm_with_rmsnorm(model)# 用LoRA微调只微调少量参数lora_configLoRAConfig(r8,lora_alpha16,target_modules[query,value])modelget_peft_model(model,lora_config)train(model,data)# WHY: RMSNorm和LayerNorm的输出分布不一样RMSNorm的输出均值不为0# 所以直接换会导致模型性能下降。# 需要用少量数据微调让模型适应新的归一化方法。6.3 技巧3用profiling工具验证融合是否生效和LayerNorm融合类似RMSNorm融合是否生效也可以用NPU的profiling工具验证# 用msprof抓profilingmsprof--output./profiling--applicationpython test_rmsnorm.py# 查看kernel调用统计msprof--exporton--output./profiling|greprms_norm# 如果融合生效你应该看到的是 fused_rmsnorm_linear 之类的kernel名# 而不是单独的 rms_norm 和 matmul。6.4 技巧4注意训练和非训练的差异和LayerNorm一样RMSNorm融合在推理和训练时的策略也不一样。fromatbimportFusionMode# 推理模式启用权重融合fusion_opRMSNormLinearFusion(modeFusionMode.INFERENCE)# WHY: 推理时gamma是固定的可以提前融合到MatMul的权重里。# 训练模式启用梯度检查点融合fusion_opRMSNormLinearFusion(modeFusionMode.TRAINING,checkpointTrue)# WHY: 训练时gamma会变化不能做权重融合。# 但可以做好显存管理融合kernel内部共享显存。总结把这件事从头到尾捋一遍RMSNorm比LayerNorm快因为少做一次全局归约不求均值。但如果不做融合RMSNorm仍然有中间结果写回显存的带宽开销。ATB的RMSNorm融合算子从三个层面解决这个问题计算层面统计计算的Vector单元优化树形归约、FMA指令、混合精度策略内存层面两阶段tile融合让中间结果留在片上调度层面Cube/Vector流水线重新平衡实测数据显示在LLaMA-2 70B模型上用ATB做RMSNorm融合端到端延迟从180ms降到138ms加速30%峰值显存从28.4GB降到22.8GB省19.7%。仓库链接https://atomgit.com/cann/ascend-transformer-boost
http://www.gsyq.cn/news/1384266.html

相关文章:

  • 昇腾NPU的推理部署:triton-inference-server-ge-backend实战
  • 【Claude容器化部署SOP v3.2】:基于OCI标准的可验证、可审计、可回滚部署流程(含CI/CD流水线YAML模板与Prometheus监控看板)
  • 如何快速上手Mobaxterm中文版:远程终端工具的终极指南
  • 2026年AI论文工具实测:5款神器从大纲到答辩全链路通关攻略
  • 大模型开发:从入门到精通,非常详细!
  • HR SaaS 选型,2026年最该看什么?
  • 基于遥感与GIS在滑坡、泥石流易发性、危险性、风险评价及普查中的实践技术应用
  • FFF的Webhook集成:搜索结果实时推送到其他系统的终极指南
  • 智能电池管理革命:Battery Toolkit如何让Apple Silicon Mac电池寿命延长40%
  • 终极资源嗅探指南:如何用猫抓一键获取网页视频音频资源?
  • Linux 负载均衡的 imbalance 计算:任务迁移的量化依据
  • Qwen-Image-Edit-Rapid-AIO:4-8步推理引擎重构AI图像编辑效率标准
  • 别再傻傻在线等了!手把手教你下载Chrome离线安装包(企业版/MSI/独立版全解析)
  • CUDA并行计算与FSR框架优化实践
  • 如何快速掌握Avidemux:新手完整入门指南与5个核心技巧
  • 文档解读神器!
  • Mist实战指南:三步解决macOS固件与安装器管理难题
  • 高效萃取是精准检测的前提:西恩士汽车弹簧清洁度萃取设备深度解析 - 工业设备研究社
  • 告别硬件依赖:用Soft-RoCE和`perftest`给你的普通服务器测个RDMA性能
  • 深度解析AICoverGen项目:RVC v2语音克隆与AI音乐生成架构演进
  • Vue.draggable.next终极指南:掌握Vue 3拖放排序的7个高效技巧
  • 如何用OCLP-Mod让旧Mac焕发新生:完整升级指南
  • 别再粗暴关闭验证!OnlyOffice Docker版‘证书错误’的两种安全修复方案
  • 如何快速掌握Topit窗口置顶工具:提升macOS工作效率的完整指南
  • 双屏演示利器:Pympress如何让您的演讲更专业高效
  • 构建私有音乐播放服务的完整技术指南:any-listen架构解析
  • ESP32语音交互终端:集成ChatGPT与TTS的嵌入式AI实践
  • sql1(DDL+DML)
  • Claude Code , Codex, Curser, OpenCode 等 CodeAgent 的实现原理与应用深度研究
  • 在Python中运行JavaScript:PyExecJS的现代应用指南