← 返回首页
🧠

SHAP值在LLM解释中的应用

📂 llm ⏱ 9 min 1606 words

---
title: "SHAP值在LLM解释中的应用"
description: "介绍SHAP值在大型语言模型解释和特征重要性分析中的应用。"
tags: ["SHAP", "shapley值", "llm", "模型解释", "特征重要性"]
category: "llm"
icon: "🧠"
---

SHAP值在LLM解释中的应用

什么是SHAP值?

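SHAP(SHapley Additive exPlanations)是一种基于合作博弈论中Shapley值的特征归因方法:它把每个特征(在LLM中通常是每个token)视为一个"玩家",对所有不含该特征的特征子集,计算加入该特征前后模型输出的变化(边际贡献),再按Shapley权重加权平均,从而把一次预测分解为各特征贡献之和。对含 $M$ 个特征的输入,特征 $i$ 的SHAP值为

$$\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(M-|S|-1)!}{M!}\,\bigl[f(S \cup \{i\}) - f(S)\bigr]$$

其中 $N$ 为全部特征的集合,$f(S)$ 表示只保留子集 $S$ 中的特征(其余特征被掩码)时的模型输出。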

SHAP值原理

1. 基本SHAP计算

import math
from itertools import combinations

import numpy as np
import torch

class SHAPCalculator:
    """Token-level SHAP attribution for a sequence model.

    A token is "absent" from a coalition when it is replaced by
    ``tokenizer.pad_token_id``. The scalar being explained is whatever
    ``_get_prediction`` extracts from the model output (the largest softmax
    probability).
    """

    def __init__(self, model, tokenizer):
        # model: callable taking an input_ids tensor and returning either a
        #   logits tensor or an object with a ``.logits`` attribute.
        # tokenizer: must expose ``pad_token_id`` (the masking value).
        self.model = model
        self.tokenizer = tokenizer

    def compute_kernel_shap(self, input_ids, background_data, n_samples=100,
                            random_state=None):
        """Approximate SHAP values with Kernel SHAP (weighted linear regression).

        Coalitions are sampled from the Shapley-kernel distribution: a size k
        in 1..M-1 is drawn with probability proportional to (M-1)/(k(M-k)),
        then a uniform subset of that size is drawn. Each coalition is scored
        with the missing tokens padded. The attributions are the least-squares
        fit constrained so that they sum exactly to
        f(full input) - f(all tokens padded) (the efficiency property).

        Args:
            input_ids: integer tensor whose row 0 is explained; shape (1, M)
                is assumed (only row 0 is masked).
            background_data: accepted for interface compatibility but not
                used. Missing tokens are replaced by the pad token, so the
                baseline is the prediction on the all-pad input.
            n_samples: number of sampled coalitions (model calls, plus two).
            random_state: optional int seed for reproducible sampling. ``None``
                uses NumPy's global random state.

        Returns:
            np.ndarray of shape (M,), one attribution per token.

        Raises:
            ValueError: if ``n_samples`` is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        n_features = input_ids.shape[1]
        if n_features == 0:
            return np.zeros(0)

        full_pred = self._masked_prediction(input_ids, np.ones(n_features, dtype=bool))
        base_pred = self._masked_prediction(input_ids, np.zeros(n_features, dtype=bool))
        total = full_pred - base_pred
        if n_features == 1:
            # A single player receives the whole difference (efficiency).
            return np.array([total])

        rng = np.random if random_state is None else np.random.RandomState(random_state)

        # The empty and full coalitions have infinite kernel weight. They are
        # enforced exactly through the constraint below, so only sizes 1..M-1
        # are sampled.
        sizes = np.arange(1, n_features)
        size_weights = (n_features - 1) / (sizes * (n_features - sizes))
        size_probs = size_weights / size_weights.sum()

        coalitions = np.zeros((n_samples, n_features))
        targets = np.zeros(n_samples)
        for row in range(n_samples):
            size = rng.choice(sizes, p=size_probs)
            members = rng.choice(n_features, size=size, replace=False)
            coalitions[row, members] = 1.0
            targets[row] = (
                self._masked_prediction(input_ids, coalitions[row].astype(bool))
                - base_pred
            )

        # Eliminate the last attribution via phi_last = total - sum(others):
        #   y - z_last * total = sum_i (z_i - z_last) * phi_i
        # Because samples are drawn from the kernel, ordinary least squares
        # (uniform weights) is the correct estimator.
        last = coalitions[:, -1]
        design = coalitions[:, :-1] - last[:, None]
        adjusted = targets - last * total
        phi_rest, *_ = np.linalg.lstsq(design, adjusted, rcond=None)
        return np.append(phi_rest, total - phi_rest.sum())

    def compute_exact_shap(self, input_ids):
        """Compute exact Shapley values by enumerating every token coalition.

        Each of the 2**M coalitions is scored once and cached. Then, for token i::

            phi_i = sum over S not containing i of
                    |S|! (M-|S|-1)! / M! * [f(S + {i}) - f(S)]

        The weight equals 1 / (M * C(M-1, |S|)). The cost is 2**M model
        calls, so this is only practical for very short inputs.

        Args:
            input_ids: integer tensor whose row 0 is explained; shape (1, M)
                is assumed.

        Returns:
            np.ndarray of shape (M,).
        """
        n_features = input_ids.shape[1]

        # Score every coalition once; keys are sorted index tuples, matching
        # what itertools.combinations yields.
        coalition_value = {}
        for size in range(n_features + 1):
            for members in combinations(range(n_features), size):
                keep = np.zeros(n_features, dtype=bool)
                keep[np.array(members, dtype=int)] = True
                coalition_value[members] = self._masked_prediction(input_ids, keep)

        shap_values = np.zeros(n_features)
        for i in range(n_features):
            # Only coalitions WITHOUT i contribute a marginal term.
            others = [j for j in range(n_features) if j != i]
            for size in range(n_features):
                weight = 1.0 / (n_features * math.comb(n_features - 1, size))
                for members in combinations(others, size):
                    with_i = tuple(sorted(members + (i,)))
                    shap_values[i] += weight * (
                        coalition_value[with_i] - coalition_value[members]
                    )
        return shap_values

    def _masked_prediction(self, input_ids, keep):
        """Return the model's scalar prediction with tokens where ``keep`` is False padded.

        ``input_ids`` itself is never modified; a clone is masked.
        NOTE(review): assumes ``tokenizer.pad_token_id`` is set. Some causal-LM
        tokenizers leave it as None, which makes the assignment fail.
        """
        masked_input = input_ids.clone()
        drop = torch.as_tensor(~np.asarray(keep, dtype=bool))
        if drop.any():
            masked_input[0][drop] = self.tokenizer.pad_token_id
        with torch.no_grad():
            output = self.model(masked_input)
        return self._get_prediction(output)

    def _get_prediction(self, output):
        """Reduce a model output to one scalar: the maximum softmax probability.

        Uses ``output.logits`` when present, otherwise treats ``output`` as the
        logits tensor. The max is taken over the whole tensor. For logits with a
        sequence dimension, this spans every position.
        """
        if hasattr(output, 'logits'):
            logits = output.logits
        else:
            logits = output

        probs = torch.softmax(logits, dim=-1)
        return probs.max().item()

2. SHAP值可视化

class SHAPVisualizer:
    """Matplotlib bar, scatter and heatmap views of token-level SHAP values.

    NOTE(review): every method uses ``plt`` (matplotlib.pyplot), which is not
    imported anywhere in the visible file. Add
    ``import matplotlib.pyplot as plt`` at module level before using this class.
    """

    def __init__(self):
        # Not read by any method of this class; kept for callers that store figures here.
        self.figures = {}

    def plot_waterfall(self, tokens, shap_values,
                       title="SHAP Waterfall Plot",
                       figsize=(12, 6)):
        """Draw a horizontal bar per token, sorted by ascending SHAP value.

        Despite the name this is not a cumulative waterfall: each bar is one
        token's value (red for negative, blue otherwise).

        Args:
            tokens: token strings aligned with ``shap_values``.
            shap_values: per-token values; lists and arrays are both accepted.
            title: axes title.
            figsize: matplotlib figure size.

        Returns:
            The matplotlib Figure.
        """
        # Lists do not support fancy indexing, so normalise to an array first.
        values = np.asarray(shap_values, dtype=float)
        fig, ax = plt.subplots(figsize=figsize)

        sorted_indices = np.argsort(values)
        sorted_tokens = [tokens[i] for i in sorted_indices]
        sorted_values = values[sorted_indices]

        colors = ['red' if v < 0 else 'blue' for v in sorted_values]
        y_pos = range(len(sorted_tokens))
        ax.barh(y_pos, sorted_values, color=colors, alpha=0.7)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(sorted_tokens)
        ax.set_xlabel('SHAP Value')
        ax.set_title(title)
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        return fig

    def plot_summary(self, tokens, shap_values,
                     title="SHAP Summary Plot",
                     figsize=(12, 8)):
        """Draw the (up to) 10 tokens with the largest |SHAP|, largest at the top.

        Bars show the signed per-token values, so the x axis is labelled
        "SHAP Value" rather than a mean absolute value.

        Returns:
            The matplotlib Figure.
        """
        values = np.asarray(shap_values, dtype=float)
        fig, ax = plt.subplots(figsize=figsize)

        sorted_indices = np.argsort(np.abs(values))[::-1]
        top_k = min(10, len(sorted_indices))
        top_indices = sorted_indices[:top_k]

        top_tokens = [tokens[i] for i in top_indices]
        top_values = values[top_indices]

        colors = ['red' if v < 0 else 'blue' for v in top_values]
        ax.barh(range(len(top_tokens)), top_values, color=colors, alpha=0.7)

        ax.set_yticks(range(len(top_tokens)))
        ax.set_yticklabels(top_tokens)
        ax.set_xlabel('SHAP Value')
        ax.set_title(title)
        # Put the most important token (index 0) at the top.
        ax.invert_yaxis()

        plt.tight_layout()
        return fig

    def plot_dependence(self, tokens, shap_values, feature_idx,
                        title="SHAP Dependence Plot",
                        figsize=(10, 6)):
        """Scatter each token's SHAP value by position and highlight ``feature_idx``.

        Returns:
            The matplotlib Figure.
        """
        values = np.asarray(shap_values, dtype=float)
        fig, ax = plt.subplots(figsize=figsize)

        x = np.arange(len(tokens))
        colors = ['red' if v < 0 else 'blue' for v in values]
        ax.scatter(x, values, c=colors, alpha=0.6)

        ax.scatter([feature_idx], [values[feature_idx]],
                   color='green', s=100, zorder=5, label='Selected Feature')

        ax.set_xticks(x)
        ax.set_xticklabels(tokens, rotation=45, ha='right')
        ax.set_ylabel('SHAP Value')
        ax.set_title(title)
        ax.legend()
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        return fig

    def plot_heatmap(self, all_tokens, all_shap_values,
                     title="SHAP Heatmap",
                     figsize=(12, 8)):
        """Draw ``all_shap_values`` as a diverging heatmap.

        Rows are labelled with ``all_tokens`` on the y axis and columns are
        labelled as samples. The matrix is therefore expected as
        (n_tokens, n_samples) — NOTE(review): confirm callers pass that layout.

        Returns:
            The matplotlib Figure.
        """
        fig, ax = plt.subplots(figsize=figsize)

        shap_matrix = np.array(all_shap_values)
        im = ax.imshow(shap_matrix, cmap='RdBu_r', aspect='auto')

        ax.set_yticks(range(len(all_tokens)))
        ax.set_yticklabels(all_tokens)
        ax.set_xlabel('Sample')
        ax.set_ylabel('Token')
        ax.set_title(title)

        plt.colorbar(im)
        plt.tight_layout()
        return fig

3. SHAP分析器

class SHAPAnalyzer:
    """Tokenize text, run SHAPCalculator, and rank/visualize the resulting values."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.calculator = SHAPCalculator(model, tokenizer)
        self.visualizer = SHAPVisualizer()

    def analyze_text(self, text, background_texts=None, method='kernel'):
        """Compute SHAP values for every token of ``text``.

        Args:
            text: input string.
            background_texts: texts forwarded to the kernel method as
                background. Defaults to ten copies of ``text`` (a simplification).
            method: 'kernel' or 'exact'.

        Returns:
            dict with 'tokens', 'shap_values' and 'method'.

        Raises:
            ValueError: for an unknown ``method``.
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        if background_texts is None:
            background_texts = [text] * 10  # simplification: reuse the input itself

        background_data = []
        for bg_text in background_texts:
            bg_inputs = self.tokenizer(bg_text, return_tensors="pt")
            background_data.append(bg_inputs['input_ids'][0])

        if method == 'kernel':
            shap_values = self.calculator.compute_kernel_shap(
                inputs['input_ids'], background_data
            )
        elif method == 'exact':
            shap_values = self.calculator.compute_exact_shap(inputs['input_ids'])
        else:
            raise ValueError(f"Unknown method: {method}")

        return {
            'tokens': tokens,
            'shap_values': shap_values,
            'method': method
        }

    def analyze_batch(self, texts, background_texts=None, method='kernel'):
        """Run ``analyze_text`` on each text in order with the same settings."""
        return [
            self.analyze_text(text, background_texts, method=method)
            for text in texts
        ]

    def compare_methods(self, text, methods=('kernel', 'exact')):
        """Analyze ``text`` with each method; failures are reported and skipped.

        ``methods`` defaults to an immutable tuple so the default is never shared
        mutable state.
        """
        results = {}

        for method in methods:
            try:
                result = self.analyze_text(text, method=method)
                results[method] = result
            except Exception as e:
                # Best-effort comparison: one failing method (e.g. exact on a
                # long input) should not hide the others.
                print(f"Method {method} failed: {e}")

        return results

    def get_top_tokens(self, analysis_result, top_k=10):
        """Return the ``top_k`` tokens with the largest |SHAP value|, largest first.

        Each entry is a dict with 'token', 'shap_value' and its position 'index'.
        """
        tokens = analysis_result['tokens']
        shap_values = analysis_result['shap_values']

        sorted_indices = np.argsort(np.abs(shap_values))[::-1]
        top_indices = sorted_indices[:top_k]

        top_tokens = []
        for idx in top_indices:
            top_tokens.append({
                'token': tokens[idx],
                'shap_value': shap_values[idx],
                'index': idx
            })

        return top_tokens

    def visualize_analysis(self, analysis_result):
        """Return {'waterfall': Figure, 'summary': Figure} for one analysis result."""
        tokens = analysis_result['tokens']
        shap_values = analysis_result['shap_values']

        visualizations = {}

        visualizations['waterfall'] = self.visualizer.plot_waterfall(
            tokens, shap_values, "SHAP Waterfall"
        )

        visualizations['summary'] = self.visualizer.plot_summary(
            tokens, shap_values, "SHAP Summary"
        )

        return visualizations

LLM SHAP分析实践

1. 文本分类SHAP分析

class TextClassificationSHAP:
    """Explain a sequence-classification model's prediction with token SHAP values."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.shap_analyzer = SHAPAnalyzer(model, tokenizer)

    def explain_classification(self, text, target_class=None):
        """Classify ``text`` and attach SHAP attributions.

        Args:
            text: input string.
            target_class: class index to report. Defaults to the argmax class.

        Returns:
            dict with 'text', 'tokens', 'prediction', 'confidence' (softmax
            probability of the reported class), 'shap_values' (class index ->
            attribution array) and 'top_tokens'.

        Note: every per-class entry references the same array returned by
        ``SHAPAnalyzer.analyze_text``. Class-specific attribution is not computed.
        """
        encoded = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            class_logits = self.model(**encoded).logits

        predicted = (
            class_logits.argmax(dim=-1).item() if target_class is None else target_class
        )

        analysis = self.shap_analyzer.analyze_text(text)
        attributions = analysis['shap_values']
        num_classes = class_logits.shape[-1]
        per_class = {class_id: attributions for class_id in range(num_classes)}

        probabilities = torch.softmax(class_logits, dim=-1)
        return {
            'text': text,
            'tokens': analysis['tokens'],
            'prediction': predicted,
            'confidence': probabilities[0][predicted].item(),
            'shap_values': per_class,
            'top_tokens': self.shap_analyzer.get_top_tokens(analysis),
        }

    def visualize_classification_explanation(self, explanation):
        """Build one waterfall figure per class plus a class-averaged summary figure."""
        plotter = SHAPVisualizer()
        per_class = explanation['shap_values']

        figs = {
            f'class_{class_id}': plotter.plot_waterfall(
                explanation['tokens'],
                values,
                title=f"SHAP Values for Class {class_id}",
            )
            for class_id, values in per_class.items()
        }

        mean_values = np.mean(list(per_class.values()), axis=0)
        figs['summary'] = plotter.plot_summary(
            explanation['tokens'],
            mean_values,
            title="SHAP Summary (Average Across Classes)",
        )
        return figs

    def batch_explain(self, texts, batch_size=32):
        """Explain each text in input order; ``batch_size`` only chunks the iteration."""
        return [
            self.explain_classification(text)
            for start in range(0, len(texts), batch_size)
            for text in texts[start:start + batch_size]
        ]

2. 问答系统SHAP分析

class QASystemSHAP:
    """SHAP explanations for extractive question-answering models."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.shap_analyzer = SHAPAnalyzer(model, tokenizer)

    def explain_qa_prediction(self, question, context):
        """Predict an answer span and attribute it to the input tokens.

        The question/context pair is encoded once (truncated to 512 tokens).
        The answer span, the returned tokens and the SHAP values all refer to
        that same encoding, so ``start_idx``/``end_idx`` index both 'tokens'
        and 'shap_values'.

        Returns:
            dict with 'question', 'context', 'tokens', 'answer', 'start_idx',
            'end_idx', 'shap_values' and 'top_tokens'.
        """
        inputs = self.tokenizer(
            question, context,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )
        input_ids = inputs['input_ids']
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])

        with torch.no_grad():
            outputs = self.model(**inputs)
        start_idx = outputs.start_logits.argmax(dim=-1).item()
        end_idx = outputs.end_logits.argmax(dim=-1).item()

        # Attribute over the SAME pair encoding the span indices refer to.
        # Re-tokenizing "question context" as one string drops the separator
        # tokens and shifts every index. Background mirrors analyze_text's
        # default (ten copies of the input).
        # NOTE(review): the calculator calls model(input_ids) only (no
        # token_type_ids / attention_mask) and reduces the output via
        # `.logits`. This class reads start_logits/end_logits instead — confirm
        # the model output is compatible with the calculator.
        background = [input_ids[0]] * 10
        shap_values = self.shap_analyzer.calculator.compute_kernel_shap(
            input_ids, background
        )
        analysis = {'tokens': tokens, 'shap_values': shap_values}

        return {
            'question': question,
            'context': context,
            'tokens': tokens,
            'answer': self.tokenizer.decode(input_ids[0][start_idx:end_idx + 1]),
            'start_idx': start_idx,
            'end_idx': end_idx,
            'shap_values': shap_values,
            'top_tokens': self.shap_analyzer.get_top_tokens(analysis)
        }

    def visualize_qa_explanation(self, explanation):
        """Return a SHAP bar figure and a text rendering with the answer tokens marked."""
        visualizer = SHAPVisualizer()

        figs = {}

        figs['waterfall'] = visualizer.plot_waterfall(
            explanation['tokens'],
            explanation['shap_values'],
            title="QA SHAP Values"
        )

        # A string (not a figure) with answer tokens wrapped in **[...]**.
        figs['answer_highlight'] = self._highlight_answer_tokens(explanation)

        return figs

    def _highlight_answer_tokens(self, explanation):
        """Join tokens with spaces, wrapping those in [start_idx, end_idx] as **[tok]**."""
        tokens = explanation['tokens']
        start_idx = explanation['start_idx']
        end_idx = explanation['end_idx']

        highlighted = []
        for i, token in enumerate(tokens):
            if start_idx <= i <= end_idx:
                highlighted.append(f"**[{token}]**")
            else:
                highlighted.append(token)

        return ' '.join(highlighted)

    def analyze_answer_quality(self, explanation):
        """Summarise how much SHAP mass falls on the answer span and the question.

        The question is measured in subword tokens (via ``tokenizer.tokenize``)
        because SHAP values are per token. Leading special tokens (e.g. a
        classifier token) are skipped when the tokenizer exposes
        ``all_special_tokens``. An empty answer span (end_idx < start_idx) or
        all-zero SHAP values yield 0.0 instead of raising or producing NaN.

        Returns:
            dict with 'answer_shap_mean', 'answer_shap_max',
            'question_shap_mean' and 'answer_importance' (share of total |SHAP|).
        """
        shap_values = np.asarray(explanation['shap_values'], dtype=float)
        answer_shap = shap_values[explanation['start_idx']:explanation['end_idx'] + 1]

        special = set(getattr(self.tokenizer, 'all_special_tokens', ()))
        offset = 0
        for token in explanation.get('tokens', ()):
            if token not in special:
                break
            offset += 1
        question_len = len(self.tokenizer.tokenize(explanation['question']))
        question_shap = shap_values[offset:offset + question_len]

        total_abs = np.sum(np.abs(shap_values))
        has_answer = answer_shap.size > 0

        return {
            'answer_shap_mean': float(np.mean(answer_shap)) if has_answer else 0.0,
            'answer_shap_max': float(np.max(answer_shap)) if has_answer else 0.0,
            'question_shap_mean': float(np.mean(question_shap)) if question_shap.size else 0.0,
            'answer_importance': (
                float(np.sum(np.abs(answer_shap)) / total_abs) if total_abs > 0 else 0.0
            ),
        }

3. 生成模型SHAP分析

class GenerationModelSHAP:
    """Attribute a prompt + generated continuation and split the result by origin."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.shap_analyzer = SHAPAnalyzer(model, tokenizer)

    def explain_generation(self, prompt, generated_text):
        """Run SHAP on ``prompt + generated_text`` and split tokens/values at the prompt boundary.

        ``analyze_text`` encodes with the tokenizer's special tokens, while
        ``tokenize(prompt)`` does not. Leading special tokens in the analysed
        sequence are therefore counted and kept on the prompt side, so the
        boundary lands after the last prompt token. Tokenizers without
        ``all_special_tokens`` fall back to an offset of 0.
        NOTE(review): tokenizing the concatenation can merge subwords across
        the boundary, so the split is approximate for such tokenizers.

        Returns:
            dict with the raw texts, 'all_tokens'/'shap_values' and their
            prompt/generated slices.
        """
        full_text = prompt + generated_text
        analysis = self.shap_analyzer.analyze_text(full_text)
        all_tokens = analysis['tokens']
        shap_values = analysis['shap_values']

        special = set(getattr(self.tokenizer, 'all_special_tokens', ()))
        leading_special = 0
        for token in all_tokens:
            if token not in special:
                break
            leading_special += 1
        prompt_len = leading_special + len(self.tokenizer.tokenize(prompt))

        return {
            'prompt': prompt,
            'generated_text': generated_text,
            'all_tokens': all_tokens,
            'prompt_tokens': all_tokens[:prompt_len],
            'generated_tokens': all_tokens[prompt_len:],
            'shap_values': shap_values,
            'prompt_shap': shap_values[:prompt_len],
            'generated_shap': shap_values[prompt_len:]
        }

    def visualize_generation_explanation(self, explanation):
        """Return bar figures for all tokens and, when non-empty, the prompt/generated parts."""
        visualizer = SHAPVisualizer()

        figs = {}

        figs['overall'] = visualizer.plot_waterfall(
            explanation['all_tokens'],
            explanation['shap_values'],
            title="Overall SHAP Values"
        )

        if len(explanation['prompt_shap']) > 0:
            figs['prompt'] = visualizer.plot_waterfall(
                explanation['prompt_tokens'],
                explanation['prompt_shap'],
                title="Prompt SHAP Values"
            )

        if len(explanation['generated_shap']) > 0:
            figs['generated'] = visualizer.plot_waterfall(
                explanation['generated_tokens'],
                explanation['generated_shap'],
                title="Generated Text SHAP Values"
            )

        return figs

    def analyze_prompt_influence(self, explanation):
        """Compare total |SHAP| on prompt tokens versus generated tokens.

        Ratios fall back to 0.5 each when both totals are zero.
        """
        prompt_importance = np.sum(np.abs(explanation['prompt_shap']))
        generated_importance = np.sum(np.abs(explanation['generated_shap']))

        total_importance = prompt_importance + generated_importance
        if total_importance > 0:
            prompt_ratio = prompt_importance / total_importance
            generated_ratio = generated_importance / total_importance
        else:
            prompt_ratio = generated_ratio = 0.5

        return {
            'prompt_importance': prompt_importance,
            'generated_importance': generated_importance,
            'prompt_ratio': prompt_ratio,
            'generated_ratio': generated_ratio,
            'top_prompt_tokens': self._get_top_tokens(
                explanation['prompt_tokens'], explanation['prompt_shap']
            ),
            'top_generated_tokens': self._get_top_tokens(
                explanation['generated_tokens'], explanation['generated_shap']
            )
        }

    def _get_top_tokens(self, tokens, shap_values, top_k=5):
        """Return up to ``top_k`` {'token', 'shap_value'} dicts by descending |SHAP|."""
        if len(tokens) == 0 or len(shap_values) == 0:
            return []

        sorted_indices = np.argsort(np.abs(shap_values))[::-1]
        top_indices = sorted_indices[:top_k]

        top_tokens = []
        for idx in top_indices:
            # Guard against tokens/values of different lengths.
            if idx < len(tokens):
                top_tokens.append({
                    'token': tokens[idx],
                    'shap_value': shap_values[idx]
                })

        return top_tokens

实际应用案例

案例:LLM SHAP分析系统

# LLM SHAP分析系统
class LLM_SHAP_Analysis_System:
    """Facade that routes texts to task-specific SHAP explainers and builds reports."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.text_shap = TextClassificationSHAP(model, tokenizer)
        self.qa_shap = QASystemSHAP(model, tokenizer)
        self.generation_shap = GenerationModelSHAP(model, tokenizer)
        self.visualizer = SHAPVisualizer()

    def comprehensive_analysis(self, text, task_type='classification'):
        """Explain ``text`` for the given task.

        Only 'classification' is executed here. The 'qa' and 'generation' tasks
        need extra inputs, so an {'error': ...} dict is returned for them and
        for unknown task types.
        """
        if task_type == 'classification':
            return self.text_shap.explain_classification(text)
        elif task_type == 'qa':
            # QA needs separate question and context arguments.
            return {"error": "QA任务需要question和context参数"}
        elif task_type == 'generation':
            return {"error": "生成任务需要prompt和generated_text参数"}
        else:
            return {"error": f"Unknown task type: {task_type}"}

    def generate_analysis_report(self, analysis_result, task_type='classification'):
        """Build a report dict (summary, top features, insights) from an analysis result."""
        report = {
            'task_type': task_type,
            'summary': {},
            'top_features': [],
            'insights': [],
            'recommendations': []
        }

        if task_type == 'classification':
            report['summary'] = {
                'prediction': analysis_result.get('prediction'),
                'confidence': analysis_result.get('confidence'),
                'num_tokens': len(analysis_result.get('tokens', []))
            }

            if 'top_tokens' in analysis_result:
                report['top_features'] = analysis_result['top_tokens']

            report['insights'] = self._generate_classification_insights(analysis_result)

        elif task_type == 'qa':
            # `or ''` also covers an explicit None answer, which .split() would reject.
            answer = analysis_result.get('answer') or ''
            report['summary'] = {
                'question': analysis_result.get('question'),
                'answer': analysis_result.get('answer'),
                'answer_length': len(answer.split())
            }

        return report

    def _generate_classification_insights(self, analysis_result):
        """List human-readable notes about the strongest positive/negative tokens."""
        insights = []

        top_tokens = analysis_result.get('top_tokens', [])

        if top_tokens:
            positive_tokens = [t for t in top_tokens if t['shap_value'] > 0]
            negative_tokens = [t for t in top_tokens if t['shap_value'] < 0]

            if positive_tokens:
                insights.append(f"正面影响词汇: {[t['token'] for t in positive_tokens[:3]]}")

            if negative_tokens:
                insights.append(f"负面影响词汇: {[t['token'] for t in negative_tokens[:3]]}")

            # A large spread among the top values suggests reliance on a few tokens.
            shap_values = [t['shap_value'] for t in top_tokens]
            if np.std(shap_values) > 0.5:
                insights.append("词汇重要性差异较大,模型可能过度依赖某些词汇")

        return insights

    def visualize_comprehensive(self, analysis_result):
        """Return waterfall/summary figures when the result carries tokens and SHAP values."""
        visualizations = {}

        if 'tokens' in analysis_result and 'shap_values' in analysis_result:
            visualizations['waterfall'] = self.visualizer.plot_waterfall(
                analysis_result['tokens'],
                analysis_result['shap_values'],
                "SHAP Waterfall Plot"
            )

            visualizations['summary'] = self.visualizer.plot_summary(
                analysis_result['tokens'],
                analysis_result['shap_values'],
                "SHAP Summary Plot"
            )

        return visualizations

    def compare_texts(self, texts, task_type='classification'):
        """Analyze every text and compare their top tokens."""
        results = []

        for text in texts:
            analysis = self.comprehensive_analysis(text, task_type)
            results.append({
                'text': text,
                'analysis': analysis
            })

        comparison = self._compare_analyses(results)

        return {
            'individual_results': results,
            'comparison': comparison
        }

    def _compare_analyses(self, results):
        """Compare the top-token sets of several analyses.

        'common_features' are tokens in every analysis' top tokens.
        'distinguishing_features' are tokens that appear in some but not all of
        them. Both are sorted for deterministic output. Results without
        'top_tokens' (e.g. error dicts) are ignored. 'pattern_analysis' is
        reserved and left empty.
        """
        comparison = {
            'common_features': [],
            'distinguishing_features': [],
            'pattern_analysis': {}
        }

        feature_sets = [
            {t['token'] for t in result['analysis']['top_tokens']}
            for result in results
            if 'top_tokens' in result['analysis']
        ]

        if feature_sets:
            common = set.intersection(*feature_sets)
            union = set.union(*feature_sets)
            comparison['common_features'] = sorted(common)
            comparison['distinguishing_features'] = sorted(union - common)

        return comparison

# 使用示例
# system = LLM_SHAP_Analysis_System(model, tokenizer)
# 
# # 分析文本分类
# analysis = system.comprehensive_analysis("This movie was fantastic!")
# report = system.generate_analysis_report(analysis)
# visualizations = system.visualize_comprehensive(analysis)
# 
# # 比较多个文本
# texts = ["Great product!", "Terrible service.", "Average experience."]
# comparison = system.compare_texts(texts)

总结

SHAP值是解释LLM的强大工具:

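  1. 理论基础 - 基于合作博弈论中的Shapley值,具有坚实的数学基础
  2. 公平归因 - 满足效率性(所有特征的SHAP值之和等于该次预测与基线预测之差)、对称性、零贡献(dummy)和可加性等公理
  3. 全局解释 - 汇总大量样本的SHAP值(如按平均绝对值排序),即可分析模型的整体行为
  4. 局部解释 - 可以解释单个预测
  5. 特征交互 - 借助SHAP交互值(SHAP interaction values)可以分析特征间的交互效应(本文代码未实现)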

通过SHAP分析,我们可以更好地理解LLM如何做出决策,提高模型的透明度和可信度,为模型优化和调试提供有力支持。