← 返回首页
🧠

LLM A/B测试

📂 llm ⏱ 3 min 535 words

---
title: "LLM A/B测试"
description: "LLM模型的A/B测试方法论和实践,包括实验设计、指标选择、统计分析和自动化测试框架"
tags: ["A/B测试", "实验分析", "统计验证"]
category: "llm"
icon: "🧠"
---

LLM A/B测试

概述

A/B测试是比较两个或多个模型版本在真实用户场景下表现的科学方法。对于LLM项目,A/B测试帮助团队验证新模型是否真正优于现有模型,避免仅依赖离线评估指标做出发布决策。

实验设计

实验框架

# ab_testing/experiment.py
import hashlib
import random
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

@dataclass
class ExperimentVariant:
    """One arm of an experiment: a model plus its share of traffic."""
    name: str
    model_name: str
    weight: float  # relative traffic weight; normalized by the sum of all variant weights
    params: Optional[dict] = None  # optional per-variant parameters; None means "none given"


@dataclass
class Experiment:
    """An A/B experiment definition and its lifecycle state."""
    name: str
    variants: list[ExperimentVariant]
    start_date: datetime
    end_date: Optional[datetime] = None
    status: str = "running"  # one of: running, paused, completed
    target_metric: str = "quality_score"
    min_sample_size: int = 1000


class ExperimentManager:
    """In-memory registry of experiments with deterministic user-to-variant assignment."""

    def __init__(self):
        self.experiments: dict[str, Experiment] = {}

    @staticmethod
    def _validate_variants(variants) -> None:
        """Raise ValueError unless variants is non-empty with usable (non-negative, non-zero-sum) weights."""
        if not variants:
            raise ValueError("an experiment needs at least one variant")
        if any(v.weight < 0 for v in variants):
            raise ValueError("variant weights must be non-negative")
        if sum(v.weight for v in variants) <= 0:
            raise ValueError("at least one variant must have a positive weight")

    def create_experiment(self, name: str, variants: list,
                          target_metric: str = "quality_score",
                          min_samples: int = 1000) -> Experiment:
        """Create, register and return a running experiment starting now.

        Args:
            name: Unique experiment name; an existing experiment with the same
                name is replaced.
            variants: ExperimentVariant objects; the first one is what
                ``assign_variant`` returns while the experiment is not running.
            target_metric: Name of the primary metric.
            min_samples: Minimum sample size per the experiment definition.

        Raises:
            ValueError: if ``variants`` is empty, has a negative weight, or
                all weights are zero.
        """
        self._validate_variants(variants)
        exp = Experiment(
            name=name,
            variants=variants,
            start_date=datetime.now(),
            target_metric=target_metric,
            min_sample_size=min_samples,
        )
        self.experiments[name] = exp
        return exp

    def assign_variant(self, experiment_name: str, user_id: str) -> ExperimentVariant:
        """Return the variant for ``user_id`` in ``experiment_name``.

        Assignment is deterministic: the same (experiment, user) pair always
        hashes to the same bucket, so a user keeps seeing the same variant.
        While the experiment is not ``"running"`` everyone gets the first
        variant.

        Raises:
            KeyError: if the experiment is unknown.
            ValueError: if the experiment's variants are empty or their
                weights are unusable (e.g. all zero).
        """
        exp = self.experiments[experiment_name]

        if not exp.variants:
            raise ValueError(f"experiment {experiment_name!r} has no variants")

        if exp.status != "running":
            return exp.variants[0]

        # Variants may have been edited after creation; re-check before dividing.
        self._validate_variants(exp.variants)

        # Deterministic assignment: same user always lands in the same bucket.
        # md5 is used for bucketing only, not security; usedforsecurity=False
        # keeps it available on FIPS-restricted builds (digest is unchanged).
        hash_input = f"{experiment_name}:{user_id}"
        digest = hashlib.md5(hash_input.encode(), usedforsecurity=False).hexdigest()
        hash_val = int(digest[:8], 16)
        # 1000 buckets: weights are effectively resolved to 0.1% granularity.
        bucket = (hash_val % 1000) / 1000.0

        total_weight = sum(v.weight for v in exp.variants)
        cumulative = 0.0
        for variant in exp.variants:
            cumulative += variant.weight / total_weight
            if bucket < cumulative:
                return variant

        # Floating-point rounding can leave cumulative just below 1.0.
        return exp.variants[-1]

指标收集

# ab_testing/metrics.py
import json
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

@dataclass
class InteractionRecord:
    """One user interaction with a model variant inside an experiment.

    The quality metrics and behaviour flags default to "not observed"
    (None / False) so callers can omit signals they did not collect.
    """
    experiment: str
    variant: str
    user_id: str
    timestamp: datetime

    # Interaction metrics
    prompt_tokens: int
    completion_tokens: int
    latency_ms: float

    # Quality metrics (None = not collected for this interaction)
    user_rating: Optional[int] = None  # 1-5 rating
    response_relevance: Optional[float] = None  # 0-1 relevance
    hallucination_score: Optional[float] = None  # hallucination score

    # Behaviour metrics
    is_requery: bool = False  # whether the user asked the question again
    is_regenerate: bool = False  # whether the user asked for a regeneration

class MetricsCollector:
    """Keep interaction records in memory and aggregate them per variant."""

    def __init__(self):
        self.records: list[InteractionRecord] = []

    def record(self, interaction: "InteractionRecord"):
        """Append one interaction record."""
        self.records.append(interaction)

    def compute_metrics(self, experiment_name: str) -> dict:
        """Aggregate the records of one experiment, keyed by variant name.

        Each variant's dict always contains ``sample_size``, average latency
        and token counts, ``requery_rate``/``regenerate_rate`` and the raw
        ``requery_count``/``regenerate_count`` (so downstream tests do not
        have to reconstruct counts from float rates). Quality metrics are
        added only when at least one record carries them:
        ``avg_rating`` + ``rating_count`` (+ ``rating_std``, the sample
        standard deviation, when there are two or more ratings),
        ``avg_relevance`` and ``avg_hallucination_score``.

        Returns an empty dict when the experiment has no records.
        """
        by_variant: dict[str, list] = {}
        for rec in self.records:
            if rec.experiment == experiment_name:
                by_variant.setdefault(rec.variant, []).append(rec)

        return {name: self._summarize(recs) for name, recs in by_variant.items()}

    @staticmethod
    def _summarize(records: list) -> dict:
        """Compute the metric dict for one non-empty list of records."""
        n = len(records)
        requeries = sum(1 for r in records if r.is_requery)
        regenerates = sum(1 for r in records if r.is_regenerate)
        summary = {
            "sample_size": n,
            "avg_latency_ms": sum(r.latency_ms for r in records) / n,
            "avg_prompt_tokens": sum(r.prompt_tokens for r in records) / n,
            "avg_completion_tokens": sum(r.completion_tokens for r in records) / n,
            "requery_rate": requeries / n,
            "regenerate_rate": regenerates / n,
            "requery_count": requeries,
            "regenerate_count": regenerates,
        }

        ratings = [r.user_rating for r in records if r.user_rating is not None]
        if ratings:
            mean_rating = sum(ratings) / len(ratings)
            summary["avg_rating"] = mean_rating
            summary["rating_count"] = len(ratings)
            if len(ratings) > 1:
                # Sample (n-1) variance, as needed by a two-sample t-test.
                variance = sum((x - mean_rating) ** 2 for x in ratings) / (len(ratings) - 1)
                summary["rating_std"] = variance ** 0.5

        relevance = [r.response_relevance for r in records
                     if r.response_relevance is not None]
        if relevance:
            summary["avg_relevance"] = sum(relevance) / len(relevance)

        hallucination = [r.hallucination_score for r in records
                         if r.hallucination_score is not None]
        if hallucination:
            summary["avg_hallucination_score"] = sum(hallucination) / len(hallucination)

        return summary

统计分析

# ab_testing/analysis.py
import math
from scipy import stats

class StatisticalAnalyzer:
    """Significance tests comparing two variants (A = control, B = treatment)."""

    def __init__(self, confidence_level: float = 0.95):
        self.confidence_level = confidence_level

    def compare_proportions(self, variant_a_successes: int, variant_a_total: int,
                           variant_b_successes: int, variant_b_total: int) -> dict:
        """Two-sided pooled two-proportion z-test of B against A.

        Returns a dict with ``significant`` (p_value < 1 - confidence_level),
        ``p_value``, ``z_score``, ``lift`` (relative change (p_b - p_a) / p_a,
        0 when p_a is 0), ``variant_a_rate`` and ``variant_b_rate``. The same
        keys are returned in every case.

        Raises:
            ValueError: if either total is not positive.
        """
        if variant_a_total <= 0 or variant_b_total <= 0:
            raise ValueError("both variants need a positive sample size")

        p_a = variant_a_successes / variant_a_total
        p_b = variant_b_successes / variant_b_total
        lift = (p_b - p_a) / p_a if p_a > 0 else 0

        # Pooled proportion under H0: p_a == p_b
        p_pool = (variant_a_successes + variant_b_successes) / (
            variant_a_total + variant_b_total
        )

        # Standard error of the difference under H0
        se = math.sqrt(p_pool * (1 - p_pool) * (
            1 / variant_a_total + 1 / variant_b_total
        ))

        if se == 0:
            # Pooled rate is exactly 0 or 1: both variants are identical.
            return {
                "significant": False,
                "p_value": 1.0,
                "z_score": 0.0,
                "lift": lift,
                "variant_a_rate": p_a,
                "variant_b_rate": p_b,
            }

        z_score = (p_b - p_a) / se
        # sf() is the upper tail computed directly; 1 - cdf() loses precision for large |z|.
        p_value = float(2 * stats.norm.sf(abs(z_score)))

        return {
            "significant": p_value < (1 - self.confidence_level),
            "p_value": p_value,
            "z_score": z_score,
            "lift": lift,
            "variant_a_rate": p_a,
            "variant_b_rate": p_b,
        }

    @staticmethod
    def _success_count(metrics: dict, prefix: str) -> int:
        """Return the event count for ``prefix`` (e.g. "requery").

        Uses ``<prefix>_count`` when present; otherwise rounds
        ``<prefix>_rate * sample_size``. Rounding (not int()) matters:
        0.29 * 100 evaluates to 28.999..., which int() would truncate to 28.
        """
        count_key = f"{prefix}_count"
        if count_key in metrics:
            return int(metrics[count_key])
        return round(metrics[f"{prefix}_rate"] * metrics["sample_size"])

    def analyze_experiment(self, metrics_a: dict, metrics_b: dict) -> dict:
        """Compare per-variant metric dicts (as produced per variant by a metrics collector).

        - ``requery_rate``: two-proportion z-test (see ``compare_proportions``).
          A re-query presumably signals an unsatisfying answer, so a negative
          lift is the desired direction here.
        - ``rating_diff``: always ``difference`` (B - A) and ``improved``
          (difference > 0). When both dicts also carry ``rating_std`` and
          ``rating_count``, a Welch two-sample t-test adds ``t_statistic``,
          ``p_value`` and ``significant``.
        """
        results = {}

        # Conversion-style metric (requery_rate used as the example)
        if "requery_rate" in metrics_a and "requery_rate" in metrics_b:
            results["requery_rate"] = self.compare_proportions(
                self._success_count(metrics_a, "requery"),
                metrics_a["sample_size"],
                self._success_count(metrics_b, "requery"),
                metrics_b["sample_size"],
            )

        # Ratings
        if "avg_rating" in metrics_a and "avg_rating" in metrics_b:
            diff = metrics_b["avg_rating"] - metrics_a["avg_rating"]
            rating = {
                "difference": diff,
                "improved": diff > 0,
            }
            has_dispersion = all(
                key in m
                for m in (metrics_a, metrics_b)
                for key in ("rating_std", "rating_count")
            )
            if has_dispersion:
                # Welch's t-test: does not assume equal variances across variants.
                t_stat, p_value = stats.ttest_ind_from_stats(
                    metrics_b["avg_rating"], metrics_b["rating_std"], metrics_b["rating_count"],
                    metrics_a["avg_rating"], metrics_a["rating_std"], metrics_a["rating_count"],
                    equal_var=False,
                )
                p_value = float(p_value)
                rating["t_statistic"] = float(t_stat)
                rating["p_value"] = p_value
                # NaN (e.g. zero variance in both arms) is never significant.
                rating["significant"] = (not math.isnan(p_value)) and p_value < (1 - self.confidence_level)
            results["rating_diff"] = rating

        return results

最佳实践

  1. 充分样本量:实验开始前根据最小可检测效应(MDE)、显著性水平和统计功效估算每个变体所需样本量,样本量达到之前不要提前下结论(边看结果边决定是否停止会显著抬高假阳性率)
  2. 随机化与一致性:以「实验名:用户ID」做确定性哈希分桶,既使流量分配近似随机,又保证同一用户始终落在同一变体
  3. 多维度评估:不要只看单一指标,综合考虑质量、延迟、成本
  4. 实验时长:至少运行一周以覆盖不同时间段的使用模式
  5. 自动决策:设置自动推广阈值,仅当胜出变体在主指标上显著占优、且护栏指标(如延迟、成本)没有明显恶化时才自动推广;同时检验多个指标时需做多重比较校正