← 返回首页
🧠

护栏系统:保护LLM安全运行

📂 llm ⏱ 4 min 782 words

---
title: "护栏系统:保护LLM安全运行"
description: "设计和实现LLM护栏系统,确保AI在安全边界内运行"
tags: ["护栏", "Guardrails", "安全边界", "LLM", "防护"]
category: "llm"
icon: "🚧"
---

护栏系统:保护LLM安全运行

护栏概述

护栏(Guardrails)是限制和引导LLM行为的机制,确保AI在预定义的安全边界内运行。

护栏架构

1. 护栏定义

from dataclasses import dataclass, field
from typing import List, Dict, Callable, Any, Optional
from enum import Enum

class GuardrailType(Enum):
    """Categories a guardrail can be tagged with.

    The member values are the strings reported in the ``"type"`` field of
    each per-guardrail result produced by ``GuardrailEngine.execute``.
    """
    # Checks applied to the text before it is processed.
    INPUT_VALIDATION = "input_validation"
    # Checks applied to generated output text.
    OUTPUT_VALIDATION = "output_validation"
    # Keyword / pattern based content screening.
    CONTENT_FILTER = "content_filter"
    # Request-frequency limiting.
    RATE_LIMIT = "rate_limit"
    # Restriction to an allowed set of topics.
    TOPIC_RESTRICTION = "topic_restriction"

@dataclass
class Guardrail:
    """Definition of a single guardrail: identity, category and check callable."""
    # Name of the guardrail; GuardrailRegistry uses it as the dictionary key.
    name: str
    # Category of the guardrail (see GuardrailType).
    type: GuardrailType
    # Human-readable description of what the guardrail enforces.
    description: str
    # Called by GuardrailEngine as check_func(text, context); it is expected to
    # return a dict with a "passed" entry and, optionally, a "message" entry.
    check_func: Callable
    # Severity label copied into each result entry.
    severity: str = "high"  # "low", "medium", "high", "critical"
    # Disabled guardrails are skipped by GuardrailEngine.execute.
    enabled: bool = True
    # Free-form extra data; a fresh dict is created for every instance.
    metadata: Dict[str, Any] = field(default_factory=dict)

class GuardrailRegistry:
    """In-memory catalogue of guardrails, keyed by guardrail name."""

    def __init__(self):
        """Start with an empty catalogue."""
        self.guardrails: Dict[str, Guardrail] = {}

    def register(self, guardrail: Guardrail):
        """Store *guardrail* under its name.

        Registering a second guardrail with the same name replaces the
        earlier entry.
        """
        self.guardrails[guardrail.name] = guardrail

    def get_guardrail(self, name: str) -> Optional[Guardrail]:
        """Return the guardrail registered as *name*, or ``None`` if absent."""
        try:
            return self.guardrails[name]
        except KeyError:
            return None

    def list_guardrails(self, type_filter: GuardrailType = None) -> List[Guardrail]:
        """Return the registered guardrails as a new list.

        When *type_filter* is given, only guardrails whose ``type`` equals it
        are returned; otherwise every registered guardrail is included
        (enabled or not).
        """
        if not type_filter:
            return list(self.guardrails.values())
        return [
            candidate
            for candidate in self.guardrails.values()
            if candidate.type == type_filter
        ]

    def enable_guardrail(self, name: str):
        """Mark the guardrail called *name* as enabled; unknown names are ignored."""
        self._set_enabled(name, True)

    def disable_guardrail(self, name: str):
        """Mark the guardrail called *name* as disabled; unknown names are ignored."""
        self._set_enabled(name, False)

    def _set_enabled(self, name: str, flag: bool):
        """Set the ``enabled`` flag of a registered guardrail, if it exists."""
        if name not in self.guardrails:
            return
        self.guardrails[name].enabled = flag

2. 护栏引擎

class GuardrailEngine:
    """Run every enabled guardrail of a registry against a piece of text.

    Each call to ``execute`` appends one entry to ``execution_log``
    containing a timestamp, a preview of the input and the per-guardrail
    results.
    """

    def __init__(self, registry: GuardrailRegistry):
        """Bind the engine to *registry* and start with an empty log."""
        self.registry = registry
        self.execution_log = []

    def execute(self, input_text: str, context: Optional[Dict] = None) -> Dict:
        """Run all enabled guardrails on *input_text*.

        Args:
            input_text: Text handed to every guardrail's ``check_func``.
            context: Optional extra data passed through unchanged as the
                second argument of every ``check_func``.

        Returns:
            A dict with three keys: ``"passed"`` (True only if no guardrail
            failed), ``"results"`` (one entry per executed guardrail with
            ``guardrail``, ``type``, ``passed``, ``message`` and
            ``severity``) and ``"failed_guardrails"`` (the subset of
            ``results`` that did not pass).
        """
        results = []

        for guardrail in self.registry.list_guardrails():
            if not guardrail.enabled:
                continue

            try:
                check_result = guardrail.check_func(input_text, context)
                passed = check_result["passed"]
                message = check_result.get("message", "")
            except Exception as e:
                # Fail closed: a guardrail that crashes or returns a malformed
                # result is reported as a failed check instead of aborting
                # the whole run.
                passed = False
                message = f"检查失败: {str(e)}"

            results.append({
                "guardrail": guardrail.name,
                "type": guardrail.type.value,
                "passed": passed,
                "message": message,
                "severity": guardrail.severity
            })

        failed = [r for r in results if not r["passed"]]

        # Record the run in the execution log.
        self._log_execution(input_text, results)

        return {
            "passed": not failed,
            "results": results,
            "failed_guardrails": failed
        }

    def _log_execution(self, input_text: str, results: List[Dict]):
        """Append one log entry for a completed run.

        Only the first 100 characters of *input_text* are stored.
        """
        # Imported here because the module's visible import block does not
        # bring `datetime` into scope; without it this method raised NameError.
        from datetime import datetime

        self.execution_log.append({
            "timestamp": datetime.now().isoformat(),
            "input_preview": input_text[:100],
            "results": results
        })

3. 内置护栏

class BuiltInGuardrails:
    """Factory helpers for the stock guardrails.

    Every factory returns a ``Guardrail`` whose ``check_func`` accepts
    ``(input_text, context)`` and returns a dict with a boolean ``"passed"``
    entry, plus a ``"message"`` entry when the check fails.
    """

    @staticmethod
    def input_length_guardrail(max_length: int = 10000) -> Guardrail:
        """Build a guardrail that rejects text longer than *max_length* characters."""
        def check(input_text: str, context: Optional[Dict] = None) -> Dict:
            if len(input_text) > max_length:
                return {
                    "passed": False,
                    "message": f"输入长度超过限制: {len(input_text)} > {max_length}"
                }
            return {"passed": True}

        return Guardrail(
            name="input_length",
            type=GuardrailType.INPUT_VALIDATION,
            description=f"限制输入长度不超过{max_length}字符",
            check_func=check
        )

    @staticmethod
    def content_filter_guardrail(prohibited_terms: List[str] = None) -> Guardrail:
        """Build a guardrail that rejects text containing a prohibited term.

        Args:
            prohibited_terms: Substrings to block. When omitted, a built-in
                default list of four terms is used.
        """
        if prohibited_terms is None:
            prohibited_terms = ["暴力", "仇恨", "歧视", "色情"]

        def check(input_text: str, context: Optional[Dict] = None) -> Dict:
            # Report the first prohibited term found, in list order.
            for term in prohibited_terms:
                if term in input_text:
                    return {
                        "passed": False,
                        "message": f"包含禁止内容: {term}"
                    }
            return {"passed": True}

        return Guardrail(
            name="content_filter",
            type=GuardrailType.CONTENT_FILTER,
            description="过滤禁止内容",
            check_func=check
        )

    @staticmethod
    def topic_restriction_guardrail(allowed_topics: List[str] = None) -> Guardrail:
        """Build a guardrail that only passes text mentioning an allowed topic.

        Args:
            allowed_topics: Topic keywords. When omitted, a built-in default
                list of three topics is used.
        """
        if allowed_topics is None:
            allowed_topics = ["技术", "教育", "娱乐"]

        def check(input_text: str, context: Optional[Dict] = None) -> Dict:
            # Simplified implementation: a topic counts as present when its
            # keyword appears as a substring of the text.
            if any(topic in input_text for topic in allowed_topics):
                return {"passed": True}
            return {
                "passed": False,
                "message": "主题不在允许范围内"
            }

        return Guardrail(
            name="topic_restriction",
            type=GuardrailType.TOPIC_RESTRICTION,
            description="限制允许的主题",
            check_func=check
        )

    @staticmethod
    def rate_limit_guardrail(max_requests: int = 100, time_window: int = 60) -> Guardrail:
        """Build a per-user sliding-window rate limiter.

        The user is identified by ``context["user_id"]``; requests without a
        context or without that key share the ``"default"`` bucket. Request
        timestamps are kept in a dict owned by the returned guardrail, so
        each guardrail created here counts independently. Rejected requests
        are not recorded.

        Args:
            max_requests: Maximum number of accepted requests per window.
            time_window: Window length in seconds.
        """
        # Imported here because the module's visible import block does not
        # bring `time` into scope; without it the check raised NameError.
        import time

        request_counts = {}

        def check(input_text: str, context: Optional[Dict] = None) -> Dict:
            user_id = context.get("user_id", "default") if context else "default"
            # A monotonic clock keeps the window correct even if the system
            # wall clock is adjusted while the process is running.
            current_time = time.monotonic()

            # Drop timestamps that have fallen out of the window.
            recent = [
                t for t in request_counts.get(user_id, [])
                if current_time - t < time_window
            ]
            request_counts[user_id] = recent

            # Enforce the limit.
            if len(recent) >= max_requests:
                return {
                    "passed": False,
                    "message": f"请求频率超过限制: {max_requests}次/{time_window}秒"
                }

            # Record the accepted request.
            recent.append(current_time)
            return {"passed": True}

        return Guardrail(
            name="rate_limit",
            type=GuardrailType.RATE_LIMIT,
            description=f"限制请求频率: {max_requests}次/{time_window}秒",
            check_func=check
        )

高级护栏

1. 输出验证护栏

class OutputValidationGuardrail:
    """Collect named validation rules and expose them as a single guardrail."""

    def __init__(self):
        """Start with no validation rules."""
        self.validation_rules = []

    def add_rule(self, name: str, check_func: Callable):
        """Append a validation rule.

        Args:
            name: Label used in failure messages.
            check_func: Callable taking the output text and returning a dict
                with a ``"passed"`` entry and, optionally, a ``"message"``.
        """
        self.validation_rules.append({"name": name, "check": check_func})

    def create_guardrail(self) -> Guardrail:
        """Build a guardrail that applies the rules in insertion order.

        The returned check reads ``self.validation_rules`` each time it runs,
        so rules added after this call are applied as well. It stops at the
        first rule that fails or raises.
        """
        def check_output(output_text: str, context: Optional[Dict] = None) -> Dict:
            for rule in self.validation_rules:
                try:
                    result = rule["check"](output_text)
                    if not result["passed"]:
                        return {
                            "passed": False,
                            "message": f"输出验证失败: {rule['name']} - {result.get('message', '')}"
                        }
                except Exception as e:
                    # Fail closed, and keep the error text so the cause of a
                    # broken rule is visible (same idea as GuardrailEngine's
                    # failure message) instead of being discarded.
                    return {
                        "passed": False,
                        "message": f"验证规则执行失败: {rule['name']} - {str(e)}"
                    }

            return {"passed": True}

        return Guardrail(
            name="output_validation",
            type=GuardrailType.OUTPUT_VALIDATION,
            description="验证输出质量",
            check_func=check_output
        )

# Usage example: build an output validator from two rules.
output_validator = OutputValidationGuardrail()
# "length" passes only when the output is longer than 10 characters.
output_validator.add_rule("length", lambda t: {"passed": len(t) > 10})
# "no_empty" passes only when the output contains non-whitespace characters.
output_validator.add_rule("no_empty", lambda t: {"passed": bool(t.strip())})

# Wrap the collected rules in a Guardrail that can be registered.
output_guardrail = output_validator.create_guardrail()

2. 自定义护栏

class CustomGuardrailFactory:
    """Factory helpers for user-defined keyword and regex guardrails."""

    @staticmethod
    def create_keyword_guardrail(
        keywords: List[str],
        mode: str = "block",
        name: Optional[str] = None,
    ) -> Guardrail:
        """Build a keyword-based guardrail.

        Args:
            keywords: Substrings to look for in the text.
            mode: ``"block"`` fails the check when any keyword is present;
                ``"allow"`` fails the check when no keyword is present.
            name: Guardrail name. Defaults to ``f"keyword_{mode}"``. Pass a
                distinct name when registering several keyword guardrails,
                because GuardrailRegistry stores guardrails by name.

        Raises:
            ValueError: If *mode* is neither ``"block"`` nor ``"allow"``.
        """
        # An unrecognised mode used to yield a guardrail that passed every
        # input; reject it up front so a typo cannot disable the check.
        if mode not in ("block", "allow"):
            raise ValueError(f"mode must be 'block' or 'allow', got {mode!r}")

        def check(input_text: str, context: Optional[Dict] = None) -> Dict:
            found_keywords = [kw for kw in keywords if kw in input_text]

            if mode == "block" and found_keywords:
                return {
                    "passed": False,
                    "message": f"包含阻止关键词: {found_keywords}"
                }
            if mode == "allow" and not found_keywords:
                return {
                    "passed": False,
                    "message": "未包含允许的关键词"
                }

            return {"passed": True}

        return Guardrail(
            name=name or f"keyword_{mode}",
            type=GuardrailType.CONTENT_FILTER,
            description=f"关键词{mode}护栏",
            check_func=check
        )

    @staticmethod
    def create_regex_guardrail(
        pattern: str,
        description: str = "",
        name: Optional[str] = None,
    ) -> Guardrail:
        """Build a guardrail that fails when *pattern* matches anywhere in the text.

        Args:
            pattern: Regular expression describing prohibited content.
            description: Guardrail description. Defaults to a text that
                includes the pattern.
            name: Guardrail name. Defaults to ``"regex_filter"``. Pass a
                distinct name when registering several regex guardrails,
                because GuardrailRegistry stores guardrails by name.

        Raises:
            re.error: If *pattern* is not a valid regular expression.
        """
        import re

        # Compile once at creation time: an invalid pattern is reported
        # immediately, and each check reuses the compiled object.
        compiled = re.compile(pattern)

        def check(input_text: str, context: Optional[Dict] = None) -> Dict:
            if compiled.search(input_text):
                return {
                    "passed": False,
                    "message": f"匹配禁止模式: {pattern}"
                }
            return {"passed": True}

        return Guardrail(
            name=name or "regex_filter",
            type=GuardrailType.CONTENT_FILTER,
            description=description or f"正则表达式过滤: {pattern}",
            check_func=check
        )

护栏配置

class GuardrailConfiguration:
    """Holder for the global guardrail settings."""

    def __init__(self):
        """Populate ``config`` with the default settings."""
        self.config = dict(
            enabled=True,
            strict_mode=False,
            logging_level="INFO",
            alert_on_violation=True,
            max_violations_before_block=3,
        )

    def configure_strict_mode(self, enabled: bool):
        """Set the ``strict_mode`` setting to *enabled*."""
        self.config.update(strict_mode=enabled)

    def get_config(self) -> Dict:
        """Return a shallow copy of the settings.

        Changes made to the returned dict do not affect this object.
        """
        return dict(self.config)

# Usage example: create a configuration and switch strict mode on.
config = GuardrailConfiguration()
config.configure_strict_mode(True)

# Build a registry holding two built-in guardrails (input length and
# content filter, both with their default arguments) and an engine over it.
registry = GuardrailRegistry()
registry.register(BuiltInGuardrails.input_length_guardrail())
registry.register(BuiltInGuardrails.content_filter_guardrail())

engine = GuardrailEngine(registry)

# Run every registered guardrail against a sample text and print the verdict.
result = engine.execute("这是一段测试文本")
print(f"护栏检查结果: {'通过' if result['passed'] else '失败'}")

最佳实践

  1. 分层防护:实施多层护栏防护
  2. 可配置性:提供灵活的配置选项
  3. 性能考虑:优化护栏检查性能
  4. 监控告警:监控护栏违规情况

总结

护栏系统是保护LLM安全运行的重要机制。通过设计和实现合适的护栏,可以确保AI在安全边界内运行。