← 返回首页
🧠

缓存策略

📂 llm ⏱ 3 min 565 words

---
title: "缓存策略"
description: "LLM缓存策略，包括KV缓存、前缀缓存、语义缓存和多级缓存的实现"
tags: ["缓存", "前缀缓存", "语义缓存", "KV缓存", "性能优化"]
category: "llm"
icon: "🧠"
---

缓存策略

LLM缓存是提升推理效率、降低成本的重要技术。通过缓存KV状态、公共前缀的计算结果以及语义相似查询的响应，可以显著减少重复计算，提高系统响应速度。不同的缓存策略适用于不同的应用场景。

KV缓存

基础KV缓存实现

import torch
from typing import Dict, Optional
from dataclasses import dataclass, field

@dataclass
class KVCache:
    """Bounded mapping from a string key to a ``(k, v)`` tensor pair.

    When the cache is full, inserting a *new* key evicts the entry that was
    inserted first (FIFO: dicts preserve insertion order). Reads do not
    affect eviction order. Tensors are stored as given; no shape or device
    checks are performed.

    Attributes:
        max_size: Maximum number of entries kept. A value <= 0 disables
            caching entirely: ``put`` stores nothing.
        cache: Underlying storage, key -> ``(k, v)``.
    """

    max_size: int
    cache: Dict[str, tuple] = field(default_factory=dict)

    def get(self, key: str) -> Optional[tuple]:
        """Return the ``(k, v)`` pair stored under *key*, or ``None`` if absent."""
        return self.cache.get(key)

    def put(self, key: str, k: torch.Tensor, v: torch.Tensor):
        """Store ``(k, v)`` under *key*, evicting the oldest entries if needed.

        Overwriting an existing key replaces its tensors in place and never
        evicts a different entry.
        """
        if self.max_size <= 0:
            # Zero capacity: nothing can be stored. Without this guard the
            # eviction below would call next(iter({})) and raise StopIteration.
            return
        if key not in self.cache:
            # Only a brand-new key grows the cache. `while` (not `if`) also
            # copes with max_size having been lowered after entries were added.
            while len(self.cache) >= self.max_size:
                del self.cache[next(iter(self.cache))]
        self.cache[key] = (k, v)

    def contains(self, key: str) -> bool:
        """Return True if *key* currently has a cached entry."""
        return key in self.cache

    def clear(self):
        """Remove every cached entry."""
        self.cache.clear()

    def get_stats(self) -> dict:
        """Return the current entry count and the configured capacity."""
        return {"size": len(self.cache), "max_size": self.max_size}

# Demo: store one random (k, v) pair and report cache occupancy and a hit.
cache = KVCache(max_size=100)
kv_shape = (1, 10, 8, 64)
# k is drawn before v, matching the original RNG consumption order.
k_tensor, v_tensor = torch.randn(*kv_shape), torch.randn(*kv_shape)
cache.put("prompt_001", k_tensor, v_tensor)
stats = cache.get_stats()
print(f"缓存状态: {stats}")
hit = cache.contains('prompt_001')
print(f"命中缓存: {hit}")

前缀缓存

缓存公共前缀的KV状态，适用于多轮对话和共享系统提示词的场景。查询时返回与输入token序列匹配长度最长的已缓存前缀。

from hashlib import sha256
from typing import Dict, List
from dataclasses import dataclass, field

@dataclass
class PrefixCache:
    """Registry of cached KV states keyed by token-ID prefixes.

    ``cache_prefix`` stores the first ``prefix_len`` tokens of a sequence
    together with an opaque ``kv_state``. ``lookup_prefix`` returns the
    stored state whose tokens share the longest leading run with a query.

    Attributes:
        max_prefix_length: Upper bound on how many tokens a stored prefix may
            hold; longer requests are truncated to this length.
        prefix_registry: Maps a hash of the prefix tokens to an entry dict
            with keys ``"tokens"``, ``"kv_state"`` and ``"hit_count"``.
    """

    max_prefix_length: int
    prefix_registry: Dict[str, Dict] = field(default_factory=dict)

    def _compute_prefix_key(self, tokens: List[int], prefix_len: int) -> str:
        """Return a 16-hex-char SHA-256 digest identifying ``tokens[:prefix_len]``."""
        prefix = tuple(tokens[:prefix_len])
        return sha256(str(prefix).encode()).hexdigest()[:16]

    @staticmethod
    def _common_prefix_len(a: List[int], b: List[int]) -> int:
        """Return the number of leading positions at which *a* and *b* agree."""
        length = 0
        for x, y in zip(a, b):
            if x != y:
                break
            length += 1
        return length

    def cache_prefix(self, tokens: List[int], prefix_len: int, kv_state: dict):
        """Register ``tokens[:prefix_len]`` together with its KV state.

        ``prefix_len`` is clamped to ``[0, max_prefix_length]`` so the
        configured limit is actually enforced. Re-registering the same prefix
        replaces its state and resets its hit count to 0.
        """
        prefix_len = max(0, min(prefix_len, self.max_prefix_length))
        key = self._compute_prefix_key(tokens, prefix_len)
        self.prefix_registry[key] = {
            "tokens": tokens[:prefix_len],
            "kv_state": kv_state,
            "hit_count": 0,
        }

    def lookup_prefix(self, tokens: List[int], min_match: int = 4) -> tuple:
        """Find the cached prefix sharing the longest leading run with *tokens*.

        Args:
            tokens: Query token IDs.
            min_match: Minimum number of matching leading tokens for a hit.

        Returns:
            ``(kv_state, match_len)`` for the best entry whose match length is
            at least ``min_match`` (that entry's hit count is incremented);
            otherwise ``(None, 0)``. On ties the earliest-registered entry
            wins. ``match_len`` can be shorter than the stored prefix when the
            query diverges partway through it. NOTE(review): callers
            presumably must truncate ``kv_state`` to ``match_len``; confirm.
        """
        best_match = None
        best_len = 0
        for entry in self.prefix_registry.values():
            match_len = self._common_prefix_len(entry["tokens"], tokens)
            if match_len >= min_match and match_len > best_len:
                best_match = entry
                best_len = match_len

        if best_match is None:
            return None, 0
        best_match["hit_count"] += 1
        return best_match["kv_state"], best_len

    def get_stats(self) -> dict:
        """Return the number of registered prefixes and the total hit count."""
        total_hits = sum(e["hit_count"] for e in self.prefix_registry.values())
        return {"prefixes": len(self.prefix_registry), "total_hits": total_hits}

# Demo: register the first 4 tokens of a "system prompt" with a placeholder
# KV state, then look the full sequence up again.
prefix_cache = PrefixCache(max_prefix_length=1024)
system_tokens = [101, 2023, 102, 3000, 400, 500]
prefix_cache.cache_prefix(system_tokens, prefix_len=4, kv_state={"k": "cached_state"})
kv, length = prefix_cache.lookup_prefix(system_tokens, min_match=4)
prefix_stats = prefix_cache.get_stats()
print(f"前缀缓存: {prefix_stats}")

语义缓存

基于查询的语义相似度缓存响应：当新查询与某条已缓存查询的嵌入向量余弦相似度不低于阈值时，直接返回缓存的回复，适用于相似问题的快速回复。

import numpy as np
from typing import List, Optional
from dataclasses import dataclass, field
from sentence_transformers import SentenceTransformer

@dataclass
class SemanticCache:
    """Response cache keyed by the meaning of a query rather than its exact text.

    ``put`` embeds the query with ``model`` and stores it alongside the
    response. ``get`` embeds the incoming query and returns the stored
    response whose query embedding has the highest cosine similarity,
    provided that similarity is at least ``similarity_threshold``. Lookup is
    a linear scan over all entries.

    Attributes:
        similarity_threshold: Minimum cosine similarity (inclusive) for a hit.
        model: Sentence encoder used for embeddings; defaults to
            ``all-MiniLM-L6-v2``.
        cache_entries: Stored entries, each a dict with keys ``"query"``,
            ``"response"``, ``"embedding"`` and ``"hit_count"``.
    """

    similarity_threshold: float
    model: SentenceTransformer = field(default_factory=lambda: SentenceTransformer("all-MiniLM-L6-v2"))
    cache_entries: List[dict] = field(default_factory=list)

    def _compute_embedding(self, text: str) -> np.ndarray:
        """Encode *text* into an embedding vector with ``self.model``."""
        return self.model.encode(text)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of *a* and *b* as a Python float.

        A zero-length vector has no direction, so its similarity to anything
        is defined as 0.0 instead of producing NaN and a RuntimeWarning.
        """
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0
        return float(np.dot(a, b) / denom)

    def get(self, query: str) -> Optional[dict]:
        """Return the best cached response for *query*, or ``None`` on a miss.

        Returns:
            ``{"response": str, "similarity": float, "cached": True}`` for the
            most similar entry at or above the threshold (its hit count is
            incremented); ``None`` if no entry qualifies.
        """
        query_embedding = self._compute_embedding(query)

        best_match = None
        # Start at -inf so any threshold (including <= 0) is honoured; a start
        # value of 0 would silently reject qualifying non-positive scores.
        best_score = float("-inf")

        for entry in self.cache_entries:
            score = self._cosine_similarity(query_embedding, entry["embedding"])
            if score >= self.similarity_threshold and score > best_score:
                best_match = entry
                best_score = score

        if best_match is None:
            return None
        best_match["hit_count"] += 1
        return {
            "response": best_match["response"],
            "similarity": best_score,
            "cached": True
        }

    def put(self, query: str, response: str):
        """Embed *query* and append it with *response* as a new cache entry."""
        embedding = self._compute_embedding(query)
        self.cache_entries.append({
            "query": query,
            "response": response,
            "embedding": embedding,
            "hit_count": 0
        })

    def get_stats(self) -> dict:
        """Return the number of stored entries and the total hit count."""
        total_hits = sum(e["hit_count"] for e in self.cache_entries)
        return {"entries": len(self.cache_entries), "total_hits": total_hits}

# Demo: seed two Q&A pairs, then probe with paraphrases and an unrelated query.
cache = SemanticCache(similarity_threshold=0.85)
seed_pairs = [
    ("什么是机器学习?", "机器学习是人工智能的一个分支..."),
    ("如何学习Python?", "学习Python可以从基础语法开始..."),
]
for question, answer in seed_pairs:
    cache.put(question, answer)

queries = ["机器学习是什么?", "Python怎么学?", "今天天气怎么样?"]
for query in queries:
    result = cache.get(query)
    message = (
        f"命中: {query} → 相似度: {result['similarity']:.2f}"
        if result
        else f"未命中: {query}"
    )
    print(message)
print(f"统计: {cache.get_stats()}")

多级缓存架构

新写入的条目先进入容量较小的L1缓存；L1超出容量时，最早插入的条目会被降级到容量更大的L2缓存，L2超出容量时则直接丢弃其最早插入的条目。查询命中L2时，条目会被提升回L1。

from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class MultiLevelCache:
    """Two-tier string cache: a small L1 in front of a larger L2.

    New entries go into L1. Whenever L1 exceeds ``l1_max_size``, its
    oldest-inserted entry is demoted to L2. Whenever L2 exceeds
    ``l2_max_size``, its oldest-inserted entry is dropped. An L2 hit promotes
    the entry back into L1, subject to the same L1 size limit. A key lives in
    at most one tier at a time.
    """

    l1_cache: Dict[str, str] = field(default_factory=dict)
    l2_cache: Dict[str, str] = field(default_factory=dict)
    l1_max_size: int = 100
    l2_max_size: int = 1000

    def get(self, key: str) -> Optional[str]:
        """Return the value for *key* (L1 first, then L2), or ``None`` on a miss.

        An L2 hit promotes the entry into L1, which may demote another entry.
        """
        if key in self.l1_cache:
            return self.l1_cache[key]
        if key in self.l2_cache:
            value = self.l2_cache[key]
            self._promote_to_l1(key, value)
            return value
        return None

    def put(self, key: str, value: str):
        """Insert or overwrite *key* in L1, demoting overflow to L2."""
        # Drop any stale L2 copy so the key is not held in both tiers.
        self.l2_cache.pop(key, None)
        self._insert_l1(key, value)

    def _insert_l1(self, key: str, value: str):
        """Write into L1 and demote oldest-inserted entries until L1 fits."""
        self.l1_cache[key] = value
        # The emptiness check keeps a negative limit from calling next() on
        # an empty iterator.
        while self.l1_cache and len(self.l1_cache) > self.l1_max_size:
            oldest = next(iter(self.l1_cache))
            self._demote_to_l2(oldest, self.l1_cache.pop(oldest))

    def _promote_to_l1(self, key: str, value: str):
        """Move *key* from L2 into L1, enforcing the L1 size limit."""
        # Remove from L2 first: the L1 insert may demote entries back into L2,
        # possibly this very key when l1_max_size is 0.
        self.l2_cache.pop(key, None)
        self._insert_l1(key, value)

    def _demote_to_l2(self, key: str, value: str):
        """Write into L2 and drop oldest-inserted entries until L2 fits."""
        self.l2_cache[key] = value
        while self.l2_cache and len(self.l2_cache) > self.l2_max_size:
            self.l2_cache.pop(next(iter(self.l2_cache)))

    def get_stats(self) -> dict:
        """Return the current entry counts of L1 and L2."""
        return {"l1_size": len(self.l1_cache), "l2_size": len(self.l2_cache)}

# Demo: overfill L1 (150 puts vs. capacity 100) so the overflow lands in L2.
cache = MultiLevelCache()
for n in range(150):
    cache.put(f"key-{n}", f"value-{n}")
level_stats = cache.get_stats()
print(f"多级缓存: {level_stats}")
looked_up = cache.get('key-100')
print(f"查询 key-100: {looked_up}")

缓存策略选择

根据场景选择缓存策略：KV缓存适合避免对相同输入重复计算KV状态；前缀缓存适合共享系统提示词的多轮对话；语义缓存适合对相似问题快速返回已有回复；多级缓存适合大规模部署。合理的缓存策略能显著提升LLM服务的性能并降低成本。