当前位置: 首页 > news >正文

FlashDecode:Decode 阶段的 Attention 并行化改造

本文基于昇腾CANN和昇腾NPU围绕 ops-transformer 仓库的相关技术展开。FlashDecode 解决了 Decode 阶段的一个结构性浪费每个 Decode Step 只产生 1 个新 Token但 Attention 计算仍然要走完整的 QK^T 路径。FlashDecode 在 CANN 上做了一个关键优化——把多个 Decode Step 的 Attention 计算合并到一起让 NPU 的 Cube Unit 跑满。Decode 阶段 Attention 的痛点# 标准 Decode Attention——每步只算 1 个 Queryimporttorchimporttorch.nn.functionalasFdefdecode_attention(q,k_cache,v_cache,step_idx): q: [1, num_heads, 1, head_dim] —— 当前步的 Query k_cache: [1, num_heads, L, head_dim] —— L 是已缓存长度 v_cache: [1, num_heads, L, head_dim] step_idx: 当前是第几步 # Q: [1, h, 1, d] × K^T: [1, h, d, L] → score: [1, h, 1, L]scoretorch.matmul(q,k_cache.transpose(-2,-1))scorescore/(head_dim**0.5)attnF.softmax(score,dim-1)# attn: [1, h, 1, L] × V: [1, h, L, d] → [1, h, 1, d]outtorch.matmul(attn,v_cache)returnout# 问题MQ 序列长度1 → Cube 利用率只有 15-25%# 瓶颈在 ScoreV 这一步——Matrix-Vector 而不是 Matrix-Matrix每步 M1NPU 的 Cube Unit 大部分时间在等数据搬运。FlashDecode 的思路很简单把 K 缓存切块让多个 Query 并行查。FlashDecode 的块式 Attention# FlashDecode按块读取 KV Cache多个 Query Step 并行计算defflash_decode_attention(q_block,k_cache,v_cache,block_size64): q_block: [num_steps, num_heads, 1, head_dim] —— 合并多个 Decode Step 的 Q k_cache: [num_heads, total_len, head_dim] v_cache: [num_heads, total_len, head_dim] block_size: 每次从 Cache 读几组 KV num_stepsq_block.shape[0]num_headsq_block.shape[1]dq_block.shape[-1]total_lenk_cache.shape[1]# 输出累积器outputtorch.zeros(num_steps,num_heads,1,d)# 分块读取 KV Cache——NPU 的 L1 Buffer 只能装 block_size 个 KVforblock_startinrange(0,total_len,block_size):block_endmin(block_startblock_size,total_len)k_blockk_cache[:,block_start:block_end,:]# [h, bs, d]v_blockv_cache[:,block_start:block_end,:]# [h, bs, d]# Q 块 × K 块^T——现在 Mnum_steps, Kbs# Cube 实际算的是 [num_steps, d] × [d, bs] [num_steps, bs]# Mnum_steps 可以到 32-64Cube 利用率 70%forhinrange(num_heads):q_hq_block[:,h,0,:]# [num_steps, d]k_hk_block[h]# [bs, d]# 批量的 Score 计算——从 Vector 变 Matrixscore_htorch.matmul(q_h,k_h.transpose(-1,-2))# [num_steps, bs]score_hscore_h/(d**0.5)# Online-Softmax避免整段 Softmax 的显存开销local_maxscore_h.max(dim-1,keepdimTrue).values local_exptorch.exp(score_h-local_max)local_sumlocal_exp.sum(dim-1,keepdimTrue)local_outtorch.matmul(local_exp,v_block[h])# [num_steps, d]# 合并到输出——实际生产用 rescale 累加而不是简单加法output[:,h,0,:]local_out.squeeze(1)returnoutputFlashDecode 把 M1 的 Matrix-Vector 变成了 Mnum_steps 的 Matrix-Matrix。步子越大利用率越高但不能超过 64——超过了注意力分布就开始分散精度会掉。CANN 上的 FlashDecode 融合// FlashDecode 在 Ascend C 上的实现——融合了 Score Softmax 累加classFlashDecodeKernel:publicAscendC::Kernel{public:__aicore__inlineFlashDecodeKernel(){}__aicore__inlinevoidProcess()override{// 从 Global Memory 搬 Q 到 L1 BufferAscendC::LocalTensorfloatq_localAscendC::LocalAllocfloat(num_steps*head_dim);AscendC::DataCopy(q_local,gm_q,num_steps*head_dim);// 逐块处理 KV Cachefor(intblock0;blocknum_blocks;block){// 搬 K 块到 L1AscendC::LocalTensorfloatk_localAscendC::LocalAllocfloat(block_size*head_dim);AscendC::DataCopy(k_local,gm_kblock_offset,block_size*head_dim);// Cube 做 QK^T——走 MMA 指令AscendC::LocalTensorfloatscore_localAscendC::LocalAllocfloat(num_steps*block_size);// 这里触发 Cube Unit 的矩阵乘法AscendC::MatMul(score_local,q_local,k_local,AscendC::CUBE_MATRIX_TYPE::TRAN_A);// 直接在 L1 上做 Scale Softmax——不用回显存AscendC::Mul(score_local,score_local,inv_scale);AscendC::Exp(score_local,score_local);// 逐元素 ExpAscendC::ReduceSum(row_sum,score_local,1);// 逐行求和// 读 V 块算加权和AscendC::LocalTensorfloatv_localAscendC::LocalAllocfloat(block_size*head_dim);AscendC::DataCopy(v_local,gm_vblock_offset,block_size*head_dim);// Score (归一化后) V——仍在 L1 完成AscendC::MatMul(partial_out,score_local,v_local);// 累加输出——走了两轮再写回 Global MemoryAscendC::Add(output_local,output_local,partial_out);}// 最终写回AscendC::DataCopy(gm_out,output_local,num_steps*head_dim);}};实测下来 FlashDecode 在 Decode 阶段能把 GPU 的利用率从 15% 拉到 52%。每步处理 32 个合并 Query 时收益最高——再多缓存就装不下 K 块了。参考仓库FlashDecode 算子实现Runtime 多流调度
http://www.gsyq.cn/news/1360723.html

相关文章:

  • 政府科技管理部门如何推动区域创新?
  • STM32F4电池电量监测实战:用HAL库和ADC DMA,从硬件分压到软件滤波全流程解析
  • 用STM32F103C8T6+L298N+蓝牙,手把手教你做个带PID调速的智能小车电机驱动(附完整代码)
  • 2026湖州GEO优化公司全面评测:五大头部服务商排名与避坑指南 - 品牌报告
  • AI 从 “模仿智能” 到 “重构世界” 的范式跃迁
  • Java 零基础全套教程,数据结构与集合源码,笔记 168-174
  • HashMap 底层原理 面试官问 如何回答
  • 从Hub到Router:家庭网络升级踩坑实录,手把手教你选对设备
  • 从“软启动”到防误触:三极管驱动MOS管的4个经典电路场景拆解(含避坑指南)
  • 2026年南京军事夏令营大揭秘,哪家才是你的最佳之选? - GrowthUME
  • MATLAB机器人工具箱终极实战指南:从建模到控制完整解决方案
  • UHF-RFID运动检测技术原理与优化实践
  • Boss-Key:职场隐私保护终极指南,一键隐藏窗口的智能解决方案
  • 保姆级教程:手把手复现XCTF攻防世界MOBILE入门9题(附Python/Java解密脚本及避坑指南)
  • 【混合可再生能源模拟】使用遗传算法优化光伏板和电池的容量附matlab代码
  • 【模型辨识】基于最小二乘法 LS 递推最小二乘法 RLS实现Hammerstein 模型辨识非线性静态环节 + 线性ARX动态环节附Matlab代码
  • 终极配置指南:如何在macOS上快速完成res-downloader HTTPS嗅探工具完整设置
  • 【MATLAB源码-第445期】基于MATLAB的高速V2X车联网OFDM系统多普勒频偏估计补偿与误码率性能仿真
  • 泉州AI培训:泉州元数科技助力晋江市退役军人AI职业技能提升 - 新闻快传
  • 别再为虚拟机卡顿烦恼!实测VMware 16 + Ubuntu 20.04下Gazebo 11流畅运行无人船仿真的完整配置清单
  • 验证旋转中心流程
  • 飞书秒变 Claude Code 控制台:一个 Bridge 项目,正在改写 AI 编程入口
  • 九点标定验证流程
  • 从原理到实战:为什么安全工程师和红队偏爱TCP Traceroute?手把手教你用它进行网络侦察
  • MacBook到手后,除了装Homebrew,这5个zsh插件能让你的终端效率翻倍
  • 为开源AI项目配置HermesAgent使用Taotoken作为模型供应商指南
  • ShiroAttack2实战指南:从漏洞检测到内存马注入的完整揭秘
  • Taotoken多模型聚合平台助力Matlab开发者构建智能分析工具
  • 在Taotoken模型广场根据任务需求挑选合适模型的实践
  • 深圳高空广告工程:物料制作要点梳理与专业安装流程详解 - GrowthUME