← 返回首页
🧠

自指令

📂 llm ⏱ 5 min 951 words

---
title: "自指令"
description: "Self-Instruct技术详解,包括指令生成、数据增强和模型微调方法"
tags: ["自指令", "Self-Instruct", "指令生成", "数据增强"]
category: "llm"
icon: "🧠"
---

自指令

什么是自指令

自指令(Self-Instruct)是一种通过LLM自身生成训练数据来提升模型能力的技术。它使用少量种子任务生成大量新指令,显著降低了人工标注成本。

核心原理

1. 种子任务定义

from typing import List, Dict
import json

class SeedTaskManager:
    """Hold the hand-written seed tasks that bootstrap instruction generation.

    Each seed is a plain dict. The built-in seeds carry the keys ``id``,
    ``task_type``, ``instruction``, ``input``, ``output`` and ``category``;
    seeds appended by other code may omit some of them, so every reader in
    this class tolerates missing optional keys.
    """

    def __init__(self):
        self.seed_tasks = [
            {
                "id": 1,
                "task_type": "classification",
                "instruction": "将以下文本分类为正面、负面或中性情感",
                "input": "这家餐厅的服务非常好,食物也很美味。",
                "output": "正面情感",
                "category": "sentiment_analysis"
            },
            {
                "id": 2,
                "task_type": "generation",
                "instruction": "根据以下描述写一段产品描述",
                "input": "一款轻便的笔记本电脑,适合商务人士",
                "output": "这款轻薄笔记本电脑专为商务精英设计,重量仅1.2kg,搭载最新处理器,续航长达12小时,让您随时随地高效工作。",
                "category": "text_generation"
            },
            {
                "id": 3,
                "task_type": "extraction",
                "instruction": "从以下文本中提取人名和地点",
                "input": "张三上周去了北京出差",
                "output": "人名:张三,地点:北京",
                "category": "information_extraction"
            }
        ]

    def get_seed_tasks(self, category: str = None) -> List[Dict]:
        """Return the seed tasks, optionally restricted to one category.

        Args:
            category: Category name to match exactly. A falsy value
                (``None`` or ``""``) returns every seed.

        Returns:
            A new list on every call, so callers can reorder or truncate
            it without touching the manager's own storage. The task dicts
            themselves are shared, not copied.
        """
        if category:
            # .get(): seeds without a "category" key simply do not match
            # instead of raising KeyError.
            return [t for t in self.seed_tasks if t.get("category") == category]
        return list(self.seed_tasks)

    def add_seed_task(self, task: Dict) -> None:
        """Register a new seed task and assign it a unique ``id``.

        The id is written into ``task`` itself (the caller's dict), and
        that same dict is stored.

        Args:
            task: The seed to add; any existing ``id`` is overwritten.
        """
        # max()+1 rather than len()+1: the list may contain id-less seeds or
        # have had entries removed, in which case len()+1 could reuse an id.
        highest_id = max(
            (t["id"] for t in self.seed_tasks if isinstance(t.get("id"), int)),
            default=0,
        )
        task["id"] = highest_id + 1
        self.seed_tasks.append(task)

    def format_for_generation(self, tasks: List[Dict]) -> str:
        """Render tasks as the few-shot example section of a prompt.

        Args:
            tasks: Tasks to render. A missing ``instruction``, ``input`` or
                ``output`` is rendered as an empty value.

        Returns:
            The header line followed by one three-line entry per task, each
            entry terminated by a blank line.
        """
        parts = ["以下是几个示例任务:\n\n"]
        for task in tasks:
            parts.append(
                f"任务:{task.get('instruction', '')}\n"
                f"输入:{task.get('input', '')}\n"
                f"输出:{task.get('output', '')}\n\n"
            )
        return "".join(parts)

2. 指令生成器

class InstructionGenerator:
    """Ask a model to invent new tasks from a handful of seed examples.

    ``model`` only needs a ``generate(prompt)`` method returning text.
    ``seed_manager`` supplies ``get_seed_tasks(category)`` and
    ``format_for_generation(tasks)``.
    """

    # Line prefixes recognised in model output, per result field. Both the
    # ASCII colon used in the prompt and the full-width colon are accepted.
    _FIELD_PREFIXES = (
        ("instruction", ("任务:", "任务:")),
        ("input", ("输入:", "输入:")),
        ("output", ("输出:", "输出:")),
    )

    def __init__(self, model, seed_manager: "SeedTaskManager"):
        self.model = model
        self.seed_manager = seed_manager
        self.generation_prompt_template = """
        根据以下示例任务,生成{num_tasks}个新的类似任务。
        
        示例任务:
        {examples}
        
        请生成新的任务,格式如下:
        任务:[指令]
        输入:[可选输入]
        输出:[期望输出]
        """

    def generate_instructions(self, num_tasks: int = 5, category: str = None,
                              seed_tasks: List[Dict] = None) -> List[Dict]:
        """Generate new tasks with one model call.

        Args:
            num_tasks: Number of tasks requested in the prompt.
            category: Seed category used to pick examples when
                ``seed_tasks`` is not given.
            seed_tasks: Explicit example tasks. When provided, the seed
                manager is not queried and ``category`` is ignored.

        Returns:
            The tasks parsed from the model's reply (see
            ``parse_generated_tasks``); its length is whatever the model
            produced, not necessarily ``num_tasks``.
        """
        if seed_tasks is None:
            seed_tasks = self.seed_manager.get_seed_tasks(category)
        # At most three examples go into the prompt.
        examples = self.seed_manager.format_for_generation(seed_tasks[:3])

        prompt = self.generation_prompt_template.format(
            num_tasks=num_tasks,
            examples=examples
        )
        return self.parse_generated_tasks(self.model.generate(prompt))

    def parse_generated_tasks(self, text: str) -> List[Dict]:
        """Parse ``任务:/输入:/输出:`` lines into task dicts.

        A ``任务`` line opens a new task; following ``输入``/``输出`` lines
        fill in that task. Values are stripped of surrounding whitespace.
        Field lines that appear before any ``任务`` line are dropped, so
        every returned dict has an ``instruction`` key. Other lines are
        ignored.

        Args:
            text: Raw model output; ``None`` or empty yields ``[]``.

        Returns:
            Task dicts in order of appearance, each with ``instruction``
            and whichever of ``input``/``output`` followed it.
        """
        tasks = []
        current_task = {}
        for raw_line in (text or "").splitlines():
            line = raw_line.strip()
            for field, prefixes in self._FIELD_PREFIXES:
                if not line.startswith(prefixes):
                    continue
                # Both prefix variants of a field have the same length.
                value = line[len(prefixes[0]):].strip()
                if field == "instruction":
                    if current_task:
                        tasks.append(current_task)
                    current_task = {"instruction": value}
                elif current_task:
                    current_task[field] = value
                break

        if current_task:
            tasks.append(current_task)
        return tasks

    def augment_with_diversity(self, seed_tasks: List[Dict], target_count: int) -> List[Dict]:
        """Generate tasks type by type so every task type is represented.

        Seeds are grouped by ``task_type`` (seeds without one share the
        ``"unknown"`` group). For each group the model is called
        ``min(3 * group size, target_count // number of groups)`` times,
        each time with that group's own seeds as prompt examples.

        Args:
            seed_tasks: Seeds to derive new tasks from.
            target_count: Upper bound on the number of returned tasks.

        Returns:
            At most ``target_count`` generated tasks, in group order
            (groups ordered by first appearance in ``seed_tasks``).
        """
        groups = {}
        for seed in seed_tasks:
            groups.setdefault(seed.get("task_type", "unknown"), []).append(seed)
        if not groups or target_count <= 0:
            return []

        per_type_quota = target_count // len(groups)
        augmented_tasks = []
        for type_tasks in groups.values():
            calls = min(len(type_tasks) * 3, per_type_quota)
            for _ in range(calls):
                # Hand the group's seeds over directly: task types are not
                # category names, so a category lookup would find nothing.
                augmented_tasks.extend(
                    self.generate_instructions(1, seed_tasks=type_tasks)
                )

        return augmented_tasks[:target_count]

3. 质量过滤

class QualityFilter:
    """Heuristic pass/fail filter and scorer for generated tasks."""

    def __init__(self):
        # "length" bounds the instruction's character count and "diversity"
        # is the minimum unique-token ratio used by check_diversity.
        # "relevance" and "completeness" are not read by any method here.
        self.quality_criteria = {
            "length": {"min": 10, "max": 500},
            "diversity": 0.3,
            "relevance": 0.5,
            "completeness": 0.7
        }

    def filter_tasks(self, tasks: List[Dict]) -> List[Dict]:
        """Return the tasks that pass ``meets_quality_standards``, in order."""
        return [task for task in tasks if self.meets_quality_standards(task)]

    def meets_quality_standards(self, task: Dict) -> bool:
        """Check a task against the completeness, length and diversity rules.

        Args:
            task: Task dict; ``instruction`` and ``output`` must be present
                and non-empty.

        Returns:
            ``True`` only if both required fields are set, the instruction
            length lies within the configured bounds, and
            ``check_diversity`` passes.
        """
        # Completeness first, so a missing or None instruction is rejected
        # before its length is measured.
        if not all(task.get(field) for field in ("instruction", "output")):
            return False

        limits = self.quality_criteria["length"]
        if not (limits["min"] <= len(task["instruction"]) <= limits["max"]):
            return False

        return self.check_diversity(task)

    def check_diversity(self, task: Dict) -> bool:
        """Tell whether the instruction is more than the same unit repeated.

        The ratio of unique units to total units must exceed
        ``quality_criteria["diversity"]``. Units are whitespace-separated
        words; text that forms a single word (for example Chinese, which is
        written without spaces) is measured per character instead, since a
        single word would otherwise always score 1.0.

        Args:
            task: Task dict; a missing or empty instruction fails.

        Returns:
            ``True`` if the unique ratio is above the configured threshold.
        """
        instruction = task.get("instruction") or ""
        units = instruction.split()
        if len(units) <= 1:
            units = list(instruction.strip())
        if not units:
            return False
        return len(set(units)) / len(units) > self.quality_criteria["diversity"]

    def calculate_quality_score(self, task: Dict) -> float:
        """Average three sub-scores into one quality score.

        The sub-scores are: instruction length / 100 capped at 1.0; the
        fraction of ``instruction``/``input``/``output`` that are set; and
        0.8 or 0.3 depending on whether ``check_diversity`` passes.

        Args:
            task: Task dict to score.

        Returns:
            The unweighted mean of the three sub-scores.
        """
        length_score = min(len(task.get("instruction") or "") / 100, 1.0)
        completeness_score = sum(
            1 for field in ("instruction", "input", "output") if task.get(field)
        ) / 3
        diversity_score = 0.8 if self.check_diversity(task) else 0.3

        scores = [length_score, completeness_score, diversity_score]
        return sum(scores) / len(scores)

完整流程

class SelfInstructPipeline:
    """End-to-end self-instruct run: generate, filter, deduplicate, annotate."""

    def __init__(self, model):
        self.seed_manager = SeedTaskManager()
        self.generator = InstructionGenerator(model, self.seed_manager)
        self.quality_filter = QualityFilter()

    def run_pipeline(self, target_count: int = 100) -> List[Dict]:
        """Run all four stages and return the annotated tasks.

        Progress is reported with ``print``.

        Args:
            target_count: Maximum number of tasks to return. Twice this
                number is requested from the generator so that filtering
                has candidates to discard.

        Returns:
            At most ``target_count`` validated, annotated task dicts.
        """
        print(f"开始生成 {target_count} 个指令...")

        # Stage 1: candidate generation.
        print("阶段1:生成候选任务...")
        candidate_tasks = self.generator.augment_with_diversity(
            self.seed_manager.get_seed_tasks(),
            target_count * 2
        )

        # Stage 2: quality filtering.
        print("阶段2:质量过滤...")
        filtered_tasks = self.quality_filter.filter_tasks(candidate_tasks)

        # Stage 3: deduplication and selection of the best candidates.
        print("阶段3:去重和精选...")
        final_tasks = self.deduplicate_and_select(
            filtered_tasks, target_count
        )

        # Stage 4: format validation and metadata annotation.
        print("阶段4:验证和标注...")
        validated_tasks = self.validate_and_annotate(final_tasks)

        print(f"生成完成!共 {len(validated_tasks)} 个高质量指令")
        return validated_tasks

    def deduplicate_and_select(self, tasks: List[Dict], count: int) -> List[Dict]:
        """Keep the best-scoring task for each distinct instruction text.

        Args:
            tasks: Candidate tasks.
            count: Maximum number of tasks to keep; zero or a negative
                value yields an empty list.

        Returns:
            Up to ``count`` tasks with pairwise different instructions,
            ordered by descending quality score (ties keep input order).
        """
        if count <= 0:
            return []

        # sorted() is stable, also with reverse=True, so equally scored
        # tasks stay in their original relative order.
        ranked = sorted(
            tasks,
            key=self.quality_filter.calculate_quality_score,
            reverse=True,
        )

        unique_tasks = []
        # The instruction strings themselves are stored, not their hash()
        # values, so two different instructions can never be mistaken for
        # duplicates through a hash collision.
        seen_instructions = set()
        for task in ranked:
            instruction = task.get("instruction", "")
            if instruction in seen_instructions:
                continue
            seen_instructions.add(instruction)
            unique_tasks.append(task)
            if len(unique_tasks) >= count:
                break

        return unique_tasks

    def validate_and_annotate(self, tasks: List[Dict]) -> List[Dict]:
        """Drop malformed tasks and stamp the rest with metadata.

        Accepted tasks are modified in place: ``source``,
        ``generation_timestamp`` and ``quality_score`` are added. Rejected
        tasks are left untouched.

        Args:
            tasks: Tasks to validate.

        Returns:
            The accepted tasks (the same dict objects), in input order.
        """
        validated_tasks = []
        for task in tasks:
            if not self.validate_task_format(task):
                continue
            task["source"] = "self_instruct"
            task["generation_timestamp"] = self.get_timestamp()
            task["quality_score"] = self.quality_filter.calculate_quality_score(task)
            validated_tasks.append(task)

        return validated_tasks

    def validate_task_format(self, task: Dict) -> bool:
        """Return ``True`` if ``instruction`` and ``output`` are both set and non-empty."""
        required_fields = ["instruction", "output"]
        return all(field in task and task[field] for field in required_fields)

    def get_timestamp(self) -> str:
        """Return the current local time as an ISO 8601 string (no timezone)."""
        from datetime import datetime
        return datetime.now().isoformat()

# Usage example: `model` must already be bound before this snippet runs.
pipeline = SelfInstructPipeline(model)
generated_tasks = pipeline.run_pipeline(target_count=50)

# Persist the generated tasks as indented UTF-8 JSON, leaving non-ASCII
# characters unescaped so the file stays human-readable.
with open("generated_instructions.json", mode="w", encoding="utf-8") as f:
    f.write(json.dumps(generated_tasks, ensure_ascii=False, indent=2))

高级技巧

1. 多样性增强

class DiversityEnhancer:
    """Derive surface-level variants of tasks by rewording their instruction.

    Every strategy returns a shallow copy of the task with a rewritten
    ``instruction`` and an ``enhancement`` key naming the strategy; the
    input task is never modified. No model is involved - all rewrites are
    fixed string templates, some chosen at random.
    """

    def __init__(self):
        # Strategy names in the order enhance_diversity applies them.
        self.diversity_strategies = [
            "paraphrase",
            "template_variation",
            "domain_transfer",
            "complexity_adjustment",
        ]

    def enhance_diversity(self, tasks: List[Dict]) -> List[Dict]:
        """Apply every configured strategy to every task.

        Returns:
            The variants only (the originals are not included), grouped
            by task and ordered by strategy within each task.
        """
        candidates = (
            self.apply_strategy(original, strategy_name)
            for original in tasks
            for strategy_name in self.diversity_strategies
        )
        return [variant for variant in candidates if variant]

    def apply_strategy(self, task: Dict, strategy: str) -> Dict:
        """Run one named strategy on ``task``; unknown names yield ``None``."""
        method_names = {
            "paraphrase": "paraphrase_task",
            "template_variation": "vary_template",
            "domain_transfer": "transfer_domain",
            "complexity_adjustment": "adjust_complexity",
        }
        method_name = method_names.get(strategy)
        if method_name is None:
            return None
        return getattr(self, method_name)(task)

    def paraphrase_task(self, task: Dict) -> Dict:
        """Return a copy whose instruction is prefixed with "请"."""
        return {
            **task,
            "instruction": f"请{task['instruction']}",
            "enhancement": "paraphrase",
        }

    def vary_template(self, task: Dict) -> Dict:
        """Return a copy whose instruction is wrapped in a randomly chosen template."""
        import random

        wrapper = random.choice([
            "请{}",
            "帮我{}",
            "如何{}",
            "完成以下任务:{}",
        ])
        return {
            **task,
            "instruction": wrapper.format(task["instruction"]),
            "enhancement": "template_variation",
        }

    def transfer_domain(self, task: Dict) -> Dict:
        """Return a copy whose instruction is prefixed with a randomly chosen domain."""
        import random

        domain = random.choice(["科技", "教育", "医疗", "金融", "娱乐"])
        return {
            **task,
            "instruction": f"在{domain}领域,{task['instruction']}",
            "enhancement": "domain_transfer",
        }

    def adjust_complexity(self, task: Dict) -> Dict:
        """Return a copy whose instruction asks for a randomly chosen level of detail."""
        import random

        complexity = random.choice(["简单", "详细", "专业"])
        return {
            **task,
            "instruction": f"请以{complexity}的方式{task['instruction']}",
            "enhancement": "complexity_adjustment",
        }

实际应用

class SelfInstructApplication:
    """Build and evaluate domain-specific instruction datasets."""

    def __init__(self, model):
        self.pipeline = SelfInstructPipeline(model)
        self.enhancer = DiversityEnhancer()

    def create_training_dataset(self, domain: str, count: int) -> List[Dict]:
        """Generate a dataset for ``domain``.

        The domain's seed tasks are added to the pipeline's seed manager
        (they stay there for later calls), the pipeline is run, and the
        result is passed through the diversity enhancer.

        Args:
            domain: Domain key understood by ``get_domain_seeds``; an
                unknown domain adds no seeds.
            count: Target size passed to the pipeline and maximum length of
                the returned list.

        Returns:
            The first ``count`` tasks of the enhancer's output.
        """
        seed_tasks = self.pipeline.seed_manager.seed_tasks
        # Skip seeds that are already registered, so calling this method
        # repeatedly does not pile up identical examples.
        known = {(t.get("instruction"), t.get("input")) for t in seed_tasks}
        for seed in self.get_domain_seeds(domain):
            key = (seed.get("instruction"), seed.get("input"))
            if key not in known:
                known.add(key)
                seed_tasks.append(seed)

        generated = self.pipeline.run_pipeline(count)
        enhanced = self.enhancer.enhance_diversity(generated)

        return enhanced[:count]

    def get_domain_seeds(self, domain: str) -> List[Dict]:
        """Return the built-in seed tasks for ``domain``.

        Each seed carries ``task_type`` and ``category`` in addition to the
        instruction/input/output triple, matching the keys of the default
        seed tasks so that code filtering on those keys can handle them.

        Args:
            domain: ``"medical"`` or ``"legal"``.

        Returns:
            Freshly built seed dicts, or ``[]`` for any other domain.
        """
        domain_seeds = {
            "medical": [
                {
                    "task_type": "explanation",
                    "instruction": "解释以下医学术语",
                    "input": "高血压",
                    "output": "高血压是指动脉血压持续升高的一种慢性疾病",
                    "category": "medical"
                }
            ],
            "legal": [
                {
                    "task_type": "analysis",
                    "instruction": "分析以下合同条款",
                    "input": "违约责任条款",
                    "output": "该条款规定了合同双方违反约定时应承担的法律责任",
                    "category": "legal"
                }
            ]
        }

        return domain_seeds.get(domain, [])

    def evaluate_generated_data(self, tasks: List[Dict]) -> Dict:
        """Summarise a list of tasks.

        Args:
            tasks: Task dicts; ``quality_score`` defaults to 0 and
                ``category``/``task_type`` to ``"unknown"`` when missing.

        Returns:
            A dict with ``total_tasks``, ``avg_quality_score`` (0 for an
            empty list), and per-value counts in ``category_distribution``
            and ``task_type_distribution``.
        """
        category_counts = {}
        task_type_counts = {}
        quality_scores = []
        for task in tasks:
            quality_scores.append(task.get("quality_score", 0))

            category = task.get("category", "unknown")
            category_counts[category] = category_counts.get(category, 0) + 1

            task_type = task.get("task_type", "unknown")
            task_type_counts[task_type] = task_type_counts.get(task_type, 0) + 1

        return {
            "total_tasks": len(tasks),
            "avg_quality_score": (
                sum(quality_scores) / len(quality_scores) if quality_scores else 0
            ),
            "category_distribution": category_counts,
            "task_type_distribution": task_type_counts,
        }

# Usage example: `model` must already be bound before this snippet runs.
app = SelfInstructApplication(model)
medical_data = app.create_training_dataset(domain="medical", count=100)
evaluation = app.evaluate_generated_data(tasks=medical_data)
print(f"生成了 {evaluation['total_tasks']} 个医疗领域指令")
print(f"平均质量分数: {evaluation['avg_quality_score']:.2f}")

总结

自指令技术通过LLM自身生成训练数据,大幅降低了数据标注成本。结合质量过滤和多样性增强,可以创建高质量的训练数据集,提升模型在特定任务上的表现。