背景介绍

当上下文长度变长后,标准全量 attention 的计算量和 KV Cache 访问成本都会迅速上升。围绕长上下文,常见研究方向包括:观察语言模型如何分配 attention weights、通过稀疏模式降低计算复杂度,以及通过 FlashAttention 这类 IO-aware 实现提升硬件效率。

长上下文中的 Attention 分布模式

Fu YaoLLaMA-2-7B-80K模型上进行了实验,并提取了其注意力张量(tensor)。注意力tensor包含了三个维度/层级:

  • 网络层数,即深度(depth):32层
  • 注意力头数(heads):每层有32个头
  • 上下文长度(context length):输入文档有50K tokens

作者在不同层观测到注意力的分布不尽相同:

  • 在Layer 0和Layer 1中,大多数(或全部)注意力头(attention heads)遵循均匀分布(uniform distribution)。
  • 在Layer 2-30中,许多注意力头呈现V型模式(V-shaped pattern)– 即注意力权重(attention mass)主要分配在序列的起始和末尾部分(first and last few tokens)。不过,也有一些注意力头表现出其他模式,例如:分散在中间部分(scattered over the middle)、集中在中间位置(concentrated on the middle)。值得注意的是,Layer 2-30中未观察到均匀分布。
  • 最后,在Layer 31(即最后一层),所有上述注意力模式均存在。
V-shaped:第5层,第20头,注意力遵循V形模式
V-shaped:第5层,第20头,注意力遵循V形模式
Attention sink:第12层,第23头,注意力沉没模式
Attention sink:第12层,第23头,注意力沉没模式
Recency bias:第5层,第15头,最近偏见模式
Recency bias:第5层,第15头,最近偏见模式
Scattered over middle:第11层,第2头。概率质量在中间的多个token上分散
Scattered over middle:第11层,第2头。概率质量在中间的多个token上分散
Concentrated on middle:第11层,第14头。概率质量集中在中间的很少几个token上
Concentrated on middle:第11层,第14头。概率质量集中在中间的很少几个token上
Uniform:第0层,第5头,所有token上的均匀分布
Uniform:第0层,第5头,所有token上的均匀分布

from:https://yaofu.notion.site/How-Do-Language-Models-put-Attention-Weights-over-Long-Context-10250219d5ce42e8b465087c383a034e

稀疏 Attention 变种

Longformer(2020,AllenAI)

Sparse Attention优化,主要结合滑动窗口局部注意力+任务驱动全局注意力方式,也支持空洞滑动窗口变体,把注意力计算从长度平方降低到线性。

如何选择全局注意力:

  1. 不同任务选择不同的全局注意力
  2. 分类任务:CLS token;QA任务:所有Q对应的token。
Longformer 的滑动窗口注意力与全局注意力结构
Longformer 的滑动窗口注意力与全局注意力结构

BigBird(2020,Google)

Sparse Attention优化,结合随机+滑动窗口+全局注意力方式,把注意力计算从长度平方降低到线性。

全局注意力:

  1. 内部选择一些行或者列
  2. 外部扩展:如Bert分类任务中的CLS token。
BigBird 结合局部窗口、全局 token 与随机注意力的稀疏结构
BigBird 结合局部窗口、全局 token 与随机注意力的稀疏结构

Routing Transformer(2020/2021,Google)

与传统的全注意力机制相比,Routing Transformer通过在线k-means聚类算法动态地学习和分配注意力权重,使得每个位置仅关注其所在簇内的关键点,从而将注意力计算复杂度从\(O(n^2d)\)降低到约\(O(n^{1.5}d)\)。具体地:

  1. 首先使用LN把 key 和 query 投影到d维的单元球上。
  2. 然后计算历史query、key和聚类中心的距离。聚成\(k\)簇,每个簇选择相似度最高的窗口大小\(w = n/k\)的index,获取对应index的 query 和 key,并计算簇内 softmax。
  3. 在每个簇中,加权求和value 得到token的上下文表示。
Routing Transformer 基于聚类路由的稀疏注意力模式
Routing Transformer 基于聚类路由的稀疏注意力模式

Native Sparse Attention(2025,DeepSeek)

NSA采用动态构建的层次化稀疏策略,结合粗粒度令牌压缩和细粒度令牌选择来保留全局、局部上下文信息。方法如下:

  1. Token压缩:首先将历史的 k 和 v 通过滑动窗口分割成块,滑动窗口大小为\(l\),滑动步幅为\(d\),保证\(d < l\)避免产生信息碎片。然后,每个块通过一个MLP压缩到一个key和value。
  2. Token选择:将键值序列划分为选择块,为每个块分配重要性得分。借用压缩token注意力计算产生的中间注意力分数,推导出选择块的重要性分数。最后选出top n个token块。
  3. 滑动窗口:在一个窗口w中维护最近的token,来明确处理局部上下文。
Native Sparse Attention 的压缩、选择与滑动窗口注意力模块
Native Sparse Attention 的压缩、选择与滑动窗口注意力模块

高性能 Attention 实现

FlashAttention

FlashAttention的原理可以通俗理解为“化整为零,边算边扔”。它是精确attention实现,不是近似attention;通过将注意力计算分块处理,使用online softmax和分块累积,避免显式存储完整的注意力矩阵,并利用GPU的高效内存访问(如SRAM)来加速计算。反向传播阶段还会通过重计算减少显存占用。

FlashAttention 通过分块计算减少 HBM 读写开销
FlashAttention 通过分块计算减少 HBM 读写开销

标准 Attention 实现伪码

标准 Attention 计算流程伪代码
标准 Attention 计算流程伪代码

FlashAttention 实现伪码

FlashAttention 分块计算流程伪代码
FlashAttention 分块计算流程伪代码
  • 向量 \(\ell\) 初始化为零,用于存储累积的归一化因子。
  • 向量 \(m\) 初始化为负无穷,用于跟踪最大值,以便在计算过程中进行数值稳定处理。
  • 10行:
    • 这一步计算每一行的最大值(rowmax)用于数值稳定。这在计算注意力分数的时候非常关键,因为通过从logits中减去每一行的最大值,可以防止指数计算中出现数值过大的情况,从而避免溢出。
    • 对稳定后的分数进行指数运算,以得到注意力权重。这是点对点的计算(pointwise),意味着对矩阵中的每一个元素进行操作。
    • 计算每一行的注意力权重和。这是为了随后进行标准化,使得注意力权重能正常归一化到0到1之间。
  • 11行:
    • 更新适应的最大值,对于当前行的最大值 \(m_i\) 与最新计算的 \(\tilde{m}_{ij}\) 进行比较并保留较大的值。这确保在随后计算中,记忆曾经遇到过的最大值,以保证数值稳定。
    • 更新归一化因子 \(\ell_i\)。使用先前与新的指数因子来调整当前块的注意力和归一化。这一步在结合权重时很关键,以确保所有概率值在理论上的总和能够稳定在期望值附近。

补充讨论:复杂度、优化目标与注意力权重

  1. 《attention is logarithmic, actually》

中文地址:https://mp.weixin.qq.com/s/iDWcMYhTpqs-4hqsPGmVaA

一个时间复杂度为 O (n³) 但能够并行的算法,和一个必须按顺序执行的算法,单从时间复杂度上看不出来它们的区别。而且,有些算法天生就是并行的,比如线性代数,但人们还在用时间复杂度来描述它们,这其实是很荒谬的。

我们不仅要考虑算法执行的原始操作数量(即「work」),更要关注计算图相对于输入大小的「depth」,也就是不可并行的顺序操作的最小数量。因为这些顺序操作是不可避免的,无论你的计算机有多少个核心,它们都会造成阻塞。

逐个元素相乘的时间复杂度是多少?

Attention 输出中权重矩阵与 Value 矩阵的逐元素加权示意
Attention 输出中权重矩阵与 Value 矩阵的逐元素加权示意

延伸阅读:《Neural GPUS LEARN ALGORITHMS》https://arxiv.org/pdf/1511.08228

  1. 在attention加速上,不同阶段Attention的优化目标有所不同

在DeepSeek的《Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention》中提到:

算术强度(Arithmetic Intensity)是计算操作与内存访问的比率。它内在地决定了算法在硬件上的优化。每个GPU都有一个关键的算术强度,由其峰值计算能力和内存带宽决定,通过这两个硬件限制的比例来计算。对于计算任务来说,高于这个临界值的算术强度成为计算受限(受GPU FLOPS限制),而低于它的则成为内存受限(受内存带宽限制)。

具体来说,对于自注意力机制,在训练和预填充阶段,批矩阵乘法和注意计算表现出高算术强度,使得这些阶段在现代加速器上成为计算限制。相比之下,自回归解码由于每次前向传递生成一个令牌,同时需要加载整个键值缓存而受到内存带宽的约束,导致低算术强度。这导致了不同的优化目标:

  • 减少训练和预填充期间的计算成本。
  • 减少解码期间的内存访问。
  1. 《How Do Language Models put Attention Weights over Long Context?》

中文地址:https://zhuanlan.zhihu.com/p/689183412

1%的注意力权重是否真的不重要?

  • 目前尚无定论。
  • 考虑到语言模型的敏感性,即使某个标记(token)仅占1%的注意力权重(attention mass),它仍可能对**下一个词预测(next-word prediction)**起到关键作用。

是否忽略了值向量(value vector)的影响

  • 确实如此。
  • 假设某个标记的值向量范数(value vector’s norm)是其他标记的100倍,即使其仅接收1%的注意力权重,仍会对输出向量(output vector)产生显著影响。

参考资料

  1. Efficient Content-Based Sparse Attention with Routing Transformers
  2. Big Bird: Transformers for Longer Sequences
  3. Longformer: The Long-Document Transformer
  4. FlashAttention
  5. attention is logarithmic, actually
  6. 注意力实际上是对数的
  7. How Do Language Models put Attention Weights over Long Context?
  8. 语言模型如何将注意力权重放在长上下文上
  9. 如何可视化LLM Attention
  10. BertViz 可视化工具库