← 返回首页
🧠

错误分析:系统化的问题诊断

📂 llm ⏱ 5 min 806 words

---
title: "错误分析:系统化的问题诊断"
description: "系统化的LLM错误分析方法,从错误中学习并改进模型"
tags: ["错误分析", "问题诊断", "LLM", "质量改进", "调试"]
category: "llm"
icon: "❌"
---

错误分析:系统化的问题诊断

错误分析概述

错误分析是系统化地收集、分类和分析模型错误的过程,帮助理解失败模式并指导改进方向。

错误收集

1. 错误记录器

import json
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional
from datetime import datetime
from pathlib import Path

@dataclass
class ErrorRecord:
    """One model error: the input, the expected and actual outputs, and labels.

    Instances are plain data holders; ErrorCollector serialises them with
    ``dataclasses.asdict``.
    """
    # Record identifier; ErrorCollector uses it in the file name "error_<id>.json".
    id: str
    # Time attached to the record (ErrorDetector fills in datetime.now()).
    timestamp: datetime
    # Text that was given to the model.
    input_text: str
    # Reference output the model was supposed to produce.
    expected_output: str
    # Output the model actually produced.
    actual_output: str
    # Category label, e.g. "empty_output", "incomplete_output",
    # "content_mismatch" or "repetition".
    error_type: str
    # Severity label, e.g. "critical", "high", "medium" or "low".
    severity: str
    # Numeric score attached to the record; ErrorDetector derives it from the
    # token overlap between expected and actual output.
    confidence: float
    # Optional free-form extra information; None when not supplied.
    metadata: Optional[Dict[str, Any]] = None

class ErrorCollector:
    """Collect error records in memory and persist each one as a JSON file.

    Attributes:
        storage_path: Directory that receives one ``error_<id>.json`` file
            per recorded error.
        errors: Recorded errors, in the order they were recorded.
    """

    def __init__(self, storage_path: str = "errors"):
        """Create the collector and make sure the storage directory exists.

        Args:
            storage_path: Directory for the JSON files. Missing parent
                directories are created too.
        """
        self.storage_path = Path(storage_path)
        # parents=True so that a nested path such as "runs/today/errors"
        # does not fail when the intermediate directories are missing.
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.errors = []

    def record_error(self, error: "ErrorRecord"):
        """Remember ``error`` and write it to ``error_<id>.json``.

        Args:
            error: Dataclass instance to store. An existing file with the
                same id is overwritten.
        """
        self.errors.append(error)

        filepath = self.storage_path / f"error_{error.id}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            # default=str turns values json cannot encode (e.g. datetime)
            # into their str() form.
            json.dump(asdict(error), f, ensure_ascii=False, indent=2, default=str)

    def get_errors(
        self,
        error_type: Optional[str] = None,
        severity: Optional[str] = None,
    ) -> List["ErrorRecord"]:
        """Return the recorded errors, optionally filtered.

        Args:
            error_type: Keep only errors with this ``error_type``. ``None``
                or an empty string disables the filter.
            severity: Keep only errors with this ``severity``. ``None`` or
                an empty string disables the filter.

        Returns:
            A new list; modifying it does not affect the collector.
        """
        matches = self.errors

        if error_type:
            matches = [e for e in matches if e.error_type == error_type]
        if severity:
            matches = [e for e in matches if e.severity == severity]

        # Always hand out a copy so callers cannot mutate self.errors by
        # accident when no filter was applied.
        return list(matches)

    def get_statistics(self) -> Dict:
        """Summarise the recorded errors.

        Returns:
            ``{"total": 0}`` when nothing has been recorded; otherwise a dict
            with ``total``, ``by_type`` and ``by_severity`` (label -> count)
            and ``avg_confidence`` (mean of the ``confidence`` values).
        """
        if not self.errors:
            return {"total": 0}

        stats = {
            "total": len(self.errors),
            "by_type": {},
            "by_severity": {},
            "avg_confidence": sum(e.confidence for e in self.errors) / len(self.errors),
        }

        for error in self.errors:
            stats["by_type"][error.error_type] = stats["by_type"].get(error.error_type, 0) + 1
            stats["by_severity"][error.severity] = stats["by_severity"].get(error.severity, 0) + 1

        return stats

2. 自动错误检测

class ErrorDetector:
    """Run a model over test cases and report bad generations as ErrorRecords.

    Attributes:
        model: Object exposing ``generate(**inputs, max_new_tokens=...)``.
        tokenizer: Callable that tokenises text (called with
            ``return_tensors="pt"``) and exposes ``decode``.
        max_new_tokens: Generation budget passed to ``model.generate``.
    """

    def __init__(self, model, tokenizer, max_new_tokens: int = 100):
        """Store the model, the tokenizer and the generation budget.

        Args:
            model: Generation model, see class docstring.
            tokenizer: Matching tokenizer, see class docstring.
            max_new_tokens: Maximum number of tokens to generate per test
                case (default 100).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_new_tokens = max_new_tokens

    def detect_errors(self, test_cases: List[Dict]) -> List["ErrorRecord"]:
        """Generate an output for every test case and collect the failures.

        Args:
            test_cases: Dicts with the keys ``"input"`` and ``"expected"``.

        Returns:
            One ErrorRecord per test case that ``_classify_error`` flags;
            ids have the form ``auto_<index>``.
        """
        # Imported here because the surrounding snippet uses torch.no_grad()
        # without importing torch anywhere.
        import torch

        errors = []

        for i, test_case in enumerate(test_cases):
            input_text = test_case["input"]
            expected = test_case["expected"]

            inputs = self.tokenizer(input_text, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
                # Drop the first len(input_ids) tokens before decoding.
                # NOTE(review): assumes generate() echoes the prompt at the
                # start of its output — confirm for the model in use.
                prompt_length = inputs["input_ids"].shape[1]
                actual = self.tokenizer.decode(outputs[0][prompt_length:],
                                               skip_special_tokens=True)

            error_type = self._classify_error(input_text, expected, actual)

            if error_type:
                error = ErrorRecord(
                    id=f"auto_{i}",
                    timestamp=datetime.now(),
                    input_text=input_text,
                    expected_output=expected,
                    actual_output=actual,
                    error_type=error_type,
                    severity=self._assess_severity(error_type, expected, actual),
                    confidence=self._compute_confidence(expected, actual)
                )
                errors.append(error)

        return errors

    def _classify_error(self, input_text: str, expected: str, actual: str) -> Optional[str]:
        """Return the first matching error label, or None if nothing matches.

        Checks run in this order, and the first hit wins:
        ``empty_output`` (no output), ``incomplete_output`` (output shorter
        than half the expected length), ``content_mismatch`` (expected text
        longer than 10 characters and not contained, case-insensitively, in
        the output), ``repetition`` (the first 10 characters of the output
        occur more than 3 times in it).
        """
        if not actual:
            return "empty_output"
        if len(actual) < len(expected) * 0.5:
            return "incomplete_output"
        if expected.lower() not in actual.lower() and len(expected) > 10:
            return "content_mismatch"
        if actual.count(actual[:10]) > 3:
            return "repetition"
        return None

    def _assess_severity(self, error_type: str, expected: str, actual: str) -> str:
        """Map an error label to a severity label; unknown labels are "low"."""
        severity_map = {
            "empty_output": "critical",
            "content_mismatch": "high",
            "incomplete_output": "medium",
            "repetition": "medium"
        }
        return severity_map.get(error_type, "low")

    def _compute_confidence(self, expected: str, actual: str) -> float:
        """Score how strongly the output deviates from the expected text.

        The score is ``1 - Jaccard similarity`` of the lower-cased,
        whitespace-split token sets, so it is 1.0 when nothing overlaps and
        0.0 when the token sets are identical. An empty output scores 1.0.

        Returns:
            A float in [0.0, 1.0].
        """
        if not actual:
            return 1.0

        expected_tokens = set(expected.lower().split())
        actual_tokens = set(actual.lower().split())
        union = expected_tokens | actual_tokens
        if not union:
            # Neither side contains a token, so there is nothing that differs.
            return 0.0

        similarity = len(expected_tokens & actual_tokens) / len(union)
        # Previously the raw similarity was returned, which ranked a totally
        # wrong output (0.0) below an empty one (1.0).
        return 1.0 - similarity

错误分析

1. 错误模式分析

class ErrorPatternAnalyzer:
    """Look for patterns in a list of error records and group likely causes."""

    # Ordering used to pick the worst severity of a group; plain string
    # comparison would rank "medium" above "critical".
    _SEVERITY_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}

    # Number of buckets the relative error position is divided into.
    _POSITION_BUCKETS = 10

    def __init__(self):
        """Start with an empty ``patterns`` dict."""
        self.patterns = {}

    def analyze_patterns(self, errors: List["ErrorRecord"]) -> Dict:
        """Bucket the errors by input length and by first point of divergence.

        Args:
            errors: Records to analyse.

        Returns:
            A dict with three keys:
            ``by_input_length`` maps ``len(input_text) // 100`` to the list
            of error types seen in that bucket;
            ``by_error_position`` maps a bucket 0-9 (relative position of the
            first difference within the expected output) to a count;
            ``common_tokens`` is returned empty — nothing fills it yet.
        """
        patterns = {
            "by_input_length": {},
            "by_error_position": {},
            "common_tokens": {}
        }

        for error in errors:
            length_bucket = len(error.input_text) // 100
            patterns["by_input_length"].setdefault(length_bucket, []).append(error.error_type)

            bucket = self._first_difference_bucket(error.expected_output, error.actual_output)
            if bucket is not None:
                positions = patterns["by_error_position"]
                positions[bucket] = positions.get(bucket, 0) + 1

        return patterns

    def _first_difference_bucket(self, expected: str, actual: str) -> Optional[int]:
        """Return the position bucket of the first difference, or None.

        None is returned when either string is empty or both are equal.
        When one string is a prefix of the other, the difference is taken to
        be where the shorter one ends.
        """
        if not expected or not actual:
            return None

        shared = min(len(expected), len(actual))
        index = next((i for i in range(shared) if expected[i] != actual[i]), None)
        if index is None:
            if len(expected) == len(actual):
                return None
            # Truncated or over-long output: it diverges right after the
            # common prefix.
            index = shared

        # Integer arithmetic avoids float rounding; clamp because an
        # over-long output diverges at position == len(expected).
        bucket = index * self._POSITION_BUCKETS // len(expected)
        return min(bucket, self._POSITION_BUCKETS - 1)

    def find_root_causes(self, errors: List["ErrorRecord"]) -> List[Dict]:
        """Group the errors by type and describe each group.

        Returns:
            One dict per error type, in order of first appearance, with the
            keys ``error_type``, ``count``, ``severity`` (the worst severity
            in the group), ``avg_confidence`` and ``possible_causes``.
        """
        by_type = {}
        for error in errors:
            by_type.setdefault(error.error_type, []).append(error)

        causes = []
        for error_type, type_errors in by_type.items():
            # Labels missing from the rank table sort below "low".
            worst = max(type_errors, key=lambda e: self._SEVERITY_RANK.get(e.severity, -1))
            causes.append({
                "error_type": error_type,
                "count": len(type_errors),
                "severity": worst.severity,
                "avg_confidence": sum(e.confidence for e in type_errors) / len(type_errors),
                "possible_causes": self._suggest_causes(error_type, type_errors)
            })

        return causes

    def _suggest_causes(self, error_type: str, errors: List["ErrorRecord"]) -> List[str]:
        """Return the candidate causes for an error type (empty if unknown)."""
        causes = []

        if error_type == "empty_output":
            causes.extend(["模型未学习到有效模式", "提示不清晰", "生成参数设置不当"])
        elif error_type == "content_mismatch":
            causes.extend(["训练数据分布偏移", "模型理解能力不足", "任务定义不明确"])
        elif error_type == "repetition":
            causes.extend(["温度设置过低", "训练数据重复", "模型退化"])
        elif error_type == "incomplete_output":
            # This label is produced by the detector but had no entry here.
            causes.extend(["max_new_tokens设置过小", "模型过早生成结束符"])

        return causes

2. 改进建议生成

class ImprovementRecommender:
    """Turn error-analysis results into a list of suggested actions."""

    def __init__(self):
        """Start with an empty ``recommendations`` list."""
        self.recommendations = []

    def generate_recommendations(self, analysis: Dict) -> List[Dict]:
        """Build suggestions from per-type counts and input-length patterns.

        Args:
            analysis: May contain ``"by_type"`` (error type -> count) and
                ``"patterns"`` (a dict that may hold ``"by_input_length"``,
                mapping a length bucket to a list of errors).

        Returns:
            Suggestion dicts with the keys ``type``, ``description`` and
            ``priority``: first those for error types counted more than 10
            times, then one per length bucket holding more than 5 errors.
        """
        type_counts = analysis["by_type"] if "by_type" in analysis else {}
        frequent = [(kind, seen) for kind, seen in type_counts.items() if seen > 10]
        candidates = [self._get_recommendation(kind, seen) for kind, seen in frequent]
        # Error types without a mapped suggestion yield None and are dropped.
        suggestions = [candidate for candidate in candidates if candidate]

        if "patterns" not in analysis:
            return suggestions
        pattern_data = analysis["patterns"]
        if "by_input_length" not in pattern_data:
            return suggestions

        for bucket, bucket_errors in pattern_data["by_input_length"].items():
            if len(bucket_errors) > 5:
                lower = bucket * 100
                upper = (bucket + 1) * 100
                suggestions.append({
                    "type": "data_augmentation",
                    "description": f"增强长度为{lower}-{upper}的样本",
                    "priority": "medium"
                })

        return suggestions

    def _get_recommendation(self, error_type: str, count: int) -> Optional[Dict]:
        """Return the suggestion for ``error_type``, or None if there is none.

        ``count`` is accepted but does not influence the result.
        """
        # (error type, action type, description, priority)
        table = (
            ("empty_output", "model_adjustment", "调整生成参数(增加温度或top_p)", "high"),
            ("content_mismatch", "data_improvement", "检查并改进训练数据质量", "high"),
            ("repetition", "training_adjustment", "增加训练数据多样性", "medium"),
            ("incomplete_output", "generation_config", "增加max_new_tokens", "low"),
        )
        lookup = {
            name: {"type": action, "description": text, "priority": priority}
            for name, action, text, priority in table
        }
        return lookup.get(error_type)

可视化

import matplotlib.pyplot as plt

def plot_error_analysis(errors: List[ErrorRecord]):
    """Show a 2x2 figure summarising the given errors.

    Panels: error-type bar chart, severity pie chart, confidence histogram
    and a cumulative error count over time.

    Args:
        errors: Records to plot; their ``error_type``, ``severity``,
            ``confidence`` and ``timestamp`` fields are read.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Error types
    type_counts = {}
    for error in errors:
        type_counts[error.error_type] = type_counts.get(error.error_type, 0) + 1

    # Lists rather than dict views: matplotlib expects array-like data.
    axes[0, 0].bar(list(type_counts.keys()), list(type_counts.values()))
    axes[0, 0].set_title("Error Type Distribution")
    axes[0, 0].set_xlabel("Error Type")
    axes[0, 0].set_ylabel("Count")
    axes[0, 0].tick_params(axis="x", rotation=45)

    # Severities
    severity_counts = {}
    for error in errors:
        severity_counts[error.severity] = severity_counts.get(error.severity, 0) + 1

    colors = {"critical": "red", "high": "orange", "medium": "yellow", "low": "green"}
    severity_labels = list(severity_counts.keys())
    axes[0, 1].pie(list(severity_counts.values()), labels=severity_labels,
                   colors=[colors.get(label, "gray") for label in severity_labels])
    axes[0, 1].set_title("Severity Distribution")

    # Confidence scores
    confidences = [e.confidence for e in errors]
    axes[1, 0].hist(confidences, bins=20, edgecolor="black")
    axes[1, 0].set_title("Confidence Distribution")
    axes[1, 0].set_xlabel("Confidence")

    # Timeline: sort first so the curve is monotonic, and count from 1 so
    # the y value is the number of errors seen up to that time.
    timestamps = sorted(e.timestamp for e in errors)
    axes[1, 1].plot(timestamps, range(1, len(timestamps) + 1), "o-")
    axes[1, 1].set_title("Error Timeline")
    axes[1, 1].set_xlabel("Time")
    axes[1, 1].set_ylabel("Cumulative Errors")

    plt.tight_layout()
    plt.show()

最佳实践

  1. 持续收集:在生产环境中持续收集错误
  2. 分类详细:建立详细的错误分类体系
  3. 优先级排序:根据严重程度和频率排序问题
  4. 闭环改进:将分析结果转化为具体改进措施

总结

错误分析是持续改进LLM性能的关键环节。通过系统化的错误收集、分析和改进,可以不断提升模型质量。