
解锁速度:快速响应 Llms 高速缓存增强生成综合指南
- Rifx.Online
- Large Language Models , Best Practices , Case Studies
- 05 Mar, 2025
检索增强生成 (RAG)
检索增强生成 (RAG) 是一种强大的方法,可以将外部知识库连接到大语言模型 (LLM),并在用户每次提问时获取上下文,但由于其检索延迟,它可能会减慢大语言模型 (LLM) 的性能。
缓存增强生成 (CAG)
缓存增强生成 (CAG) 提供了一种更快的替代方案;它不是进行实时检索,而是将相关文档 预加载 到模型的上下文中,并存储该推理状态 — 也称为键值 (KV) 缓存。这种方法消除了检索延迟,使模型能够瞬时访问预加载的信息,从而更快、更高效地响应。
要了解有关 CAG 的更技术性解释,请查看 这篇文章。
在本教程中,我们将展示如何构建一个简单的 CAG 设置,以便提前嵌入所有知识,快速回答多个用户查询,并在不每次重新加载整个上下文的情况下重置缓存。
前提条件
- 一个 HuggingFace 账户和一个 HuggingFace 访问令牌
- 一个包含关于您的句子的 document.txt 文件。
项目设置
我们导入必要的库:
torch
用于 PyTorch。transformers
用于 Hugging Face。DynamicCache
用于存储模型的键值状态。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os
生成函数
接下来我们将定义 generate
函数。
generate
函数处理使用贪婪解码的缓存知识进行逐个标记的生成。
贪婪解码是一种简单的文本生成方法,在每一步中,选择概率最高的标记(logits 中的最大值)作为下一个标记。
我们传入这些输入:
model
: 大语言模型 (LLM),本教程中使用的是 Mistral-7B。input_ids
: 包含分词输入序列的张量。past_key_values
: 缓存增强生成 (CAG) 的核心组件。使用先前计算的注意力值的缓存来加速推理,避免重新计算。max_new_tokens
: 生成的新标记的最大数量。默认值为 50。
该函数在一个循环中运行,迭代次数最多为 max_new_tokens
次,或者如果生成了结束标记(如果配置)则提前终止。
在每次迭代中:
- 模型处理当前输入标记以及缓存的
past_key_values
,产生下一个标记的 logits。 - 分析 logits,以使用贪婪解码识别概率最高的标记。
- 将这个新标记附加到输出序列中,并更新缓存 (
past_key_values
) 以包含当前上下文。 - 新生成的标记成为下一次迭代的输入。
def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:
device = model.model.embed_tokens.weight.device
origin_len = input_ids.shape[-1]
input_ids = input_ids.to(device)
output_ids = input_ids.clone()
next_token = input_ids
with torch.no_grad():
for _ in range(max_new_tokens):
out = model(
input_ids=next_token,
past_key_values=past_key_values,
use_cache=True
)
logits = out.logits[:, -1, :]
token = torch.argmax(logits, dim=-1, keepdim=True)
output_ids = torch.cat([output_ids, token], dim=-1)
past_key_values = out.past_key_values
next_token = token.to(device)
if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
break
return output_ids[:, origin_len:]
动态缓存 (DynamicCache) 设置
接下来,我们将定义 get_kv_cache
函数,该函数为 transformer 模型的注意力机制准备一个可重用的键值 (KV) 缓存,以及 clean_up
函数,该函数通过删除不必要的条目来清理键值缓存,以确保您可以回答多个独立的问题而不会“污染”缓存。
get_kv_cache
将一个提示(在我们的例子中,来自 document.txt
的知识)通过模型处理一次,创建一个记录每一层所有隐藏状态的 KV 缓存。
get_kv_cache
接收以下输入:
model
: 用于编码提示的 transformer 模型。tokenizer
: 将提示转换为 token ID 的分词器。prompt
: 用作提示的字符串输入。
并返回一个 DynamicCache
类型的对象。
get_kv_cache
函数首先使用分词器对提供的提示进行分词,将其转换为输入 ID,然后初始化一个 DynamicCache
对象以存储键值对,接着在启用缓存的情况下(use_cache=True
)通过模型执行前向传播。这会用模型计算的结果填充缓存中的键值对。
clean_up
通过移除处理过程中添加的任何额外 token 来修剪 DynamicCache
对象以匹配原始序列长度。对于缓存的每一层,它切片键和值张量,仅保留序列维度上的前 origin_len
个 token。
def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
device = model.model.embed_tokens.weight.device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
cache = DynamicCache()
with torch.no_grad():
_ = model(
input_ids=input_ids,
past_key_values=cache,
use_cache=True
)
return cache
def clean_up(cache: DynamicCache, origin_len: int):
for i in range(len(cache.key_cache)):
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
加载大语言模型 (Mistral)
现在我们将加载 Mistral-7B 模型,并在 GPU 上以全精度或半精度 (FP16) 加载分词器和模型(如果可用)。
请记得输入 YOUR_HF_TOKEN
作为您的唯一 HuggingFace 令牌。
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, token="YOUR_HF_TOKEN", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
token="YOUR_HF_TOKEN"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Loaded {model_name}.")
从 document.txt 创建知识提示
接下来,我们将读取 document.txt
,您可以在其中填写有关自己的信息。在本教程中,document.txt
包含关于我的信息(Ronan Takizawa)。
在这里,我们构建一个简单的系统提示嵌入,包含文档信息,并将其传递给 get_kv_cache
以生成 KV 缓存。
with open("document.txt", "r", encoding="utf-8") as f:
doc_text = f.read()
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
ronan_cache = get_kv_cache(model, tokenizer, system_prompt)
origin_len = ronan_cache.key_cache[0].shape[-2]
print("KV cache built.")
Ask Questions Reusing the Cache
我们首先运行 clean_up
来清理我们的缓存(这是 CAG 的良好实践)。
接下来,我们将问题转换为 input_ids_q1
中的令牌,然后附加到存储在 ronan_cache
中的知识上下文。
最后,我们调用 generate
来生成答案,并使用 tokenizer.decode
解码最终结果。
question1 = "Who is Ronan Takizawa?"
clean_up(ronan_cache, origin_len)
input_ids_q1 = tokenizer(question1 + "\\n", return_tensors="pt").input_ids.to(device)
gen_ids_q1 = generate(model, input_ids_q1, ronan_cache)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)
print("Q1:", question1)
print("A1:", answer1)
你应该期待这样的响应:
Q1: Who is Ronan Takizawa?
A1: Answer: Ronan Takizawa is an ambitious and accomplished
tech enthusiast. He has a diverse skill set in
software development, AI/ML…
结论
缓存增强生成 (CAG) 通过直接在模型的上下文窗口中存储小型知识库,简化了 AI 架构,消除了 RAG 中检索循环的需要,并减少了延迟。这种方法提高了响应速度,并改善了大语言模型 (LLM) 对外部知识的响应能力。通过利用 CAG,开发者可以简化他们的 AI 系统,以实现更快和更高效的知识集成,特别是对于具有稳定、紧凑数据集的任务。