错误分析:系统化的问题诊断
---
title: "错误分析:系统化的问题诊断"
description: "系统化的LLM错误分析方法,从错误中学习并改进模型"
tags: ["错误分析", "问题诊断", "LLM", "质量改进", "调试"]
category: "llm"
icon: "❌"
---
错误分析:系统化的问题诊断
错误分析概述
错误分析是系统化地收集、分类和分析模型错误的过程,帮助理解失败模式并指导改进方向。
错误收集
1. 错误记录器
import json
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional
from datetime import datetime
from pathlib import Path
@dataclass
class ErrorRecord:
    """One model failure: the input, the expected and actual outputs, and labels.

    All fields are plain values, so an instance can be converted with
    ``dataclasses.asdict``.
    """

    id: str  # identifier of this record
    timestamp: datetime  # when the record was created
    input_text: str  # text that was given to the model
    expected_output: str  # reference output
    actual_output: str  # output the model actually produced
    error_type: str  # category label, e.g. "empty_output"
    severity: str  # severity label, e.g. "critical" / "high" / "medium" / "low"
    confidence: float  # numeric score attached by the code that built the record
    # Optional free-form context. The default is None (not a shared dict), so
    # the annotation has to be Optional.
    metadata: Optional[Dict[str, Any]] = None
class ErrorCollector:
    """Keep error records in memory and write each one to its own JSON file."""

    def __init__(self, storage_path: str = "errors"):
        """Create the collector and ensure the storage directory exists.

        Args:
            storage_path: Directory that receives one JSON file per record.
        """
        self.storage_path = Path(storage_path)
        # parents=True so that a nested path such as "runs/today/errors" is
        # created in one go instead of raising FileNotFoundError.
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.errors = []

    def record_error(self, error: "ErrorRecord") -> None:
        """Remember ``error`` and write it to ``error_<id>.json``.

        A record whose ``id`` was already used overwrites the earlier file.
        """
        self.errors.append(error)
        filepath = self.storage_path / f"error_{error.id}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            # default=str stringifies values json cannot encode natively
            # (for example a datetime timestamp).
            json.dump(asdict(error), f, ensure_ascii=False, indent=2, default=str)

    def get_errors(
        self,
        error_type: Optional[str] = None,
        severity: Optional[str] = None,
    ) -> List["ErrorRecord"]:
        """Return the recorded errors, optionally filtered.

        Args:
            error_type: Keep only records with this ``error_type``; a falsy
                value (None or "") disables the filter.
            severity: Keep only records with this ``severity``; a falsy value
                disables the filter.

        Returns:
            A new list, so callers cannot mutate the collector's own storage
            by accident.
        """
        filtered = list(self.errors)
        if error_type:
            filtered = [e for e in filtered if e.error_type == error_type]
        if severity:
            filtered = [e for e in filtered if e.severity == severity]
        return filtered

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the recorded errors.

        Returns:
            A dict with ``total``, ``by_type`` and ``by_severity`` (label ->
            count) and ``avg_confidence``. The same keys are present when no
            error has been recorded (zero / empty values), so callers do not
            need a special case for the empty collector.
        """
        total = len(self.errors)
        by_type: Dict[str, int] = {}
        by_severity: Dict[str, int] = {}
        for error in self.errors:
            by_type[error.error_type] = by_type.get(error.error_type, 0) + 1
            by_severity[error.severity] = by_severity.get(error.severity, 0) + 1

        avg_confidence = (
            sum(e.confidence for e in self.errors) / total if total else 0.0
        )
        return {
            "total": total,
            "by_type": by_type,
            "by_severity": by_severity,
            "avg_confidence": avg_confidence,
        }
2. 自动错误检测
class ErrorDetector:
    """Generate model outputs for test cases and flag the suspicious ones."""

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by ``detect_errors``.

        NOTE(review): the calls made below (``tokenizer(..., return_tensors="pt")``,
        ``model.generate``, ``tokenizer.decode``) look like the Hugging Face
        Transformers interface -- confirm against the caller.
        """
        self.model = model
        self.tokenizer = tokenizer

    def detect_errors(
        self, test_cases: List[Dict], max_new_tokens: int = 100
    ) -> List["ErrorRecord"]:
        """Run every test case through the model and collect detected errors.

        Args:
            test_cases: Dicts with the keys ``"input"`` and ``"expected"``.
            max_new_tokens: Generation budget forwarded to ``model.generate``.

        Returns:
            One ``ErrorRecord`` per test case that ``_classify_error`` flags;
            ids have the form ``auto_<index>``.
        """
        # Imported here because nothing else in this code brings torch into
        # scope; this also keeps the pure-Python heuristics usable without it.
        import torch

        errors = []
        for i, test_case in enumerate(test_cases):
            input_text = test_case["input"]
            expected = test_case["expected"]

            # Generate the output without tracking gradients.
            inputs = self.tokenizer(input_text, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

            # Skip the first input_ids.shape[1] positions before decoding --
            # presumably the echoed prompt tokens; verify for the model in use.
            prompt_length = inputs["input_ids"].shape[1]
            actual = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )

            error_type = self._classify_error(input_text, expected, actual)
            if error_type:
                errors.append(
                    ErrorRecord(
                        id=f"auto_{i}",
                        timestamp=datetime.now(),
                        input_text=input_text,
                        expected_output=expected,
                        actual_output=actual,
                        error_type=error_type,
                        severity=self._assess_severity(error_type, expected, actual),
                        confidence=self._compute_confidence(expected, actual),
                    )
                )
        return errors

    def _classify_error(self, input_text: str, expected: str, actual: str) -> Optional[str]:
        """Return the first matching error label, or None when nothing matches.

        Checks run in this order:
            1. ``empty_output`` -- nothing, or only whitespace, was produced.
            2. ``incomplete_output`` -- output shorter than half the expected length.
            3. ``content_mismatch`` -- expected text (longer than 10 chars) is not a
               case-insensitive substring of the output.
            4. ``repetition`` -- the first 10 characters occur more than 3 times.
        """
        # Whitespace-only output carries no content, so treat it as empty too.
        if not actual or not actual.strip():
            return "empty_output"
        if len(actual) < len(expected) * 0.5:
            return "incomplete_output"
        if len(expected) > 10 and expected.lower() not in actual.lower():
            return "content_mismatch"
        if actual.count(actual[:10]) > 3:
            return "repetition"
        return None

    def _assess_severity(self, error_type: str, expected: str, actual: str) -> str:
        """Map an error label to a severity label; unknown labels are ``"low"``."""
        severity_map = {
            "empty_output": "critical",
            "content_mismatch": "high",
            "incomplete_output": "medium",
            "repetition": "medium",
        }
        return severity_map.get(error_type, "low")

    def _compute_confidence(self, expected: str, actual: str) -> float:
        """Return how confident we are that ``actual`` is wrong, in [0, 1].

        The score is ``1 - Jaccard similarity`` of the lower-cased,
        whitespace-split token sets. An empty output scores 1.0 (certainly an
        error); outputs with identical token sets score 0.0. Returning the raw
        similarity instead would give both of those cases the same 1.0.
        """
        if not actual:
            return 1.0

        expected_tokens = set(expected.lower().split())
        actual_tokens = set(actual.lower().split())
        union = expected_tokens | actual_tokens
        if not union:
            # Neither side has any token, so there is nothing to disagree on.
            return 0.0
        similarity = len(expected_tokens & actual_tokens) / len(union)
        return 1.0 - similarity
错误分析
1. 错误模式分析
class ErrorPatternAnalyzer:
    """Look for regularities in a list of error records."""

    # Explicit ordering for severity labels. Plain string comparison would
    # rank "medium" above "critical"; unknown labels rank below "low".
    _SEVERITY_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}

    def __init__(self):
        self.patterns = {}

    def analyze_patterns(self, errors: List["ErrorRecord"]) -> Dict:
        """Bucket errors by input length and by where the output first diverges.

        Returns:
            A dict with three keys:
            ``by_input_length`` -- ``len(input_text) // 100`` -> list of error types;
            ``by_error_position`` -- decile (0-9) of the expected output at which
            the actual output first differs -> count;
            ``common_tokens`` -- always empty; nothing in this method fills it.
        """
        patterns = {
            "by_input_length": {},
            "by_error_position": {},
            "common_tokens": {},
        }
        for error in errors:
            length_bucket = len(error.input_text) // 100
            patterns["by_input_length"].setdefault(length_bucket, []).append(
                error.error_type
            )

            expected, actual = error.expected_output, error.actual_output
            if not (expected and actual):
                continue
            index = self._first_difference(expected, actual)
            if index is None:
                continue
            position = index / len(expected)
            # position is 1.0 when the output only has extra trailing text;
            # keep that in the last decile rather than inventing bucket 10.
            bucket = min(int(position * 10), 9)
            positions = patterns["by_error_position"]
            positions[bucket] = positions.get(bucket, 0) + 1
        return patterns

    @staticmethod
    def _first_difference(expected: str, actual: str) -> Optional[int]:
        """Return the index where the two strings first differ, or None if equal.

        When one string is a prefix of the other (e.g. a truncated output) the
        difference is at the end of the shorter one.
        """
        for i, (e, a) in enumerate(zip(expected, actual)):
            if e != a:
                return i
        if len(expected) != len(actual):
            return min(len(expected), len(actual))
        return None

    def find_root_causes(self, errors: List["ErrorRecord"]) -> List[Dict]:
        """Group errors by type and summarise each group.

        Returns:
            One dict per error type (in first-seen order) with ``error_type``,
            ``count``, the worst ``severity`` in the group, ``avg_confidence``
            and ``possible_causes``.
        """
        by_type: Dict[str, list] = {}
        for error in errors:
            by_type.setdefault(error.error_type, []).append(error)

        causes = []
        for error_type, type_errors in by_type.items():
            worst_severity = max(
                (e.severity for e in type_errors),
                key=lambda label: self._SEVERITY_RANK.get(label, -1),
            )
            causes.append(
                {
                    "error_type": error_type,
                    "count": len(type_errors),
                    "severity": worst_severity,
                    "avg_confidence": sum(e.confidence for e in type_errors)
                    / len(type_errors),
                    "possible_causes": self._suggest_causes(error_type, type_errors),
                }
            )
        return causes

    def _suggest_causes(self, error_type: str, errors: List["ErrorRecord"]) -> List[str]:
        """Return canned candidate causes for ``error_type`` (empty if unknown).

        ``errors`` is accepted for interface compatibility; the suggestions
        depend only on the type.
        """
        suggestions = {
            "empty_output": ["模型未学习到有效模式", "提示不清晰", "生成参数设置不当"],
            "content_mismatch": ["训练数据分布偏移", "模型理解能力不足", "任务定义不明确"],
            "repetition": ["温度设置过低", "训练数据重复", "模型退化"],
            "incomplete_output": ["max_new_tokens设置过小", "模型过早生成结束符", "输入过长导致截断"],
        }
        return list(suggestions.get(error_type, []))
2. 改进建议生成
class ImprovementRecommender:
    """Turn error statistics and patterns into improvement suggestions."""

    def __init__(self, type_count_threshold: int = 10, length_bucket_threshold: int = 5):
        """Configure how many errors are needed before something is suggested.

        Args:
            type_count_threshold: An error type is reported only when its count
                is strictly greater than this value.
            length_bucket_threshold: An input-length bucket is reported only
                when it holds strictly more errors than this value.
        """
        # Not updated by the methods of this class; kept for existing callers.
        self.recommendations = []
        self.type_count_threshold = type_count_threshold
        self.length_bucket_threshold = length_bucket_threshold

    def generate_recommendations(self, analysis: Dict) -> List[Dict]:
        """Build the list of recommendations for ``analysis``.

        Args:
            analysis: May contain ``"by_type"`` (error type -> count) and
                ``"patterns"`` (a dict that may contain ``"by_input_length"``,
                bucket -> list of errors). Missing keys are simply skipped.

        Returns:
            Recommendation dicts with ``type``, ``description`` and ``priority``;
            error-type based entries come first, length based ones after.
        """
        recommendations = []

        # Frequent error types.
        for error_type, count in analysis.get("by_type", {}).items():
            if count <= self.type_count_threshold:
                continue
            rec = self._get_recommendation(error_type, count)
            if rec:
                recommendations.append(rec)

        # Input-length buckets that collect many errors.
        by_input_length = analysis.get("patterns", {}).get("by_input_length", {})
        for length_bucket, errors in by_input_length.items():
            if len(errors) > self.length_bucket_threshold:
                recommendations.append({
                    "type": "data_augmentation",
                    "description": f"增强长度为{length_bucket*100}-{(length_bucket+1)*100}的样本",
                    "priority": "medium"
                })
        return recommendations

    def _get_recommendation(self, error_type: str, count: int) -> Optional[Dict]:
        """Return the canned recommendation for ``error_type``, or None if unknown.

        ``count`` is accepted for interface compatibility and does not affect
        the result. A fresh dict is returned on every call.
        """
        recommendations_map = {
            "empty_output": {
                "type": "model_adjustment",
                "description": "调整生成参数(增加温度或top_p)",
                "priority": "high"
            },
            "content_mismatch": {
                "type": "data_improvement",
                "description": "检查并改进训练数据质量",
                "priority": "high"
            },
            "repetition": {
                "type": "training_adjustment",
                "description": "增加训练数据多样性",
                "priority": "medium"
            },
            "incomplete_output": {
                "type": "generation_config",
                "description": "增加max_new_tokens",
                "priority": "low"
            }
        }
        return recommendations_map.get(error_type)
可视化
import matplotlib.pyplot as plt
def plot_error_analysis(errors: List["ErrorRecord"]):
    """Show a 2x2 figure summarising ``errors``.

    Panels: error-type bar chart, severity pie chart, confidence histogram and
    a cumulative error timeline. The figure is displayed with ``plt.show()``;
    nothing is returned.

    Args:
        errors: Records providing ``error_type``, ``severity``, ``confidence``
            and ``timestamp``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Error type distribution.
    type_counts = {}
    for error in errors:
        type_counts[error.error_type] = type_counts.get(error.error_type, 0) + 1
    axes[0, 0].bar(list(type_counts.keys()), list(type_counts.values()))
    axes[0, 0].set_title("Error Type Distribution")
    axes[0, 0].set_xlabel("Error Type")
    axes[0, 0].set_ylabel("Count")
    axes[0, 0].tick_params(axis="x", rotation=45)

    # Severity distribution; labels without a configured colour are drawn gray.
    severity_counts = {}
    for error in errors:
        severity_counts[error.severity] = severity_counts.get(error.severity, 0) + 1
    colors = {"critical": "red", "high": "orange", "medium": "yellow", "low": "green"}
    axes[0, 1].pie(
        list(severity_counts.values()),
        labels=list(severity_counts.keys()),
        colors=[colors.get(k, "gray") for k in severity_counts],
    )
    axes[0, 1].set_title("Severity Distribution")

    # Confidence distribution.
    confidences = [e.confidence for e in errors]
    axes[1, 0].hist(confidences, bins=20, edgecolor="black")
    axes[1, 0].set_title("Confidence Distribution")
    axes[1, 0].set_xlabel("Confidence")

    # Cumulative timeline. Timestamps are sorted first, because the input is
    # not guaranteed to be chronological, and the count starts at 1 so the
    # y value is the number of errors seen up to that moment.
    timestamps = sorted(e.timestamp for e in errors)
    axes[1, 1].plot(timestamps, range(1, len(timestamps) + 1), "o-")
    axes[1, 1].set_title("Error Timeline")
    axes[1, 1].set_xlabel("Time")
    axes[1, 1].set_ylabel("Cumulative Errors")

    plt.tight_layout()
    plt.show()
最佳实践
- 持续收集:在生产环境中持续收集错误
- 分类详细:建立详细的错误分类体系
- 优先级排序:根据严重程度和频率排序问题
- 闭环改进:将分析结果转化为具体改进措施
总结
错误分析是持续改进LLM性能的关键环节。通过系统化的错误收集、分析和改进,可以不断提升模型质量。