背景介绍

在自回归生成中,模型每次只生成一个新 token。为了避免在每一步重复计算所有历史 token 的 Key 和 Value,推理框架通常会缓存历史 tokens 的 KV 向量。KV Cache 能显著减少重复计算,但也会带来可观的显存占用,因此 MQA、GQA、MLA 等注意力变种都在不同程度上围绕 KV Cache 做优化。

KV Cache 基础:MHA

Multi-Head Attention vs Grouped-Query Attention vs Multi-head Latent Attention
Multi-Head Attention vs Grouped-Query Attention vs Multi-head Latent Attention

在 Multi-Head Attention(MHA)中,每个 query head 都有对应的 key head 和 value head。前向计算过程中,为加速生成,可以缓存历史tokens的 Key 和 Value 向量。从向量计算的角度看,每个token需要存储的缓存元素为:

  • 每层的 KV Cache 需要存储\(n_h\)个头的 key 和 value。
  • 每个头的 key 和 value 维度为\(d_h\)。
  • 因此,每层单 token 的存储量为:

\(单 token 存储量=n_h⋅(d_h+d_h)=2⋅n_h⋅d_h\)

  • 对于 \(l\) 层 的Transformer,单 token 总存储量为:

\(单 token 总存储量=l⋅n_h⋅(d_h+d_h)=2⋅l⋅n_h⋅d_h\)

举个例子:Transformer层数\(l=60\),注意力头数\(n_h=128\),每个头的向量维度\(d_h=128\)。假设每个缓存元素采取FP16精度,占2个字节。则单token KV cache总存储大小为3840KiB,即3.75MiB。

KV Cache 优化一:Multi-Query Attention(MQA)

对于每个token,映射成key和value向量时,共享一个 \(W_k\) 和 \(W_v\),得到每个头都一样的 key 和 value 向量;映射成query向量时,和MHA保持一致,每个头是各不相同的 query 向量。所以:

  1. 计算上,减少了token向量映射到key和value的计算量,softmax score的计算量保持不变。
  2. 内存上,key 和 value 的向量缓存量级大大减少。

缺点是可能带来模型效果下降,具体幅度取决于模型规模、训练方式和任务。

KV Cache 优化二:Grouped-Query Attention(GQA)

为了缓解模型效果衰退,需要适当增加 key 和 value 的头数。对于单个token,映射成多头的key和value向量时,会把所有head分为g组,同一组内的head共享\(W_K\)和\(W_V\)。因为key和value的头数介于MQA和MHA,所以计算和存储上的节省介于二者之间;在合适训练或uptraining下,GQA通常可以取得接近MHA的效果,同时保留接近MQA的推理效率优势。

MHA/MQA/GQA 源码实现

下面我们一起看 Llama3 attention 源码,它通过 n_kv_heads 把MHA、MQA、GQA统一起来:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# https://github.com/meta-llama/llama3/blob/main/llama/model.py

class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# kv的head数,决定attention的类型
# MHA:args.n_kv_heads = args.n_heads
# MQA:args.n_kv_heads = 1
# GQA:args.n_kv_heads > 1 and args.n_kv_head < args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 训练集群模型并行数,便于理解,假设为1,即整个模型在单卡中训练
model_parallel_size = fs_init.get_model_parallel_world_size()
# 单卡:query的head数
self.n_local_heads = args.n_heads // model_parallel_size
# 单卡:key、value的head数
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
# 单卡:每个kv head需要复制到多少个query head
self.n_rep = self.n_local_heads // self.n_local_kv_heads
# 每个head的维度 = 输入embeding维度 // query的head数
self.head_dim = args.dim // args.n_heads

# token映射成query的变换矩阵
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim, # 注意:这里是n_heads
bias=False,
gather_output=False,
init_method=lambda x: x,
)
# token映射成key的变换矩阵
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim, # 注意:这里是n_kv_heads
bias=False,
gather_output=False,
init_method=lambda x: x,
)
# token映射成value的变换矩阵
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim, # 注意:这里是n_kv_heads
bias=False,
gather_output=False,
init_method=lambda x: x,
)

# 多头concat后映射回d_model的变换矩阵
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)

# 历史token的K缓存
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
# 历史token的V缓存
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()

def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape

# token -> query, key, value
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# 旋转位置编码(Rotary Position Embedding,RoPE)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

# 缓存当前tokens的key、value向量
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

# 获取序列历史tokens的key、value向量
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]

# 对于GQA/MQA,保证相同group下的key保持一致
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)

# 计算softmax score
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

# decode-only架构,需要mask掉右边的token
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
# 对所有value加权求和
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 把向量映射回输入embeding维度,方便输入到下一层
return self.wo(output)

KV Cache 优化三:Multi-Head Latent Attention(MLA)

多头潜在注意力(Multi-Head Latent Attention,MLA)的核心在于对注意力的键(key)和值(value)进行低秩联合压缩,以便在推理过程中减少键值(KV)缓存。MLA的计算过程如下:

Multi-Head Latent Attention 通过低秩压缩减少 KV Cache 存储
Multi-Head Latent Attention 通过低秩压缩减少 KV Cache 存储

通过多头向量计算,对比MHA和MLA的向量计算过程:

MHA 与 MLA 的 Key/Value 向量缓存结构对比
MHA 与 MLA 的 Key/Value 向量缓存结构对比

上述图示仅展示了MLA对token向量如何映射为多头key和value向量的过程,未展示RoPE位置编码的映射计算过程。MLA将内容相关的低秩压缩向量与解耦的RoPE key向量分开缓存:前者用于恢复key/value的内容部分,后者提供注意力打分所需的位置信息。

从向量计算的角度看,每个token需要存储的缓存元素为:

  • 每层的 KV Cache 需要存储一个潜在压缩向量(compressed latent vector),维度为\(d_c\)。其中\(d_c\)远小于\(n_h d_h\)。对于DeepSeek-V2,\(d_c\)设置为\(4d_h\)
  • 同时,每层的 KV Cache 需要存储一个解耦的RoPE key向量,维度为\(d_h^R\)。对于DeepSeek-V2,\(d_h^R\)设置为\(\frac{d_h}{2}\)。
  • 因此,每层单 token 的存储量为:

\(单 token 存储量=(d_c+d_h^R) = \frac{9}{2}d_h\)

  • 对于 \(l\) 层 的Transformer,单 token 总存储量为:

\(单 token 总存储量=l⋅\frac{9}{2}d_h=\frac{9}{2}·d_h·l\)

参数对比

以DeepSeek-V2的参数为例,

  • Transformer层数\(l=60\)
  • 注意力头数\(n_h=128\)
  • 每个头的向量维度\(d_h=128\)
  • KV压缩向量维度\(d_c=512\)
  • Q和K每个头的解耦RoPE维度\(d_h^R=64\)
  • \(n_g=8\)

假设每个缓存元素采取FP16精度,占2个字节。

注意力机制 单token的KV Cache元素数 单token KV Cache存储大小
Multi-Head Attention(MHA) \(2n_hd_hl\) 3840KiB ~ 3.75MiB
Grouped-Query Attention(GQA) \(2n_gd_hl\) 240KiB
Multi-Query Attention(MQA) \(2d_hl\) 30KiB
Multi-Head Latent Attention(MLA) \((d_c+d_h^R)l\approx\frac92d_hl\) 67.5KiB

参考资料

  1. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
  2. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
  3. DeepSeek-V3 Technical Report
  4. 大模型百倍推理加速之KV cache篇
  5. llm_note: transformer model