← 返回首页
🧠

分布外检测:识别异常输入

📂 llm ⏱ 4 min 681 words

---
title: "分布外检测:识别异常输入"
description: "检测LLM的分布外输入,防止模型处理未知数据"
tags: ["分布外检测", "OOD", "异常检测", "LLM", "输入验证"]
category: "llm"
icon: "🔍"
---

分布外检测:识别异常输入

OOD检测概述

分布外(Out-of-Distribution, OOD)检测用于识别与训练数据分布不一致的输入,从而避免模型在未知或异常数据上给出不可靠的输出。

检测方法

1. 基于距离的检测

import numpy as np
import torch
from typing import List, Dict, Tuple
from sklearn.metrics.pairwise import cosine_distances

class DistanceBasedOODDetector:
    """Detect out-of-distribution (OOD) text by embedding distance.

    The score of an input is the cosine distance between its embedding and
    the nearest reference embedding. Once a threshold has been calibrated,
    inputs scoring above it are reported as OOD.
    """

    def __init__(self, model, tokenizer, reference_embeddings: np.ndarray):
        """Store the encoder, its tokenizer and the reference embeddings.

        Args:
            model: Called as ``model(**inputs, output_hidden_states=True)``;
                the result must expose ``hidden_states``.
            tokenizer: Called on a string to build the model inputs.
            reference_embeddings: Array handed to ``cosine_distances`` as the
                comparison set (presumably one in-distribution embedding per
                row -- confirm against the caller).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.reference_embeddings = reference_embeddings
        # None means "not calibrated yet"; 0.0 is a legitimate threshold.
        self.threshold = None

    def set_threshold(self, percentile: float = 95):
        """Calibrate the threshold from the reference set itself.

        Every reference embedding is scored against all *other* references
        (leave-one-out). Scoring a reference against the full set would
        always find itself at distance 0 and collapse the threshold to 0.

        Args:
            percentile: Percentile of the leave-one-out nearest-neighbour
                distances to use as the threshold.

        Raises:
            ValueError: If fewer than two reference embeddings are available.
        """
        references = self.reference_embeddings
        if len(references) < 2:
            raise ValueError(
                "set_threshold needs at least two reference embeddings"
            )
        pairwise = cosine_distances(references, references)
        # Exclude the trivial self-match before taking the nearest neighbour.
        np.fill_diagonal(pairwise, np.inf)
        self.threshold = float(np.percentile(pairwise.min(axis=1), percentile))

    def _compute_distances(self, embeddings: np.ndarray) -> np.ndarray:
        """Return, per row of ``embeddings``, the distance to the closest reference."""
        distances = cosine_distances(embeddings, self.reference_embeddings)
        return distances.min(axis=1)

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed ``text`` as the mean of the model's last hidden state.

        The input is truncated to 512 tokens before encoding.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            pooled = outputs.hidden_states[-1].mean(dim=1).squeeze()
        # .cpu() is a no-op on CPU and is required before .numpy() otherwise.
        return pooled.cpu().numpy()

    def _is_ood(self, distance: float) -> bool:
        """Return True when a threshold is set and ``distance`` exceeds it."""
        if self.threshold is None:
            return False
        return bool(distance > self.threshold)

    def _confidence(self, distance: float):
        """Return ``distance / threshold`` capped at 2.0, or 0 if uncalibrated."""
        if self.threshold is None:
            return 0
        if self.threshold > 0:
            return min(distance / self.threshold, 2.0)
        # Zero threshold: the ratio is unbounded, so report the cap for any
        # positive distance instead of dividing by zero.
        return 2.0 if distance > 0 else 0.0

    def detect(self, text: str) -> Dict:
        """Score a single text.

        Returns:
            Dict with ``text``, ``distance``, ``threshold``, ``is_ood`` and
            ``confidence``. Without a calibrated threshold ``is_ood`` is
            False and ``confidence`` is 0.
        """
        embedding = self._get_embedding(text)
        distance = float(self._compute_distances(embedding.reshape(1, -1))[0])

        return {
            "text": text,
            "distance": distance,
            "threshold": self.threshold,
            "is_ood": self._is_ood(distance),
            "confidence": self._confidence(distance),
        }

    def detect_batch(self, texts: List[str]) -> List[Dict]:
        """Score several texts; returns one dict (text, distance, is_ood) per text."""
        if not texts:
            # An empty batch cannot be stacked into a 2-D embedding matrix.
            return []

        embeddings = np.array([self._get_embedding(text) for text in texts])
        distances = self._compute_distances(embeddings)

        return [
            {
                "text": text,
                "distance": float(distance),
                "is_ood": self._is_ood(float(distance)),
            }
            for text, distance in zip(texts, distances)
        ]

2. 基于不确定性的检测

class UncertaintyOODDetector:
    """Detect OOD text from predictive uncertainty under repeated sampling.

    The model is run several times in training mode (so that any dropout
    layers stay active, i.e. MC Dropout) and the spread of the resulting
    probability distributions is summarised.
    """

    def __init__(self, model, tokenizer, n_samples: int = 10):
        """Store the model, tokenizer and number of stochastic forward passes.

        Raises:
            ValueError: If ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        self.model = model
        self.tokenizer = tokenizer
        self.n_samples = n_samples

    def compute_uncertainty(self, text: str) -> Dict:
        """Run ``n_samples`` stochastic passes and summarise their disagreement.

        Returns:
            Dict with ``entropy`` (entropy of the mean prediction),
            ``mutual_information`` (that entropy minus the mean per-pass
            entropy), ``prediction_variance`` and ``mean_confidence``; each
            is averaged over the remaining axes of the logits.
        """
        inputs = self.tokenizer(text, return_tensors="pt")

        # Remember the caller's mode so sampling does not leave the model in
        # a different state than it was handed over in.
        was_training = getattr(self.model, "training", False)
        self.model.train()  # keep dropout active while sampling
        try:
            with torch.no_grad():
                samples = [
                    torch.softmax(self.model(**inputs).logits, dim=-1).cpu().numpy()
                    for _ in range(self.n_samples)
                ]
        finally:
            # Restore the mode even if a forward pass raised.
            if was_training:
                self.model.train()
            else:
                self.model.eval()

        predictions = np.array(samples)
        mean_prediction = predictions.mean(axis=0)

        # Entropy of the averaged distribution (total uncertainty).
        entropy = -np.sum(mean_prediction * np.log(mean_prediction + 1e-10), axis=-1)

        # Mean entropy of the individual passes; the gap to the total entropy
        # is the mutual information between prediction and sampled weights.
        expected_entropy = -np.mean(
            np.sum(predictions * np.log(predictions + 1e-10), axis=-1), axis=0
        )
        mutual_info = entropy - expected_entropy

        prediction_variance = np.var(predictions, axis=0).mean()

        return {
            "entropy": float(entropy.mean()),
            "mutual_information": float(mutual_info.mean()),
            "prediction_variance": float(prediction_variance),
            "mean_confidence": float(mean_prediction.max(axis=-1).mean()),
        }

    def detect(self, text: str, uncertainty_threshold: float = 1.0) -> Dict:
        """Flag ``text`` as OOD when its entropy exceeds ``uncertainty_threshold``.

        Returns:
            Dict with ``text``, the full ``uncertainty`` dict, ``is_ood`` and
            ``confidence`` (entropy divided by the threshold, or 0 when the
            threshold is 0).
        """
        uncertainty = self.compute_uncertainty(text)
        entropy = uncertainty["entropy"]

        return {
            "text": text,
            "uncertainty": uncertainty,
            "is_ood": entropy > uncertainty_threshold,
            "confidence": entropy / uncertainty_threshold if uncertainty_threshold else 0,
        }

3. 基于重构的检测

class ReconstructionOODDetector:
    """Detect OOD text by how poorly an autoencoder reconstructs it."""

    def __init__(self, autoencoder, tokenizer):
        """Store the autoencoder and tokenizer; the threshold starts unset.

        Args:
            autoencoder: Called with the ``input_ids`` tensor; must return a
                mapping with a ``"reconstructions"`` tensor.
            tokenizer: Called on a string; must return a mapping with
                ``"input_ids"``.
        """
        self.autoencoder = autoencoder
        self.tokenizer = tokenizer
        # None means "not calibrated yet"; 0.0 is a legitimate threshold.
        self.reconstruction_threshold = None

    def set_threshold(self, reference_texts: List[str], percentile: float = 95):
        """Set the threshold to a percentile of the reference texts' errors.

        Raises:
            ValueError: If ``reference_texts`` is empty.
        """
        if not reference_texts:
            raise ValueError("set_threshold needs at least one reference text")

        recon_errors = [
            self._compute_reconstruction_error(text) for text in reference_texts
        ]
        self.reconstruction_threshold = float(np.percentile(recon_errors, percentile))

    def _compute_reconstruction_error(self, text: str) -> float:
        """Return the mean squared error between reconstruction and input ids."""
        inputs = self.tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            reconstructed = self.autoencoder(inputs["input_ids"])
            # NOTE(review): the target is the raw token ids cast to float, so
            # this assumes the autoencoder emits values in token-id space --
            # confirm against the autoencoder's output definition.
            error = torch.nn.functional.mse_loss(
                reconstructed["reconstructions"],
                inputs["input_ids"].float()
            ).item()

        return error

    def _confidence(self, recon_error: float):
        """Return ``recon_error / threshold``, or 0 when no threshold is set."""
        threshold = self.reconstruction_threshold
        if threshold is None:
            return 0
        if threshold > 0:
            return recon_error / threshold
        # Zero threshold: avoid dividing by zero; any positive error is
        # infinitely far above it.
        return float("inf") if recon_error > 0 else 0.0

    def detect(self, text: str) -> Dict:
        """Score a single text.

        Returns:
            Dict with ``text``, ``reconstruction_error``, ``threshold``,
            ``is_ood`` and ``confidence``. Without a calibrated threshold
            ``is_ood`` is False and ``confidence`` is 0.
        """
        recon_error = self._compute_reconstruction_error(text)
        threshold = self.reconstruction_threshold

        # Compare against None explicitly so a threshold of 0.0 still applies.
        is_ood = threshold is not None and bool(recon_error > threshold)

        return {
            "text": text,
            "reconstruction_error": recon_error,
            "threshold": threshold,
            "is_ood": is_ood,
            "confidence": self._confidence(recon_error),
        }

高级检测器

1. 集成OOD检测

class EnsembleOODDetector:
    """Combine several OOD detectors by weighted voting."""

    def __init__(self, detectors: List):
        """Store the detectors and give each an equal weight.

        Args:
            detectors: Objects exposing ``detect(text)`` that returns a dict
                with ``"is_ood"`` and optionally ``"confidence"``.

        Raises:
            ValueError: If ``detectors`` is empty.
        """
        if not detectors:
            raise ValueError("EnsembleOODDetector needs at least one detector")
        self.detectors = detectors
        self.weights = [1.0 / len(detectors)] * len(detectors)

    def detect(self, text: str, voting_threshold: float = 0.5) -> Dict:
        """Run every detector on ``text`` and take a weighted vote.

        The decision uses the detectors' ``is_ood`` votes: the input is OOD
        when the total weight of detectors voting OOD exceeds
        ``voting_threshold``. The per-detector ``confidence`` values are
        ratios against each detector's own threshold, so their weighted mean
        is reported but not compared against ``voting_threshold``.

        Returns:
            Dict with ``text``, ``individual_results``, ``ensemble_score``
            (weighted share of OOD votes), ``is_ood`` and ``confidence``
            (weighted mean of the detectors' confidence values).
        """
        results = [detector.detect(text) for detector in self.detectors]
        predictions = [bool(result["is_ood"]) for result in results]
        scores = [result.get("confidence", 0) for result in results]

        vote_share = sum(
            weight for voted_ood, weight in zip(predictions, self.weights) if voted_ood
        )
        weighted_confidence = sum(
            score * weight for score, weight in zip(scores, self.weights)
        )

        return {
            "text": text,
            "individual_results": [
                {"detector": type(d).__name__, "is_ood": pred, "score": score}
                for d, pred, score in zip(self.detectors, predictions, scores)
            ],
            "ensemble_score": vote_share,
            "is_ood": vote_share > voting_threshold,
            "confidence": weighted_confidence,
        }

2. 自适应阈值

class AdaptiveThresholdOODDetector:
    """Wrap a detector and re-derive its threshold from recent scores."""

    def __init__(self, base_detector, window_size: int = 100):
        """Store the wrapped detector and the sliding-window length.

        Args:
            base_detector: Object exposing ``detect(text)`` returning a dict.
            window_size: Number of most recent scores the threshold is
                computed from.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        self.base_detector = base_detector
        self.window_size = window_size
        self.score_history = []
        # None until a full window of scores has been observed.
        self.current_threshold = None

    def update_threshold(self):
        """Set the threshold to the 95th percentile of the last full window."""
        if len(self.score_history) >= self.window_size:
            recent_scores = self.score_history[-self.window_size:]
            self.current_threshold = float(np.percentile(recent_scores, 95))

    @staticmethod
    def _extract_score(result: Dict):
        """Return the raw score carried by a detector result, or None.

        Recognises ``distance``, ``reconstruction_error`` and
        ``uncertainty["entropy"]``, in that order.
        """
        for key in ("distance", "reconstruction_error"):
            if key in result:
                return float(result[key])
        uncertainty = result.get("uncertainty")
        if isinstance(uncertainty, dict) and "entropy" in uncertainty:
            return float(uncertainty["entropy"])
        return None

    def detect(self, text: str) -> Dict:
        """Run the base detector, record its score and apply the adaptive threshold.

        Once a threshold exists, ``is_ood`` in the base result is overwritten
        and ``adaptive_threshold`` is added. If the base result carries no
        recognisable score it is returned unchanged.
        """
        result = self.base_detector.detect(text)

        score = self._extract_score(result)
        if score is None:
            # Nothing to track; recording a constant 0 would only drag the
            # threshold to 0 and override the base detector's verdict.
            return result

        self.score_history.append(score)
        # Let the history grow to twice the window before trimming so the
        # list is not re-sliced on every call.
        if len(self.score_history) > self.window_size * 2:
            self.score_history = self.score_history[-self.window_size:]

        self.update_threshold()

        # Compare against None explicitly so a threshold of 0.0 still applies.
        if self.current_threshold is not None:
            result["is_ood"] = score > self.current_threshold
            result["adaptive_threshold"] = self.current_threshold

        return result

评估工具

class OODEvaluator:
    """Measure how well a detector separates in-distribution from OOD texts."""

    @staticmethod
    def _extract_score(result: Dict) -> float:
        """Return the raw score carried by a detector result.

        Looks for ``distance``, ``reconstruction_error`` and
        ``uncertainty["entropy"]`` in that order, then falls back to
        ``confidence`` and finally to 0.
        """
        for key in ("distance", "reconstruction_error"):
            if key in result:
                return float(result[key])
        uncertainty = result.get("uncertainty")
        if isinstance(uncertainty, dict) and "entropy" in uncertainty:
            return float(uncertainty["entropy"])
        return float(result.get("confidence", 0))

    @staticmethod
    def _auroc(id_scores: List[float], ood_scores: List[float]) -> float:
        """Return the AUROC of the scores, or 0.5 when it is undefined."""
        if not id_scores or not ood_scores:
            # With a single class there is no ranking to evaluate.
            return 0.5

        # Imported here so scikit-learn is only needed when an AUROC can
        # actually be computed.
        from sklearn.metrics import roc_auc_score

        labels = [0] * len(id_scores) + [1] * len(ood_scores)
        try:
            return float(roc_auc_score(labels, id_scores + ood_scores))
        except ValueError:
            # Raised for scores the metric rejects; keep the original
            # best-effort fallback to chance level.
            return 0.5

    @staticmethod
    def evaluate(detector, in_distribution_texts: List[str], ood_texts: List[str]) -> Dict:
        """Score both text sets with ``detector`` and summarise the separation.

        Args:
            detector: Object exposing ``detect(text)`` returning a dict.
            in_distribution_texts: Texts labelled 0 (in-distribution).
            ood_texts: Texts labelled 1 (OOD).

        Returns:
            Dict with ``auroc``, ``id_mean_score``, ``ood_mean_score`` and
            ``separation`` (OOD mean minus in-distribution mean). The means
            are NaN for an empty text list.
        """
        id_scores = [
            OODEvaluator._extract_score(detector.detect(text))
            for text in in_distribution_texts
        ]
        ood_scores = [
            OODEvaluator._extract_score(detector.detect(text))
            for text in ood_texts
        ]

        # np.mean([]) would emit a RuntimeWarning; return NaN explicitly.
        id_mean = float(np.mean(id_scores)) if id_scores else float("nan")
        ood_mean = float(np.mean(ood_scores)) if ood_scores else float("nan")

        return {
            "auroc": OODEvaluator._auroc(id_scores, ood_scores),
            "id_mean_score": id_mean,
            "ood_mean_score": ood_mean,
            "separation": ood_mean - id_mean,
        }

最佳实践

  1. 多方法集成:结合多种OOD检测方法
  2. 阈值调优:根据应用需求调整检测阈值
  3. 持续监控:在生产环境中持续监控OOD比例
  4. 优雅处理:对检测到的OOD样本采取拒答、返回兜底回复或转人工审核等降级措施,而不是直接交给模型处理

总结

OOD检测是确保LLM安全可靠的重要环节。通过识别异常输入,可以防止模型处理未知数据,提高系统鲁棒性。