分布外检测:识别异常输入
---
title: "分布外检测:识别异常输入"
description: "检测LLM的分布外输入,防止模型处理未知数据"
tags: ["分布外检测", "OOD", "异常检测", "LLM", "输入验证"]
category: "llm"
icon: "🔍"
---
分布外检测:识别异常输入
OOD检测概述
分布外(Out-of-Distribution, OOD)检测旨在识别与训练数据分布显著不同的输入,以便在模型处理这些未知或异常数据之前将其标记出来,避免模型给出不可靠的输出。
检测方法
1. 基于距离的检测
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_distances
class DistanceBasedOODDetector:
    """Distance-based OOD detection.

    A text is embedded by mean-pooling the model's last hidden layer over the token
    axis; its OOD score is the cosine distance to the nearest reference embedding.
    Inputs whose score exceeds ``threshold`` (calibrated by ``set_threshold``) are
    flagged as OOD.
    """

    def __init__(self, model, tokenizer, reference_embeddings: np.ndarray):
        self.model = model
        self.tokenizer = tokenizer
        # Assumed to be a 2-D (n_reference, hidden_dim) array built the same way as
        # _get_embedding builds query embeddings — TODO confirm with the caller.
        self.reference_embeddings = reference_embeddings
        # None until set_threshold() is called; until then nothing is flagged as OOD.
        self.threshold = None

    def set_threshold(self, percentile: float = 95):
        """Calibrate ``threshold`` from the reference set itself.

        For every reference embedding, the cosine distance to its nearest *other*
        reference embedding is computed (leave-one-out). The threshold is the
        given percentile of those distances. Comparing each point against the whole
        set, itself included, would make every nearest distance 0 and collapse the
        threshold to 0.

        Args:
            percentile: Percentile in [0, 100] of the leave-one-out distances.

        Raises:
            ValueError: If fewer than two reference embeddings are available.
        """
        reference = np.asarray(self.reference_embeddings)
        if len(reference) < 2:
            raise ValueError("set_threshold needs at least two reference embeddings")
        pairwise = cosine_distances(reference)
        # Exclude self-matches so each row's minimum is the nearest *other* point.
        np.fill_diagonal(pairwise, np.inf)
        self.threshold = float(np.percentile(pairwise.min(axis=1), percentile))

    def _compute_distances(self, embeddings: np.ndarray) -> np.ndarray:
        """Return, for each row of ``embeddings``, the cosine distance to the closest reference embedding."""
        distances = cosine_distances(embeddings, self.reference_embeddings)
        return distances.min(axis=1)

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed ``text`` as the token-mean of the model's last hidden layer (1-D array)."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # presumably (batch=1, seq_len, hidden) for a single text — TODO confirm.
        # Index the batch row instead of .squeeze() so a hidden size of 1 is not
        # collapsed to a scalar; .cpu() allows models that live on a GPU.
        return outputs.hidden_states[-1].mean(dim=1)[0].cpu().numpy()

    @staticmethod
    def _confidence(distance: float, threshold: float) -> float:
        """Distance relative to the threshold, capped at 2.0; safe for a zero threshold."""
        if threshold > 0:
            return min(distance / threshold, 2.0)
        # A zero threshold (e.g. duplicated reference points): any positive
        # distance is maximally suspicious, zero distance is not.
        return 2.0 if distance > 0 else 0.0

    def _build_result(self, text: str, distance: float) -> Dict:
        """Assemble the per-text result dict shared by detect() and detect_batch()."""
        if self.threshold is None:
            is_ood, confidence = False, 0
        else:
            is_ood = bool(distance > self.threshold)
            confidence = self._confidence(distance, self.threshold)
        return {
            "text": text,
            "distance": distance,
            "threshold": self.threshold,
            "is_ood": is_ood,
            "confidence": confidence,
        }

    def detect(self, text: str) -> Dict:
        """Score a single text.

        Returns:
            Dict with ``text``, ``distance``, ``threshold``, ``is_ood`` and
            ``confidence`` (distance / threshold, capped at 2.0; 0 when uncalibrated).
        """
        embedding = self._get_embedding(text)
        distance = float(self._compute_distances(embedding.reshape(1, -1))[0])
        return self._build_result(text, distance)

    def detect_batch(self, texts: List[str]) -> List[Dict]:
        """Score several texts; returns one result dict per text (empty list for no texts)."""
        if not texts:
            return []
        embeddings = np.stack([self._get_embedding(text) for text in texts])
        distances = self._compute_distances(embeddings)
        return [self._build_result(text, float(distance)) for text, distance in zip(texts, distances)]
2. 基于不确定性的检测
class UncertaintyOODDetector:
    """Uncertainty-based OOD detection using Monte Carlo dropout.

    The model is run ``n_samples`` times with dropout active; entropy, mutual
    information and variance of the sampled softmax outputs measure how uncertain
    the model is about an input.
    """

    def __init__(self, model, tokenizer, n_samples: int = 10):
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        self.model = model
        self.tokenizer = tokenizer
        self.n_samples = n_samples

    def compute_uncertainty(self, text: str) -> Dict:
        """Compute MC-dropout uncertainty statistics for ``text``.

        The model's train/eval mode is restored to whatever it was before the call,
        even if a forward pass raises.

        Returns:
            Dict with ``entropy`` (of the mean prediction), ``mutual_information``,
            ``prediction_variance`` and ``mean_confidence`` (mean max probability).
        """
        inputs = self.tokenizer(text, return_tensors="pt")

        was_training = self.model.training
        samples = []
        # train() keeps dropout active so each forward pass is a stochastic sample.
        self.model.train()
        try:
            with torch.no_grad():
                for _ in range(self.n_samples):
                    logits = self.model(**inputs).logits
                    samples.append(torch.softmax(logits, dim=-1).cpu().numpy())
        finally:
            # Restore the caller's mode instead of unconditionally switching to eval.
            self.model.train(was_training)

        # Shape (n_samples, *logits_shape); the last axis is the class axis.
        predictions = np.array(samples)
        mean_prediction = predictions.mean(axis=0)

        # Entropy of the averaged (predictive) distribution.
        entropy = -np.sum(mean_prediction * np.log(mean_prediction + 1e-10), axis=-1)
        # Mutual information = predictive entropy - mean per-sample entropy.
        expected_entropy = -np.mean(np.sum(predictions * np.log(predictions + 1e-10), axis=-1), axis=0)
        mutual_info = entropy - expected_entropy
        # Spread of the sampled probabilities across dropout masks.
        prediction_variance = np.var(predictions, axis=0).mean()

        return {
            "entropy": float(entropy.mean()),
            "mutual_information": float(mutual_info.mean()),
            "prediction_variance": float(prediction_variance),
            "mean_confidence": float(mean_prediction.max(axis=-1).mean()),
        }

    def detect(self, text: str, uncertainty_threshold: float = 1.0) -> Dict:
        """Flag ``text`` as OOD when its predictive entropy exceeds ``uncertainty_threshold``."""
        uncertainty = self.compute_uncertainty(text)
        is_ood = uncertainty["entropy"] > uncertainty_threshold
        return {
            "text": text,
            "uncertainty": uncertainty,
            "is_ood": is_ood,
            "confidence": uncertainty["entropy"] / uncertainty_threshold if uncertainty_threshold else 0,
        }
3. 基于重构的检测
class ReconstructionOODDetector:
    """Reconstruction-based OOD detection.

    An autoencoder reconstructs the tokenized input; inputs whose reconstruction
    error exceeds a threshold calibrated on reference texts are flagged as OOD.
    """

    def __init__(self, autoencoder, tokenizer):
        # autoencoder(input_ids) must return a dict with a "reconstructions" tensor.
        self.autoencoder = autoencoder
        self.tokenizer = tokenizer
        # None until set_threshold() is called; until then nothing is flagged as OOD.
        self.reconstruction_threshold = None

    def set_threshold(self, reference_texts: List[str], percentile: float = 95):
        """Set the threshold to the given percentile of reference reconstruction errors.

        Raises:
            ValueError: If ``reference_texts`` is empty.
        """
        recon_errors = [self._compute_reconstruction_error(text) for text in reference_texts]
        if not recon_errors:
            raise ValueError("reference_texts must contain at least one text")
        self.reconstruction_threshold = float(np.percentile(recon_errors, percentile))

    def _compute_reconstruction_error(self, text: str) -> float:
        """Return the MSE between the autoencoder's reconstruction and the input ids."""
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        with torch.no_grad():
            reconstructed = self.autoencoder(input_ids)
            # NOTE(review): this compares against raw token ids cast to float, so the
            # error scale depends on vocabulary numbering — confirm this matches what
            # the autoencoder was trained to reconstruct.
            error = torch.nn.functional.mse_loss(
                reconstructed["reconstructions"],
                input_ids.float(),
            ).item()
        return error

    def detect(self, text: str) -> Dict:
        """Score ``text`` by reconstruction error.

        Returns:
            Dict with ``text``, ``reconstruction_error``, ``threshold``, ``is_ood`` and
            ``confidence`` (error / threshold; 0 when uncalibrated, ``inf`` for a
            positive error against a zero threshold).
        """
        recon_error = self._compute_reconstruction_error(text)
        threshold = self.reconstruction_threshold
        # Compare against None explicitly: a calibrated threshold of 0.0 is valid.
        if threshold is None:
            is_ood, confidence = False, 0
        else:
            is_ood = bool(recon_error > threshold)
            if threshold > 0:
                confidence = recon_error / threshold
            else:
                confidence = float("inf") if recon_error > 0 else 0.0
        return {
            "text": text,
            "reconstruction_error": recon_error,
            "threshold": threshold,
            "is_ood": is_ood,
            "confidence": confidence,
        }
高级检测器
1. 集成OOD检测
class EnsembleOODDetector:
    """Ensemble OOD detection by weighted voting.

    Each member detector must provide ``detect(text) -> dict`` with an ``"is_ood"``
    entry and, optionally, a ``"confidence"`` entry. An input is flagged as OOD when
    the weighted fraction of members voting OOD exceeds ``voting_threshold``.
    """

    def __init__(self, detectors: List, weights: Optional[List[float]] = None):
        """
        Args:
            detectors: Member detectors (at least one).
            weights: Optional non-negative per-detector weights; normalized to sum
                to 1. Defaults to uniform weights.

        Raises:
            ValueError: If ``detectors`` is empty or ``weights`` is malformed.
        """
        if not detectors:
            raise ValueError("EnsembleOODDetector requires at least one detector")
        self.detectors = detectors
        if weights is None:
            weights = [1.0] * len(detectors)
        if len(weights) != len(detectors):
            raise ValueError("weights must have exactly one entry per detector")
        if any(w < 0 for w in weights):
            raise ValueError("weights must be non-negative")
        total = float(sum(weights))
        if total <= 0:
            raise ValueError("weights must sum to a positive value")
        self.weights = [w / total for w in weights]

    def detect(self, text: str, voting_threshold: float = 0.5) -> Dict:
        """Run every member detector on ``text`` and combine the results.

        The decision uses the weighted share of OOD votes. Each member's
        ``confidence`` is relative to that member's own threshold (1.0 = at threshold),
        so averaging those ratios and comparing them to 0.5 would flag inputs that
        sit at only half of every member's threshold. The weighted mean confidence is
        still reported as ``ensemble_score``/``confidence``.

        Returns:
            Dict with ``text``, ``individual_results``, ``ensemble_score``,
            ``vote_fraction``, ``is_ood`` and ``confidence``.
        """
        individual_results = []
        vote_fraction = 0.0
        weighted_score = 0.0
        for detector, weight in zip(self.detectors, self.weights):
            result = detector.detect(text)
            member_is_ood = bool(result["is_ood"])
            member_score = result.get("confidence", 0)
            vote_fraction += weight * member_is_ood
            weighted_score += weight * member_score
            individual_results.append(
                {"detector": type(detector).__name__, "is_ood": member_is_ood, "score": member_score}
            )

        return {
            "text": text,
            "individual_results": individual_results,
            "ensemble_score": weighted_score,
            "vote_fraction": vote_fraction,
            "is_ood": vote_fraction > voting_threshold,
            "confidence": weighted_score,
        }
2. 自适应阈值
class AdaptiveThresholdOODDetector:
    """Wrap a base detector and re-derive its threshold from recently seen scores.

    Once ``window_size`` scores have been observed, the threshold is the given
    percentile of the most recent ``window_size`` scores, and the base detector's
    ``is_ood`` decision is replaced by a comparison against that threshold.
    """

    def __init__(
        self,
        base_detector,
        window_size: int = 100,
        percentile: float = 95,
        score_key: str = "distance",
    ):
        """
        Args:
            base_detector: Object with ``detect(text) -> dict``.
            window_size: Number of recent scores used for the threshold (>= 1).
            percentile: Percentile in [0, 100] of the recent scores used as threshold.
            score_key: Key of the base result holding the score (missing -> 0).

        Raises:
            ValueError: If ``window_size`` is less than 1.
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        self.base_detector = base_detector
        self.window_size = window_size
        self.percentile = percentile
        self.score_key = score_key
        self.score_history = []
        # None until window_size scores have been seen.
        self.current_threshold = None

    def update_threshold(self):
        """Recompute the threshold from the last ``window_size`` scores, if enough exist."""
        if len(self.score_history) >= self.window_size:
            recent_scores = self.score_history[-self.window_size:]
            self.current_threshold = float(np.percentile(recent_scores, self.percentile))

    def detect(self, text: str) -> Dict:
        """Detect with the base detector, record its score and apply the adaptive threshold.

        The base result dict is updated in place: ``is_ood`` is overwritten and
        ``adaptive_threshold`` is added once a threshold exists.
        """
        result = self.base_detector.detect(text)
        score = result.get(self.score_key, 0)

        self.score_history.append(score)
        # Bound memory: only the last window_size scores are ever used.
        if len(self.score_history) > self.window_size * 2:
            self.score_history = self.score_history[-self.window_size:]

        self.update_threshold()

        # Compare against None explicitly: a threshold of 0.0 is a valid threshold.
        if self.current_threshold is not None:
            result["is_ood"] = bool(score > self.current_threshold)
            result["adaptive_threshold"] = self.current_threshold
        return result
评估工具
class OODEvaluator:
    """Offline evaluation of an OOD detector on labelled in-distribution and OOD texts."""

    @staticmethod
    def evaluate(
        detector,
        in_distribution_texts: List[str],
        ood_texts: List[str],
        score_key: str = "distance",
    ) -> Dict:
        """Evaluate how well ``detector``'s scores separate ID from OOD texts.

        Larger scores are treated as "more OOD" (OOD texts are the positive class).

        Args:
            detector: Object with ``detect(text) -> dict``.
            in_distribution_texts: Texts expected to be in-distribution (label 0).
            ood_texts: Texts expected to be out-of-distribution (label 1).
            score_key: Key of each result dict used as the score (missing -> 0).
                Use e.g. ``"confidence"`` for detectors that report no ``"distance"``.

        Returns:
            Dict with ``auroc`` (0.5 when it cannot be computed, e.g. one list is
            empty), ``id_mean_score``, ``ood_mean_score`` and ``separation``
            (NaN where a list is empty).
        """
        from sklearn.metrics import roc_auc_score

        id_scores = [detector.detect(text).get(score_key, 0) for text in in_distribution_texts]
        ood_scores = [detector.detect(text).get(score_key, 0) for text in ood_texts]

        labels = [0] * len(id_scores) + [1] * len(ood_scores)
        try:
            auroc = float(roc_auc_score(labels, id_scores + ood_scores))
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present (or the
            # input is otherwise invalid); report chance level instead.
            auroc = 0.5

        # Explicit NaN for empty lists instead of np.mean's "mean of empty slice" warning.
        id_mean = float(np.mean(id_scores)) if id_scores else float("nan")
        ood_mean = float(np.mean(ood_scores)) if ood_scores else float("nan")
        return {
            "auroc": auroc,
            "id_mean_score": id_mean,
            "ood_mean_score": ood_mean,
            "separation": ood_mean - id_mean,
        }
最佳实践
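- 多方法集成:结合基于距离、不确定性和重构的多种检测方法,降低单一方法的漏检率
- 阈值调优:在分布内验证集上按可接受的误报率(例如取 95% 分位数)校准阈值,再根据应用需求调整
- 持续监控:在生产环境中持续监控OOD样本比例,输入分布发生漂移时重新校准阈值
- 优雅处理:对检测到的OOD样本采取明确的降级策略(如拒绝回答、提示用户或转人工),而不是直接交给模型处理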
总结
OOD检测是确保LLM安全可靠的重要环节。通过识别异常输入,可以在模型处理未知数据之前将其拦截或降级处理,从而提高系统的鲁棒性。