KV Cache(Key-Value Cache,键值缓存)是大语言模型推理里非常核心的优化。它解决的问题很直接:模型自回归生成时,每一步都会多出一个 token,如果每次都把完整上下文重新送进 Transformer,就会反复计算历史 token 的 Key 和 Value,浪费大量计算。
KV Cache 的做法是:历史 token 的 Key 和 Value 只算一次,后续生成时直接复用;每来一个新 token,只计算这个新 token 的 Query、Key、Value,再把新的 Key 和 Value 追加到缓存里。
一句话概括:
KV Cache 用显存换计算,避免重复计算历史 token 的 Key/Value,让每一步解码只处理新 token。
Transformer 自注意力里的 Q、K、V
Transformer 的注意力机制会把输入 token 表示映射成三组向量:
- Query(查询):当前 token 想要找什么信息;
- Key(键):每个 token 提供一个“可被匹配”的索引;
- Value(值):真正被聚合的信息内容。
设输入序列为:
其中:
- T 是序列长度;
- d_{model} 是模型隐藏层维度。
经过三个线性变换后得到:
注意力计算为:
在 decoder-only 模型里,例如 GPT、Llama、Qwen,每个 token 只能关注自己和它之前的 token,不能看到未来 token,因此还要加入 causal mask(因果掩码)。
整体结构可以抽象成这样:
flowchart LR
X[输入 token 表示 X] --> Q[线性变换得到 Q]
X --> K[线性变换得到 K]
X --> V[线性变换得到 V]
Q --> A[QK^T / sqrt(d)]
K --> A
A --> S[softmax + causal mask]
S --> O[注意力权重乘以 V]
V --> O
O --> Y[输出表示]
注意力里最关键的一步是 QK^T。它会计算每个 Query 和所有 Key 的匹配程度,然后用这个匹配结果去加权聚合 Value。
自回归生成为什么会重复计算
大语言模型生成文本时通常是自回归的。也就是说,模型一次只预测一个新 token,然后把这个新 token 拼到上下文后面,再预测下一个 token。
假设 prompt 有 3 个 token:
[t1, t2, t3]
模型生成第 4 个 token 后,输入变成:
[t1, t2, t3, t4]
再生成第 5 个 token 时,输入变成:
[t1, t2, t3, t4, t5]
如果没有 KV Cache,每一步都重新把完整序列送进模型,那么历史 token 的 Key 和 Value 会被反复计算。
| 生成步骤 | 当前序列 | 如果不使用 KV Cache,需要计算什么 |
|---|---|---|
| 生成 t4 | t1, t2, t3 | 计算 t1、t2、t3 的 Q/K/V |
| 生成 t5 | t1, t2, t3, t4 | 重新计算 t1、t2、t3、t4 的 Q/K/V |
| 生成 t6 | t1, t2, t3, t4, t5 | 重新计算 t1 到 t5 的 Q/K/V |
问题在于:历史 token 的隐藏表示在推理过程中不会变。对于同一层 Transformer 来说,t1、t2、t3 的 Key 和 Value 已经算过,再算一遍没有意义。
这个浪费可以用流程图表示:
flowchart TD
A[已有上下文 t1 t2 t3] --> B[生成 t4]
B --> C[输入 t1 t2 t3 t4]
C --> D[重新计算 t1 t2 t3 t4 的 K/V]
D --> E[生成 t5]
E --> F[输入 t1 t2 t3 t4 t5]
F --> G[再次重新计算历史 token 的 K/V]
随着生成长度增加,这种重复计算会越来越严重。
KV Cache 的核心机制
KV Cache 的核心思想非常简单:缓存每一层 attention 里已经计算过的 Key 和 Value。
对于某一层 attention,可以把缓存看成两个张量:
K_cache: [batch, num_heads, cached_seq_len, head_dim]
V_cache: [batch, num_heads, cached_seq_len, head_dim]
每生成一个新 token,只需要计算这个新 token 的:
q_new, k_new, v_new
然后把新的 Key 和 Value 追加到缓存中:
当前 token 的注意力计算变成:
这里的 Query 只有当前 token 一个,但 Key 和 Value 包含完整历史上下文。
flowchart LR
N[新 token] --> QKV[计算 q_new k_new v_new]
QKV --> Q[q_new]
QKV --> K[k_new]
QKV --> V[v_new]
KCache[(K Cache)]
VCache[(V Cache)]
K --> KCache
V --> VCache
Q --> ATT[当前 q 与完整 K Cache 做注意力]
KCache --> ATT
VCache --> ATT
ATT --> LOGITS[输出 logits]
LOGITS --> NEXT[采样得到下一个 token]
这样做之后,每一步不再为所有历史 token 重算 Q/K/V,而是只处理新 token。
Prefill 和 Decode:KV Cache 的两个阶段
LLM 推理通常可以拆成两个阶段:
- Prefill:处理 prompt;
- Decode:逐 token 生成。
这两个阶段的计算模式不同,KV Cache 的作用也不同。
sequenceDiagram
participant User as 输入 prompt
participant Model as Transformer
participant Cache as KV Cache
participant Sampler as 采样器
User->>Model: 一次性送入完整 prompt
Model->>Cache: 缓存每层的 K/V
Model-->>Sampler: 输出最后位置的 logits
Sampler-->>Model: 采样得到第一个新 token
loop 每生成一个 token
Model->>Model: 只计算当前 token 的 Q/K/V
Model->>Cache: 追加新的 K/V
Cache-->>Model: 提供完整历史 K/V
Model-->>Sampler: 输出下一个 token 的 logits
Sampler-->>Model: 采样得到新 token
end
Prefill 阶段
Prefill 阶段会一次性处理完整 prompt。
假设 prompt 长度为 P,模型会对这 P 个 token 做完整前向计算,并在每一层 attention 中保存它们的 Key 和 Value。
Prefill 阶段有两个结果:
- 得到 prompt 最后一个位置的 logits,用于采样第一个新 token;
- 建立初始 KV Cache,供后续 decode 阶段复用。
这一阶段的输入长度可能很长,所以它通常是大矩阵计算,GPU 利用率较高。
Decode 阶段
Decode 阶段每次只输入一个新 token。
对于当前新 token,模型会:
- 计算当前 token 的 Query、Key、Value;
- 把新的 Key 追加到 K Cache;
- 把新的 Value 追加到 V Cache;
- 用当前 Query 关注完整历史 Key/Value;
- 输出 logits;
- 采样得到下一个 token。
Decode 阶段每一步计算量小,但会循环很多次,因此延迟很容易被内存访问、缓存管理和小 batch 计算影响。
复杂度变化:不是魔法,只是少算了旧 token
KV Cache 最重要的收益来自“避免重算历史 token”。
设当前上下文长度为 t。
如果不使用 KV Cache,每一步都处理完整序列,attention 矩阵规模接近:
如果使用 KV Cache,每一步只处理当前新 token,它只需要和历史 t 个 Key 做注意力,attention 矩阵规模变成:
| 对比项 | 不使用 KV Cache | 使用 KV Cache |
|---|---|---|
| 每步输入 | 完整上下文 | 当前新 token |
| 历史 token 的 K/V | 每步重算 | 从缓存读取 |
| 当前步 attention 规模 | t \times t | 1 \times t |
| 每步 attention 复杂度 | 约 O(t^2) | 约 O(t) |
| 主要代价 | 计算浪费 | 额外显存 |
需要注意,KV Cache 并不会让生成长文本的总成本变成常数。因为第 t 个新 token 仍然要关注前面所有 token,所以 decode 阶段的 attention 访问长度仍然随上下文增长。
它减少的是“重复计算历史 token 表示”的成本。
KV Cache 占多少显存
KV Cache 的显存开销和序列长度、层数、KV 头数、head 维度、数据类型有关。
常见估算公式:
其中:
| 符号 | 含义 |
|---|---|
| B | batch size |
| L | Transformer 层数 |
| 2 | Key 和 Value 两份缓存 |
| T | 当前缓存 token 数 |
| H_{kv} | KV heads 数量 |
| D_{head} | 每个 head 的维度 |
| bytes | 每个元素占用字节数,例如 FP16 为 2 字节 |
以一个 Llama-like 7B 结构估算:
- batch size = 1;
- 层数 = 32;
- KV heads = 32;
- head_dim = 128;
- 上下文长度 = 4096;
- 数据类型 = FP16,2 字节。
显存约为:
如果上下文长度从 4K 增加到 128K,KV Cache 也会线性增长,显存压力会非常明显。
这也是长上下文推理昂贵的重要原因之一:模型权重大小固定,但 KV Cache 会随着请求长度和并发数增长。
用 Hugging Face Transformers 对比开启和关闭 KV Cache
Hugging Face Transformers 的 generate 默认通常会启用 KV Cache,也可以通过 use_cache 显式控制。
import time
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
model.eval()
prompt = "What is KV caching?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
for use_cache in [True, False]:
times = []
for _ in range(10):
start = time.time()
with torch.inference_mode():
_ = model.generate(
**inputs,
use_cache=use_cache,
max_new_tokens=1000,
do_sample=False,
)
times.append(time.time() - start)
print(
f"{'with' if use_cache else 'without'} KV Cache: "
f"{np.mean(times):.3f} ± {np.std(times):.3f} seconds"
)
在 GPT-2、Tesla T4、生成 1000 个 token 的一次测量中,可以得到类似结果:
| 设置 | 平均耗时 |
|---|---|
| 开启 KV Cache | 约 11.9 秒 |
| 关闭 KV Cache | 约 56.2 秒 |
具体数字会受 GPU、模型大小、batch size、生成长度和 Transformers 版本影响,但趋势通常很稳定:生成越长,KV Cache 的收益越明显。
简化版 KV Cache 伪代码
真实模型会在每一层维护 KV Cache。为了突出核心逻辑,可以先看一个简化版本:
# 伪代码:只展示单层 attention 的核心逻辑
k_cache = None
v_cache = None
for token in generation_loop:
# 当前步只输入一个 token
q_new, k_new, v_new = compute_qkv(token)
if k_cache is None:
k_cache = k_new
v_cache = v_new
else:
# 沿 sequence 维度追加
k_cache = concat(k_cache, k_new, dim="seq")
v_cache = concat(v_cache, v_new, dim="seq")
# 当前 token 的 Query 关注完整历史 K/V
attn_output = attention(q_new, k_cache, v_cache)
logits = lm_head(attn_output)
token = sample_next_token(logits)
多层 Transformer 中,每一层都有自己的缓存:
# 更接近 Transformer 的伪代码
past_key_values = [
{"k": None, "v": None}
for _ in range(num_layers)
]
for token in generation_loop:
hidden = embed(token)
for layer_id, layer in enumerate(transformer_layers):
q, k, v = layer.compute_qkv(hidden)
old_k = past_key_values[layer_id]["k"]
old_v = past_key_values[layer_id]["v"]
if old_k is not None:
k_full = torch.cat([old_k, k], dim=2) # seq 维度
v_full = torch.cat([old_v, v], dim=2)
else:
k_full = k
v_full = v
hidden = layer.attention(q, k_full, v_full)
past_key_values[layer_id]["k"] = k_full
past_key_values[layer_id]["v"] = v_full
logits = lm_head(hidden)
token = sample_next_token(logits)
生产级推理框架通常不会在每一步直接 torch.cat,因为反复拼接会触发内存复制。更常见的做法是预分配缓存空间,或者使用分页式 KV Cache 管理。
KV Cache 只适合自回归解码
KV Cache 主要服务于自回归生成,也就是 decoder-only 模型或 encoder-decoder 模型的 decoder 部分。
| 模型类型 | 是否使用 KV Cache | 原因 |
|---|---|---|
| GPT / Llama / Qwen 等 decoder-only 模型 | 使用 | 逐 token 生成,需要复用历史 K/V |
| T5 等 encoder-decoder 模型的 decoder | 使用 | decoder 自回归生成 |
| encoder-decoder 的 encoder | 通常不需要逐步追加 | encoder 一次性处理输入 |
| BERT 等双向编码模型 | 通常不使用 | 不是逐 token 生成模型 |
训练阶段也通常不会用推理式 KV Cache。训练时会把完整序列并行送入模型,通过 causal mask 同时计算所有位置的损失;推理时才需要一步一步生成新 token。
常见优化方向
KV Cache 带来了速度收益,也带来了显存和调度问题。现代推理系统围绕 KV Cache 做了很多工程优化。
| 问题 | 优化方法 | 核心思路 |
|---|---|---|
| 长上下文导致 KV Cache 过大 | KV Cache 量化 | 用 INT8、INT4 等低精度格式保存缓存 |
| 多请求并发时显存碎片严重 | PagedAttention | 像虚拟内存一样分页管理 KV Cache |
| 多个请求共享相同 prompt | Prefix Caching | 复用公共前缀的 KV Cache |
| batch 内请求长度不同 | Continuous Batching | 动态把不同请求合并调度,提高 GPU 利用率 |
| KV heads 太多 | MQA / GQA | 多个 Query heads 共享较少的 KV heads |
| 上下文无限增长 | Sliding Window Attention | 只保留最近一段窗口内的 KV |
其中,PagedAttention 是 vLLM 这类推理框架的重要设计。它不是改变注意力数学公式,而是改变 KV Cache 的内存管理方式,减少碎片和无效拷贝,让更多请求可以同时跑在 GPU 上。
使用 KV Cache 时容易踩的坑
1. 显存会随着上下文长度线性增长
模型权重加载后显存占用基本固定,但 KV Cache 会随着请求增长。长上下文、多并发、大 batch 会让缓存迅速变大,最终触发 OOM(Out Of Memory,内存不足)。
2. 位置编码必须对齐
使用 RoPE(Rotary Position Embedding,旋转位置编码)或绝对位置编码时,新 token 的 position id 必须正确递增。如果缓存里的 token 位置和新 token 位置错位,模型输出会异常。
3. Beam Search 需要重排缓存
Beam Search 会在每一步保留多个候选分支。分支被选择、丢弃或重排时,对应的 KV Cache 也必须同步重排,否则 token 分支和缓存内容会对不上。
4. 频繁拼接缓存会拖慢速度
示例代码里的 torch.cat 适合理解原理,但每步拼接都会产生新张量。高性能推理框架通常使用预分配、分页、块管理等方式维护缓存。
5. KV Cache 不能跨不同上下文随便复用
只有当请求前缀完全一致时,缓存才可以复用。如果 prompt 中任意 token 不同,对应位置之后的 Key/Value 都不能直接共享。
小结
KV Cache 是 LLM 推理加速的基础组件。它把历史 token 的 Key 和 Value 存起来,使 decode 阶段每一步只需要处理当前新 token,而不是反复处理完整上下文。
它带来的变化可以概括为:
| 维度 | 变化 |
|---|---|
| 计算 | 避免重复计算历史 token 的 K/V |
| 延迟 | 长文本生成明显变快 |
| 显存 | 额外保存每层 Key/Value |
| 工程难点 | 长上下文、多并发、缓存碎片、分支重排 |
只要是自回归 LLM 推理,KV Cache 几乎都是默认配置。不开启 KV Cache,短输出还能勉强运行;一旦生成长度变长,重复计算历史 token 的代价会迅速放大。