← 返回首页
🧠

特性标志与LLM

📂 llm ⏱ 2 min 373 words

---
title: "特性标志与LLM"
description: "使用特性标志控制LLM功能的渐进式发布,实现灵活的模型版本管理和用户级别的功能开关"
tags: ["特性标志", "渐进发布", "功能开关"]
category: "llm"
icon: "🧠"
---

特性标志与LLM

概述

特性标志（Feature Flags）是一种在运行时动态控制功能开关的技术。在LLM应用中，特性标志可以用来控制不同模型版本之间的流量分配，并支撑A/B测试、金丝雀发布以及用户级别的功能定制。

核心概念

特性标志的基本思想是将代码部署与功能发布解耦：代码可以提前部署到生产环境，但功能只在标志开启时才对用户可见。这对于LLM项目特别有价值，因为新模型的回答质量、延迟和成本很难仅凭离线评测确认，往往需要先在一小部分真实流量上灰度验证。

实现方案

基础特性标志系统

# feature_flags/core.py
import json
import hashlib
from dataclasses import dataclass
from typing import Optional
from datetime import datetime

@dataclass
class FeatureFlag:
    """Declarative description of a single feature flag.

    ``user_segments`` and ``model_params`` default to ``None``, meaning
    "no segment restriction" and "no parameter overrides" respectively.
    """

    name: str
    enabled: bool  # master switch; False turns the flag off for everyone
    description: str
    created_at: datetime

    # Target audience: segment names understood by
    # FeatureFlagManager._check_user_segment, e.g. ["beta_users", "internal"].
    user_segments: Optional[list[str]] = None
    # Share of identified users that receive the flag, in percent (1 % resolution).
    percentage: float = 100.0

    # Model-specific configuration.
    model_override: Optional[str] = None  # model name to use instead of the default
    model_params: Optional[dict] = None   # generation kwargs that override the caller's


class FeatureFlagManager:
    """In-memory registry that evaluates feature flags per user."""

    # Simplified demo allow-list for the "beta_users" segment.
    _BETA_USERS = frozenset({"user_001", "user_002", "user_003"})

    def __init__(self):
        self.flags: dict[str, FeatureFlag] = {}

    def register(self, flag: FeatureFlag):
        """Add ``flag``, replacing any previously registered flag with the same name."""
        self.flags[flag.name] = flag

    def is_enabled(self, flag_name: str, user_id: Optional[str] = None) -> bool:
        """Return True if ``flag_name`` is active for ``user_id``.

        Evaluation order:
          1. unknown or disabled flag -> False;
          2. if ``user_segments`` is set, the user must belong to one of them;
          3. if ``percentage`` < 100, the user must hash into the first
             ``percentage`` of 100 stable buckets.

        Targeting rules need an identity: when ``user_id`` is None or empty,
        a flag restricted by segment or by percentage is reported as
        disabled instead of being silently enabled for every anonymous
        request. Flags without any targeting stay enabled for anonymous users.
        """
        flag = self.flags.get(flag_name)
        if flag is None or not flag.enabled:
            return False

        if flag.user_segments:
            if not user_id or not self._check_user_segment(user_id, flag.user_segments):
                return False

        if flag.percentage < 100.0:
            if not user_id:
                return False
            # md5 is used only for stable, uniformly spread bucketing (not for
            # security); salting with the flag name decorrelates rollouts.
            digest = hashlib.md5(f"{flag_name}:{user_id}".encode()).hexdigest()
            if int(digest[:8], 16) % 100 >= flag.percentage:
                return False

        return True

    def get_model_for_user(self, flag_name: str, user_id: str) -> Optional[str]:
        """Return the flag's ``model_override`` if the flag is on for ``user_id``, else None."""
        if not self.is_enabled(flag_name, user_id):
            return None
        # is_enabled() returned True, so the flag is guaranteed to be registered.
        return self.flags[flag_name].model_override

    def _check_user_segment(self, user_id: str, segments: list) -> bool:
        """Return True if ``user_id`` belongs to at least one of ``segments``.

        Simplified demo logic: "beta_users" is a hard-coded allow-list and
        "internal" matches ids prefixed with "emp_". Other segment names
        never match.
        """
        if "beta_users" in segments and user_id in self._BETA_USERS:
            return True
        return "internal" in segments and user_id.startswith("emp_")

与LLM推理集成

# inference_with_flags.py
from feature_flags.core import FeatureFlagManager

class LLMInferenceService:
    """Route generation requests to a model chosen by a rollout feature flag.

    Args:
        flag_manager: flag registry consulted on every request.
        default_model: name of the model served when the rollout flag is off
            for the user, or when the flag defines no ``model_override``.
        rollout_flag: name of the flag carrying ``model_override`` and
            ``model_params``; defaults to ``"new_model_rollout"``.
    """

    def __init__(self, flag_manager: FeatureFlagManager, default_model: str,
                 rollout_flag: str = "new_model_rollout"):
        self.flag_manager = flag_manager
        self.default_model = default_model
        self.rollout_flag = rollout_flag
        self.models = {}

    def load_model(self, name: str, model):
        """Register ``model`` (an object exposing ``generate(prompt, **kwargs)``) under ``name``."""
        self.models[name] = model

    def predict(self, prompt: str, user_id: str, **kwargs) -> str:
        """Generate a response to ``prompt`` on behalf of ``user_id``.

        When the rollout flag is enabled for the user, its ``model_override``
        (if set) selects the model and its ``model_params`` (if set) override
        the caller's ``kwargs``. Otherwise the default model is called with
        the caller's ``kwargs`` unchanged.

        Raises:
            ValueError: if the selected model has not been loaded.
        """
        model_name = self.default_model
        overrides = None
        if self.flag_manager.is_enabled(self.rollout_flag, user_id):
            flag = self.flag_manager.flags[self.rollout_flag]
            model_name = flag.model_override or self.default_model
            overrides = flag.model_params

        model = self.models.get(model_name)
        if model is None:
            raise ValueError(f"Model {model_name} not loaded")

        if overrides:
            # Flag parameters take precedence over caller kwargs, but only for
            # users inside the rollout; applying them to everyone would
            # silently change the default model's behaviour as well.
            kwargs.update(overrides)

        return model.generate(prompt, **kwargs)

# Usage example
# FeatureFlag and datetime are needed below but were not imported by this snippet.
from datetime import datetime

from feature_flags.core import FeatureFlag

flag_manager = FeatureFlagManager()
flag_manager.register(FeatureFlag(
    name="new_model_rollout",
    enabled=True,
    description="New model gradual rollout",
    created_at=datetime.now(),
    user_segments=["beta_users"],
    percentage=20.0,
    model_override="qwen-7b-v2",
    model_params={"temperature": 0.7, "max_tokens": 512}
))

service = LLMInferenceService(flag_manager, default_model="qwen-7b-v1")
# old_model / new_model stand for already-loaded model objects exposing
# generate(prompt, **kwargs); they are not defined in this snippet.
service.load_model("qwen-7b-v1", old_model)
service.load_model("qwen-7b-v2", new_model)

# Different users may be served by different models.
# user_001 is in "beta_users", so it gets qwen-7b-v2 only if it also hashes
# into the 20% rollout bucket; otherwise it stays on qwen-7b-v1.
response = service.predict("你好", user_id="user_001")
# user_999 is not a beta user, so it always gets the default model.
response = service.predict("你好", user_id="user_999")

高级特性

配置热更新

import threading
import time

class HotReloadFlags:
    """Keep a FeatureFlagManager in sync with a JSON config file.

    The file is expected to look like ``{"flags": [{<FeatureFlag fields>}, ...]}``;
    ``created_at`` may be given as an ISO-8601 string. The config is loaded
    once synchronously in ``__init__`` (errors propagate, so a broken config
    surfaces at startup), then re-read every ``interval`` seconds by a daemon
    thread. A failed periodic reload is logged and the last good flags are kept.

    Note: each reload replaces ``manager.flags`` wholesale, so flags
    registered on ``manager`` by other code are dropped at the next reload.
    """

    def __init__(self, config_path: str, interval: int = 60):
        self.config_path = config_path
        self.interval = interval
        self.manager = FeatureFlagManager()
        self._stop_event = threading.Event()
        # Load before returning so callers never see an empty manager.
        self._reload_config()
        self._start_watcher()

    def stop(self):
        """Ask the watcher thread to exit; takes effect at its next wake-up."""
        self._stop_event.set()

    def _start_watcher(self):
        import logging

        logger = logging.getLogger(__name__)

        def watch():
            # Event.wait() is an interruptible sleep: it returns True once
            # stop() is called, which ends the loop.
            while not self._stop_event.wait(self.interval):
                try:
                    self._reload_config()
                except Exception:
                    # Top-level boundary of the thread: a bad edit to the
                    # config must not kill hot reload for the process lifetime.
                    logger.exception(
                        "Failed to reload feature flags from %s", self.config_path
                    )

        thread = threading.Thread(target=watch, daemon=True)
        thread.start()

    def _reload_config(self):
        """Re-read the config file and replace the manager's flags in one step.

        Flags missing from the file are dropped, so deleting a flag from the
        config really turns it off. The new mapping is fully built before it
        is assigned, so a parse error leaves the previous flags untouched.
        """
        import json
        from datetime import datetime
        from pathlib import Path  # was referenced without being imported

        config = json.loads(Path(self.config_path).read_text(encoding="utf-8"))
        new_flags = {}
        for flag_data in config.get("flags", []):
            data = dict(flag_data)
            created_at = data.get("created_at")
            if isinstance(created_at, str):
                # JSON has no datetime type, but FeatureFlag.created_at is a datetime.
                data["created_at"] = datetime.fromisoformat(created_at)
            flag = FeatureFlag(**data)
            new_flags[flag.name] = flag
        # A single rebinding: readers see either the old or the new mapping.
        self.manager.flags = new_flags

最佳实践

  1. 渐进式发布：从小百分比开始，逐步增加流量
  2. 快速回滚：发现问题时立即关闭标志，无需重新部署（使用上文的配置热更新时，修改会在一个轮询周期内生效）
  3. 监控指标：为每个标志关联关键指标（如错误率、延迟、用户反馈），据此自动判断发布效果
  4. 定期清理：功能稳定后移除标志，避免过期标志长期堆积导致代码腐化