RAG检索增强生成系统架构
RAG检索增强生成系统架构
RAG系统概述
RAG(Retrieval-Augmented Generation,检索增强生成)通过结合检索和生成,让大语言模型能够访问外部知识库,从而减少幻觉并提供准确、有据可依的回答。RAG系统的核心流程包括四个阶段:文档摄入与分块、向量化与索引、查询理解与检索、上下文组装与生成。
# RAG系统核心架构
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class Document:
    """A unit of source text (a whole document or one chunk) plus bookkeeping.

    Only ``content`` is required; every other field defaults to an empty value.
    """

    # Raw text of the document or chunk.
    content: str
    # Free-form annotations; each instance receives its own fresh dict.
    metadata: Dict = field(default_factory=dict)
    # Identifier of the originating document; empty string when unset.
    doc_id: str = ""
    # Identifier of this chunk; empty string when unset.
    chunk_id: str = ""
@dataclass
class RetrievedChunk:
    """One retrieval hit: the matched document together with its score."""

    # The document (or chunk) that was matched.
    document: Document
    # Relevance score assigned by whichever retriever produced this hit.
    score: float
    # Position associated with the hit; defaults to 0 when not supplied.
    chunk_index: int = 0
@dataclass
class RAGResponse:
    """Result bundle returned for a single question."""

    # Generated answer text.
    answer: str
    # Retrieval hits that were used as context for the answer.
    sources: List[RetrievedChunk]
    # Confidence value attached to the answer by the producer of this response.
    confidence: float
    # The question this response answers.
    query: str
class RAGPipeline:
    """Runs the retrieve -> (optional) rerank -> build context -> generate flow."""

    def __init__(self, retriever, generator, reranker=None):
        """Store the collaborators; ``reranker`` is optional."""
        self.retriever = retriever
        self.generator = generator
        self.reranker = reranker

    async def query(self, question: str, top_k: int = 5) -> RAGResponse:
        """Answer ``question`` using at most ``top_k`` retrieved chunks as context.

        The retriever and generator are awaited; the reranker, when present,
        is called synchronously.
        """
        # Ask for twice as many candidates as requested; they are cut back to
        # ``top_k`` by the reranker or by plain truncation below.
        candidates = await self.retriever.retrieve(question, top_k=top_k * 2)

        if not self.reranker:
            selected = candidates[:top_k]
        else:
            selected = self.reranker.rerank(question, candidates, top_k=top_k)

        prompt_context = self._build_context(selected)
        generated = await self.generator.generate(question, prompt_context)
        confidence = self._calculate_confidence(selected)

        return RAGResponse(
            answer=generated,
            sources=selected,
            confidence=confidence,
            query=question,
        )

    def _build_context(self, chunks: List[RetrievedChunk]) -> str:
        """Render chunks as ``[n] text`` entries (1-based) separated by blank lines."""
        return "\n\n".join(
            f"[{number}] {hit.document.content}"
            for number, hit in enumerate(chunks, start=1)
        )

    def _calculate_confidence(self, chunks: List[RetrievedChunk]) -> float:
        """Return the arithmetic mean of the chunk scores, or 0.0 for no chunks."""
        if not chunks:
            return 0.0
        return sum(hit.score for hit in chunks) / len(chunks)
文档处理与分块策略
文档处理是RAG的基础,需要处理多种格式(PDF、Word、HTML等)并进行智能分块。分块策略直接影响检索质量,常见方法包括固定长度分块、语义分块和递归分块。
# 文档分块器
import re
from typing import List
class DocumentChunker:
    """Splits text into ``Document`` chunks sized in whitespace-separated words.

    ``chunk_size`` is the word budget per chunk and ``chunk_overlap`` is the
    word budget of trailing sentences repeated at the start of the next chunk.
    Both are measured with ``len(str.split())``.

    NOTE(review): sentence boundaries are only recognised after ASCII ``.``,
    ``!`` or ``?`` followed by whitespace, and length is counted in
    whitespace-separated tokens, so text without spaces (e.g. Chinese) is
    treated as a single sentence / single word. Confirm whether CJK input
    needs dedicated handling.
    """

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def chunk_text(self, text: str, metadata: Dict = None) -> List[Document]:
        """Chunk ``text`` on sentence boundaries.

        Args:
            text: Text to split. Empty or whitespace-only text yields ``[]``.
            metadata: Optional annotations; each chunk gets its own shallow
                copy so that editing one chunk's metadata does not affect the
                others or the caller's dict.

        Returns:
            Chunks with ids ``chunk_0``, ``chunk_1``, ... in order.
        """
        chunks: List[Document] = []
        for body in self._pack_sentences(text):
            chunks.append(self._make_chunk(body, metadata, len(chunks)))
        return chunks

    def chunk_markdown(self, text: str, metadata: Dict = None) -> List[Document]:
        """Chunk Markdown by level 1-3 headings, sub-splitting oversized sections.

        A section longer than ``chunk_size`` words is split on sentence
        boundaries; shorter sections become one chunk each (stripped).
        Chunk ids are numbered consecutively across the whole result, so
        sub-split sections never reuse an id that is already taken.
        """
        chunks: List[Document] = []
        # Split in front of each heading line so the heading stays with its body.
        for section in re.split(r'\n(?=#{1,3}\s)', text):
            if not section.strip():
                continue
            if len(section.split()) > self.chunk_size:
                bodies = self._pack_sentences(section)
            else:
                bodies = [section.strip()]
            for body in bodies:
                chunks.append(self._make_chunk(body, metadata, len(chunks)))
        return chunks

    def _pack_sentences(self, text: str) -> List[str]:
        """Group the sentences of ``text`` into chunk strings within the word budget."""
        # Blank fragments (empty input, trailing whitespace after the last
        # sentence) are dropped so they cannot produce empty chunks.
        sentences = [
            piece
            for piece in re.split(r'(?<=[.!?])\s+', text)
            if piece.strip()
        ]

        bodies: List[str] = []
        current: List[str] = []
        current_words = 0
        for sentence in sentences:
            sentence_words = len(sentence.split())
            # Flush only when something is buffered: a single sentence larger
            # than chunk_size still becomes a chunk of its own.
            if current and current_words + sentence_words > self.chunk_size:
                bodies.append(" ".join(current))
                current, current_words = self._overlap_tail(current)
            current.append(sentence)
            current_words += sentence_words

        if current:
            bodies.append(" ".join(current))
        return bodies

    def _overlap_tail(self, sentences: List[str]):
        """Return the trailing sentences that fit in ``chunk_overlap`` words, and their word count."""
        tail: List[str] = []
        tail_words = 0
        for sentence in reversed(sentences):
            words = len(sentence.split())
            if tail_words + words > self.chunk_overlap:
                break
            tail.append(sentence)
            tail_words += words
        tail.reverse()  # restore reading order
        return tail, tail_words

    @staticmethod
    def _make_chunk(content: str, metadata, index: int) -> Document:
        """Build one chunk with its own metadata copy and a positional id."""
        return Document(
            content=content,
            metadata=dict(metadata) if metadata else {},
            chunk_id=f"chunk_{index}",
        )
# 多格式文档加载器
class DocumentLoader:
    """Dispatches file loading to a callable registered per file extension."""

    def __init__(self):
        # Maps a lower-case extension without the dot (e.g. "pdf") to a loader callable.
        self.loaders = {}

    def register_loader(self, file_type: str, loader_fn):
        """Register ``loader_fn`` for ``file_type``.

        ``file_type`` is normalised the same way ``load`` normalises the path
        extension (lower-case, no leading dot), so ``"PDF"``, ``".pdf"`` and
        ``"pdf"`` all register the same loader.
        """
        self.loaders[file_type.strip().lstrip(".").lower()] = loader_fn

    def load(self, file_path: str) -> str:
        """Load ``file_path`` with the loader registered for its extension.

        Returns:
            Whatever the registered loader returns for ``file_path``.

        Raises:
            ValueError: If no loader is registered for the extension, or the
                file name has no extension at all.
        """
        file_type = self._extension_of(file_path)
        loader = self.loaders.get(file_type)
        if loader:
            return loader(file_path)
        raise ValueError(f"Unsupported file type: {file_type}")

    @staticmethod
    def _extension_of(file_path: str) -> str:
        """Return the lower-case extension of the final path component, or ""."""
        # Look only at the file name so a dot in a directory name
        # ("data.v2/README") is not mistaken for an extension separator.
        file_name = file_path.replace("\\", "/").rsplit("/", 1)[-1]
        _, dot, extension = file_name.rpartition(".")
        return extension.lower() if dot else ""
# 示例加载器
def load_pdf(file_path: str) -> str:
    """Placeholder PDF loader: ignores ``file_path`` and returns a fixed string.

    A real implementation would extract the text with PyPDF2 or another
    PDF library.
    """
    placeholder_text = "PDF content..."
    return placeholder_text
def load_markdown(file_path: str) -> str:
    """Return the complete text of the file at ``file_path``, decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as source:
        text = source.read()
    return text
向量化与索引构建
将文档块转换为向量并建立索引是RAG的关键步骤。常用Embedding模型包括OpenAI、Cohere和开源的BGE模型。向量数据库选择需考虑性能、可扩展性和成本。
# 向量索引管理
from typing import List, Tuple
import numpy as np
class VectorIndex:
    """In-memory vector index with brute-force cosine-similarity search.

    NOTE: ``_encode`` is a placeholder that returns random vectors and never
    calls ``embedding_model``, so search scores are not meaningful until it
    is replaced with a real embedding call.
    """

    def __init__(self, embedding_model, dimension: int = 1536):
        self.embedding_model = embedding_model
        self.dimension = dimension
        # Parallel lists: vectors[i] is the embedding of documents[i].
        self.vectors = []
        self.documents = []

    def add_documents(self, documents: List[Document]):
        """Embed ``documents`` and append them to the index; prints a summary line."""
        embeddings = self._encode([doc.content for doc in documents])
        self.vectors.extend(embeddings)
        self.documents.extend(documents)
        print(f"Added {len(documents)} documents to index")

    def search(self, query: str, top_k: int = 5) -> List[RetrievedChunk]:
        """Return the ``top_k`` indexed documents most similar to ``query``.

        Every stored vector is scored against the query embedding; hits are
        returned in descending score order with ``chunk_index`` set to the
        document's position in the index.
        """
        query_vector = self._encode([query])[0]
        ranked = sorted(
            (
                (position, self._cosine_similarity(query_vector, stored))
                for position, stored in enumerate(self.vectors)
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [
            RetrievedChunk(
                document=self.documents[position],
                score=similarity,
                chunk_index=position,
            )
            for position, similarity in ranked[:top_k]
        ]

    def _encode(self, texts: List[str]) -> List[np.ndarray]:
        """Return one random ``dimension``-sized vector per text (embedding stub)."""
        vectors = []
        for _ in texts:
            vectors.append(np.random.rand(self.dimension))
        return vectors

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b``; 1e-8 in the denominator avoids division by zero."""
        dot_product = np.dot(a, b)
        magnitude = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
        return dot_product / magnitude
# 混合检索策略
class HybridRetriever:
    """Combines vector search with an optional BM25 index via Reciprocal Rank Fusion."""

    def __init__(self, vector_index, bm25_index=None):
        self.vector_index = vector_index
        self.bm25_index = bm25_index

    async def retrieve(self, query: str, top_k: int = 10) -> List[RetrievedChunk]:
        """Return up to ``top_k`` hits for ``query``.

        Without a BM25 index the vector results are returned unchanged;
        otherwise both result lists are fused and cut to ``top_k``.
        """
        vector_results = self.vector_index.search(query, top_k)
        if self.bm25_index is None:
            return vector_results
        bm25_results = self.bm25_index.search(query, top_k)
        return self._merge_results(vector_results, bm25_results, top_k=top_k)

    def _merge_results(self, vec_results: List, bm25_results: List,
                       top_k: int = 10, rrf_k: int = 60) -> List:
        """Fuse two rankings with Reciprocal Rank Fusion (RRF).

        Each hit contributes ``1 / (rrf_k + rank + 1)`` (rank is 0-based) per
        list it appears in; hits are ordered by the summed contribution.

        Args:
            vec_results: Ranked hits from the vector index.
            bm25_results: Ranked hits from the BM25 index.
            top_k: Maximum number of fused hits to return.
            rrf_k: RRF smoothing constant.

        Returns:
            The first-seen hit object for each distinct chunk, best first.
            The objects keep their original retriever score; the fused score
            is only used for ordering.
        """
        fused = {}
        for ranking in (vec_results, bm25_results):
            for rank, result in enumerate(ranking):
                entry = fused.setdefault(
                    self._result_key(result), {"result": result, "score": 0}
                )
                entry["score"] += 1 / (rrf_k + rank + 1)

        ordered = sorted(fused.values(), key=lambda entry: entry["score"], reverse=True)
        return [entry["result"] for entry in ordered[:top_k]]

    @staticmethod
    def _result_key(result):
        """Identify the chunk behind ``result`` for de-duplication across rankings."""
        document = result.document
        # doc_id alone cannot identify a chunk: it defaults to "" and is shared
        # by all chunks of one document. Trust the id pair only when both parts
        # are set; otherwise include the text so distinct chunks stay distinct.
        if document.doc_id and document.chunk_id:
            return (document.doc_id, document.chunk_id)
        return (document.doc_id, document.chunk_id, document.content)
查询理解与改写
查询理解是提升RAG效果的关键环节,包括查询意图识别、查询扩展、多查询生成和查询路由。好的查询处理能显著提升检索召回率。
# 查询处理器
class QueryProcessor:
    """Query understanding helpers: expansion, decomposition, intent detection, rewriting."""

    # (intent label, trigger keywords) checked in order; first match wins.
    _INTENT_KEYWORDS = (
        ("how_to", ("如何", "怎么", "步骤")),
        ("definition", ("是什么", "定义", "含义")),
        ("comparison", ("比较", "区别", "差异")),
    )

    def __init__(self, llm_client):
        # ``llm_client`` must expose an awaitable ``generate(prompt)``.
        self.llm = llm_client

    async def expand_query(self, query: str) -> List[str]:
        """Ask the LLM for query variants and return them one per list item.

        Blank lines in the model output are dropped and each remaining line
        is stripped, so callers never receive empty queries.
        """
        prompt = (
            "基于以下问题,生成3个相关的查询变体:\n"
            f"原始问题:{query}\n"
            "请生成3个不同角度的查询:"
        )
        expanded = await self.llm.generate(prompt)
        return self._non_empty_lines(expanded)

    async def decompose_query(self, query: str) -> List[str]:
        """Ask the LLM to break a complex question into sub-questions.

        Returns the non-blank, stripped lines of the model output.
        """
        prompt = (
            "将以下复杂问题分解为多个简单子问题:\n"
            f"问题:{query}\n"
            "分解为子问题:"
        )
        decomposed = await self.llm.generate(prompt)
        return self._non_empty_lines(decomposed)

    def detect_intent(self, query: str) -> str:
        """Classify ``query`` by keyword as how_to / definition / comparison / general."""
        # Deliberately independent of ``self``: the keyword table is reached
        # through the class so the method also works when invoked unbound.
        for intent, keywords in QueryProcessor._INTENT_KEYWORDS:
            if any(keyword in query for keyword in keywords):
                return intent
        return "general"

    async def rewrite_query(self, query: str, context: str = "") -> str:
        """Ask the LLM to rewrite ``query`` for retrieval and return its raw output.

        When ``context`` is non-empty it is included in the prompt; otherwise
        that prompt line is left empty.
        """
        context_line = f"上下文:{context}" if context else ""
        prompt = (
            "改写以下查询,使其更适合检索相关文档:\n"
            f"原始查询:{query}\n"
            f"{context_line}\n"
            "改写后的查询:"
        )
        return await self.llm.generate(prompt)

    @staticmethod
    def _non_empty_lines(text: str) -> List[str]:
        """Split ``text`` into stripped lines, discarding blank ones."""
        return [line.strip() for line in text.splitlines() if line.strip()]
# 查询路由
class QueryRouter:
    """Chooses a retriever for a query based on its detected intent."""

    # Intent label -> key looked up in ``self.retrievers``.
    _ROUTE_BY_INTENT = {
        "how_to": "technical_docs",
        "definition": "knowledge_base",
        "comparison": "comparison_docs",
    }
    _DEFAULT_ROUTE = "general"

    def __init__(self, retrievers: Dict[str, Any]):
        # Maps route name -> retriever exposing an awaitable ``retrieve(query)``.
        self.retrievers = retrievers

    async def route(self, query: str) -> str:
        """Return the route name for ``query``: one of the mapped routes or "general"."""
        # Intent detection needs no LLM, so a processor without a client is
        # enough and avoids passing this router as another class's ``self``.
        intent = QueryProcessor(llm_client=None).detect_intent(query)
        return self._ROUTE_BY_INTENT.get(intent, self._DEFAULT_ROUTE)

    async def retrieve(self, query: str) -> List[RetrievedChunk]:
        """Retrieve with the retriever registered for the query's route.

        Falls back to the "general" retriever when the specific route has no
        retriever registered.

        Raises:
            ValueError: If neither the routed nor the "general" retriever exists.
        """
        route = await self.route(query)
        retriever = self.retrievers.get(route)
        if retriever is None:
            retriever = self.retrievers.get(self._DEFAULT_ROUTE)
        if retriever is None:
            raise ValueError(f"No retriever registered for route: {route}")
        return await retriever.retrieve(query)
生成优化与答案质量
生成阶段需要确保答案准确、有据可依,并正确引用来源。优化策略包括提示工程、答案验证、幻觉检测和引用标注。
# RAG生成器
class RAGGenerator:
    """Builds prompts from retrieved context and asks the LLM for an answer."""

    def __init__(self, llm_client, prompt_template: Optional[str] = None):
        """Store the client and template.

        Args:
            llm_client: Object exposing an awaitable ``generate(prompt)``.
            prompt_template: ``str.format`` template with ``{context}`` and
                ``{question}`` placeholders; the built-in template is used
                when omitted or empty.
        """
        self.llm = llm_client
        self.template = prompt_template or self._default_template()

    def _default_template(self) -> str:
        """Return the built-in answer prompt template."""
        return (
            "基于以下参考文档回答问题。如果文档中没有相关信息,请说明无法回答。\n"
            "参考文档:\n"
            "{context}\n"
            "问题:{question}\n"
            "请提供准确、简洁的回答,并标注信息来源:"
        )

    async def generate(self, question: str, context: str) -> str:
        """Fill the template with ``context`` and ``question`` and return the LLM output."""
        prompt = self.template.format(context=context, question=question)
        return await self.llm.generate(prompt)

    async def generate_with_citations(self, question: str,
                                      chunks: List[RetrievedChunk]) -> dict:
        """Generate an answer whose ``[n]`` markers refer to ``chunks`` (1-based).

        Returns:
            A dict with ``answer`` (raw LLM text), ``citations`` (the cited
            numbers in order of appearance, limited to numbers that actually
            refer to one of ``chunks``) and ``sources`` (``chunks`` itself).
        """
        context = "\n\n".join(
            f"[{number}] {hit.document.content}"
            for number, hit in enumerate(chunks, start=1)
        )
        prompt = (
            "基于以下文档回答问题,并在回答中标注引用编号。\n"
            "文档:\n"
            f"{context}\n"
            f"问题:{question}\n"
            "回答(使用[1]、[2]等标注来源):"
        )
        answer = await self.llm.generate(prompt)
        return {
            "answer": answer,
            # Drop markers such as [0] or [7] (with 3 chunks) that point at no source.
            "citations": self._extract_citations(answer, max_index=len(chunks)),
            "sources": chunks,
        }

    def _extract_citations(self, text: str,
                           max_index: Optional[int] = None) -> List[int]:
        """Return the numbers found in ``[n]`` markers, in order of appearance.

        Duplicates are kept. When ``max_index`` is given, only numbers in
        ``1..max_index`` are returned.
        """
        import re
        numbers = [int(match) for match in re.findall(r'\[(\d+)\]', text)]
        if max_index is None:
            return numbers
        return [number for number in numbers if 1 <= number <= max_index]
# 幻觉检测
class HallucinationDetector:
    """Asks the LLM to list answer statements that the context does not support."""

    def __init__(self, llm_client):
        # ``llm_client`` must expose an awaitable ``generate(prompt)``.
        self.llm = llm_client

    async def detect(self, answer: str, context: str) -> dict:
        """Check ``answer`` against ``context``.

        Returns:
            A dict with ``analysis`` (raw LLM output), ``has_hallucination``
            (True when the analysis contains the phrase "无法验证") and
            ``confidence`` (0.3 when flagged, otherwise 0.85).

        NOTE(review): the flag is a plain substring test on the model output,
        so an analysis that merely mentions the phrase while reporting no
        issues is also flagged.
        """
        prompt_lines = [
            "检查以下回答是否基于提供的文档,标记可能的幻觉。",
            "文档内容:",
            f"{context}",
            "回答:",
            f"{answer}",
            "请列出回答中任何无法从文档中验证的信息:",
        ]
        analysis = await self.llm.generate("\n".join(prompt_lines))

        flagged = "无法验证" in analysis
        confidence = 0.3 if flagged else 0.85
        return {
            "has_hallucination": flagged,
            "analysis": analysis,
            "confidence": confidence,
        }