← 返回首页
🧠

Alpaca数据集

📂 llm ⏱ 6 min 1071 words

---
title: "Alpaca数据集"
description: "Stanford Alpaca数据集详解,包括指令微调数据格式、生成方法和应用场景"
tags: ["Alpaca", "指令微调", "数据集", "Stanford"]
category: "llm"
icon: "🧠"
---

Alpaca数据集

什么是Alpaca数据集

Alpaca数据集是斯坦福大学(Stanford)发布的指令微调数据集,包含约52K条指令-输出对,这些数据由OpenAI的text-davinci-003(GPT-3.5系列)基于Self-Instruct方法生成。它为开源LLM的指令微调提供了重要资源。

数据格式

1. 标准数据结构

import json
import re
from typing import Dict, List, Optional

class AlpacaDataFormat:
    """Helpers for the three-field Alpaca record layout.

    A record is a dict with an "instruction", an optional "input" and an
    "output".
    """

    def __init__(self):
        # Reference record illustrating the expected keys.
        self.sample_structure = dict(
            instruction="将以下英文翻译成中文",
            input="Hello, how are you?",
            output="你好,你好吗?",
        )

    def validate_format(self, data: Dict) -> bool:
        """Return True when ``data`` carries both mandatory keys.

        Only the presence of "instruction" and "output" is checked; the
        values themselves are not inspected.
        """
        for key in ("instruction", "output"):
            if key not in data:
                return False
        return True

    def format_for_training(self, data: List[Dict]) -> List[Dict]:
        """Project valid records onto the three training keys.

        Records that fail ``validate_format`` are dropped, a missing
        "input" becomes an empty string, and any extra keys are discarded.
        """
        return [
            {
                "instruction": record["instruction"],
                "input": record.get("input", ""),
                "output": record["output"],
            }
            for record in data
            if self.validate_format(record)
        ]

    def create_prompt_template(self) -> str:
        """Return the prompt template with ``str.format``-style placeholders."""
        sections = (
            "### Instruction:\n{instruction}",
            "### Input:\n{input}",
            "### Response:\n{output}",
        )
        # Sections are separated by one blank line.
        return "\n\n".join(sections)

# Sample records in Alpaca format; "input" is left empty when the
# instruction needs no additional context.
alpaca_examples = [
    dict(
        instruction="写一首关于春天的诗",
        input="",
        output="春风拂面花满枝,\n燕子归来筑新巢。\n万物复苏生机盎,\n人间四月芳菲时。",
    ),
    dict(
        instruction="解释什么是机器学习",
        input="",
        output="机器学习是人工智能的一个分支,它使计算机能够从数据中学习,而无需显式编程。通过算法分析数据模式,机器学习模型可以做出预测或决策。",
    ),
    dict(
        instruction="将以下JSON数据转换为CSV格式",
        input='[{"name": "张三", "age": 25}, {"name": "李四", "age": 30}]',
        output="name,age\n张三,25\n李四,30",
    ),
]

2. 数据加载与处理

import pandas as pd
from datasets import Dataset, DatasetDict

class AlpacaDataLoader:
    """Load, clean and split Alpaca-style records.

    Records are dicts with "instruction", "input" and "output" text fields.
    """

    def __init__(self):
        # Path of the most recently loaded JSON file; load_dataset() reads it.
        self.data_path = None
        # Last HuggingFace Dataset built by load_dataset().
        self.dataset = None

    def load_from_json(self, file_path: str) -> List[Dict]:
        """Load records from a UTF-8 JSON file and remember its path.

        Args:
            file_path: Path of the JSON file to read.

        Returns:
            The parsed JSON content.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        self.data_path = file_path
        return data

    def load_from_csv(self, file_path: str) -> List[Dict]:
        """Load records from a CSV file, one dict per row.

        Args:
            file_path: Path of the CSV file to read.

        Returns:
            A list of row dicts keyed by the CSV header.
        """
        # By default pandas converts empty cells and literal text such as
        # "NA", "null" or "None" into float NaN. Empty "input" cells are the
        # norm for Alpaca data and such words are legitimate text, so keep
        # every cell exactly as written.
        df = pd.read_csv(file_path, keep_default_na=False)
        return df.to_dict('records')

    def load_dataset(self, split: str = "train") -> Optional["Dataset"]:
        """Build a HuggingFace ``Dataset`` from the last loaded JSON file.

        Args:
            split: Currently unused; kept for interface compatibility.

        Returns:
            The dataset, or ``None`` when no JSON file has been loaded yet.
        """
        if not self.data_path:
            return None
        self.dataset = Dataset.from_json(self.data_path)
        return self.dataset

    def preprocess_data(self, data: List[Dict]) -> List[Dict]:
        """Clean every record and drop those lacking an instruction or output.

        Args:
            data: Raw records; missing or ``None`` fields count as empty text.

        Returns:
            Cleaned records containing only the three Alpaca keys.
        """
        processed_data = []

        for item in data:
            processed_item = {
                "instruction": self.clean_text(item.get("instruction", "")),
                "input": self.clean_text(item.get("input", "")),
                "output": self.clean_text(item.get("output", ""))
            }

            # Skip records that end up without an instruction or an output.
            if processed_item["instruction"] and processed_item["output"]:
                processed_data.append(processed_item)

        return processed_data

    def clean_text(self, text: str) -> str:
        """Strip surrounding whitespace and normalise line endings to ``\\n``.

        ``None`` (e.g. a JSON ``null``) becomes an empty string and other
        non-string values are converted with ``str()`` instead of raising.
        """
        if text is None:
            return ""
        if not isinstance(text, str):
            text = str(text)
        text = text.strip()
        # Replace "\r\n" first so a Windows line ending yields one newline.
        return text.replace('\r\n', '\n').replace('\r', '\n')

    def create_splits(self, data: List[Dict], train_ratio: float = 0.9,
                      seed: Optional[int] = None) -> "DatasetDict":
        """Shuffle the records and split them into train/validation sets.

        Args:
            data: Records to split; the input list is not modified.
            train_ratio: Fraction of records assigned to "train" (0 to 1).
            seed: Optional seed for a reproducible shuffle. When ``None`` the
                module-level ``random`` state is used.

        Returns:
            A ``DatasetDict`` with "train" and "validation" entries.

        Raises:
            ValueError: If ``train_ratio`` lies outside the range [0, 1].
        """
        import random

        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio}")

        shuffled_data = list(data)
        if seed is None:
            random.shuffle(shuffled_data)
        else:
            # A private generator keeps the global random state untouched.
            random.Random(seed).shuffle(shuffled_data)

        split_idx = int(len(shuffled_data) * train_ratio)

        return DatasetDict({
            "train": Dataset.from_list(shuffled_data[:split_idx]),
            "validation": Dataset.from_list(shuffled_data[split_idx:])
        })

数据生成方法

1. 使用GPT-3.5生成

import openai
from typing import List, Dict

class AlpacaDataGenerator:
    """Generate Alpaca-style instruction records through the OpenAI chat API."""

    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo",
                 temperature: float = 0.8, max_tokens: int = 200):
        """Create the API client and store the generation settings.

        Args:
            api_key: OpenAI API key.
            model: Chat model used for generation.
            temperature: Sampling temperature passed to the API.
            max_tokens: Completion token limit; long outputs are cut off at
                this limit, so raise it for detailed answers.
        """
        self.client = openai.OpenAI(api_key=api_key)
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.generation_prompt = """
        请生成一个指令-输出对,用于训练AI助手。
        
        要求:
        1. 指令应该清晰明确
        2. 输出应该详细准确
        3. 涵盖不同领域和任务类型
        
        请按照以下格式生成:
        Instruction: [指令]
        Input: [可选输入]
        Output: [输出]
        """

    def generate_instruction_pair(self) -> Dict:
        """Request one completion and parse it into a record.

        Returns:
            A dict with "instruction", "input" and "output" strings; fields
            the model did not provide are empty.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "你是一个数据生成助手"},
                {"role": "user", "content": self.generation_prompt}
            ],
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )

        # parse_response() treats a missing (None) message body as empty.
        content = response.choices[0].message.content
        return self.parse_response(content)

    def parse_response(self, response: str) -> Dict:
        """Parse "Instruction:/Input:/Output:" formatted text into a record.

        A line starting with one of the three headers opens that field;
        following lines belong to it until the next header. Continuation
        lines keep their line breaks and indentation, so multi-line outputs
        (code, lists, poems) survive intact. Lines before the first header
        are ignored.

        Args:
            response: Raw model output; ``None`` or empty text gives a
                record with three empty strings.

        Returns:
            A dict with "instruction", "input" and "output" strings.
        """
        headers = (
            ("Instruction:", "instruction"),
            ("Input:", "input"),
            ("Output:", "output"),
        )
        collected = {"instruction": [], "input": [], "output": []}

        current_field = None
        for raw_line in (response or "").split('\n'):
            stripped = raw_line.strip()
            for prefix, field in headers:
                if stripped.startswith(prefix):
                    current_field = field
                    # A repeated header restarts its field.
                    collected[field] = [stripped[len(prefix):].strip()]
                    break
            else:
                if current_field:
                    # Only trailing whitespace is removed, to keep indentation.
                    collected[current_field].append(raw_line.rstrip())

        return {
            field: "\n".join(parts).strip()
            for field, parts in collected.items()
        }

    def generate_dataset(self, num_samples: int = 100) -> List[Dict]:
        """Generate up to ``num_samples`` records.

        Responses without an instruction or an output are skipped, so the
        result may hold fewer than ``num_samples`` records.
        """
        dataset = []

        for i in range(num_samples):
            print(f"生成第 {i+1}/{num_samples} 个样本...")
            instruction_pair = self.generate_instruction_pair()

            if instruction_pair["instruction"] and instruction_pair["output"]:
                dataset.append(instruction_pair)

        return dataset

    def save_dataset(self, data: List[Dict], file_path: str):
        """Write the records to ``file_path`` as indented UTF-8 JSON."""
        with open(file_path, 'w', encoding='utf-8') as f:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            json.dump(data, f, ensure_ascii=False, indent=2)

# Usage example: request 50 samples through the OpenAI API and save the
# kept records as JSON. "your-api-key" is a placeholder. Responses without
# an instruction or output are skipped, so fewer than 50 may be written.
generator = AlpacaDataGenerator(api_key="your-api-key")
dataset = generator.generate_dataset(50)
generator.save_dataset(dataset, "alpaca_generated.json")

2. 模板化生成

class TemplateBasedGenerator:
    """Produce Alpaca-style records from hand-written task templates.

    Each template holds an instruction, parallel lists of sample inputs and
    outputs, and optionally a "params" dict whose values are substituted
    into ``{placeholders}`` of the instruction.
    """

    def __init__(self):
        self.templates = {
            "translation": {
                "instruction": "将以下{source_lang}文本翻译成{target_lang}",
                # Values for the placeholders above; the samples below are
                # English-to-Chinese pairs.
                "params": {"source_lang": "英文", "target_lang": "中文"},
                "input_samples": [
                    "Hello, how are you?",
                    "Good morning",
                    "Thank you very much"
                ],
                "output_samples": [
                    "你好,你好吗?",
                    "早上好",
                    "非常感谢"
                ]
            },
            "summarization": {
                "instruction": "请总结以下文本的主要内容",
                "input_samples": [
                    "人工智能是计算机科学的一个分支,致力于创建能够执行通常需要人类智能才能完成的任务的系统。",
                    "机器学习是人工智能的核心技术之一,它使计算机能够从数据中学习。"
                ],
                "output_samples": [
                    "人工智能是创建智能系统的计算机科学分支。",
                    "机器学习让计算机从数据中学习,是AI的核心技术。"
                ]
            },
            "question_answering": {
                "instruction": "根据以下信息回答问题",
                "input_samples": [
                    "北京是中国的首都,人口约2100万。",
                    "Python是一种高级编程语言,由Guido van Rossum于1991年创建。"
                ],
                "output_samples": [
                    "中国的首都是北京。",
                    "Python由Guido van Rossum于1991年创建。"
                ]
            }
        }

    def generate_from_template(self, template_name: str, variations: int = 5) -> List[Dict]:
        """Generate ``variations`` records from one template.

        Samples are used round-robin, so asking for more variations than
        there are samples repeats records.

        Args:
            template_name: Key into ``self.templates``.
            variations: Number of records to produce.

        Returns:
            Records with "instruction", "input", "output" and "template"
            keys. The list is empty when the template is unknown or has no
            samples.
        """
        template = self.templates.get(template_name)
        if not template:
            return []

        # Fill the instruction placeholders only when the template declares
        # values, so instructions containing literal braces stay untouched.
        instruction = template["instruction"]
        params = template.get("params")
        if params:
            instruction = instruction.format(**params)

        # Pair inputs with outputs once so each input keeps its own answer,
        # even if the two lists differ in length.
        pairs = list(zip(template["input_samples"], template["output_samples"]))
        if not pairs:
            return []

        generated_data = []
        for i in range(variations):
            sample_input, sample_output = pairs[i % len(pairs)]
            generated_data.append({
                "instruction": instruction,
                "input": sample_input,
                "output": sample_output,
                "template": template_name
            })

        return generated_data

    def generate_all_templates(self, variations_per_template: int = 3) -> List[Dict]:
        """Generate records from every template and concatenate them."""
        all_data = []

        for template_name in self.templates:
            all_data.extend(
                self.generate_from_template(template_name, variations_per_template)
            )

        return all_data

数据质量控制

class AlpacaDataQualityControl:
    """Heuristic quality scoring for Alpaca-style records.

    Each score is the fraction of records passing a simple check; these are
    rough signals, not a real evaluation of content.
    """

    # One token per CJK ideograph, otherwise runs of other word characters.
    # Chinese text has no spaces, so whitespace splitting would turn a whole
    # sentence into a single token.
    _TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]|[^\W\u4e00-\u9fff]+")

    def __init__(self):
        # Minimum acceptable score per metric; generate_recommendations()
        # compares the measured scores against these values.
        self.quality_metrics = {
            "completeness": 0.8,
            "relevance": 0.7,
            "diversity": 0.6,
            "accuracy": 0.9
        }

    def assess_quality(self, data: List[Dict]) -> Dict:
        """Run every check and bundle the results into one report.

        Returns:
            A dict with "total_samples", "quality_scores" (one score per
            metric), "issues" and "recommendations".
        """
        quality_scores = {
            "completeness": self.assess_completeness(data),
            "relevance": self.assess_relevance(data),
            "diversity": self.assess_diversity(data),
            "accuracy": self.assess_accuracy(data)
        }

        return {
            "total_samples": len(data),
            "quality_scores": quality_scores,
            "issues": self.identify_issues(data),
            "recommendations": self.generate_recommendations(quality_scores)
        }

    def _tokenize(self, text) -> set:
        """Return the set of tokens in ``text`` (``None`` counts as empty)."""
        return set(self._TOKEN_PATTERN.findall(text or ""))

    def assess_completeness(self, data: List[Dict]) -> float:
        """Fraction of records whose instruction has more than 5 characters
        and whose output has more than 10; 0 for an empty dataset."""
        if not data:
            return 0

        complete_samples = sum(
            1 for sample in data
            if len(sample.get("instruction") or "") > 5
            and len(sample.get("output") or "") > 10
        )
        return complete_samples / len(data)

    def assess_relevance(self, data: List[Dict]) -> float:
        """Fraction of records whose instruction and output share tokens.

        A record counts as relevant when the Jaccard overlap of the two
        token sets exceeds 0.1. Returns 0 for an empty dataset.
        """
        if not data:
            return 0

        relevant_samples = 0
        for sample in data:
            instruction_tokens = self._tokenize(sample.get("instruction"))
            output_tokens = self._tokenize(sample.get("output"))

            overlap = len(instruction_tokens & output_tokens)
            total = len(instruction_tokens | output_tokens)

            if total > 0 and overlap / total > 0.1:
                relevant_samples += 1

        return relevant_samples / len(data)

    def assess_diversity(self, data: List[Dict]) -> float:
        """Share of distinct instruction strings; 0 for an empty dataset."""
        instructions = [sample.get("instruction", "") for sample in data]
        if not instructions:
            return 0
        return len(set(instructions)) / len(instructions)

    def assess_accuracy(self, data: List[Dict]) -> float:
        """Fraction of outputs containing none of the known failure phrases.

        This only flags refusal or error wording; it does not verify facts.
        Returns 0 for an empty dataset.
        """
        if not data:
            return 0

        error_patterns = ["我不知道", "无法回答", "错误"]
        accurate_samples = sum(
            1 for sample in data
            if not any(
                pattern in (sample.get("output") or "")
                for pattern in error_patterns
            )
        )
        return accurate_samples / len(data)

    def identify_issues(self, data: List[Dict]) -> List[str]:
        """List human-readable findings: empty, duplicate and overlong
        (more than 200 characters) instructions."""
        issues = []
        instructions = [s.get("instruction") or "" for s in data]

        empty_instructions = sum(1 for text in instructions if not text)
        if empty_instructions > 0:
            issues.append(f"发现 {empty_instructions} 个空指令")

        duplicates = len(instructions) - len(set(instructions))
        if duplicates > 0:
            issues.append(f"发现 {duplicates} 个重复指令")

        long_instructions = sum(1 for text in instructions if len(text) > 200)
        if long_instructions > 0:
            issues.append(f"发现 {long_instructions} 个过长指令")

        return issues

    def generate_recommendations(self, scores: Dict) -> List[str]:
        """Return advice for every metric scoring below its threshold.

        Thresholds come from ``self.quality_metrics``, so adjusting that
        dict changes which recommendations are issued.

        Args:
            scores: Mapping with "completeness", "relevance", "diversity"
                and "accuracy" scores.
        """
        advice = (
            ("completeness", "增加更多完整的指令-输出对"),
            ("relevance", "确保指令和输出内容相关"),
            ("diversity", "增加指令的多样性,涵盖更多领域"),
            ("accuracy", "验证输出内容的准确性"),
        )
        return [
            message
            for metric, message in advice
            if scores[metric] < self.quality_metrics[metric]
        ]

应用场景

class AlpacaDataApplication:
    """Pipeline tying together loading, cleaning, scoring and splitting."""

    def __init__(self):
        self.data_loader = AlpacaDataLoader()
        self.quality_control = AlpacaDataQualityControl()

    def prepare_training_data(self, data_path: str) -> Dict:
        """Load a JSON file and turn it into train/validation data.

        Args:
            data_path: Path of the JSON file holding the raw records.

        Returns:
            A dict with "data" (the train/validation splits),
            "quality_report" and "statistics".
        """
        raw_data = self.data_loader.load_from_json(data_path)
        processed_data = self.data_loader.preprocess_data(raw_data)

        # The report and statistics describe the cleaned records, before
        # they are split.
        quality_report = self.quality_control.assess_quality(processed_data)
        splits = self.data_loader.create_splits(processed_data)

        return {
            "data": splits,
            "quality_report": quality_report,
            "statistics": self.calculate_statistics(processed_data)
        }

    def calculate_statistics(self, data: List[Dict]) -> Dict:
        """Compute character-length statistics over the records.

        Args:
            data: Records; missing fields count as empty strings.

        Returns:
            A dict with "total_samples", the average instruction and output
            lengths, and the min/max instruction length. Every value is
            zero for an empty dataset.
        """
        if not data:
            # Avoid dividing by zero and calling min()/max() on nothing.
            return {
                "total_samples": 0,
                "avg_instruction_length": 0.0,
                "avg_output_length": 0.0,
                "min_instruction_length": 0,
                "max_instruction_length": 0
            }

        instruction_lengths = [len(d.get("instruction", "")) for d in data]
        output_lengths = [len(d.get("output", "")) for d in data]

        return {
            "total_samples": len(data),
            "avg_instruction_length": sum(instruction_lengths) / len(instruction_lengths),
            "avg_output_length": sum(output_lengths) / len(output_lengths),
            "min_instruction_length": min(instruction_lengths),
            "max_instruction_length": max(instruction_lengths)
        }

    def create_custom_dataset(self, domain: str, count: int) -> List[Dict]:
        """Create ``count`` placeholder records for ``domain``.

        The records follow the Alpaca layout but carry templated
        placeholder text, not real explanations.
        """
        return [
            {
                "instruction": f"在{domain}领域,解释第{i+1}个概念",
                "input": "",
                "output": f"这是关于{domain}领域第{i+1}个概念的详细解释..."
            }
            for i in range(count)
        ]

# Usage example: run the preparation pipeline on a local JSON file and print
# the per-metric quality scores. These statements execute at import time and
# read "alpaca_data.json" via a relative path.
app = AlpacaDataApplication()
training_data = app.prepare_training_data("alpaca_data.json")
print(f"数据质量评分: {training_data['quality_report']['quality_scores']}")

总结

Alpaca数据集为LLM指令微调提供了重要资源。通过标准化的数据格式、灵活的生成方法和严格的质量控制,可以创建高质量的训练数据,提升模型的指令遵循能力。