当前位置: 首页 > news >正文

昇腾CANN ops-transformer FlashAttention 反向传播:不存 Attention 矩阵怎么求梯度

FlashAttention 前向传播的精髓不存 N×N 的 attention 矩阵只存 O(N) 的输出和 softmax 归一化因子。反向传播时需要 attention 矩阵来计算梯度——但矩阵没存。解法重新算一遍。用额外的计算换显存——这是典型的 compute-for-memory tradeoff。512K 上下文下标准 attention backward 需要 512GB存 attention 矩阵 512GB存梯度 1TB 显存。FlashAttention backward 只需要 ~几 GB。标准 Attention 反向传播的梯度公式前向O softmax(QK^T / √d) × V P × V反向需要三个梯度dQ, dK, dV都需要 attention 矩阵 PdV P^T × dO dP dO × V^T dS dP ⊙ P - P ⊙ (sum(dP ⊙ P, dim-1)) # softmax 反向 dQ dS × K dK dS^T × Q所有公式都依赖 Pattention 矩阵——但 FlashAttention 前向没存它。FlashAttention 反向重算 分块反向传播的核心思路在反向 pass 中重新执行前向计算。前向时跑了分块 softmax 分块加权求和反向时再次分块——但这次不仅计算 O还要计算 dQ, dK, dV。// ops-transformer/kernels/flash_attention/flash_attention_backward.cpp__aicore__voidFlashAttentionBackward(GlobalTensorfloat16dO,// 输出梯度 [B, H, N, D]GlobalTensorfloat16Q,// 前向的 Q保留GlobalTensorfloat16K,// 前向的 K保留GlobalTensorfloat16V,// 前向的 V保留GlobalTensorfloat16L,// 前向的 row_sum (softmax 分母)GlobalTensorfloat16dQ,// Q 的梯度GlobalTensorfloat16dK,// K 的梯度GlobalTensorfloat16dV,// V 的梯度intN,intD){constexprintBr32;constexprintBc32;// 第一步重算 dV最简单——只需 P 和 dOfor(intj0;jnum_kv_blocks;j){LocalTensorfloat16dV_local(Bc,D);for(intbc0;bcBc;bc)for(intd0;dD;d)dV_local[bc][d]0.0f;for(inti0;inum_q_blocks;i){LocalTensorfloat16Qi(Br,D);LocalTensorfloat16Kj(Bc,D);DataCopy(Qi,Qi*Br*D,Br*D);DataCopy(Kj,Kj*Bc*D,Bc*D);// 重算 S Qi × Kj^TLocalTensorfloat16S_block(Br,Bc);for(intr0;rBr;r)for(intc0;cBc;c){floatsum0.0f;for(intd0;dD;d)sumfloat(Qi[r*Dd])*float(Kj[c*Dd]);S_block[r*Bcc]float16(sum);}// 重算 P softmax(S)用前向存的 row_max 和 Lfor(intr0;rBr;r){floatmax_valrow_max_forward[i*Brr];floatsum_expL[i*Brr];for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;// dV_j sum_i(P_ij × dO_i)for(intd0;dD;d)dV_local[c][d]P_val*float(dO[(i*Brr)*Dd]);}}}DataCopy(dVj*Bc*D,dV_local,Bc*D);}// 第二步重算 dQ需要 dP 和 dSfor(inti0;inum_q_blocks;i){LocalTensorfloat16Qi(Br,D);DataCopy(Qi,Qi*Br*D,Br*D);floatdQi_init[Br][D]{0.0f};for(intj0;jnum_kv_blocks;j){LocalTensorfloat16Kj(Bc,D);LocalTensorfloat16Vj(Bc,D);LocalTensorfloat16dOi(Br,D);DataCopy(Kj,Kj*Bc*D,Bc*D);DataCopy(Vj,Vj*Bc*D,Bc*D);DataCopy(dOi,dOi*Br*D,Br*D);// 重算 S_blockLocalTensorfloat16S_block(Br,Bc);for(intr0;rBr;r)for(intc0;cBc;c){floatsum0.0f;for(intd0;dD;d)sumfloat(Qi[r*Dd])*float(Kj[c*Dd]);S_block[r*Bcc]float16(sum);}for(intr0;rBr;r){floatmax_valrow_max_forward[i*Brr];floatsum_expL[i*Brr];// dP_ij dO_i × V_j^TfloatdP[Bc];for(intc0;cBc;c){dP[c]0.0f;for(intd0;dD;d)dP[c]float(dOi[r*Dd])*float(Vj[c*Dd]);}// D_i sum_j(dP_ij × P_ij)floatD_i0.0f;for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;D_idP[c]*P_val;}// dS_ij (dP_ij - D_i) × P_ij → dQ_i sum_j(dS_ij × Kj)for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;floatdS_val(dP[c]-D_i)*P_val;for(intd0;dD;d)dQi_init[r][d]dS_val*float(Kj[c*Dd]);}}}for(intr0;rBr;r)for(intd0;dD;d)dQ[(i*Brr)*Dd]float16(dQi_init[r][d]);}// 第三步dK对称于 dQ公式是 dK_j sum_i(dS_ij^T × Qi)for(intj0;jnum_kv_blocks;j){LocalTensorfloat16Kj(Bc,D);DataCopy(Kj,Kj*Bc*D,Bc*D);floatdKj_init[Bc][D]{0.0f};for(inti0;inum_q_blocks;i){LocalTensorfloat16Qi(Br,D);LocalTensorfloat16Vj(Bc,D);LocalTensorfloat16dOi(Br,D);DataCopy(Qi,Qi*Br*D,Br*D);DataCopy(Vj,Vj*Bc*D,Bc*D);DataCopy(dOi,dOi*Br*D,Br*D);LocalTensorfloat16S_block(Br,Bc);for(intr0;rBr;r)for(intc0;cBc;c){floatsum0.0f;for(intd0;dD;d)sumfloat(Qi[r*Dd])*float(Kj[c*Dd]);S_block[r*Bcc]float16(sum);}for(intr0;rBr;r){floatmax_valrow_max_forward[i*Brr];floatsum_expL[i*Brr];floatdP[Bc];for(intc0;cBc;c){dP[c]0.0f;for(intd0;dD;d)dP[c]float(dOi[r*Dd])*float(Vj[c*Dd]);}floatD_i0.0f;for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;D_idP[c]*P_val;}for(intc0;cBc;c){floatP_valexpf(float(S_block[r*Bcc])-max_val)/sum_exp;floatdS_val(dP[c]-D_i)*P_val;// dK_j_c sum_r(dS_rc × Qi_r) ← 转置关系for(intd0;dD;d)dKj_init[c][d]dS_val*float(Qi[r*Dd]);}}}for(intc0;cBc;c)for(intd0;dD;d)dK[(j*Bcc)*Dd]float16(dKj_init[c][d]);}}前向保存的关键数据FlashAttention 前向完成后只保存两个向量不是 N×N 矩阵structFlashAttentionForwardCache{float32*row_max;// [num_q_blocks × Br] — softmax 每行的最大值float32*L;// [num_q_blocks × Br] — softmax 分母指数和float16*O;// [B, H, N, D] — 正常大小的输出};反向传播时用row_max和L重算 softmax 矩阵 P——每次只重算一个分块不同时存在于显存中。计算量分析标准 Attention 反向传播 - 前向O(N² × D) 计算 O(N²) 存储存 P - 反向O(N² × D) 计算直接读 P O(N²) 存储 P dP - 显存O(N²) 512GB (512K seq) FlashAttention 反向传播 - 前向O(N² × D) 计算 O(N) 存储只存 row_max L - 反向O(N² × D) 计算 × 2重算两次 P → dV 和 dQ/dK - 显存O(N) ~几 GB 额外计算反向多一倍重算了 P 两次 显存节省O(N²) → O(N)512K seq 下 512GB → 几 GB踩坑一row_max 和 L 用 FP16 保存 → 梯度偏移重算 softmax 需要前向的 row_max 和 L。FP16 保存 ±0.001 误差在 exp(S - max) 中偏差被放大// ❌ FP16 保存 row_max → 还原时 ±0.001 误差float16 row_max_fwd_fp16[N];floatrow_max_restoredfloat(row_max_fwd_fp16[i]);// 偏差 0.001// exp(S - max) 中偏差放大floattrue_expexpf(88.0f-88.0f)1.0f;floatwrong_expexpf(88.0f-88.001f)0.999f;// 偏差 0.1%// ✅ FP32 保存 row_max 和 Lfloat32 row_max_fwd_fp32[N];float32 L_fwd_fp32[N];实测FP16 row_max → LLaMA 7B 训练 loss 在 5000 步后偏离 0.03vs 基线FP32 row_max → 只偏离 0.001。踩坑二dQ 和 dK 各自重算一遍 P → 白白多算一次反向需要重算两次前向一次算 dV需要 P一次算 dQ 和 dK也需要 P。两次重算之间 P 没保存 → 算了两遍。// ❌ 两次重算 P —— 第二次浪费了for(i,j){Precompute(Qi,Kj);dVP^T × dO;// 第一次重算——只用了一次}for(i,j){Precompute(Qi,Kj);// 又算一次——浪费dQdS × K;dKdS^T × Q;}// ✅ 一次重算 P同时输出 dV/dQ/dKfor(i,j){Precompute(Qi,Kj);dVP^T × dO[i];// 一次 P三种梯度dS(dP-D_i)× P;dQdS × K[j];dKdS^T × Q[i];}踩坑三FP16 累加器精度损失dQi_init[r][d]跨 8 个 Bc chunks 累加——每个贡献微量。FP16 累加 8 次 → 误差累积。// ❌ FP16 累加器float16 dQi_init[Br][D];// 8 次累加后每次舍入 → 总误差 ~0.1%// ✅ FP32 累加器只在写回时转 FP16floatdQi_init[Br][D];// ... 8 次累加全精度dQ[...]float16(dQi_init[r][d]);// 只一次转换FlashAttention 反向的本质不存 N² 矩阵用 N² 的额外计算换回来。前向保存 row_max 和 LO(N) 大小反向重算两次 PO(N²) 计算得到完整的 dQ/dK/dV。512K seq 下显存从 512GB 降到几 GB——这是训练长上下文模型的唯一可行路径。三个关键row_max/L 用 FP32 保存不要节约 4 bytes 丢了精度、dQ 和 dK 的重算合并为一次重算P 算一次输出三种梯度、累加器全用 FP32最后才转 FP16。
http://www.gsyq.cn/news/1372000.html

相关文章:

  • QKeyMapper:打破设备界限,让Windows键鼠手柄随心所欲的终极映射方案
  • 鸿蒙PC:Qt适配OpenHarmony实战【微习惯】:把每日习惯、完成率和周视图放在一个窗口里
  • 2026年5月衡水饶阳地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • 2026年5月赣州全南地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 镜像视界浙江科技有限公司煤矿领域技术地位与核心优势
  • 2026年5月赣州瑞金地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 2026年5月衡水深州地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • 通过Taotoken的Token Plan套餐实现项目成本的可预测与精细控制
  • 2026年5月甘南临潭地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 2026年5月惠州龙门地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • 2026年5月赣州上犹地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 2026年5月甘南碌曲地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • Google 广告场景下 Uniswap 钓鱼攻击机理与 Web3 防御体系研究
  • 人机协同闭环:AI 时代邮件安全 “人在回路” 防御体系研究
  • 高校邮件安全体系升级与 Proofpoint 部署实践研究 —— 以特拉华大学为例
  • Go语言数据库迁移与版本管理
  • 2026年5月恩施巴东地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 2026年5月甘南玛曲地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • AliceSoft游戏文件逆向工程深度解析:从二进制格式到高级编辑的完整方案
  • 终极指南:在VS Code中构建高效的R语言数据分析环境
  • 2026年5月恩施地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 为什么92%的DeepSeek微调失败?资深架构师拆解3类致命配置错误及实时诊断命令
  • 【Gemini生命周期价值深度解码】:20年AI架构师亲授5大阶段ROI测算模型与避坑指南
  • ChatGPT移动端隐私红线报告(2024Q2):麦克风/剪贴板/位置数据采集路径全曝光,3步彻底锁死敏感权限
  • 2026年5月抚顺清原地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 检测回收中心
  • 2026年5月菏泽巨野地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • 【Gemini免费额度高效利用指南】:20年AI平台实战总结的7大避坑技巧与超额续命策略
  • Veo+Runway+Pika+Synthesia+HeyGen+Kaedim+Adobe Firefly:7大AI视频工具协同工作流全拆解,3小时搭建企业级智能剪辑中枢
  • 创业团队如何利用Taotoken统一管理多个AI应用API成本
  • 如何为嵌入式项目配置大模型API调用使用Taotoken与Python