当前位置: 首页 > news >正文

Transformer 核心模块详解:多头注意力、前馈网络与词嵌入

【学习记录】Transformer 核心模块详解:多头注意力、前馈网络与词嵌入

Transformer 是现代大语言模型的基石,而多头注意力(MultiHeadAttention)前馈网络(FFN)词嵌入(Embedding)是其最核心的三个组件。本文从原理到代码,逐层拆解这三个模块,并提供 Python(PyTorch)和 C++(LibTorch)实现,附带完整的复杂度分析。


📌 目录

  1. MultiHeadAttention(多头注意力)
  2. FFN(前馈网络)
  3. Embedding(词嵌入)
  4. 三个模块的组合使用
  5. 复杂度总结

一、多头注意力(MultiHeadAttention)

1.1 作用

多头注意力机制允许模型同时关注输入序列中不同位置的不同表示子空间。它通过将查询(Q)、键(K)、值(V)线性映射到多个头,分别计算注意力,最后拼接并映射回原维度。

1.2 数学公式

标准缩放点积注意力:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V

多头注意力:

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)

1.3 代码实现(Python/PyTorch)

importtorchimporttorch.nnasnnimportmathclassMultiHeadAttention(nn.Module):def__init__(self,d_model,num_heads):super().__init__()assertd_model%num_heads==0self.d_model=d_model self.num_heads=num_heads self.d_k=d_model//num_heads self.Wq=nn.Linear(d_model,d_model)self.Wk=nn.Linear(d_model,d_model)self.Wv=nn.Linear(d_model,d_model)self.Wo=nn.Linear(d_model,d_model)defforward(self,Q,K,V,mask=None):batch_size=Q.size(0)# 1. 线性映射并拆分为多头Q=self.Wq(Q).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)K=self.Wk(K).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)V=self.Wv(V).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# 2. 缩放点积注意力scores=torch.matmul(Q,K.transpose(-2,-1))/math.sqrt(self.d_k)# 3. 应用掩码(可选)ifmaskisnotNone:scores=scores.masked_fill(mask==0,-1e9)attn_weights=torch.softmax(scores,dim=-1)output=torch.matmul(attn_weights,V)# 4. 合并多头并输出output=output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)returnself.Wo(output)

1.4 图解(文本示意)

输入: (B, T, D) │ ├─→ 线性映射 Wq, Wk, Wv → (B, T, D) │ ├─→ view + transpose → (B, n_head, T, d_k) │ ├─→ scores = Q @ K^T / sqrt(d_k) → (B, n_head, T, T) │ │ │ └─→ mask (可选) 填充 -1e9 │ ├─→ softmax → (B, n_head, T, T) │ ├─→ output = attn @ V → (B, n_head, T, d_k) │ ├─→ transpose + view → (B, T, D) │ └─→ Wo 线性映射 → (B, T, D)

1.5 C++ 代码(LibTorch)

#include<torch/torch.h>classMultiHeadAttentionImpl:publictorch::nn::Module{public:intd_model,num_heads,d_k;torch::nn::Linear Wq,Wk,Wv,Wo;MultiHeadAttentionImpl(intd_model_,intnum_heads_):d_model(d_model_),num_heads(num_heads_),d_k(d_model_/num_heads_),Wq(torch::nn::Linear(d_model,d_model)),Wk(torch::nn::Linear(d_model,d_model)),Wv(torch::nn::Linear(d_model,d_model)),Wo(torch::nn::Linear(d_model,d_model)){register_module("Wq",Wq);register_module("Wk",Wk);register_module("Wv",Wv);register_module("Wo",Wo);}torch::Tensorforward(torch::Tensor Q,torch::Tensor K,torch::Tensor V,torch::Tensor mask={}){intbatch_size=Q.size(0);// 线性映射Q=Wq->forward(Q).view({batch_size,-1,num_heads,d_k}).transpose(1,2);K=Wk->forward(K).view({batch_size,-1,num_heads,d_k}).transpose(1,2);V=Wv->forward(V).view({batch_size,-1,num_heads,d_k}).transpose(1,2);// 注意力分数autoscores=torch::matmul(Q,K.transpose(-2,-1))/std::sqrt(d_k);if(mask.defined()){scores=scores.masked_fill(mask==0,-1e9);}autoattn=torch::softmax(scores,-1);autooutput=torch::matmul(attn,V);output=output.transpose(1,2).contiguous().view({batch_size,-1,d_model});returnWo->forward(output);}};TORCH_MODULE(MultiHeadAttention);

1.6 复杂度分析

操作时间复杂度空间复杂度
线性映射 (Q,K,V)O(B×T×D²)O(B×T×D)
拆分多头O(B×T×D)O(B×n_head×T×d_k)
分数矩阵乘法O(B×n_head×T²×d_k)O(B×n_head×T²)
SoftmaxO(B×n_head×T²)O(B×n_head×T²)
加权求和O(B×n_head×T²×d_k)O(B×n_head×T×d_k)
合并与输出映射O(B×T×D²)O(B×T×D)
总计O(B × T² × D)O(B × n_head × T²)

其中D = d_model,d_k = D / n_head


二、前馈网络(FFN)

2.1 作用

FFN 对每个位置独立进行非线性变换,增加模型表达能力。标准结构:线性 → ReLU → 线性,通常中间维度d_ffd_model的 4 倍左右。

2.2 数学公式

FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2FFN(x)=ReLU(xW1+b1)W2+b2

2.3 代码实现(Python/PyTorch)

classFFN(nn.Module):def__init__(self,d_model,d_ff):super().__init__()self.linear1=nn.Linear(d_model,d_ff)self.linear2=nn.Linear(d_ff,d_model)self.activation=nn.ReLU()defforward(self,x):returnself.linear2(self.activation(self.linear1(x)))

2.4 图解

输入 (B, T, D) │ ├─→ linear1 (D → d_ff) → (B, T, d_ff) │ ├─→ ReLU → (B, T, d_ff) │ └─→ linear2 (d_ff → D) → (B, T, D)

2.5 C++ 代码(LibTorch)

classFFNImpl:publictorch::nn::Module{public:torch::nn::Linear linear1,linear2;FFNImpl(intd_model,intd_ff):linear1(d_model,d_ff),linear2(d_ff,d_model){register_module("linear1",linear1);register_module("linear2",linear2);}torch::Tensorforward(torch::Tensor x){returnlinear2->forward(torch::relu(linear1->forward(x)));}};TORCH_MODULE(FFN);

2.6 复杂度分析

操作时间复杂度空间复杂度
linear1O(B × T × D × d_ff)O(B × T × d_ff)
ReLUO(B × T × d_ff)O(B × T × d_ff)
linear2O(B × T × d_ff × D)O(B × T × D)
总计O(B × T × D × d_ff)O(B × T × max(D, d_ff))

d_ff = 4 × D时,复杂度约为O(4 × B × T × D²)


三、词嵌入(Embedding)

3.1 作用

将离散的 token ID 序列映射为稠密的连续向量,并乘以√d_model进行缩放,以便与位置编码相加时尺度匹配。

3.2 代码实现(Python/PyTorch)

classEmbedding(nn.Module):def__init__(self,vocab_size,d_model):super().__init__()self.embedding=nn.Embedding(vocab_size,d_model)self.d_model=d_modeldefforward(self,x):returnself.embedding(x)*math.sqrt(self.d_model)

3.3 图解

输入: (B, T) token IDs [ [1, 3, 0, ...] ] │ └─→ nn.Embedding 查表 (vocab_size × D) │ └─→ 输出 (B, T, D) │ └─→ 乘以 √D → (B, T, D)

3.4 C++ 代码(LibTorch)

classEmbeddingImpl:publictorch::nn::Module{public:torch::nn::Embedding embedding;intd_model;EmbeddingImpl(intvocab_size,intd_model_):embedding(vocab_size,d_model_),d_model(d_model_){register_module("embedding",embedding);}torch::Tensorforward(torch::Tensor x){returnembedding->forward(x)*std::sqrt(d_model);}};TORCH_MODULE(Embedding);

3.5 复杂度分析

操作时间复杂度空间复杂度
查表O(B × T)O(B × T × D)
乘法O(B × T × D)O(B × T × D)
总计O(B × T × D)O(B × T × D)

四、三个模块的组合使用

一个完整的 Transformer 编码器层通常由多头注意力 + 残差连接 + 层归一化FFN + 残差连接 + 层归一化构成。

classTransformerEncoderLayer(nn.Module):def__init__(self,d_model,num_heads,d_ff):super().__init__()self.self_attn=MultiHeadAttention(d_model,num_heads)self.ffn=FFN(d_model,d_ff)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)defforward(self,x,mask=None):# 自注意力 + 残差 + 层归一化attn_out=self.self_attn(x,x,x,mask)x=self.norm1(x+attn_out)# FFN + 残差 + 层归一化ffn_out=self.ffn(x)x=self.norm2(x+ffn_out)returnx

完整流程示例

vocab_size=10000d_model=512num_heads=8d_ff=2048batch_size=2seq_len=10# 输入 token IDsinput_ids=torch.randint(0,vocab_size,(batch_size,seq_len))# 嵌入层embed=Embedding(vocab_size,d_model)x=embed(input_ids)# (2,10,512)# 位置编码(此处略,可加上)# pos_enc = PositionalEncoding(d_model)# x = pos_enc(x)# Transformer 编码器层encoder_layer=TransformerEncoderLayer(d_model,num_heads,d_ff)output=encoder_layer(x)# (2,10,512)print(output.shape)# torch.Size([2, 10, 512])

五、复杂度总结

模块时间复杂度空间复杂度说明
MultiHeadAttentionO(B × T² × D)O(B × n_head × T²)核心瓶颈在 T²,长序列需优化
FFNO(B × T × D × d_ff)O(B × T × max(D, d_ff))通常 d_ff = 4D,复杂度约为 4×
EmbeddingO(B × T × D)O(B × T × D)查表操作,轻量级

优化建议

  • 对于长序列(T 很大),可使用稀疏注意力(如 FlashAttention)降低 T² 复杂度。
  • FFN 的中间维度 d_ff 越大模型容量越大,但计算量线性增加。
  • 嵌入层占参数量主要为vocab_size × D,大词表时需考虑参数共享或压缩。

🎯 总结

本文详细拆解了 Transformer 的三个核心模块:

  1. 多头注意力:让模型关注不同位置的多种关系,是 Transformer 成功的核心。
  2. 前馈网络:提供非线性变换,增强模型表达能力。
  3. 词嵌入:将离散符号映射到连续空间,是深度学习处理文本的起点。

通过理解这些模块的输入输出、形状变化和复杂度,能轻松搭建并优化自己的 Transformer 模型。

http://www.gsyq.cn/news/1335757.html

相关文章:

  • Delphi二进制迷宫破解:IDR交互式重构器的逆向工程革命
  • 你的闹钟为何总在熄屏后“哑火”?——AlarmManager 精准唤醒与 Doze 破解全指南
  • 2026年知名的镇江防腐网格桥架优质厂家推荐榜 - 行业平台推荐
  • Attractor Models 深度拆解:当循环 Transformer 遇见不动点,AI 学会了自己迭代到答案
  • 【从零学Vibe Coding】第一章:Vibe Coding 到底是什么?
  • O2OA(翱途)开发平台V10 财务管理|中小企业费用业务一体化
  • LLM结构化输出工程:让模型输出你真正需要的格式
  • MobileNetV2肺癌病理图像分类|全网独家实战,MSA注意力改进篇 引入MSA多尺度注意力,强化病理特征提取、助力微小病灶识别、病理切片分类、临床辅助诊断有效涨点
  • CAPEv2 沙箱安装部署
  • 一多 OS 的技术闭环彻底打通
  • 鸿蒙动态信息流与健康档案模块:声明式列表与网格的深度融合
  • AI产品经理入门实战:如何理解数字人驱动?
  • 百万级 MySQL 大表导入前,别让这两个默认参数拖垮性能_2026-05-20
  • COMSOL电磁超声仿真避坑指南:从‘域不适用’报错到结果收敛的完整调试流程
  • 无人机算法之第四章 ArduPilot 主要配置参数及效果
  • GNSS模块教程:大夏龙雀 DX-GP21,从硬件接线到 NMEA 数据解析
  • [具身智能-824]:人的大脑,如何实现高实时、多模态联合、发现表象背后的各种规律和层层叠叠的不同层次的语义的?
  • 【C++】类和对象( 类的定义、实例化、 this指针、 C++和C语言实现Stack对比)
  • 电脑截图工具深度测评:PixPin、Snipaste、兔灵截图(Utools插件)
  • ⚡ 淘汰你的不是 AI,而是会用 AI 的同行
  • 8 张 RTX 5090 跑 Qwen3.6-27B:从装 vLLM 到压测调优的真实数据(含完整脚本)
  • 全面详解 bgfx
  • 别再乱改Rime配置了!先搞懂程序文件夹和用户文件夹的区别(Windows/Ubuntu路径详解)
  • Cursor试用限制终极解决方案:3分钟快速重置设备标识实战指南
  • 无磁钻具:市场发展现状与未来前景趋势
  • FPGA管脚不够用?手把手教你用74HC595级联驱动8位数码管(附Verilog代码与仿真)
  • 测试经理为保障项目按期交付,主动规划核心内容
  • YimMenu:GTA5终极防护与增强完整指南
  • 保姆级教程:在S32G274ARDB2上,用IPCF点亮RGB LED(附源码解析)
  • cp520靶场学习笔记