背景介绍
在自回归生成中,模型每次只生成一个新 token。为了避免在每一步重复计算所有历史 token 的 Key 和 Value,推理框架通常会缓存历史 tokens 的 KV 向量。KV Cache 能显著减少重复计算,但也会带来可观的显存占用,因此 MQA、GQA、MLA 等注意力变种都在不同程度上围绕 KV Cache 做优化。
KV Cache 基础:MHA

在 Multi-Head Attention(MHA)中,每个 query head 都有对应的 key head 和 value head。前向计算过程中,为加速生成,可以缓存历史tokens的 Key 和 Value 向量。从向量计算的角度看,每个token需要存储的缓存元素为:
- 每层的 KV Cache 需要存储\(n_h\)个头的 key 和 value。
- 每个头的 key 和 value 维度为\(d_h\)。
- 因此,每层单 token 的存储量为:
\(单 token 存储量=n_h⋅(d_h+d_h)=2⋅n_h⋅d_h\)
- 对于 \(l\) 层 的Transformer,单 token 总存储量为:
\(单 token 总存储量=l⋅n_h⋅(d_h+d_h)=2⋅l⋅n_h⋅d_h\)
举个例子:Transformer层数\(l=60\),注意力头数\(n_h=128\),每个头的向量维度\(d_h=128\)。假设每个缓存元素采取FP16精度,占2个字节。则单token KV cache总存储大小为3840KiB,即3.75MiB。
KV Cache 优化一:Multi-Query Attention(MQA)
对于每个token,映射成key和value向量时,共享一个 \(W_k\) 和 \(W_v\),得到每个头都一样的 key 和 value 向量;映射成query向量时,和MHA保持一致,每个头是各不相同的 query 向量。所以:
- 计算上,减少了token向量映射到key和value的计算量,softmax score的计算量保持不变。
- 内存上,key 和 value 的向量缓存量级大大减少。
缺点是可能带来模型效果下降,具体幅度取决于模型规模、训练方式和任务。
KV Cache 优化二:Grouped-Query Attention(GQA)
为了缓解模型效果衰退,需要适当增加 key 和 value 的头数。对于单个token,映射成多头的key和value向量时,会把所有head分为g组,同一组内的head共享\(W_K\)和\(W_V\)。因为key和value的头数介于MQA和MHA,所以计算和存储上的节省介于二者之间;在合适训练或uptraining下,GQA通常可以取得接近MHA的效果,同时保留接近MQA的推理效率优势。
MHA/MQA/GQA 源码实现
下面我们一起看 Llama3 attention 源码,它通过 n_kv_heads 把MHA、MQA、GQA统一起来:
1 | # https://github.com/meta-llama/llama3/blob/main/llama/model.py |
KV Cache 优化三:Multi-Head Latent Attention(MLA)
多头潜在注意力(Multi-Head Latent Attention,MLA)的核心在于对注意力的键(key)和值(value)进行低秩联合压缩,以便在推理过程中减少键值(KV)缓存。MLA的计算过程如下:

通过多头向量计算,对比MHA和MLA的向量计算过程:

上述图示仅展示了MLA对token向量如何映射为多头key和value向量的过程,未展示RoPE位置编码的映射计算过程。MLA将内容相关的低秩压缩向量与解耦的RoPE key向量分开缓存:前者用于恢复key/value的内容部分,后者提供注意力打分所需的位置信息。
从向量计算的角度看,每个token需要存储的缓存元素为:
- 每层的 KV Cache 需要存储一个潜在压缩向量(compressed latent vector),维度为\(d_c\)。其中\(d_c\)远小于\(n_h d_h\)。对于DeepSeek-V2,\(d_c\)设置为\(4d_h\)
- 同时,每层的 KV Cache 需要存储一个解耦的RoPE key向量,维度为\(d_h^R\)。对于DeepSeek-V2,\(d_h^R\)设置为\(\frac{d_h}{2}\)。
- 因此,每层单 token 的存储量为:
\(单 token 存储量=(d_c+d_h^R) = \frac{9}{2}d_h\)
- 对于 \(l\) 层 的Transformer,单 token 总存储量为:
\(单 token 总存储量=l⋅\frac{9}{2}d_h=\frac{9}{2}·d_h·l\)
参数对比
以DeepSeek-V2的参数为例,
- Transformer层数\(l=60\)
- 注意力头数\(n_h=128\)
- 每个头的向量维度\(d_h=128\)
- KV压缩向量维度\(d_c=512\)
- Q和K每个头的解耦RoPE维度\(d_h^R=64\)
- \(n_g=8\)
假设每个缓存元素采取FP16精度,占2个字节。
| 注意力机制 | 单token的KV Cache元素数 | 单token KV Cache存储大小 |
|---|---|---|
| Multi-Head Attention(MHA) | \(2n_hd_hl\) | 3840KiB ~ 3.75MiB |
| Grouped-Query Attention(GQA) | \(2n_gd_hl\) | 240KiB |
| Multi-Query Attention(MQA) | \(2d_hl\) | 30KiB |
| Multi-Head Latent Attention(MLA) | \((d_c+d_h^R)l\approx\frac92d_hl\) | 67.5KiB |