SHAP值在LLM解释中的应用
---
title: "SHAP值在LLM解释中的应用"
description: "介绍SHAP值在大型语言模型解释和特征重要性分析中的应用。"
tags: ["SHAP", "shapley值", "llm", "模型解释", "特征重要性"]
category: "llm"
icon: "🧠"
---
SHAP值在LLM解释中的应用
什么是SHAP值?
SHAP(SHapley Additive exPlanations)是一种基于博弈论中Shapley值的特征归因方法:它对每个特征在所有可能的特征子集(联盟)上的边际贡献取加权平均,并以此解释模型的单次预测。
SHAP值原理
1. 基本SHAP计算
import math
from itertools import combinations

import numpy as np
import torch
class SHAPCalculator:
    """Shapley-value attribution over the token positions of one input.

    A "feature" is one position along dimension 1 of ``input_ids``. A feature
    is removed from a coalition by overwriting its id with
    ``tokenizer.pad_token_id``; the value of a coalition is whatever
    ``_get_prediction`` returns for the model output on that masked input.
    """

    def __init__(self, model, tokenizer):
        """Store the model (called as ``model(input_ids)``) and its tokenizer."""
        self.model = model
        self.tokenizer = tokenizer

    def compute_kernel_shap(self, input_ids, background_data, n_samples=100):
        """Estimate SHAP values with sampled coalitions and least squares.

        Coalitions are drawn from the Shapley-kernel distribution (size ``s``
        with probability proportional to ``1 / (s * (n - s))``, members chosen
        uniformly), so an unweighted least-squares fit of the coalition values
        approximates the Shapley values. The efficiency constraint is imposed
        exactly: the returned values sum to ``target_pred - background_pred``.

        Args:
            input_ids: Tensor indexed as ``input_ids[0][position]``; its
                second dimension is the number of features.
            background_data: Iterable of 1-D id tensors. The mean prediction
                over them is used as the base value.
            n_samples: Number of random coalitions to evaluate (>= 1).

        Returns:
            ``np.ndarray`` of length ``input_ids.shape[1]``.

        Raises:
            ValueError: If ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")

        n_features = input_ids.shape[1]
        target_pred = self._coalition_value(
            input_ids, np.ones(n_features, dtype=bool)
        )

        background_preds = []
        for bg in background_data:
            with torch.no_grad():
                bg_output = self.model(bg.unsqueeze(0))
            background_preds.append(self._get_prediction(bg_output))
        background_pred = np.mean(background_preds)

        # Total attribution that has to be distributed over the features.
        total = target_pred - background_pred
        if n_features == 0:
            return np.zeros(0)
        if n_features == 1:
            return np.array([total], dtype=float)

        # Shapley kernel marginalised over coalition size (sizes 1 .. n-1).
        sizes = np.arange(1, n_features)
        size_probs = 1.0 / (sizes * (n_features - sizes))
        size_probs = size_probs / size_probs.sum()

        masks = np.zeros((n_samples, n_features))
        gains = np.zeros(n_samples)
        for row in range(n_samples):
            size = int(np.random.choice(sizes, p=size_probs))
            members = np.random.choice(n_features, size=size, replace=False)
            masks[row, members] = 1.0
            gains[row] = (
                self._coalition_value(input_ids, masks[row] == 1.0)
                - background_pred
            )

        # Substitute phi_last = total - sum(phi_others) so the fit cannot
        # violate efficiency, then solve for the remaining n-1 values.
        last_column = masks[:, -1]
        design = masks[:, :-1] - last_column[:, None]
        rhs = gains - last_column * total
        head, *_ = np.linalg.lstsq(design, rhs, rcond=None)
        return np.append(head, total - head.sum())

    def compute_exact_shap(self, input_ids):
        """Compute exact Shapley values by enumerating every coalition.

        For feature ``i`` the value is the sum over all subsets ``S`` of the
        *other* features of ``|S|! (n-|S|-1)! / n! * (v(S + {i}) - v(S))``.
        The cost is ``2 ** n`` model calls, so this is only practical for
        short inputs. Coalition values are cached, which assumes the model is
        deterministic for a fixed input.

        Args:
            input_ids: Tensor indexed as ``input_ids[0][position]``.

        Returns:
            ``np.ndarray`` of length ``input_ids.shape[1]``.
        """
        n_features = input_ids.shape[1]
        shap_values = np.zeros(n_features)
        if n_features == 0:
            return shap_values

        cache = {}

        def value(members):
            # Each distinct coalition is sent through the model only once.
            key = frozenset(members)
            if key not in cache:
                keep = np.zeros(n_features, dtype=bool)
                for position in key:
                    keep[position] = True
                cache[key] = self._coalition_value(input_ids, keep)
            return cache[key]

        n_factorial = math.factorial(n_features)
        for i in range(n_features):
            # Coalitions must not already contain i, otherwise adding i is a
            # no-op and the marginal contribution is spuriously zero.
            others = [j for j in range(n_features) if j != i]
            for subset_size in range(n_features):
                weight = (
                    math.factorial(subset_size)
                    * math.factorial(n_features - subset_size - 1)
                    / n_factorial
                )
                for subset in combinations(others, subset_size):
                    shap_values[i] += weight * (
                        value(subset + (i,)) - value(subset)
                    )
        return shap_values

    def _coalition_value(self, input_ids, keep):
        """Return the prediction with every position where ``keep`` is False padded out."""
        masked_input = input_ids.clone()
        drop = torch.as_tensor(
            ~np.asarray(keep, dtype=bool), device=input_ids.device
        )
        masked_input[0, drop] = self.tokenizer.pad_token_id
        with torch.no_grad():
            return self._get_prediction(self.model(masked_input))

    def _get_prediction(self, output):
        """Reduce a model output to one scalar: the largest softmax probability.

        Uses ``output.logits`` when that attribute exists, otherwise treats
        ``output`` itself as the logits tensor.

        NOTE(review): the maximum is taken over the whole tensor, so different
        coalitions may be scored on different classes/positions.
        """
        if hasattr(output, 'logits'):
            logits = output.logits
        else:
            logits = output
        probs = torch.softmax(logits, dim=-1)
        return probs.max().item()
2. SHAP值可视化
class SHAPVisualizer:
    """Chart builders for per-token SHAP attributions.

    NOTE(review): every plotting method uses a module-level ``plt``, but no
    ``matplotlib.pyplot`` import is visible in this file — confirm it is
    imported somewhere before these methods run.
    """

    def __init__(self):
        # No method of this class writes to this dict.
        self.figures = dict()

    @staticmethod
    def _sign_colors(values):
        """Map each value to 'red' when negative, otherwise 'blue'."""
        return ['red' if value < 0 else 'blue' for value in values]

    def plot_waterfall(self, tokens, shap_values,
                       title="SHAP Waterfall Plot",
                       figsize=(12, 6)):
        """Draw a horizontal bar chart of all tokens, ordered by ascending SHAP value.

        ``shap_values`` is indexed with an integer array, so it must be a
        NumPy array. Returns the created figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        ascending = np.argsort(shap_values)
        labels = [tokens[pos] for pos in ascending]
        bars = shap_values[ascending]
        rows = range(len(labels))

        ax.barh(rows, bars, color=self._sign_colors(bars), alpha=0.7)
        ax.set_yticks(rows)
        ax.set_yticklabels(labels)
        ax.set_xlabel('SHAP Value')
        ax.set_title(title)
        # Zero line separating negative from positive contributions.
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        return fig

    def plot_summary(self, tokens, shap_values,
                     title="SHAP Summary Plot",
                     figsize=(12, 8)):
        """Draw the (at most) ten tokens with the largest absolute SHAP value.

        Bars show the signed values, largest magnitude at the top. Returns
        the created figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        by_magnitude = np.argsort(np.abs(shap_values))[::-1]
        chosen = by_magnitude[:10]
        labels = [tokens[pos] for pos in chosen]
        bars = shap_values[chosen]
        rows = range(len(labels))

        ax.barh(rows, bars, color=self._sign_colors(bars), alpha=0.7)
        ax.set_yticks(rows)
        ax.set_yticklabels(labels)
        ax.set_xlabel('Mean |SHAP Value|')
        ax.set_title(title)
        ax.invert_yaxis()

        plt.tight_layout()
        return fig

    def plot_dependence(self, tokens, shap_values, feature_idx,
                        title="SHAP Dependence Plot",
                        figsize=(10, 6)):
        """Scatter SHAP value against token position and highlight one position.

        Returns the created figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        positions = np.arange(len(tokens))
        ax.scatter(positions, shap_values,
                   c=self._sign_colors(shap_values), alpha=0.6)

        # Overlay the selected feature as a larger green marker.
        ax.scatter([feature_idx], [shap_values[feature_idx]],
                   color='green', s=100, zorder=5, label='Selected Feature')

        ax.set_xticks(positions)
        ax.set_xticklabels(tokens, rotation=45, ha='right')
        ax.set_ylabel('SHAP Value')
        ax.set_title(title)
        ax.legend()
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        return fig

    def plot_heatmap(self, all_tokens, all_shap_values,
                     title="SHAP Heatmap",
                     figsize=(12, 8)):
        """Render ``all_shap_values`` as a heatmap, one y tick per entry of ``all_tokens``.

        Returns the created figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        image = ax.imshow(np.array(all_shap_values),
                          cmap='RdBu_r', aspect='auto')

        ax.set_yticks(range(len(all_tokens)))
        ax.set_yticklabels(all_tokens)
        ax.set_xlabel('Sample')
        ax.set_ylabel('Token')
        ax.set_title(title)

        plt.colorbar(image)
        plt.tight_layout()
        return fig
3. SHAP分析器
class SHAPAnalyzer:
    """Text-level front end: tokenizes text, runs a SHAP method, ranks and plots tokens."""

    def __init__(self, model, tokenizer):
        """Keep the model/tokenizer and build the calculator and visualizer helpers."""
        self.model = model
        self.tokenizer = tokenizer
        self.calculator = SHAPCalculator(model, tokenizer)
        self.visualizer = SHAPVisualizer()

    def analyze_text(self, text, background_texts=None, method='kernel'):
        """Compute per-token SHAP values for ``text``.

        Args:
            text: Input string to explain.
            background_texts: Optional reference texts for the ``'kernel'``
                baseline. When omitted, the input itself is the baseline.
                Ignored by ``'exact'``.
            method: ``'kernel'`` or ``'exact'``.

        Returns:
            Dict with keys ``'tokens'``, ``'shap_values'`` and ``'method'``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        # Reject bad input before paying for tokenization or model calls.
        if method not in ('kernel', 'exact'):
            raise ValueError(f"Unknown method: {method}")

        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs['input_ids']
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])

        if method == 'kernel':
            if background_texts is None:
                # The baseline is the mean background prediction, so one copy
                # of the input gives the same baseline as repeating it.
                background_data = [input_ids[0]]
            else:
                background_data = [
                    self.tokenizer(bg_text, return_tensors="pt")['input_ids'][0]
                    for bg_text in background_texts
                ]
            shap_values = self.calculator.compute_kernel_shap(
                input_ids, background_data
            )
        else:
            shap_values = self.calculator.compute_exact_shap(input_ids)

        return {
            'tokens': tokens,
            'shap_values': shap_values,
            'method': method
        }

    def analyze_batch(self, texts, background_texts=None):
        """Run ``analyze_text`` on every text and return the results in order."""
        return [self.analyze_text(text, background_texts) for text in texts]

    def compare_methods(self, text, methods=('kernel', 'exact')):
        """Run several SHAP methods on the same text.

        A method that raises is reported on stdout and left out of the
        result, so one failing method does not hide the others.

        Returns:
            Dict mapping each successful method name to its analysis dict.
        """
        results = {}
        for method in methods:
            try:
                results[method] = self.analyze_text(text, method=method)
            except Exception as e:  # deliberate best-effort comparison
                print(f"Method {method} failed: {e}")
        return results

    def get_top_tokens(self, analysis_result, top_k=10):
        """Return the ``top_k`` tokens with the largest absolute SHAP value.

        Each entry is a dict with ``'token'``, ``'shap_value'`` and
        ``'index'`` (position in the token list), strongest first.
        """
        tokens = analysis_result['tokens']
        shap_values = analysis_result['shap_values']

        ranked = np.argsort(np.abs(shap_values))[::-1]
        return [
            {
                'token': tokens[idx],
                'shap_value': shap_values[idx],
                'index': idx
            }
            for idx in ranked[:top_k]
        ]

    def visualize_analysis(self, analysis_result):
        """Build the waterfall and summary figures for one analysis result."""
        tokens = analysis_result['tokens']
        shap_values = analysis_result['shap_values']

        return {
            'waterfall': self.visualizer.plot_waterfall(
                tokens, shap_values, "SHAP Waterfall"
            ),
            'summary': self.visualizer.plot_summary(
                tokens, shap_values, "SHAP Summary"
            ),
        }
LLM SHAP分析实践
1. 文本分类SHAP分析
class TextClassificationSHAP:
    """SHAP explanations for a sequence-classification model."""

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer
        self.shap_analyzer = SHAPAnalyzer(model, tokenizer)

    def explain_classification(self, text, target_class=None):
        """Explain the model's classification of ``text``.

        When ``target_class`` is None the arg-max class of the logits is
        used. The same SHAP vector is stored for every class index: the
        per-class computation is a placeholder.

        Returns:
            Dict with ``'text'``, ``'tokens'``, ``'prediction'``,
            ``'confidence'``, ``'shap_values'`` (class index -> values) and
            ``'top_tokens'``.
        """
        encoded = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**encoded).logits

        if target_class is None:
            target_class = logits.argmax(dim=-1).item()

        analysis = self.shap_analyzer.analyze_text(text)

        # Simplification: every class shares the one SHAP vector computed
        # above; a real implementation would compute one vector per class.
        per_class = dict.fromkeys(
            range(logits.shape[-1]), analysis['shap_values']
        )

        confidence = torch.softmax(logits, dim=-1)[0][target_class].item()
        top_tokens = self.shap_analyzer.get_top_tokens(analysis)
        return {
            'text': text,
            'tokens': analysis['tokens'],
            'prediction': target_class,
            'confidence': confidence,
            'shap_values': per_class,
            'top_tokens': top_tokens
        }

    def visualize_classification_explanation(self, explanation):
        """Plot one waterfall per class plus a summary of the class-averaged values.

        Returns a dict of figures keyed ``'class_<idx>'`` and ``'summary'``.
        """
        visualizer = SHAPVisualizer()
        tokens = explanation['tokens']
        by_class = explanation['shap_values']

        figs = {}
        for class_idx, shap_values in by_class.items():
            figs[f'class_{class_idx}'] = visualizer.plot_waterfall(
                tokens,
                shap_values,
                title=f"SHAP Values for Class {class_idx}"
            )

        # Element-wise mean over the per-class vectors.
        averaged = np.mean(list(by_class.values()), axis=0)
        figs['summary'] = visualizer.plot_summary(
            tokens,
            averaged,
            title="SHAP Summary (Average Across Classes)"
        )
        return figs

    def batch_explain(self, texts, batch_size=32):
        """Explain every text, walking the list in chunks of ``batch_size``.

        Texts are still explained one at a time; the chunking only affects
        iteration order, which stays the input order.
        """
        explanations = []
        for offset in range(0, len(texts), batch_size):
            chunk = texts[offset:offset + batch_size]
            explanations.extend(
                self.explain_classification(item) for item in chunk
            )
        return explanations
2. 问答系统SHAP分析
class QASystemSHAP:
    """SHAP explanations for an extractive question-answering model."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.shap_analyzer = SHAPAnalyzer(model, tokenizer)

    def explain_qa_prediction(self, question, context):
        """Explain the answer span predicted for ``question`` over ``context``.

        SHAP values are computed on the same pair-encoded ``input_ids`` that
        produced ``start_idx``/``end_idx``, so ``tokens``, ``shap_values`` and
        the span indices all refer to the same positions.

        Returns:
            Dict with ``'question'``, ``'context'``, ``'tokens'``,
            ``'answer'``, ``'start_idx'``, ``'end_idx'``, ``'shap_values'``
            and ``'top_tokens'``.
        """
        inputs = self.tokenizer(
            question, context,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )
        input_ids = inputs['input_ids']
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])

        with torch.no_grad():
            outputs = self.model(**inputs)
        start_idx = outputs.start_logits.argmax(dim=-1).item()
        end_idx = outputs.end_logits.argmax(dim=-1).item()

        # Use the pair encoding itself (not a re-tokenized
        # "question context" string) so indices stay aligned with the span.
        # The input doubles as the only background sample.
        # NOTE(review): SHAPCalculator._get_prediction reads `.logits`, while
        # this model's output is read via start_logits/end_logits — confirm
        # the calculator can score this output type.
        shap_values = self.shap_analyzer.calculator.compute_kernel_shap(
            input_ids, [input_ids[0]]
        )
        analysis = {
            'tokens': tokens,
            'shap_values': shap_values,
            'method': 'kernel'
        }

        return {
            'question': question,
            'context': context,
            'tokens': tokens,
            'answer': self.tokenizer.decode(
                input_ids[0][start_idx:end_idx + 1]
            ),
            'start_idx': start_idx,
            'end_idx': end_idx,
            'shap_values': shap_values,
            'top_tokens': self.shap_analyzer.get_top_tokens(analysis)
        }

    def visualize_qa_explanation(self, explanation):
        """Return a waterfall figure plus a text rendering with the answer marked."""
        visualizer = SHAPVisualizer()
        return {
            'waterfall': visualizer.plot_waterfall(
                explanation['tokens'],
                explanation['shap_values'],
                title="QA SHAP Values"
            ),
            'answer_highlight': self._highlight_answer_tokens(explanation),
        }

    def _highlight_answer_tokens(self, explanation):
        """Join the tokens with spaces, wrapping answer-span tokens as ``**[tok]**``."""
        start_idx = explanation['start_idx']
        end_idx = explanation['end_idx']
        return ' '.join(
            f"**[{token}]**" if start_idx <= i <= end_idx else token
            for i, token in enumerate(explanation['tokens'])
        )

    def analyze_answer_quality(self, explanation):
        """Summarize how much attribution falls on the answer and the question.

        An empty span (``end_idx < start_idx``), an empty question slice or
        all-zero SHAP values yield ``0.0`` for the affected metrics instead
        of an exception or NaN.

        Returns:
            Dict with ``'answer_shap_mean'``, ``'answer_shap_max'``,
            ``'question_shap_mean'`` and ``'answer_importance'`` (share of
            total absolute attribution inside the answer span).
        """
        shap_values = np.asarray(explanation['shap_values'])
        answer_shap = shap_values[
            explanation['start_idx']:explanation['end_idx'] + 1
        ]

        # NOTE(review): whitespace word count only approximates the number of
        # question tokens (sub-word pieces and special tokens are ignored).
        question_len = len(explanation['question'].split())
        question_shap = shap_values[:question_len]

        has_answer = answer_shap.size > 0
        total_abs = np.sum(np.abs(shap_values))

        return {
            'answer_shap_mean': np.mean(answer_shap) if has_answer else 0.0,
            'answer_shap_max': np.max(answer_shap) if has_answer else 0.0,
            'question_shap_mean': (
                np.mean(question_shap) if question_shap.size > 0 else 0.0
            ),
            'answer_importance': (
                np.sum(np.abs(answer_shap)) / total_abs
                if total_abs > 0 else 0.0
            )
        }
3. 生成模型SHAP分析
class GenerationModelSHAP:
    """SHAP explanations for prompt + generated-text pairs."""

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer
        self.shap_analyzer = SHAPAnalyzer(model, tokenizer)

    def explain_generation(self, prompt, generated_text):
        """Attribute SHAP values over ``prompt + generated_text`` and split them.

        The split point is the number of tokens ``tokenizer.tokenize(prompt)``
        returns.

        NOTE(review): this presumes the analyzed token list has no leading
        special token and that the prompt tokenizes identically inside the
        concatenated text — verify for the tokenizer in use.
        """
        analysis = self.shap_analyzer.analyze_text(prompt + generated_text)

        split_at = len(self.tokenizer.tokenize(prompt))
        # Result is unused; the call is kept so behaviour is unchanged.
        self.tokenizer.tokenize(generated_text)

        all_tokens = analysis['tokens']
        all_values = analysis['shap_values']
        return {
            'prompt': prompt,
            'generated_text': generated_text,
            'all_tokens': all_tokens,
            'prompt_tokens': all_tokens[:split_at],
            'generated_tokens': all_tokens[split_at:],
            'shap_values': all_values,
            'prompt_shap': all_values[:split_at],
            'generated_shap': all_values[split_at:]
        }

    def visualize_generation_explanation(self, explanation):
        """Plot waterfalls for the whole text and for each non-empty part.

        Returns a dict with ``'overall'`` and, when the respective slice is
        non-empty, ``'prompt'`` and ``'generated'``.
        """
        visualizer = SHAPVisualizer()
        figs = {
            'overall': visualizer.plot_waterfall(
                explanation['all_tokens'],
                explanation['shap_values'],
                title="Overall SHAP Values"
            )
        }

        sections = (
            ('prompt', 'prompt_tokens', 'prompt_shap',
             "Prompt SHAP Values"),
            ('generated', 'generated_tokens', 'generated_shap',
             "Generated Text SHAP Values"),
        )
        for name, tokens_key, values_key, heading in sections:
            if len(explanation[values_key]) > 0:
                figs[name] = visualizer.plot_waterfall(
                    explanation[tokens_key],
                    explanation[values_key],
                    title=heading
                )
        return figs

    def analyze_prompt_influence(self, explanation):
        """Compare total absolute attribution on prompt vs. generated tokens.

        Ratios fall back to 0.5 / 0.5 when the combined importance is not
        positive. Also returns the top tokens of each part.
        """
        prompt_importance = np.sum(np.abs(explanation['prompt_shap']))
        generated_importance = np.sum(np.abs(explanation['generated_shap']))

        combined = prompt_importance + generated_importance
        prompt_ratio = generated_ratio = 0.5
        if combined > 0:
            prompt_ratio = prompt_importance / combined
            generated_ratio = generated_importance / combined

        top_prompt = self._get_top_tokens(
            explanation['prompt_tokens'], explanation['prompt_shap']
        )
        top_generated = self._get_top_tokens(
            explanation['generated_tokens'], explanation['generated_shap']
        )
        return {
            'prompt_importance': prompt_importance,
            'generated_importance': generated_importance,
            'prompt_ratio': prompt_ratio,
            'generated_ratio': generated_ratio,
            'top_prompt_tokens': top_prompt,
            'top_generated_tokens': top_generated
        }

    def _get_top_tokens(self, tokens, shap_values, top_k=5):
        """Return up to ``top_k`` ``{'token', 'shap_value'}`` dicts, largest magnitude first.

        Indices beyond the token list are skipped; empty input gives ``[]``.
        """
        if not len(tokens) or not len(shap_values):
            return []

        by_magnitude = np.argsort(np.abs(shap_values))[::-1]
        token_count = len(tokens)
        return [
            {'token': tokens[pos], 'shap_value': shap_values[pos]}
            for pos in by_magnitude[:top_k]
            if pos < token_count
        ]
实际应用案例
案例:LLM SHAP分析系统
# LLM SHAP分析系统
class LLM_SHAP_Analysis_System:
    """Facade bundling the classification, QA and generation SHAP explainers."""

    def __init__(self, model, tokenizer):
        """Build one explainer per task type around the same model and tokenizer."""
        self.model = model
        self.tokenizer = tokenizer
        self.text_shap = TextClassificationSHAP(model, tokenizer)
        self.qa_shap = QASystemSHAP(model, tokenizer)
        self.generation_shap = GenerationModelSHAP(model, tokenizer)
        self.visualizer = SHAPVisualizer()

    def comprehensive_analysis(self, text, task_type='classification'):
        """Run the analysis for ``task_type`` on a single text.

        Only ``'classification'`` can be served from one string; the other
        task types, and unknown ones, return a dict with an ``'error'`` key.
        """
        if task_type == 'classification':
            return self.text_shap.explain_classification(text)
        if task_type == 'qa':
            # QA needs a separate question and context.
            return {"error": "QA任务需要question和context参数"}
        if task_type == 'generation':
            return {"error": "生成任务需要prompt和generated_text参数"}
        return {"error": f"Unknown task type: {task_type}"}

    def generate_analysis_report(self, analysis_result, task_type='classification'):
        """Build a report dict from an analysis result.

        The report always has ``'task_type'``, ``'summary'``,
        ``'top_features'``, ``'insights'`` and ``'recommendations'``; only
        the ``'classification'`` and ``'qa'`` task types fill in a summary.
        """
        report = {
            'task_type': task_type,
            'summary': {},
            'top_features': [],
            'insights': [],
            'recommendations': []
        }

        if task_type == 'classification':
            report['summary'] = {
                'prediction': analysis_result.get('prediction'),
                'confidence': analysis_result.get('confidence'),
                'num_tokens': len(analysis_result.get('tokens', []))
            }
            if 'top_tokens' in analysis_result:
                report['top_features'] = analysis_result['top_tokens']
            report['insights'] = self._generate_classification_insights(
                analysis_result
            )
        elif task_type == 'qa':
            # Treat a missing or None answer as an empty string.
            answer_text = analysis_result.get('answer') or ''
            report['summary'] = {
                'question': analysis_result.get('question'),
                'answer': analysis_result.get('answer'),
                'answer_length': len(answer_text.split())
            }

        return report

    def _generate_classification_insights(self, analysis_result):
        """Derive short textual observations from the result's top tokens.

        Lists up to three positively and three negatively contributing
        tokens, and adds a warning when the standard deviation of the top
        SHAP values exceeds 0.5.
        """
        insights = []
        top_tokens = analysis_result.get('top_tokens', [])
        if not top_tokens:
            return insights

        positive_tokens = [t for t in top_tokens if t['shap_value'] > 0]
        negative_tokens = [t for t in top_tokens if t['shap_value'] < 0]

        if positive_tokens:
            insights.append(f"正面影响词汇: {[t['token'] for t in positive_tokens[:3]]}")
        if negative_tokens:
            insights.append(f"负面影响词汇: {[t['token'] for t in negative_tokens[:3]]}")

        shap_values = [t['shap_value'] for t in top_tokens]
        if np.std(shap_values) > 0.5:
            insights.append("词汇重要性差异较大,模型可能过度依赖某些词汇")

        return insights

    def visualize_comprehensive(self, analysis_result):
        """Build waterfall and summary figures for an analysis result.

        ``analysis_result['shap_values']`` may be a flat array or, as
        produced for classification, a dict mapping class index to array.
        For a dict, the entry for ``analysis_result['prediction']`` is
        plotted, falling back to the first entry. Returns ``{}`` when tokens
        or SHAP values are missing.
        """
        if 'tokens' not in analysis_result or 'shap_values' not in analysis_result:
            return {}

        tokens = analysis_result['tokens']
        shap_values = analysis_result['shap_values']

        if isinstance(shap_values, dict):
            # The plotters index and sort an array; a dict would fail there.
            if not shap_values:
                return {}
            predicted = analysis_result.get('prediction')
            if predicted in shap_values:
                shap_values = shap_values[predicted]
            else:
                shap_values = next(iter(shap_values.values()))

        return {
            'waterfall': self.visualizer.plot_waterfall(
                tokens, shap_values, "SHAP Waterfall Plot"
            ),
            'summary': self.visualizer.plot_summary(
                tokens, shap_values, "SHAP Summary Plot"
            ),
        }

    def compare_texts(self, texts, task_type='classification'):
        """Analyze each text and compare the results.

        Returns:
            Dict with ``'individual_results'`` (one ``{'text', 'analysis'}``
            per input) and ``'comparison'``.
        """
        results = [
            {
                'text': text,
                'analysis': self.comprehensive_analysis(text, task_type)
            }
            for text in texts
        ]
        return {
            'individual_results': results,
            'comparison': self._compare_analyses(results)
        }

    def _compare_analyses(self, results):
        """Find the top tokens shared by every analysis that reports ``'top_tokens'``.

        Only ``'common_features'`` is populated; the other keys are returned
        empty.
        """
        comparison = {
            'common_features': [],
            'distinguishing_features': [],
            'pattern_analysis': {}
        }

        feature_sets = [
            {t['token'] for t in result['analysis']['top_tokens']}
            for result in results
            if 'top_tokens' in result['analysis']
        ]
        if feature_sets:
            comparison['common_features'] = list(
                set.intersection(*feature_sets)
            )

        return comparison
# 使用示例
# system = LLM_SHAP_Analysis_System(model, tokenizer)
#
# # 分析文本分类
# analysis = system.comprehensive_analysis("This movie was fantastic!")
# report = system.generate_analysis_report(analysis)
# visualizations = system.visualize_comprehensive(analysis)
#
# # 比较多个文本
# texts = ["Great product!", "Terrible service.", "Average experience."]
# comparison = system.compare_texts(texts)
总结
SHAP值是解释LLM的强大工具:
- 理论基础 - 基于博弈论,具有坚实的数学基础
- 公平归因 - 满足对称性、效率性等公理
- 全局解释 - 可以分析模型整体行为
- 局部解释 - 可以解释单个预测
- 特征交互 - 借助SHAP交互值可以分析特征间的交互效应(本文示例代码未涉及)
通过SHAP分析,我们可以更好地理解LLM如何做出决策,提高模型的透明度和可信度,为模型优化和调试提供有力支持。