深度学习与神经网络学习笔记 —— Transformer模型原理与实现
课程内容:Transformer模型综述、输入模块(Input Block)、编码器(Encoder)、解码器(Decoder)、输出模块(Output Layer)以及网络训练过程。
核心关键词:Attention、Self-Attention、Multi-Head Attention、Positional Encoding、Encoder、Decoder、大语言模型
一、Transformer的提出背景与整体架构
在Transformer出现之前,自然语言处理领域主要采用循环神经网络(RNN)和长短时记忆网络(LSTM)完成序列建模任务。虽然LSTM能够在一定程度上解决传统RNN的梯度消失问题,但随着序列长度增加,模型仍然面临长期依赖难以学习、训练效率低以及难以并行计算等问题。课程中首先回顾了Transformer出现之前序列模型的发展过程,并指出传统RNN结构存在串行计算的天然缺陷。
Transformer最大的创新在于:
完全抛弃循环结构,仅利用Attention机制完成序列建模。
其整体结构由Encoder和Decoder两部分组成:
Input ↓ Embedding ↓ Positional Encoding ↓ Encoder × N ↓ Decoder × N ↓ Linear ↓ Softmax ↓ Output传统RNN的信息传播过程可以表示为:
ht=f(ht−1,xt)h_t=f(h_{t-1},x_t)ht=f(ht−1,xt)
这种结构导致计算必须按照时间顺序逐步进行。
而Transformer通过Attention直接建立任意两个位置之间的联系:
yi=∑j=1nαijxjy_i=\sum_{j=1}^{n}\alpha_{ij}x_jyi=∑j=1nαijxj
因此模型能够实现完全并行化计算,大幅提高训练效率。课程中特别强调了GPU并行计算对于Transformer成功的重要意义。
二、输入模块:Embedding与位置编码
Transformer虽然能够处理序列数据,但神经网络无法直接理解文本,因此首先需要将单词转换成向量表示。
假设词表大小为 VVV,词向量维度为 ddd,则Embedding矩阵可以表示为:
E∈RV×dE\in\mathbb{R}^{V\times d}E∈RV×d
每个单词经过Embedding后变成:
xi∈Rdx_i\in\mathbb{R}^{d}xi∈Rd
然而,Transformer没有RNN的时间顺序结构,因此无法自动获得位置信息。
为了解决这一问题,课程中介绍了位置编码(Positional Encoding)机制。
Transformer采用正弦函数与余弦函数构造位置向量:
PE(pos,2i)=sin(pos100002id)PE(pos,2i)=\sin\left(\frac{pos}{10000^{\frac{2i}{d}}}\right)PE(pos,2i)=sin(10000d2ipos)
PE(pos,2i+1)=cos(pos100002id)PE(pos,2i+1)=\cos\left(\frac{pos}{10000^{\frac{2i}{d}}}\right)PE(pos,2i+1)=cos(10000d2ipos)
课程中给出了上述位置编码公式。
最终输入表示为:
Z=X+PEZ=X+PEZ=X+PE
其中:
- XXX 表示词向量
- PEPEPE 表示位置编码
这种设计使Transformer既能够获取语义信息,又能够获取词语的顺序信息。
三、Encoder结构与Self-Attention机制
Encoder是Transformer最核心的部分。
课程中指出,一个Encoder层主要由以下模块组成:
Multi-Head Attention ↓ Add & Norm ↓ Feed Forward ↓ Add & Norm其中最重要的是Self-Attention机制。
Self-Attention的基本思想是:
一个单词在表示自身时,需要关注句子中其它所有单词。
例如:
The animal didn't cross the street because it was tired.这里的“it”究竟指什么?
Self-Attention会自动计算“it”与所有词之间的关联程度,从而确定其真正含义。
在Transformer中,每个输入向量会生成三个向量:
- Query(查询向量)
- Key(键向量)
- Value(值向量)
计算方式分别为:
Q=XWQQ=XW_QQ=XWQ
K=XWKK=XW_KK=XWK
V=XWVV=XW_VV=XWV
随后利用Query与Key计算相似度:
Score=QKTScore=QK^TScore=QKT
为了避免维度过大导致数值不稳定,需要进行缩放:
QKTdk\frac{QK^T}{\sqrt{d_k}}dkQKT
最终经过Softmax得到注意力权重:
A=softmax(QKTdk)A=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)A=softmax(dkQKT)
课程中给出了完整的Scaled Dot-Product Attention公式:
Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
该公式被认为是Transformer最核心的数学表达式。
为了提升模型表达能力,Transformer进一步引入多头注意力机制(Multi-Head Attention)。
第 iii 个注意力头:
headi=Attention(QWiQ,KWiK,VWiV)head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)
多个头拼接后:
MultiHead=Concat(head1,...,headh)WOMultiHead=Concat(head_1,...,head_h)W^OMultiHead=Concat(head1,...,headh)WO
这样模型能够同时关注不同语义关系。
四、Decoder结构与输出生成机制
Decoder的整体结构与Encoder类似,但增加了Masked Self-Attention机制。
原因在于:
在文本生成过程中,当前位置不能提前看到未来单词。
例如:
I am going to buy a new car当模型生成“buy”时,不应该提前知道后面的“car”。课程中也利用类似句子说明了解码过程。
因此Decoder中的注意力矩阵需要加入Mask:
Mask(i,j)={0,j≤i−∞,j>iMask(i,j)=\begin{cases}0,&j\le i\\-\infty,&j>i\end{cases}Mask(i,j)={0,−∞,j≤ij>i
Masked Attention变为:
softmax(QKT+Maskdk)softmax\left(\frac{QK^T+Mask}{\sqrt{d_k}}\right)softmax(dkQKT+Mask)
Decoder除了Masked Self-Attention外,还会接收Encoder输出的信息。
因此Decoder包含两种Attention:
- Self-Attention
- Encoder-Decoder Attention
最终输出经过线性层:
z=Woh+bz=W_o h+bz=Woh+b
再经过Softmax转换为概率分布:
P(yi)=ezi∑jezjP(y_i)=\frac{e^{z_i}}{\sum_j e^{z_j}}P(yi)=∑jezjezi
概率最大的词将作为下一时刻输出。
五、Transformer训练过程与课程总结
Transformer训练本质上属于监督学习过程。
对于一个长度为 nnn 的目标序列:
(y1,y2,⋯ ,yn)(y_1,y_2,\cdots,y_n)(y1,y2,⋯,yn)
模型希望最大化整个序列出现概率:
P(Y)=∏t=1nP(yt∣y<t,X)P(Y)=\prod_{t=1}^{n}P(y_t|y_{<t},X)P(Y)=∏t=1nP(yt∣y<t,X)
训练时通常采用交叉熵损失函数:
L=−∑iyilog(y^i)L=-\sum_i y_i\log(\hat y_i)L=−∑iyilog(y^i)
对于整个数据集:
L=−1N∑n=1N∑iyi(n)log(y^i(n))L=-\frac{1}{N}\sum_{n=1}^{N}\sum_i y_i^{(n)}\log(\hat y_i^{(n)})L=−N1∑n=1N∑iyi(n)log(y^i(n))
通过反向传播不断更新参数:
θt+1=θt−η∇L(θt)\theta_{t+1}=\theta_t-\eta\nabla L(\theta_t)θt+1=θt−η∇L(θt)
课程最后总结指出,Transformer最大的贡献在于:
- 利用Attention替代循环结构;
- 实现完全并行训练;
- 更好地建模长距离依赖关系;
- 成为现代大模型的基础架构。
目前几乎所有主流大模型都建立在Transformer框架之上,例如:
- GPT系列
- BERT系列
- T5
- PaLM
- LLaMA
- Qwen
从技术发展角度来看,Transformer已经不仅仅是一个神经网络模型,而是现代人工智能尤其是大语言模型时代最重要的基础架构之一。
