← 返回首页
🧠

模型校准:对齐信心与准确度

📂 llm ⏱ 4 min 664 words

---
title: "模型校准:对齐信心与准确度"
description: "校准LLM的预测置信度,使模型信心与实际准确度对齐"
tags: ["校准", "置信度校准", "可靠性", "LLM", "概率校准"]
category: "llm"
icon: "⚖️"
---

模型校准:对齐信心与准确度

校准概述

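模型校准是使模型输出的预测概率(置信度)与实际准确率对齐的过程:对于所有置信度约为 90% 的预测,其中应当约有 90% 是正确的。衡量校准程度最常用的指标是期望校准误差(ECE):先把样本按置信度划分到 $M$ 个等宽区间 $B_1,\dots,B_M$,再对每个区间内"准确率与平均置信度之差"按样本占比加权求和,即 $\mathrm{ECE}=\sum_{m=1}^{M}\frac{|B_m|}{n}\,\bigl|\mathrm{acc}(B_m)-\mathrm{conf}(B_m)\bigr|$。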

校准方法

1. 校准度量

import numpy as np
from typing import List, Tuple

class CalibrationMetrics:
    """Calibration metrics computed over equal-width confidence bins on [0, 1].

    All binned metrics share one partition: bin ``i`` covers
    ``(i / n_bins, (i + 1) / n_bins]``, except that the first bin is also
    closed on the left so a confidence of exactly 0 is still counted.
    """

    @staticmethod
    def _bin_masks(confidences: np.ndarray, n_bins: int) -> List[np.ndarray]:
        """Return one boolean membership mask per confidence bin."""
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        masks = []
        for i in range(n_bins):
            lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
            # Close the first bin on the left; otherwise conf == 0.0 lands in no bin.
            above_lower = confidences >= lower if i == 0 else confidences > lower
            masks.append(above_lower & (confidences <= upper))
        return masks

    @staticmethod
    def expected_calibration_error(
        confidences: np.ndarray, 
        accuracies: np.ndarray, 
        n_bins: int = 10
    ) -> float:
        """Expected Calibration Error (ECE).

        Args:
            confidences: Per-sample confidence of the predicted class.
            accuracies: Per-sample correctness (1.0 correct, 0.0 wrong).
            n_bins: Number of equal-width confidence bins.

        Returns:
            Sum over non-empty bins of (bin share of samples) * |accuracy - confidence|.
            Returns 0.0 for empty input.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        n_samples = len(confidences)
        ece = 0.0

        for mask in CalibrationMetrics._bin_masks(confidences, n_bins):
            count = mask.sum()
            if count > 0:
                gap = abs(accuracies[mask].mean() - confidences[mask].mean())
                ece += (count / n_samples) * gap

        return float(ece)

    @staticmethod
    def maximum_calibration_error(
        confidences: np.ndarray, 
        accuracies: np.ndarray, 
        n_bins: int = 10
    ) -> float:
        """Maximum Calibration Error (MCE): the largest per-bin |accuracy - confidence|.

        Empty bins are ignored; returns 0.0 for empty input.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        mce = 0.0

        for mask in CalibrationMetrics._bin_masks(confidences, n_bins):
            if mask.sum() > 0:
                gap = abs(accuracies[mask].mean() - confidences[mask].mean())
                mce = max(mce, gap)

        return float(mce)

    @staticmethod
    def brier_score(probabilities: np.ndarray, labels: np.ndarray) -> float:
        """Brier score: mean squared difference between probabilities and labels."""
        probabilities = np.asarray(probabilities, dtype=float)
        labels = np.asarray(labels, dtype=float)
        return np.mean((probabilities - labels) ** 2)

    @staticmethod
    def reliability_diagram_data(
        confidences: np.ndarray, 
        accuracies: np.ndarray, 
        n_bins: int = 10
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Per-bin data for drawing a reliability diagram.

        Returns:
            ``(bin_confidences, bin_accuracies, bin_counts)``, each of length
            ``n_bins``. For an empty bin the confidence is the bin midpoint and
            both accuracy and count are 0.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_confidences = []
        bin_accuracies = []
        bin_counts = []

        for i, mask in enumerate(CalibrationMetrics._bin_masks(confidences, n_bins)):
            if mask.sum() > 0:
                bin_confidences.append(confidences[mask].mean())
                bin_accuracies.append(accuracies[mask].mean())
                bin_counts.append(mask.sum())
            else:
                bin_confidences.append((bin_boundaries[i] + bin_boundaries[i + 1]) / 2)
                bin_accuracies.append(0)
                bin_counts.append(0)

        return np.array(bin_confidences), np.array(bin_accuracies), np.array(bin_counts)

2. 温度缩放校准

class TemperatureScaling:
    """Temperature scaling: divide all logits by one learned scalar T > 0.

    T is fit by minimising the mean negative log-likelihood of the true labels.
    T > 1 softens confidences and T < 1 sharpens them. The argmax never changes
    because every logit is divided by the same T.
    """

    # Search range for T. L-BFGS-B needs bounds, otherwise T can hit 0
    # (division by zero) or go negative (which inverts the ranking).
    MIN_TEMPERATURE = 1e-2
    MAX_TEMPERATURE = 100.0

    def __init__(self):
        # 1.0 means "no scaling" until fit() is called.
        self.temperature = 1.0

    @staticmethod
    def _log_softmax(logits: np.ndarray) -> np.ndarray:
        """Numerically stable log-softmax over the last axis."""
        # Subtracting the row max keeps exp() from overflowing on large logits.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    def fit(self, logits: np.ndarray, labels: np.ndarray, lr: float = 0.01, max_iter: int = 100):
        """Fit the temperature on held-out logits.

        Args:
            logits: Array of shape (n_samples, n_classes).
            labels: Integer class index per sample.
            lr: Unused; kept only for backward compatibility (L-BFGS-B chooses its own steps).
            max_iter: Maximum number of L-BFGS-B iterations.
        """
        from scipy.optimize import minimize

        logits = np.asarray(logits, dtype=float)
        labels = np.asarray(labels, dtype=int)
        rows = np.arange(len(labels))

        def nll_loss(params):
            # ``params`` is the 1-element array L-BFGS-B optimises.
            log_probs = self._log_softmax(logits / params[0])
            return -np.mean(log_probs[rows, labels])

        result = minimize(
            nll_loss,
            x0=np.array([1.0]),
            method='L-BFGS-B',
            bounds=[(self.MIN_TEMPERATURE, self.MAX_TEMPERATURE)],
            options={'maxiter': max_iter},
        )
        self.temperature = float(result.x[0])

    def calibrate(self, logits: np.ndarray) -> np.ndarray:
        """Return softmax(logits / T) as calibrated probabilities."""
        scaled_logits = np.asarray(logits, dtype=float) / self.temperature
        return np.exp(self._log_softmax(scaled_logits))

3. Platt缩放校准

class PlattScaling:
    """Platt scaling: map a scalar logit feature f through sigmoid(a * f + b).

    * Binary (2-column logits): f = logit[1] - logit[0]. The sigmoid output is
      P(class 1), fit against the 0/1 class labels (classical Platt scaling).
    * Multiclass: f = max logit. The sigmoid output is the calibrated
      probability that the top-1 prediction is correct ("top-label" Platt).
      The remaining mass 1 - p_top is shared among the other classes in
      proportion to their softmax probabilities. If p_top falls below another
      class's share, the argmax can change.
    """

    def __init__(self):
        self.a = 1.0
        self.b = 0.0

    @staticmethod
    def _features(logits: np.ndarray) -> np.ndarray:
        """Scalar feature per row: logit margin (binary) or max logit (multiclass)."""
        if logits.shape[1] == 2:
            return logits[:, 1] - logits[:, 0]
        return logits.max(axis=1)

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis (tolerates -inf entries)."""
        shifted = logits - logits.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=-1, keepdims=True)

    def fit(self, logits: np.ndarray, labels: np.ndarray):
        """Fit ``a`` and ``b`` with a one-feature logistic regression.

        Args:
            logits: Array of shape (n_samples, n_classes).
            labels: Integer class index per sample.
        """
        from sklearn.linear_model import LogisticRegression

        logits = np.asarray(logits, dtype=float)
        labels = np.asarray(labels)
        features = self._features(logits)

        if logits.shape[1] == 2:
            targets = labels
        else:
            # Multiclass: learn P(top-1 prediction is correct | max logit).
            # Regressing raw class ids on one feature and keeping coef_[0][0]
            # would be meaningless.
            targets = (logits.argmax(axis=1) == labels).astype(int)

        # Note: sklearn's default L2 penalty (C=1.0) applies to ``a``.
        lr = LogisticRegression()
        lr.fit(features.reshape(-1, 1), targets)

        self.a = lr.coef_[0][0]
        self.b = lr.intercept_[0]

    def calibrate(self, logits: np.ndarray) -> np.ndarray:
        """Return calibrated class probabilities with the same shape as ``logits``."""
        from scipy.special import expit  # overflow-free sigmoid

        logits = np.asarray(logits, dtype=float)
        if logits.shape[1] == 1:
            # A single class always has probability 1.
            return np.ones_like(logits)

        scaled = expit(self.a * self._features(logits) + self.b)

        if logits.shape[1] == 2:
            return np.column_stack([1 - scaled, scaled])

        rows = np.arange(len(logits))
        top = logits.argmax(axis=1)
        # Softmax over the non-top classes only: mask the top logit with -inf.
        masked = logits.copy()
        masked[rows, top] = -np.inf
        probs = self._softmax(masked) * (1.0 - scaled)[:, np.newaxis]
        probs[rows, top] = scaled
        return probs

可视化工具

import matplotlib.pyplot as plt

def plot_reliability_diagram(
    confidences: np.ndarray, 
    accuracies: np.ndarray, 
    n_bins: int = 10,
    title: str = "Reliability Diagram"
):
    """Plot a reliability diagram next to the per-bin sample histogram.

    Args:
        confidences: Per-sample confidence of the predicted class, in [0, 1].
        accuracies: Per-sample correctness (1.0 correct, 0.0 wrong).
        n_bins: Number of equal-width confidence bins.
        title: Title of the left (reliability) panel.

    Returns:
        The matplotlib Figure, so callers can save or further customise it.
    """
    _, bin_accuracies, bin_counts = CalibrationMetrics.reliability_diagram_data(
        confidences, accuracies, n_bins
    )

    # Draw each bar at its bin centre. The per-bin mean confidence can sit
    # anywhere inside the bin, so full-width bars placed there overlap their
    # neighbours or leave gaps.
    bin_width = 1.0 / n_bins
    bin_centers = (np.arange(n_bins) + 0.5) * bin_width

    fig, (rel_ax, count_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Reliability panel: bars on the dashed diagonal mean perfect calibration.
    rel_ax.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')
    rel_ax.bar(bin_centers, bin_accuracies, width=bin_width, alpha=0.7,
               edgecolor='black', label='Model')
    rel_ax.set_xlim(0, 1)
    rel_ax.set_ylim(0, 1)
    rel_ax.set_xlabel('Confidence')
    rel_ax.set_ylabel('Accuracy')
    rel_ax.set_title(title)
    # Keep the upper-left corner free for the ECE annotation below.
    rel_ax.legend(loc='lower right')

    ece = CalibrationMetrics.expected_calibration_error(confidences, accuracies, n_bins)
    rel_ax.text(0.1, 0.9, f'ECE: {ece:.4f}', transform=rel_ax.transAxes)

    # Histogram panel: how many samples fall into each confidence bin.
    count_ax.bar(bin_centers, bin_counts, width=bin_width, alpha=0.7, edgecolor='black')
    count_ax.set_xlim(0, 1)
    count_ax.set_xlabel('Confidence')
    count_ax.set_ylabel('Sample Count')
    count_ax.set_title('Sample Distribution')

    plt.tight_layout()
    plt.show()
    return fig

完整校准流程

class CalibrationPipeline:
    """End-to-end calibration: collect last-position logits, fit a calibrator, predict.

    ``model`` and ``tokenizer`` are called as ``tokenizer(text, return_tensors="pt")``
    and ``model(**inputs).logits``. The logits at the final sequence position are
    treated as class scores, so labels are presumably indices along that logit axis
    (e.g. token ids). Confirm this against the caller.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Fitted TemperatureScaling / PlattScaling instance; None until calibrate().
        self.calibrator = None

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis."""
        shifted = logits - logits.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=-1, keepdims=True)

    def _last_token_logits(self, text: str) -> np.ndarray:
        """Run the model on one text; return last-position logits as a (1, n_classes) array."""
        import torch

        inputs = self.tokenizer(text, return_tensors="pt")
        # Move inputs next to the model when it reports a device (e.g. GPU).
        device = getattr(self.model, "device", None)
        if device is not None:
            inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
        # .float().cpu() so GPU-resident and bf16/fp16 tensors convert to NumPy.
        return outputs.logits[:, -1, :].float().cpu().numpy()

    def collect_predictions(self, texts: List[str], labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Collect last-position logits for each text alongside its label.

        Returns:
            ``(logits, labels)`` with shapes (n_texts, n_classes) and (n_texts,).

        Raises:
            ValueError: If ``texts`` and ``labels`` differ in length. Silently
                truncating, as ``zip`` would, loses data.
        """
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length ({len(texts)} != {len(labels)})"
            )

        all_logits = [self._last_token_logits(text)[0] for text in texts]
        return np.array(all_logits), np.array(list(labels))

    def calibrate(self, logits: np.ndarray, labels: np.ndarray, method: str = "temperature"):
        """Fit a calibrator and report ECE before and after calibration.

        Args:
            logits: Array of shape (n_samples, n_classes).
            labels: Integer class index per sample.
            method: ``"temperature"`` or ``"platt"``.

        Returns:
            Dict with ``original_ece``, ``calibrated_ece`` and ``temperature``
            (None for calibrators without a temperature).

        Raises:
            ValueError: If ``method`` is not recognised.
        """
        if method not in ("temperature", "platt"):
            raise ValueError(f"Unknown calibration method: {method!r}")

        if method == "temperature":
            self.calibrator = TemperatureScaling()
        else:
            self.calibrator = PlattScaling()

        logits = np.asarray(logits, dtype=float)
        labels = np.asarray(labels)
        self.calibrator.fit(logits, labels)

        # Compare metrics before and after calibration.
        original_probs = self._softmax(logits)
        calibrated_probs = self.calibrator.calibrate(logits)

        original_confidences = original_probs.max(axis=-1)
        calibrated_confidences = calibrated_probs.max(axis=-1)

        original_predictions = original_probs.argmax(axis=-1)
        calibrated_predictions = calibrated_probs.argmax(axis=-1)

        return {
            "original_ece": CalibrationMetrics.expected_calibration_error(
                original_confidences, (original_predictions == labels).astype(float)
            ),
            "calibrated_ece": CalibrationMetrics.expected_calibration_error(
                calibrated_confidences, (calibrated_predictions == labels).astype(float)
            ),
            "temperature": getattr(self.calibrator, 'temperature', None)
        }

    def predict_calibrated(self, text: str) -> dict:
        """Predict for one text, using the fitted calibrator if there is one.

        Returns:
            Dict with ``probabilities`` (nested list of shape (1, n_classes)),
            ``prediction`` (argmax index) and ``confidence`` (max probability).
        """
        logits = self._last_token_logits(text)

        if self.calibrator is not None:
            calibrated_probs = self.calibrator.calibrate(logits)
        else:
            calibrated_probs = self._softmax(logits)

        return {
            "probabilities": calibrated_probs.tolist(),
            "prediction": calibrated_probs.argmax(axis=-1).item(),
            "confidence": calibrated_probs.max().item()
        }

最佳实践

  1. 验证集校准:在独立的验证集上拟合校准器(温度、Platt 参数)——该数据集既不能是训练集,也不能是最终报告指标的测试集,否则校准效果会被高估
  2. 定期重新校准:随着数据分布变化重新校准
  3. 多方法比较:比较不同校准方法的效果
  4. 监控校准指标:持续监控ECE等校准指标

总结

模型校准是确保LLM输出可靠的重要环节。通过校准预测概率,可以使模型的信心与实际准确度对齐,提高决策质量。