← 返回首页
🤖

RAG检索增强生成系统架构

📂 architecture ⏱ 5 min 957 words

RAG检索增强生成系统架构

RAG系统概述

RAG(Retrieval-Augmented Generation)通过结合检索和生成,让大语言模型能够访问外部知识库,减少幻觉并提供准确、有据可依的回答。RAG系统的核心流程包括:文档摄入与分块、向量化与索引、查询理解与检索、上下文组装与生成。

# RAG系统核心架构
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
from abc import ABC, abstractmethod

@dataclass
class Document:
    content: str
    metadata: Dict = field(default_factory=dict)
    doc_id: str = ""
    chunk_id: str = ""

@dataclass
class RetrievedChunk:
    document: Document
    score: float
    chunk_index: int = 0

@dataclass
class RAGResponse:
    answer: str
    sources: List[RetrievedChunk]
    confidence: float
    query: str

class RAGPipeline:
    def __init__(self, retriever, generator, reranker=None):
        self.retriever = retriever
        self.generator = generator
        self.reranker = reranker
    
    async def query(self, question: str, top_k: int = 5) -> RAGResponse:
        # 1. 检索相关文档
        retrieved = await self.retriever.retrieve(question, top_k=top_k * 2)
        
        # 2. 重排序(可选)
        if self.reranker:
            retrieved = self.reranker.rerank(question, retrieved, top_k=top_k)
        else:
            retrieved = retrieved[:top_k]
        
        # 3. 组装上下文
        context = self._build_context(retrieved)
        
        # 4. 生成回答
        answer = await self.generator.generate(question, context)
        
        return RAGResponse(
            answer=answer,
            sources=retrieved,
            confidence=self._calculate_confidence(retrieved),
            query=question
        )
    
    def _build_context(self, chunks: List[RetrievedChunk]) -> str:
        context_parts = []
        for i, chunk in enumerate(chunks):
            context_parts.append(
                f"[{i+1}] {chunk.document.content}"
            )
        return "\n\n".join(context_parts)
    
    def _calculate_confidence(self, chunks: List[RetrievedChunk]) -> float:
        if not chunks:
            return 0.0
        scores = [c.score for c in chunks]
        return sum(scores) / len(scores)

文档处理与分块策略

文档处理是RAG的基础,需要处理多种格式(PDF、Word、HTML等)并进行智能分块。分块策略直接影响检索质量,常见方法包括固定长度分块、语义分块和递归分块。

# 文档分块器
import re
from typing import List

class DocumentChunker:
    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
    
    def chunk_text(self, text: str, metadata: Dict = None) -> List[Document]:
        """基于句子边界的智能分块"""
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = []
        current_length = 0
        
        for sentence in sentences:
            sentence_length = len(sentence.split())
            
            if current_length + sentence_length > self.chunk_size and current_chunk:
                chunk_text = " ".join(current_chunk)
                chunks.append(Document(
                    content=chunk_text,
                    metadata=metadata or {},
                    chunk_id=f"chunk_{len(chunks)}"
                ))
                
                # 保留overlap部分
                overlap_sentences = []
                overlap_length = 0
                for s in reversed(current_chunk):
                    if overlap_length + len(s.split()) > self.chunk_overlap:
                        break
                    overlap_sentences.insert(0, s)
                    overlap_length += len(s.split())
                
                current_chunk = overlap_sentences
                current_length = overlap_length
            
            current_chunk.append(sentence)
            current_length += sentence_length
        
        # 处理剩余内容
        if current_chunk:
            chunks.append(Document(
                content=" ".join(current_chunk),
                metadata=metadata or {},
                chunk_id=f"chunk_{len(chunks)}"
            ))
        
        return chunks
    
    def chunk_markdown(self, text: str, metadata: Dict = None) -> List[Document]:
        """基于Markdown结构的分块"""
        sections = re.split(r'\n(?=#{1,3}\s)', text)
        chunks = []
        
        for section in sections:
            if len(section.strip()) == 0:
                continue
            
            # 如果section太长,进一步分块
            if len(section.split()) > self.chunk_size:
                sub_chunks = self.chunk_text(section, metadata)
                chunks.extend(sub_chunks)
            else:
                chunks.append(Document(
                    content=section.strip(),
                    metadata=metadata or {},
                    chunk_id=f"chunk_{len(chunks)}"
                ))
        
        return chunks

# 多格式文档加载器
class DocumentLoader:
    def __init__(self):
        self.loaders = {}
    
    def register_loader(self, file_type: str, loader_fn):
        self.loaders[file_type] = loader_fn
    
    def load(self, file_path: str) -> str:
        file_type = file_path.split(".")[-1].lower()
        loader = self.loaders.get(file_type)
        if loader:
            return loader(file_path)
        raise ValueError(f"Unsupported file type: {file_type}")

# 示例加载器
def load_pdf(file_path: str) -> str:
    # 使用PyPDF2或其他库
    return "PDF content..."

def load_markdown(file_path: str) -> str:
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()

向量化与索引构建

将文档块转换为向量并建立索引是RAG的关键步骤。常用Embedding模型包括OpenAI、Cohere和开源的BGE模型。向量数据库选择需考虑性能、可扩展性和成本。

# 向量索引管理
from typing import List, Tuple
import numpy as np

class VectorIndex:
    def __init__(self, embedding_model, dimension: int = 1536):
        self.embedding_model = embedding_model
        self.dimension = dimension
        self.vectors = []
        self.documents = []
    
    def add_documents(self, documents: List[Document]):
        """批量添加文档到索引"""
        texts = [doc.content for doc in documents]
        embeddings = self._encode(texts)
        
        self.vectors.extend(embeddings)
        self.documents.extend(documents)
        
        print(f"Added {len(documents)} documents to index")
    
    def search(self, query: str, top_k: int = 5) -> List[RetrievedChunk]:
        """向量相似度搜索"""
        query_embedding = self._encode([query])[0]
        
        scores = []
        for i, vec in enumerate(self.vectors):
            score = self._cosine_similarity(query_embedding, vec)
            scores.append((i, score))
        
        # 按相似度排序
        scores.sort(key=lambda x: x[1], reverse=True)
        
        results = []
        for idx, score in scores[:top_k]:
            results.append(RetrievedChunk(
                document=self.documents[idx],
                score=score,
                chunk_index=idx
            ))
        
        return results
    
    def _encode(self, texts: List[str]) -> List[np.ndarray]:
        # 调用embedding模型
        return [np.random.rand(self.dimension) for _ in texts]
    
    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)

# 混合检索策略
class HybridRetriever:
    def __init__(self, vector_index, bm25_index=None):
        self.vector_index = vector_index
        self.bm25_index = bm25_index
    
    async def retrieve(self, query: str, top_k: int = 10) -> List[RetrievedChunk]:
        # 向量检索
        vector_results = self.vector_index.search(query, top_k)
        
        # BM25检索(可选)
        if self.bm25_index:
            bm25_results = self.bm25_index.search(query, top_k)
            # 合并结果
            return self._merge_results(vector_results, bm25_results)
        
        return vector_results
    
    def _merge_results(self, vec_results: List, bm25_results: List) -> List:
        # RRF (Reciprocal Rank Fusion) 合并
        merged = {}
        k = 60
        
        for rank, result in enumerate(vec_results):
            key = result.document.doc_id
            if key not in merged:
                merged[key] = {"result": result, "score": 0}
            merged[key]["score"] += 1 / (k + rank + 1)
        
        for rank, result in enumerate(bm25_results):
            key = result.document.doc_id
            if key not in merged:
                merged[key] = {"result": result, "score": 0}
            merged[key]["score"] += 1 / (k + rank + 1)
        
        sorted_results = sorted(
            merged.values(), 
            key=lambda x: x["score"], 
            reverse=True
        )
        
        return [item["result"] for item in sorted_results[:10]]

查询理解与改写

查询理解是提升RAG效果的关键环节,包括查询意图识别、查询扩展、多查询生成和查询路由。好的查询处理能显著提升检索召回率。

# 查询处理器
class QueryProcessor:
    def __init__(self, llm_client):
        self.llm = llm_client
    
    async def expand_query(self, query: str) -> List[str]:
        """查询扩展:生成多个相关查询"""
        prompt = f"""基于以下问题,生成3个相关的查询变体:
        
原始问题:{query}

请生成3个不同角度的查询:"""
        
        expanded = await self.llm.generate(prompt)
        return expanded.split("\n")
    
    async def decompose_query(self, query: str) -> List[str]:
        """复杂查询分解"""
        prompt = f"""将以下复杂问题分解为多个简单子问题:

问题:{query}

分解为子问题:"""
        
        decomposed = await self.llm.generate(prompt)
        return decomposed.split("\n")
    
    def detect_intent(self, query: str) -> str:
        """查询意图检测"""
        if any(word in query for word in ["如何", "怎么", "步骤"]):
            return "how_to"
        elif any(word in query for word in ["是什么", "定义", "含义"]):
            return "definition"
        elif any(word in query for word in ["比较", "区别", "差异"]):
            return "comparison"
        return "general"
    
    async def rewrite_query(self, query: str, context: str = "") -> str:
        """查询改写:使查询更适合检索"""
        prompt = f"""改写以下查询,使其更适合检索相关文档:

原始查询:{query}
{f'上下文:{context}' if context else ''}

改写后的查询:"""
        
        return await self.llm.generate(prompt)

# 查询路由
class QueryRouter:
    def __init__(self, retrievers: Dict[str, any]):
        self.retrievers = retrievers
    
    async def route(self, query: str) -> str:
        """根据查询内容路由到不同的检索器"""
        intent = QueryProcessor.detect_intent(self, query)
        
        if intent == "how_to":
            return "technical_docs"
        elif intent == "definition":
            return "knowledge_base"
        elif intent == "comparison":
            return "comparison_docs"
        return "general"
    
    async def retrieve(self, query: str) -> List[RetrievedChunk]:
        route = await self.route(query)
        retriever = self.retrievers.get(route)
        return await retriever.retrieve(query)

生成优化与答案质量

生成阶段需要确保答案准确、有据可依,并正确引用来源。优化策略包括提示工程、答案验证、幻觉检测和引用标注。

# RAG生成器
class RAGGenerator:
    def __init__(self, llm_client, prompt_template: str = None):
        self.llm = llm_client
        self.template = prompt_template or self._default_template()
    
    def _default_template(self) -> str:
        return """基于以下参考文档回答问题。如果文档中没有相关信息,请说明无法回答。

参考文档:
{context}

问题:{question}

请提供准确、简洁的回答,并标注信息来源:"""
    
    async def generate(self, question: str, context: str) -> str:
        prompt = self.template.format(context=context, question=question)
        return await self.llm.generate(prompt)
    
    async def generate_with_citations(self, question: str, 
                                     chunks: List[RetrievedChunk]) -> dict:
        """生成带引用的回答"""
        context = "\n\n".join([
            f"[{i+1}] {c.document.content}" 
            for i, c in enumerate(chunks)
        ])
        
        prompt = f"""基于以下文档回答问题,并在回答中标注引用编号。

文档:
{context}

问题:{question}

回答(使用[1]、[2]等标注来源):"""
        
        answer = await self.llm.generate(prompt)
        
        return {
            "answer": answer,
            "citations": self._extract_citations(answer),
            "sources": chunks
        }
    
    def _extract_citations(self, text: str) -> List[int]:
        import re
        citations = re.findall(r'\[(\d+)\]', text)
        return [int(c) for c in citations]

# 幻觉检测
class HallucinationDetector:
    def __init__(self, llm_client):
        self.llm = llm_client
    
    async def detect(self, answer: str, context: str) -> dict:
        prompt = f"""检查以下回答是否基于提供的文档,标记可能的幻觉。

文档内容:
{context}

回答:
{answer}

请列出回答中任何无法从文档中验证的信息:"""
        
        analysis = await self.llm.generate(prompt)
        
        return {
            "has_hallucination": "无法验证" in analysis,
            "analysis": analysis,
            "confidence": 0.85 if "无法验证" not in analysis else 0.3
        }