模型校准:对齐信心与准确度
---
title: "模型校准:对齐信心与准确度"
description: "校准LLM的预测置信度,使模型信心与实际准确度对齐"
tags: ["校准", "置信度校准", "可靠性", "LLM", "概率校准"]
category: "llm"
icon: "⚖️"
---
模型校准:对齐信心与准确度
校准概述
模型校准是使模型输出的预测概率(置信度)与实际准确率对齐的过程:当模型对一批预测都给出"90%确信"时,其中应当约有90%是正确的。
校准方法
1. 校准度量
import numpy as np
from typing import List, Tuple
class CalibrationMetrics:
    """Calibration metrics computed over equal-width confidence bins.

    Every public method is static and takes per-sample confidences in
    ``[0, 1]`` together with per-sample correctness indicators
    (1.0 for a correct prediction, 0.0 otherwise).
    """

    @staticmethod
    def _bin_masks(confidences: np.ndarray, n_bins: int):
        """Yield ``(lower, upper, mask)`` for each of ``n_bins`` bins on [0, 1].

        Bins are half-open ``(lower, upper]``, except the first one, which
        also includes its lower edge so that a confidence of exactly 0.0 is
        assigned to a bin instead of being silently dropped.
        """
        edges = np.linspace(0, 1, n_bins + 1)
        for index, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
            if index == 0:
                mask = (confidences >= lower) & (confidences <= upper)
            else:
                mask = (confidences > lower) & (confidences <= upper)
            yield lower, upper, mask

    @staticmethod
    def expected_calibration_error(
        confidences: np.ndarray,
        accuracies: np.ndarray,
        n_bins: int = 10
    ) -> float:
        """Return the Expected Calibration Error (ECE).

        ECE is the sample-weighted average, over non-empty bins, of
        ``|mean accuracy - mean confidence|``.

        Args:
            confidences: Per-sample confidence values in ``[0, 1]``.
            accuracies: Per-sample correctness indicators.
            n_bins: Number of equal-width bins on ``[0, 1]``.

        Returns:
            The ECE as a float; 0.0 when there are no samples.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        total = len(confidences)
        ece = 0.0
        for _, _, mask in CalibrationMetrics._bin_masks(confidences, n_bins):
            count = mask.sum()
            if count > 0:
                gap = abs(accuracies[mask].mean() - confidences[mask].mean())
                ece += (count / total) * gap
        return float(ece)

    @staticmethod
    def maximum_calibration_error(
        confidences: np.ndarray,
        accuracies: np.ndarray,
        n_bins: int = 10
    ) -> float:
        """Return the Maximum Calibration Error (MCE).

        MCE is the largest ``|mean accuracy - mean confidence|`` over all
        non-empty bins.

        Args:
            confidences: Per-sample confidence values in ``[0, 1]``.
            accuracies: Per-sample correctness indicators.
            n_bins: Number of equal-width bins on ``[0, 1]``.

        Returns:
            The MCE as a float; 0.0 when every bin is empty.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        mce = 0.0
        for _, _, mask in CalibrationMetrics._bin_masks(confidences, n_bins):
            if mask.sum() > 0:
                gap = abs(accuracies[mask].mean() - confidences[mask].mean())
                mce = max(mce, gap)
        return float(mce)

    @staticmethod
    def brier_score(probabilities: np.ndarray, labels: np.ndarray) -> float:
        """Return the Brier score: the mean squared error between
        ``probabilities`` and ``labels`` (averaged over all elements)."""
        probabilities = np.asarray(probabilities, dtype=float)
        labels = np.asarray(labels, dtype=float)
        return float(np.mean((probabilities - labels) ** 2))

    @staticmethod
    def reliability_diagram_data(
        confidences: np.ndarray,
        accuracies: np.ndarray,
        n_bins: int = 10
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return per-bin statistics for drawing a reliability diagram.

        Args:
            confidences: Per-sample confidence values in ``[0, 1]``.
            accuracies: Per-sample correctness indicators.
            n_bins: Number of equal-width bins on ``[0, 1]``.

        Returns:
            A tuple ``(bin_confidences, bin_accuracies, bin_counts)``, each of
            length ``n_bins``. An empty bin reports its midpoint as the
            confidence, 0 as the accuracy and 0 as the count.
        """
        confidences = np.asarray(confidences, dtype=float)
        accuracies = np.asarray(accuracies, dtype=float)
        bin_confidences = []
        bin_accuracies = []
        bin_counts = []
        for lower, upper, mask in CalibrationMetrics._bin_masks(confidences, n_bins):
            count = mask.sum()
            if count > 0:
                bin_confidences.append(confidences[mask].mean())
                bin_accuracies.append(accuracies[mask].mean())
                bin_counts.append(count)
            else:
                # Keep a placeholder so every output has exactly n_bins entries.
                bin_confidences.append((lower + upper) / 2)
                bin_accuracies.append(0)
                bin_counts.append(0)
        return np.array(bin_confidences), np.array(bin_accuracies), np.array(bin_counts)
2. 温度缩放校准
class TemperatureScaling:
    """Temperature-scaling calibrator.

    Divides logits by a single scalar temperature before the softmax. The
    temperature starts at 1.0 (no change) and is learned by ``fit``.
    """

    def __init__(self):
        self.temperature = 1.0

    @staticmethod
    def _softmax(scaled_logits: np.ndarray) -> np.ndarray:
        """Softmax over the last axis, shifted by the row maximum so that
        large logits do not overflow ``np.exp`` and produce NaN."""
        shifted = scaled_logits - scaled_logits.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=-1, keepdims=True)

    def fit(self, logits: np.ndarray, labels: np.ndarray, lr: float = 0.01, max_iter: int = 100):
        """Fit the temperature by minimising the negative log-likelihood.

        Args:
            logits: Array of shape ``(n_samples, n_classes)``.
            labels: Integer class index for each sample.
            lr: Unused. L-BFGS-B chooses its own step size; the parameter is
                kept so existing callers keep working.
            max_iter: Maximum number of optimizer iterations.
        """
        from scipy.optimize import minimize

        logits = np.asarray(logits, dtype=float)
        labels = np.asarray(labels).astype(np.intp)
        rows = np.arange(len(labels))

        def nll_loss(temperature):
            scaled = logits / temperature
            # Log-softmax computed directly, so neither exp() nor log() can
            # overflow or hit log(0).
            shifted = scaled - scaled.max(axis=-1, keepdims=True)
            log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
            return -np.mean(log_probs[rows, labels])

        # The lower bound keeps the optimizer away from zero or negative
        # temperatures, which would divide by zero or flip the ranking.
        result = minimize(nll_loss, x0=1.0, method='L-BFGS-B',
                          bounds=[(1e-3, None)],
                          options={'maxiter': max_iter})
        self.temperature = float(result.x[0])

    def calibrate(self, logits: np.ndarray) -> np.ndarray:
        """Return softmax probabilities of ``logits / temperature``.

        Args:
            logits: Array whose last axis indexes classes.

        Returns:
            Array of the same shape whose last axis sums to 1.
        """
        scaled_logits = np.asarray(logits, dtype=float) / self.temperature
        return self._softmax(scaled_logits)
3. Platt缩放校准
class PlattScaling:
    """Platt-scaling calibrator: a 1-D logistic regression on a logit feature.

    * Two classes: the feature is ``logit[1] - logit[0]`` and the sigmoid
      output is the probability of class 1.
    * More than two classes: the feature is the maximum logit and the sigmoid
      output is the calibrated confidence of the predicted (arg-max) class.
      The remaining probability mass is shared among the other classes in
      proportion to their softmax probabilities.
    """

    def __init__(self):
        self.a = 1.0
        self.b = 0.0

    @staticmethod
    def _features(logits: np.ndarray) -> np.ndarray:
        """Return the scalar feature per sample used by the logistic model."""
        if logits.shape[1] == 2:
            return logits[:, 1] - logits[:, 0]
        return logits.max(axis=1)

    def fit(self, logits: np.ndarray, labels: np.ndarray):
        """Fit the slope ``a`` and intercept ``b``.

        Args:
            logits: Array of shape ``(n_samples, n_classes)``.
            labels: Integer class index for each sample.
        """
        from sklearn.linear_model import LogisticRegression

        logits = np.asarray(logits, dtype=float)
        labels = np.asarray(labels)
        features = self._features(logits)

        if logits.shape[1] == 2:
            targets = labels
        else:
            # Multiclass: regress "was the arg-max prediction correct?" on the
            # max logit. Fitting the raw class labels would yield one
            # coefficient row per class, of which only the first was kept.
            targets = (logits.argmax(axis=1) == labels).astype(int)
            if np.unique(targets).size < 2:
                # Logistic regression needs both outcomes. Fall back to a
                # constant confidence: the Laplace-smoothed hit rate.
                rate = (targets.sum() + 1) / (targets.size + 2)
                self.a = 0.0
                self.b = float(np.log(rate / (1 - rate)))
                return

        lr = LogisticRegression()
        lr.fit(features.reshape(-1, 1), targets)
        self.a = lr.coef_[0][0]
        self.b = lr.intercept_[0]

    def calibrate(self, logits: np.ndarray) -> np.ndarray:
        """Return calibrated class probabilities.

        Args:
            logits: Array of shape ``(n_samples, n_classes)``.

        Returns:
            Array of the same shape whose rows sum to 1. In the multiclass
            case the originally predicted class receives the fitted
            confidence; if that confidence is very low another class may end
            up with the largest probability.
        """
        logits = np.asarray(logits, dtype=float)
        z = self.a * self._features(logits) + self.b
        # sigmoid(z) written via tanh: same value, no overflow for large |z|.
        scaled = 0.5 * (1.0 + np.tanh(0.5 * z))

        if logits.shape[1] == 2:
            return np.column_stack([1 - scaled, scaled])

        # Multiplying every class by `scaled` and renormalising cancels the
        # factor, so assign the calibrated confidence to the top class and
        # redistribute the rest explicitly.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        probs = np.exp(shifted)
        probs = probs / probs.sum(axis=-1, keepdims=True)

        rows = np.arange(len(logits))
        top = probs.argmax(axis=-1)
        rest = probs.copy()
        rest[rows, top] = 0.0
        rest_mass = rest.sum(axis=-1, keepdims=True)
        n_other = max(logits.shape[1] - 1, 1)
        # When softmax put all mass on the top class, spread the remainder
        # uniformly instead of dividing by zero.
        safe_mass = np.where(rest_mass > 0, rest_mass, 1.0)
        rest = np.where(rest_mass > 0, rest / safe_mass, 1.0 / n_other)
        rest[rows, top] = 0.0

        calibrated = rest * (1.0 - scaled)[:, np.newaxis]
        calibrated[rows, top] = scaled
        return calibrated
可视化工具
import matplotlib.pyplot as plt
def plot_reliability_diagram(
    confidences: np.ndarray,
    accuracies: np.ndarray,
    n_bins: int = 10,
    title: str = "Reliability Diagram"
):
    """Show a two-panel figure: reliability diagram and confidence histogram.

    Args:
        confidences: Per-sample confidence values.
        accuracies: Per-sample correctness indicators.
        n_bins: Number of bins passed to ``CalibrationMetrics``.
        title: Title of the left (reliability) panel.

    The left panel draws per-bin accuracy bars against the diagonal and is
    annotated with the ECE; the right panel draws the per-bin sample counts.
    The figure is displayed with ``plt.show()`` and nothing is returned.
    """
    mean_conf, mean_acc, counts = CalibrationMetrics.reliability_diagram_data(
        confidences, accuracies, n_bins
    )
    bar_width = 1.0 / n_bins

    fig, (diagram_ax, hist_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: accuracy per bin versus the perfect-calibration diagonal.
    diagram_ax.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')
    diagram_ax.bar(mean_conf, mean_acc, width=bar_width, alpha=0.7, label='Model')
    diagram_ax.set_xlabel('Confidence')
    diagram_ax.set_ylabel('Accuracy')
    diagram_ax.set_title(title)
    diagram_ax.legend()

    # Annotate the panel with the ECE, positioned in axes coordinates.
    ece = CalibrationMetrics.expected_calibration_error(confidences, accuracies, n_bins)
    diagram_ax.text(0.1, 0.9, f'ECE: {ece:.4f}', transform=diagram_ax.transAxes)

    # Right panel: how many samples fall into each bin.
    hist_ax.bar(mean_conf, counts, width=bar_width, alpha=0.7)
    hist_ax.set_xlabel('Confidence')
    hist_ax.set_ylabel('Sample Count')
    hist_ax.set_title('Sample Distribution')

    plt.tight_layout()
    plt.show()
完整校准流程
class CalibrationPipeline:
    """End-to-end calibration helper around a model/tokenizer pair.

    Collects logits from the model, fits a calibrator on them, reports ECE
    before and after calibration, and serves calibrated predictions.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.calibrator = None  # set by calibrate()

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Softmax over the last axis, shifted by the row maximum so that
        large logits do not overflow ``np.exp``."""
        shifted = logits - logits.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=-1, keepdims=True)

    def _last_position_logits(self, text: str) -> np.ndarray:
        """Run the model on ``text`` and return ``logits[:, -1, :]`` as NumPy.

        NOTE(review): assumes ``outputs.logits`` is a 3-D torch tensor whose
        axis 1 is the sequence position (e.g. a causal LM) -- confirm.
        """
        import torch  # local import: the file does not import torch at top level

        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        # .cpu() is a no-op for CPU tensors and lets .numpy() work otherwise.
        return outputs.logits[:, -1, :].cpu().numpy()

    def collect_predictions(self, texts: List[str], labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Collect one logit vector per text, paired with its label.

        Args:
            texts: Input texts.
            labels: Label for each text; pairs are formed with ``zip``, so
                extra items in the longer list are ignored.

        Returns:
            ``(logits, labels)`` as NumPy arrays.
        """
        all_logits = []
        all_labels = []
        for text, label in zip(texts, labels):
            all_logits.append(self._last_position_logits(text).squeeze())
            all_labels.append(label)
        return np.array(all_logits), np.array(all_labels)

    def calibrate(self, logits: np.ndarray, labels: np.ndarray, method: str = "temperature"):
        """Fit a calibrator and report ECE before and after calibration.

        Args:
            logits: Array of shape ``(n_samples, n_classes)``.
            labels: Integer class index for each sample.
            method: ``"temperature"`` or ``"platt"``.

        Returns:
            Dict with keys ``"original_ece"``, ``"calibrated_ece"`` and
            ``"temperature"`` (``None`` when the calibrator has no
            ``temperature`` attribute).

        Raises:
            ValueError: If ``method`` is not a supported name. The current
                calibrator is left untouched in that case.
        """
        if method == "temperature":
            calibrator = TemperatureScaling()
        elif method == "platt":
            calibrator = PlattScaling()
        else:
            # Previously this fell through and crashed on `None.fit(...)`.
            raise ValueError(
                f"Unknown calibration method: {method!r} "
                "(expected 'temperature' or 'platt')"
            )
        calibrator.fit(logits, labels)
        self.calibrator = calibrator

        # Compare confidence / correctness before and after calibration.
        original_probs = self._softmax(np.asarray(logits, dtype=float))
        calibrated_probs = self.calibrator.calibrate(logits)
        original_confidences = original_probs.max(axis=-1)
        calibrated_confidences = calibrated_probs.max(axis=-1)
        original_predictions = original_probs.argmax(axis=-1)
        calibrated_predictions = calibrated_probs.argmax(axis=-1)
        return {
            "original_ece": CalibrationMetrics.expected_calibration_error(
                original_confidences, (original_predictions == labels).astype(float)
            ),
            "calibrated_ece": CalibrationMetrics.expected_calibration_error(
                calibrated_confidences, (calibrated_predictions == labels).astype(float)
            ),
            "temperature": getattr(self.calibrator, 'temperature', None)
        }

    def predict_calibrated(self, text: str) -> dict:
        """Predict for a single text, using the fitted calibrator if any.

        Args:
            text: Input text.

        Returns:
            Dict with ``"probabilities"`` (nested list), ``"prediction"``
            (arg-max index) and ``"confidence"`` (largest probability). Plain
            softmax is used when no calibrator has been fitted.
        """
        logits = self._last_position_logits(text)
        if self.calibrator is not None:
            calibrated_probs = self.calibrator.calibrate(logits)
        else:
            calibrated_probs = self._softmax(logits)
        return {
            "probabilities": calibrated_probs.tolist(),
            "prediction": calibrated_probs.argmax(axis=-1).item(),
            "confidence": calibrated_probs.max().item()
        }
最佳实践
- 验证集校准:在独立验证集上进行校准
- 定期重新校准:随着数据分布变化重新校准
- 多方法比较:比较不同校准方法的效果
- 监控校准指标:持续监控ECE等校准指标
总结
模型校准是确保LLM输出可靠的重要环节。通过校准预测概率,可以使模型的信心与实际准确度对齐,提高决策质量。