缓存策略
---
title: "缓存策略"
description: "LLM缓存策略,包括前缀缓存和语义缓存实现"
tags: ["缓存", "前缀缓存", "语义缓存", "KV缓存", "性能优化"]
category: "llm"
icon: "🧠"
---
缓存策略
LLM缓存是提升推理效率、降低成本的重要技术。通过缓存KV状态、公共前缀的计算结果以及语义相似查询的响应,可以显著减少重复计算,提高系统响应速度。不同的缓存策略适用于不同的应用场景。
KV缓存
基础KV缓存实现
import torch
from typing import Dict, Optional
from dataclasses import dataclass, field
@dataclass
class KVCache:
    """Bounded in-memory store mapping a string key to a ``(k, v)`` pair.

    Eviction is first-in-first-out: when a *new* key arrives while the cache
    already holds ``max_size`` entries, the earliest-inserted entry is
    dropped. Reads never change the eviction order.

    Attributes:
        max_size: Maximum number of entries kept. A value of zero or less
            means nothing is ever stored.
        cache: Backing dict; its insertion order is the eviction order.
    """

    max_size: int
    cache: Dict[str, tuple] = field(default_factory=dict)

    def get(self, key: str) -> Optional[tuple]:
        """Return the stored ``(k, v)`` tuple for ``key``, or ``None`` on a miss."""
        return self.cache.get(key)

    def put(self, key: str, k: torch.Tensor, v: torch.Tensor) -> None:
        """Store ``(k, v)`` under ``key``, evicting the oldest entries if needed.

        Overwriting a key that is already cached replaces its value in place
        (keeping its position in the eviction order) and evicts nothing.

        Args:
            key: Identifier for the cached state.
            k: Key tensor to cache.
            v: Value tensor to cache.
        """
        if self.max_size <= 0:
            # No capacity: storing would exceed the limit, and evicting from
            # an empty dict would raise StopIteration.
            return
        if key not in self.cache:
            # Only a new key needs room. The loop also recovers if the dict
            # was filled past the limit from outside this method.
            while len(self.cache) >= self.max_size:
                del self.cache[next(iter(self.cache))]
        self.cache[key] = (k, v)

    def contains(self, key: str) -> bool:
        """Return ``True`` if ``key`` currently has an entry."""
        return key in self.cache

    def clear(self) -> None:
        """Remove every entry."""
        self.cache.clear()

    def get_stats(self) -> dict:
        """Return ``{"size": <current entry count>, "max_size": <capacity>}``."""
        return {"size": len(self.cache), "max_size": self.max_size}
# Demo: cache one pair of random tensors under a prompt id, then report the
# cache size and confirm the key is present.
cache = KVCache(max_size=100)
kv_shape = (1, 10, 8, 64)
k_tensor = torch.randn(*kv_shape)
v_tensor = torch.randn(*kv_shape)
cache.put(key="prompt_001", k=k_tensor, v=v_tensor)
print(f"缓存状态: {cache.get_stats()}")
print(f"命中缓存: {cache.contains('prompt_001')}")
前缀缓存
缓存公共前缀的KV状态,适用于多轮对话和系统提示。
from hashlib import sha256
from typing import Dict, List
from dataclasses import dataclass, field
@dataclass
class PrefixCache:
    """Registry of KV states keyed by a hash of a token prefix.

    Attributes:
        max_prefix_length: Configured upper bound on prefix length.
            NOTE(review): no method of this class reads it, so the bound is
            not enforced here.
        prefix_registry: Maps a prefix key to an entry dict with the keys
            ``"tokens"``, ``"kv_state"`` and ``"hit_count"``.
    """

    max_prefix_length: int
    prefix_registry: Dict[str, Dict] = field(default_factory=dict)

    def _compute_prefix_key(self, tokens: List[int], prefix_len: int) -> str:
        """Return a 16-hex-digit key for the first ``prefix_len`` tokens.

        The key is the leading part of the SHA-256 digest of the string form
        of the prefix tuple.
        """
        head = tokens[:prefix_len]
        digest = sha256(str(tuple(head)).encode()).hexdigest()
        return digest[:16]

    def cache_prefix(self, tokens: List[int], prefix_len: int, kv_state: dict):
        """Register ``kv_state`` for the first ``prefix_len`` tokens.

        An existing entry for the same prefix is replaced and its hit count
        starts again from zero.

        Args:
            tokens: Full token sequence; only its leading slice is stored.
            prefix_len: Number of leading tokens that form the prefix.
            kv_state: State object to hand back on a later lookup.
        """
        entry = {
            "tokens": tokens[:prefix_len],
            "kv_state": kv_state,
            "hit_count": 0,
        }
        self.prefix_registry[self._compute_prefix_key(tokens, prefix_len)] = entry

    def lookup_prefix(self, tokens: List[int], min_match: int = 4) -> tuple:
        """Find the cached prefix sharing the longest leading run with ``tokens``.

        The shared run may be shorter than the cached prefix itself, so the
        returned length can be smaller than the number of tokens the state
        was registered for.

        Args:
            tokens: Query token sequence.
            min_match: Minimum number of leading tokens that must agree.

        Returns:
            ``(kv_state, matched_length)`` for the best entry, whose hit
            count is incremented, or ``(None, 0)`` when nothing qualifies.
        """
        winner = None
        winner_len = 0
        for candidate in self.prefix_registry.values():
            # Count how many leading tokens agree before the first mismatch.
            shared = 0
            for cached_tok, query_tok in zip(candidate["tokens"], tokens):
                if cached_tok != query_tok:
                    break
                shared += 1
            if shared < min_match or shared <= winner_len:
                continue
            winner, winner_len = candidate, shared
        if winner is None:
            return None, 0
        winner["hit_count"] += 1
        return winner["kv_state"], winner_len

    def get_stats(self) -> dict:
        """Return the number of registered prefixes and the summed hit counts."""
        hits = 0
        for entry in self.prefix_registry.values():
            hits += entry["hit_count"]
        return {"prefixes": len(self.prefix_registry), "total_hits": hits}
# Demo: register the first four tokens of a sequence as a cached prefix,
# look the same sequence up again, and print the registry statistics.
prefix_cache = PrefixCache(max_prefix_length=1024)
system_tokens = [101, 2023, 102, 3000, 400, 500]
prefix_cache.cache_prefix(
    tokens=system_tokens,
    prefix_len=4,
    kv_state={"k": "cached_state"},
)
kv, length = prefix_cache.lookup_prefix(system_tokens)
print(f"前缀缓存: {prefix_cache.get_stats()}")
语义缓存
基于语义相似度缓存响应,适用于相似问题的快速回复。
import numpy as np
from typing import List, Optional
from dataclasses import dataclass, field
from sentence_transformers import SentenceTransformer
@dataclass
class SemanticCache:
    """Response cache looked up by embedding similarity rather than exact text.

    Each stored query is embedded once on insertion. A lookup embeds the
    incoming query and scans every entry for the highest cosine similarity
    that reaches ``similarity_threshold``.

    Attributes:
        similarity_threshold: Minimum cosine similarity for a cache hit.
        model: Object whose ``encode(text)`` yields the embedding. By default
            a ``SentenceTransformer("all-MiniLM-L6-v2")`` is created for each
            instance.
        cache_entries: Entry dicts with the keys ``"query"``, ``"response"``,
            ``"embedding"`` and ``"hit_count"``.
    """

    similarity_threshold: float
    model: SentenceTransformer = field(
        default_factory=lambda: SentenceTransformer("all-MiniLM-L6-v2")
    )
    cache_entries: List[dict] = field(default_factory=list)

    def _compute_embedding(self, text: str) -> np.ndarray:
        """Return the model's embedding for ``text``."""
        return self.model.encode(text)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of ``a`` and ``b`` as a Python float.

        A zero-norm vector has no direction, so the similarity is reported
        as ``0.0`` instead of dividing by zero and producing NaN.
        """
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def get(self, query: str) -> Optional[dict]:
        """Return the cached response most similar to ``query``, if any qualifies.

        Args:
            query: Text to look up.

        Returns:
            ``{"response", "similarity", "cached": True}`` for the best entry
            at or above the threshold (its hit count is incremented), or
            ``None`` when no entry qualifies.
        """
        query_embedding = self._compute_embedding(query)
        best_match = None
        # Start below every possible score so the threshold alone decides
        # whether an entry qualifies, including for non-positive thresholds.
        best_score = float("-inf")
        for entry in self.cache_entries:
            score = self._cosine_similarity(query_embedding, entry["embedding"])
            if score >= self.similarity_threshold and score > best_score:
                best_match = entry
                best_score = score
        if best_match is None:
            return None
        best_match["hit_count"] += 1
        return {
            "response": best_match["response"],
            "similarity": best_score,
            "cached": True,
        }

    def put(self, query: str, response: str) -> None:
        """Embed ``query`` and append a new entry holding ``response``.

        Entries are only ever appended; this class neither de-duplicates
        queries nor limits the number of entries.
        """
        embedding = self._compute_embedding(query)
        self.cache_entries.append({
            "query": query,
            "response": response,
            "embedding": embedding,
            "hit_count": 0,
        })

    def get_stats(self) -> dict:
        """Return the entry count and the summed hit counts."""
        total_hits = sum(e["hit_count"] for e in self.cache_entries)
        return {"entries": len(self.cache_entries), "total_hits": total_hits}
# Demo: seed the cache with two question/answer pairs, probe it with three
# questions, and print the hit statistics once all queries have run.
cache = SemanticCache(similarity_threshold=0.85)
cache.put(query="什么是机器学习?", response="机器学习是人工智能的一个分支...")
cache.put(query="如何学习Python?", response="学习Python可以从基础语法开始...")
queries = ["机器学习是什么?", "Python怎么学?", "今天天气怎么样?"]
for query in queries:
    result = cache.get(query)
    if not result:
        print(f"未命中: {query}")
        continue
    print(f"命中: {query} → 相似度: {result['similarity']:.2f}")
print(f"统计: {cache.get_stats()}")
多级缓存架构
from dataclasses import dataclass, field
from typing import Dict, Optional
@dataclass
class MultiLevelCache:
    """Two-tier string cache: a small L1 in front of a larger L2.

    New and promoted entries go into L1. When L1 overflows, its oldest entry
    is demoted to L2; when L2 overflows, its oldest entry is discarded. A key
    lives in at most one tier at a time.

    Attributes:
        l1_cache: First-tier dict; insertion order is the demotion order.
        l2_cache: Second-tier dict; insertion order is the discard order.
        l1_max_size: Maximum number of entries kept in L1.
        l2_max_size: Maximum number of entries kept in L2.
    """

    l1_cache: Dict[str, str] = field(default_factory=dict)
    l2_cache: Dict[str, str] = field(default_factory=dict)
    l1_max_size: int = 100
    l2_max_size: int = 1000

    def get(self, key: str) -> Optional[str]:
        """Return the value for ``key``, or ``None`` if it is in neither tier.

        A hit in L2 moves the entry up into L1 before returning.
        """
        if key in self.l1_cache:
            return self.l1_cache[key]
        if key in self.l2_cache:
            value = self.l2_cache[key]
            self._promote_to_l1(key, value)
            return value
        return None

    def put(self, key: str, value: str) -> None:
        """Store ``value`` under ``key`` in L1, demoting older entries on overflow."""
        # Drop any demoted copy first so a stale value cannot linger in L2.
        self.l2_cache.pop(key, None)
        self._insert_into_l1(key, value)

    def _promote_to_l1(self, key: str, value: str) -> None:
        """Move an entry from L2 into L1, keeping L1 within its size limit."""
        # Remove from L2 before inserting: this frees the slot that the entry
        # demoted by the insertion will need.
        self.l2_cache.pop(key, None)
        self._insert_into_l1(key, value)

    def _insert_into_l1(self, key: str, value: str) -> None:
        """Write to L1 and demote its oldest entries while it is over capacity."""
        self.l1_cache[key] = value
        while self.l1_cache and len(self.l1_cache) > self.l1_max_size:
            oldest = next(iter(self.l1_cache))
            self._demote_to_l2(oldest, self.l1_cache.pop(oldest))

    def _demote_to_l2(self, key: str, value: str) -> None:
        """Write to L2 and discard its oldest entries while it is over capacity."""
        self.l2_cache[key] = value
        while self.l2_cache and len(self.l2_cache) > self.l2_max_size:
            self.l2_cache.pop(next(iter(self.l2_cache)))

    def get_stats(self) -> dict:
        """Return the current entry counts as ``{"l1_size": ..., "l2_size": ...}``."""
        return {"l1_size": len(self.l1_cache), "l2_size": len(self.l2_cache)}
# Demo: write 150 entries into a cache built with its default tier sizes,
# then print the tier sizes and look one key up.
cache = MultiLevelCache()
for i in range(150):
    cache.put(key=f"key-{i}", value=f"value-{i}")
print(f"多级缓存: {cache.get_stats()}")
print(f"查询 key-100: {cache.get('key-100')}")
缓存策略选择
根据场景选择缓存策略:KV缓存适合减少重复计算,前缀缓存适合多轮对话,语义缓存适合相似问题快速回复,多级缓存适合大规模部署。合理的缓存策略能显著提升LLM服务的性能和成本效率。