← 返回首页
🤖

Embedding与向量检索架构:ANN索引优化

📂 architecture ⏱ 6 min 1199 words

Embedding与向量检索架构:ANN索引优化

向量检索系统概述

向量检索是指在高维向量(如文本、图像的Embedding)集合中,按相似度查找与查询向量最接近的若干向量的技术。精确搜索需要把查询与每个向量逐一比较,开销随数据规模线性增长。而在高维空间中,KD树等传统空间划分索引又会因"维度灾难"而失效。因此通常使用近似最近邻(ANN)算法,在精度和速度之间取得平衡。

# 向量检索引擎核心
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict
import numpy as np
from abc import ABC, abstractmethod

@dataclass
class VectorRecord:
    """A stored vector together with its identifier and free-form metadata.

    Attributes:
        id: Identifier of the record.
        vector: The embedding vector (a NumPy array).
        metadata: Arbitrary payload; every instance gets its own fresh dict
            when none is supplied.
    """

    id: str
    vector: np.ndarray
    # default_factory avoids one dict being shared by every instance.
    metadata: Dict = field(default_factory=dict)

@dataclass
class SearchResult:
    """A single hit produced by an index search.

    Attributes:
        id: Identifier of the matched record.
        score: Relevance score; its scale depends on the index that produced it.
        vector: The matched vector, or None when it is not returned.
        metadata: Payload attached to the hit; a fresh dict per instance by default.
    """

    id: str
    score: float
    vector: Optional[np.ndarray] = None
    metadata: Dict = field(default_factory=dict)

class VectorIndex(ABC):
    """Abstract interface that concrete vector indexes implement."""

    @abstractmethod
    def add(self, record: VectorRecord):
        """Insert ``record`` into the index."""

    @abstractmethod
    def search(self, query: np.ndarray, top_k: int) -> List[SearchResult]:
        """Return at most ``top_k`` results for ``query``."""

    @abstractmethod
    def delete(self, id: str) -> bool:
        """Remove the record with ``id``; report whether anything was removed."""

    @abstractmethod
    def size(self) -> int:
        """Return the number of records currently held by the index."""

class LinearIndex(VectorIndex):
    """Exact brute-force index, for small collections or as an ANN baseline.

    Every live record is scored against the query with cosine similarity;
    higher scores rank first.
    """

    def __init__(self, dimension: int):
        self.dimension = dimension
        # Slot storage. delete() sets a slot to None (soft delete) so that the
        # positions stored in id_index never shift.
        self.records: List[Optional[VectorRecord]] = []
        # id -> slot position in self.records, for live records only.
        self.id_index: Dict[str, int] = {}

    def add(self, record: VectorRecord):
        """Insert ``record``; a record whose id already exists replaces the old one.

        Raises:
            ValueError: if the vector length differs from ``dimension``.
        """
        if len(record.vector) != self.dimension:
            raise ValueError(f"Vector dimension mismatch: expected {self.dimension}")

        slot = self.id_index.get(record.id)
        if slot is not None:
            # Overwrite in place; appending would leave a stale copy that still
            # shows up in search results and is counted by size().
            self.records[slot] = record
            return

        self.id_index[record.id] = len(self.records)
        self.records.append(record)

    def search(self, query: np.ndarray, top_k: int) -> List[SearchResult]:
        """Return up to ``top_k`` live records ordered by descending cosine similarity."""
        if top_k <= 0:
            return []

        # Soft-deleted slots hold None and must be skipped.
        scored = [
            (record, self._cosine_similarity(query, record.vector))
            for record in self.records
            if record is not None
        ]
        scored.sort(key=lambda item: item[1], reverse=True)

        return [
            SearchResult(
                id=record.id,
                score=score,
                vector=record.vector,
                metadata=record.metadata,
            )
            for record, score in scored[:top_k]
        ]

    def delete(self, id: str) -> bool:
        """Soft-delete the record with ``id``; return True if it existed."""
        if id in self.id_index:
            idx = self.id_index.pop(id)
            self.records[idx] = None  # soft delete
            return True
        return False

    def size(self) -> int:
        """Return the number of live (non-deleted) records."""
        # id_index holds exactly one entry per live record.
        return len(self.id_index)

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b``; the epsilon guards zero vectors."""
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

HNSW索引算法

HNSW(Hierarchical Navigable Small World)是目前最流行的ANN算法之一,通过构建层次化的小世界图实现高效的近似搜索。上层图稀疏,用于快速贪心地定位入口点;底层图包含全部节点,用于精细的束搜索(beam search)。其特点是在精度和速度之间提供了优秀的平衡。

# HNSW索引实现
import heapq
from collections import defaultdict

class HNSWIndex:
    """Hierarchical Navigable Small World graph index using Euclidean distance.

    Each node is placed on layers ``0..level``, where ``level`` is drawn at
    insertion time. Upper layers are sparse and are descended greedily to find
    a good entry point. Layer 0 contains every node and is searched with a
    beam of width ``ef``.

    Args:
        dimension: Required vector length.
        max_connections: Neighbours kept per node (``M``) on layers >= 1;
            layer 0 keeps up to ``2 * M``, as in the HNSW paper.
        max_layers: Upper bound for a node's level.
        ef_construction: Beam width used while inserting.
    """

    def __init__(self, dimension: int, max_connections: int = 16,
                 max_layers: int = 16, ef_construction: int = 200):
        self.dimension = dimension
        self.M = max_connections
        self.max_layers = max_layers
        self.ef_construction = ef_construction

        self.vectors = {}  # node_id -> vector
        # node_id -> {layer: {neighbor_id: distance}}
        self.graph = defaultdict(dict)
        self.layers = defaultdict(list)  # layer -> [node_ids on that layer]
        self.entry_point = None
        self.current_layer = 0  # level of entry_point, i.e. the top layer

    def add(self, record: VectorRecord):
        """Insert ``record`` into the graph.

        Raises:
            ValueError: if the vector length differs from ``dimension`` or the
                id has already been added.
        """
        node_id = record.id
        vector = record.vector
        if len(vector) != self.dimension:
            raise ValueError(f"Vector dimension mismatch: expected {self.dimension}")
        if node_id in self.vectors:
            raise ValueError(f"Duplicate id: {node_id}")

        level = self._random_level()
        self.vectors[node_id] = vector
        for layer in range(level + 1):
            self.graph[node_id][layer] = {}
            self.layers[layer].append(node_id)

        # The very first node becomes the entry point at its own level.
        if self.entry_point is None:
            self.entry_point = node_id
            self.current_layer = level
            return

        # Greedy descent (beam width 1) through layers above the new node's level.
        current = self.entry_point
        for layer in range(self.current_layer, level, -1):
            current = self._closest(self._search_layer(current, vector, 1, layer))

        # Link the node on every layer it shares with the existing graph.
        for layer in range(min(level, self.current_layer), -1, -1):
            candidates = self._search_layer(current, vector, self.ef_construction, layer)
            neighbors = self._select_neighbors(vector, candidates, self.M)
            max_links = 2 * self.M if layer == 0 else self.M

            for neighbor_id, dist in neighbors.items():
                self.graph[node_id][layer][neighbor_id] = dist
                links = self.graph[neighbor_id].setdefault(layer, {})
                links[node_id] = dist
                # Keep the neighbour's degree bounded by dropping its farthest links.
                if len(links) > max_links:
                    self.graph[neighbor_id][layer] = self._select_neighbors(
                        self.vectors[neighbor_id], links, max_links
                    )

            # The closest node found here seeds the search on the next layer down.
            current = self._closest(candidates)

        # Update the entry point
        if level > self.current_layer:
            self.current_layer = level
            self.entry_point = node_id

    def search(self, query: np.ndarray, top_k: int,
               ef_search: int = 50) -> List[SearchResult]:
        """Return up to ``top_k`` approximate nearest neighbours of ``query``.

        ``score`` is ``1 / (1 + euclidean_distance)``, so higher is closer.
        ``ef_search`` is raised to ``top_k`` if it is smaller; otherwise the
        beam could not hold ``top_k`` results.
        """
        if self.entry_point is None or top_k <= 0:
            return []

        current = self.entry_point
        # Descend from the top layer to layer 1 with a beam of width 1.
        for layer in range(self.current_layer, 0, -1):
            current = self._closest(self._search_layer(current, query, 1, layer))

        candidates = self._search_layer(current, query, max(ef_search, top_k), 0)
        ranked = sorted(candidates.items(), key=lambda item: item[1])[:top_k]

        return [
            SearchResult(
                id=node_id,
                score=1.0 / (1.0 + dist),
                vector=self.vectors.get(node_id),
                metadata={},
            )
            for node_id, dist in ranked
        ]

    def _search_layer(self, entry_id: str, query: np.ndarray,
                      ef: int, layer: int) -> Dict[str, float]:
        """Beam search on one layer; return up to ``ef`` nodes -> distance to ``query``."""
        ef = max(1, ef)
        entry_dist = self._distance(query, self.vectors[entry_id])
        visited = {entry_id}
        # Min-heap of nodes still to expand.
        to_expand = [(entry_dist, entry_id)]
        # Max-heap (negated distances) of the best ``ef`` nodes found so far.
        best = [(-entry_dist, entry_id)]

        while to_expand:
            dist, node = heapq.heappop(to_expand)
            # Everything left is farther than the worst result kept: stop.
            if dist > -best[0][0]:
                break

            for neighbor_id in self.graph.get(node, {}).get(layer, {}):
                if neighbor_id in visited:
                    continue
                visited.add(neighbor_id)

                neighbor_vec = self.vectors.get(neighbor_id)
                if neighbor_vec is None:
                    continue

                d = self._distance(query, neighbor_vec)
                if len(best) < ef or d < -best[0][0]:
                    heapq.heappush(to_expand, (d, neighbor_id))
                    heapq.heappush(best, (-d, neighbor_id))
                    if len(best) > ef:
                        heapq.heappop(best)  # evict the current farthest

        return {node: -neg_dist for neg_dist, node in best}

    @staticmethod
    def _closest(candidates: Dict[str, float]) -> str:
        """Return the id with the smallest distance in a non-empty ``candidates``."""
        return min(candidates, key=candidates.get)

    def _select_neighbors(self, query: np.ndarray,
                          candidates: Dict[str, float],
                          M: int) -> Dict[str, float]:
        """Keep the ``M`` candidates with the smallest distances.

        This is the simple selection strategy; ``query`` is accepted for
        interface symmetry with the paper's heuristic but is not used.
        """
        sorted_candidates = sorted(candidates.items(), key=lambda x: x[1])
        return dict(sorted_candidates[:M])

    def _random_level(self) -> int:
        """Draw a level: repeated coin flips with p=0.5, capped at ``max_layers``."""
        import random
        level = 0
        while random.random() < 0.5 and level < self.max_layers:
            level += 1
        return level

    def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
        """Euclidean distance between ``a`` and ``b``."""
        return float(np.linalg.norm(a - b))

IVF-PQ索引算法

IVF-PQ(Inverted File Index with Product Quantization)结合倒排索引与乘积量化来加速搜索并压缩存储,适合超大规模数据集。IVF先用聚类将向量空间划分为多个区域,每个区域对应一个倒排列表。PQ再把高维向量切分为若干子向量,并分别量化为紧凑编码。查询时只扫描离查询最近的n_probe个聚类,并通过PQ码本查表计算近似距离,从而避免对全部向量做精确距离计算。

# IVF-PQ索引实现
class IVFPQIndex:
    """IVF coarse partitioning plus product-quantized codes (Euclidean distance).

    ``train`` learns two things:
    * ``n_clusters`` coarse centroids.
    * For each of the ``n_subquantizers`` equal-width slices of a vector, a
      codebook of ``2 ** n_bits`` sub-centroids.

    Each added vector is filed under its nearest coarse centroid together with
    its PQ code, which holds one codebook index per slice. The code quantizes
    the raw vector, not the residual to the coarse centroid. Raw vectors are
    also kept in ``self.vectors``, so they can be returned with results and
    re-encoded after retraining.

    Raises:
        ValueError: if ``dimension`` is not a positive multiple of
            ``n_subquantizers``. Otherwise the tail dimensions would be
            silently ignored.
    """

    def __init__(self, dimension: int, n_clusters: int = 1024,
                 n_subquantizers: int = 8, n_bits: int = 8):
        if n_subquantizers <= 0 or dimension % n_subquantizers != 0:
            raise ValueError(
                f"dimension ({dimension}) must be divisible by "
                f"n_subquantizers ({n_subquantizers})"
            )
        self.dimension = dimension
        self.n_clusters = n_clusters
        self.n_subquantizers = n_subquantizers
        self.n_bits = n_bits

        self.centroids = None  # coarse K-means centroids, shape (n_clusters, dimension)
        self.subquantizers = []  # one codebook per slice, each (2**n_bits, sub_dim)
        self.inverted_lists = defaultdict(list)  # cluster_id -> [(id, pq_code)]
        self.vectors = {}  # id -> raw vector

    def train(self, vectors: np.ndarray):
        """Learn the coarse centroids and PQ codebooks from ``vectors``.

        Retraining replaces the previous codebooks and re-encodes every vector
        already added, so the stored codes stay consistent with the codebooks.
        """
        from sklearn.cluster import MiniBatchKMeans

        vectors = np.asarray(vectors)

        # Coarse (IVF) clustering
        kmeans = MiniBatchKMeans(n_clusters=self.n_clusters, random_state=42)
        kmeans.fit(vectors)

        # One PQ codebook per slice. Build a fresh list rather than appending
        # to the old one, which would leave stale codebooks in front.
        sub_dim = self.dimension // self.n_subquantizers
        codebooks = []
        for i in range(self.n_subquantizers):
            sub_vectors = vectors[:, i * sub_dim:(i + 1) * sub_dim]
            kmeans_pq = MiniBatchKMeans(n_clusters=2 ** self.n_bits, random_state=42)
            kmeans_pq.fit(sub_vectors)
            codebooks.append(kmeans_pq.cluster_centers_)

        self.centroids = kmeans.cluster_centers_
        self.subquantizers = codebooks

        # Codes and cluster assignments from earlier codebooks are now invalid.
        self.inverted_lists = defaultdict(list)
        for doc_id, vector in list(self.vectors.items()):
            self._insert(doc_id, vector)

    def add(self, record: VectorRecord):
        """Index ``record``; an existing id is replaced rather than duplicated.

        Raises:
            RuntimeError: if ``train`` has not been called.
        """
        self._require_trained()
        if record.id in self.vectors:
            self._remove_posting(record.id)
        self.vectors[record.id] = record.vector
        self._insert(record.id, record.vector)

    def search(self, query: np.ndarray, top_k: int,
               n_probe: int = 10) -> List[SearchResult]:
        """Return up to ``top_k`` hits from the ``n_probe`` nearest coarse clusters.

        Distances are PQ approximations (asymmetric: exact query vs. quantized
        database vector). ``score`` is ``1 / (1 + distance)``.

        Raises:
            RuntimeError: if ``train`` has not been called.
        """
        self._require_trained()
        if top_k <= 0 or n_probe <= 0:
            return []

        query = np.asarray(query)

        # Find the n_probe nearest clusters (stable sort keeps ties in index order).
        centroid_dists = np.linalg.norm(self.centroids - query, axis=1)
        probe_clusters = np.argsort(centroid_dists, kind="stable")[:n_probe]

        # Precompute squared distances from each query slice to every codeword
        # once. Each candidate then costs n_subquantizers table lookups.
        table = self._pq_distance_table(query)
        rows = np.arange(self.n_subquantizers)

        candidates = []
        for cluster_id in probe_clusters:
            # .get avoids creating empty lists in the defaultdict for empty clusters.
            for doc_id, pq_code in self.inverted_lists.get(int(cluster_id), ()):
                approx_dist = float(np.sqrt(table[rows, pq_code].sum()))
                candidates.append((doc_id, approx_dist))

        # Sort and return top_k
        candidates.sort(key=lambda x: x[1])

        return [
            SearchResult(
                id=doc_id,
                score=1.0 / (1.0 + dist),
                vector=self.vectors.get(doc_id),
                metadata={},
            )
            for doc_id, dist in candidates[:top_k]
        ]

    def _require_trained(self):
        """Raise a clear error instead of failing deep inside NumPy."""
        if self.centroids is None:
            raise RuntimeError("IVFPQIndex must be trained before add/search")

    def _insert(self, doc_id, vector: np.ndarray):
        """Append ``(doc_id, pq_code)`` to the posting list of the nearest cluster."""
        cluster_id = self._find_cluster(vector)
        self.inverted_lists[cluster_id].append((doc_id, self._pq_encode(vector)))

    def _remove_posting(self, doc_id):
        """Drop ``doc_id``'s posting, located via its stored raw vector."""
        cluster_id = self._find_cluster(self.vectors[doc_id])
        postings = self.inverted_lists.get(cluster_id)
        if postings:
            self.inverted_lists[cluster_id] = [p for p in postings if p[0] != doc_id]

    def _find_cluster(self, vector: np.ndarray) -> int:
        """Return the index of the nearest coarse centroid."""
        distances = np.linalg.norm(self.centroids - vector, axis=1)
        return int(np.argmin(distances))

    def _pq_encode(self, vector: np.ndarray) -> np.ndarray:
        """Encode ``vector`` as one nearest-codeword index per slice."""
        sub_dim = self.dimension // self.n_subquantizers
        codes = []

        for i in range(self.n_subquantizers):
            sub_vector = vector[i * sub_dim:(i + 1) * sub_dim]
            distances = np.linalg.norm(self.subquantizers[i] - sub_vector, axis=1)
            codes.append(int(np.argmin(distances)))

        # Compact storage is the point of PQ: one byte per slice when it fits.
        return np.array(codes, dtype=np.uint8 if self.n_bits <= 8 else np.int64)

    def _pq_distance_table(self, query: np.ndarray) -> np.ndarray:
        """Squared distances from each query slice to each codeword.

        Returns an array of shape ``(n_subquantizers, 2 ** n_bits)``.
        """
        sub_dim = self.dimension // self.n_subquantizers
        return np.stack([
            np.sum((self.subquantizers[i] - query[i * sub_dim:(i + 1) * sub_dim]) ** 2, axis=1)
            for i in range(self.n_subquantizers)
        ])

    def _pq_distance(self, query: np.ndarray, pq_code: np.ndarray) -> float:
        """Approximate Euclidean distance between ``query`` and a PQ-encoded vector."""
        table = self._pq_distance_table(np.asarray(query))
        return float(np.sqrt(table[np.arange(self.n_subquantizers), pq_code].sum()))

向量数据库集成

生产环境通常使用专业的向量数据库(Milvus、Pinecone、Weaviate等),它们提供分布式架构、高可用性和丰富的查询功能。

# 向量数据库适配器
class VectorDBAdapter:
    """Thin adapter over several vector-database client libraries.

    Backends implemented per operation:
        connect:                  "milvus", "pinecone", "weaviate"
        create_collection/upsert/query: "milvus", "pinecone"
    Any other combination raises an error instead of silently doing nothing.

    Args:
        db_type: Backend name.
        connection_config: Backend-specific settings. The keys read are:
            "host" and "port" (milvus), "api_key" and "environment"
            (pinecone), "url" (weaviate).
    """

    def __init__(self, db_type: str, connection_config: Dict):
        self.db_type = db_type
        self.config = connection_config
        self.client = None  # only set by _connect_weaviate

    def connect(self):
        """Open a connection for the configured backend.

        Raises:
            ValueError: if ``db_type`` is not a known backend.
        """
        connectors = {
            "milvus": self._connect_milvus,
            "pinecone": self._connect_pinecone,
            "weaviate": self._connect_weaviate,
        }
        connector = connectors.get(self.db_type)
        if connector is None:
            raise ValueError(f"Unsupported db_type: {self.db_type!r}")
        connector()

    def _connect_milvus(self):
        from pymilvus import connections
        connections.connect(
            alias="default",
            host=self.config["host"],
            port=self.config["port"]
        )

    def _connect_pinecone(self):
        # NOTE(review): pinecone.init is the legacy (v2) client API; newer
        # clients use pinecone.Pinecone(api_key=...). Confirm the pinned version.
        import pinecone
        pinecone.init(
            api_key=self.config["api_key"],
            environment=self.config["environment"]
        )

    def _connect_weaviate(self):
        import weaviate
        self.client = weaviate.Client(
            url=self.config["url"]
        )

    def _unsupported(self, operation: str):
        """Raise for operations that have no implementation for this backend."""
        raise NotImplementedError(
            f"{operation} is not implemented for db_type {self.db_type!r}"
        )

    def create_collection(self, collection_name: str, dimension: int):
        """Create a collection/index holding ``dimension``-wide vectors.

        For Milvus, an IVF_FLAT/COSINE index is also built on the vector field.
        Milvus refuses to ``load()`` an unindexed collection, so without the
        index ``query`` could not run.
        """
        if self.db_type == "milvus":
            from pymilvus import CollectionSchema, FieldSchema, DataType, Collection

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535)
            ]
            collection = Collection(collection_name, CollectionSchema(fields))
            # Metric and nprobe-style params match what query() sends.
            collection.create_index(
                field_name="embedding",
                index_params={
                    "index_type": "IVF_FLAT",
                    "metric_type": "COSINE",
                    "params": {"nlist": 1024},
                },
            )

        elif self.db_type == "pinecone":
            import pinecone
            pinecone.create_index(
                name=collection_name,
                dimension=dimension,
                metric="cosine"
            )

        else:
            self._unsupported("create_collection")

    def upsert(self, collection_name: str, vectors: List[Dict]):
        """Insert or update vectors.

        Each item needs "id" and "vector" and may have "metadata". For Milvus,
        "id" must be an integer (the schema's primary key is INT64). Metadata
        is stored as a JSON string so it can be parsed back later.
        """
        if self.db_type not in ("milvus", "pinecone"):
            self._unsupported("upsert")
        if not vectors:
            return

        if self.db_type == "milvus":
            import json
            from pymilvus import Collection
            collection = Collection(collection_name)
            # Column order must match the schema: id, embedding, metadata.
            collection.insert([
                [v["id"] for v in vectors],
                [v["vector"] for v in vectors],
                [json.dumps(v.get("metadata", {}), ensure_ascii=False) for v in vectors],
            ])

        else:  # pinecone
            import pinecone
            index = pinecone.Index(collection_name)
            index.upsert(vectors=[
                (v["id"], v["vector"], v.get("metadata", {}))
                for v in vectors
            ])

    def query(self, collection_name: str, vector: np.ndarray,
              top_k: int = 10, filter_expr: Optional[str] = None) -> List[Dict]:
        """Return up to ``top_k`` hits as dicts with "id", "score" and "metadata".

        ``filter_expr`` is passed through unchanged. Milvus expects a boolean
        expression string. Pinecone expects its metadata-filter format
        (typically a dict).
        """
        # Accept lists as well as NumPy arrays.
        query_vector = np.asarray(vector).tolist()

        if self.db_type == "milvus":
            from pymilvus import Collection
            collection = Collection(collection_name)
            collection.load()

            results = collection.search(
                data=[query_vector],
                anns_field="embedding",
                param={"metric_type": "COSINE", "params": {"nprobe": 16}},
                limit=top_k,
                expr=filter_expr,
                # Without output_fields the hit entity carries no metadata.
                output_fields=["metadata"],
            )

            return [
                {"id": hit.id, "score": hit.score, "metadata": hit.entity.get("metadata")}
                for hit in results[0]
            ]

        if self.db_type == "pinecone":
            import pinecone
            index = pinecone.Index(collection_name)

            results = index.query(
                vector=query_vector,
                top_k=top_k,
                filter=filter_expr,
                include_metadata=True
            )

            return [
                {"id": match["id"], "score": match["score"], "metadata": match.get("metadata")}
                for match in results["matches"]
            ]

        self._unsupported("query")

多模态Embedding架构

多模态Embedding将不同模态(文本、图像、音频)映射到统一的向量空间,支持跨模态检索。常用模型包括CLIP、ALIGN和多模态BERT。

# 多模态Embedding模型
import torch
import torch.nn as nn

class MultiModalEncoder(nn.Module):
    """Map pre-extracted text and image features into one shared embedding space.

    Each modality has its own MLP (Linear -> ReLU -> Linear) down to
    ``output_dim``. A single projection layer, shared by both modalities, is
    applied afterwards.

    Args:
        text_dim: Width of the incoming text feature vectors.
        image_dim: Width of the incoming image feature vectors.
        output_dim: Width of the shared embedding space.
        hidden_dim: Hidden width of each modality MLP (default 1024, the
            previously hard-coded value).
    """

    def __init__(self, text_dim=768, image_dim=2048, output_dim=512, hidden_dim=1024):
        super().__init__()

        # Text encoder
        self.text_encoder = self._mlp(text_dim, hidden_dim, output_dim)

        # Image encoder
        self.image_encoder = self._mlp(image_dim, hidden_dim, output_dim)

        # Projection shared by both modalities
        self.projection = nn.Linear(output_dim, output_dim)

    @staticmethod
    def _mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        """Build the Linear -> ReLU -> Linear stack used for each modality."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def encode_text(self, text_features: torch.Tensor) -> torch.Tensor:
        """Embed text features (last dimension ``text_dim``)."""
        return self.projection(self.text_encoder(text_features))

    def encode_image(self, image_features: torch.Tensor) -> torch.Tensor:
        """Embed image features (last dimension ``image_dim``)."""
        return self.projection(self.image_encoder(image_features))

    def forward(self, text_features=None, image_features=None):
        """Embed one modality.

        If ``text_features`` is given, its embedding is returned and
        ``image_features`` is ignored. Otherwise the image embedding is
        returned. If neither is given, returns None.
        """
        if text_features is not None:
            return self.encode_text(text_features)
        elif image_features is not None:
            return self.encode_image(image_features)
        return None

# CLIP风格的对比学习
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE loss in the style of CLIP.

    Row ``i`` of the text batch and row ``i`` of the image batch form a
    matching pair. Every other row in the batch serves as a negative.

    Args:
        temperature: Divisor applied to the cosine-similarity logits.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, text_embeddings, image_embeddings):
        """Return the mean of the text->image and image->text cross-entropy losses.

        Expects 2-D inputs whose batch sizes match. Targets are the diagonal
        indices of the similarity matrix and are reused for its transpose.
        """
        F = nn.functional

        # Cosine similarity = dot product of L2-normalised rows.
        text_unit = F.normalize(text_embeddings, dim=1)
        image_unit = F.normalize(image_embeddings, dim=1)
        logits = (text_unit @ image_unit.T) / self.temperature

        # Matching pairs sit on the diagonal.
        targets = torch.arange(len(logits), device=logits.device)

        loss_text_to_image = F.cross_entropy(logits, targets)
        loss_image_to_text = F.cross_entropy(logits.T, targets)
        return 0.5 * (loss_text_to_image + loss_image_to_text)