← 返回首页
🧠

Airflow与LLM:构建可靠的AI工作流调度

📂 llm ⏱ 3 min 439 words

---
title: "Airflow与LLM:构建可靠的AI工作流调度"
description: "使用Apache Airflow调度和管理LLM任务,实现可靠的AI工作流"
tags: ["Airflow", "LLM", "工作流调度", "DAG", "MLOps"]
category: "llm"
icon: "🌬️"
---

Airflow与LLM:构建可靠的AI工作流调度

Airflow概述

Apache Airflow是一个工作流编排平台,通过DAG(有向无环图)定义和调度任务。结合LLM,可以构建可靠的AI工作流。

LLM任务定义

1. LLMOperator

from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from openai import OpenAI

class _XComFallbackContext(dict):
    """Mapping for ``str.format_map`` that resolves unknown fields via XCom.

    Fields present in the task context are returned as-is. A missing field is
    looked up as an XCom key through the context's ``ti`` entry, if any.
    """

    def __missing__(self, key):
        ti = self.get("ti")
        value = ti.xcom_pull(key=key) if ti is not None else None
        if value is None:
            # Same exception ``str.format(**context)`` raises for an
            # unknown field.
            raise KeyError(key)
        return value


class LLMOperator(BaseOperator):
    """Operator that sends one rendered prompt to an OpenAI chat model.

    Args:
        prompt: ``str.format``-style template. Fields are filled from the
            task context; fields not found there are looked up as XCom keys.
        model: Model name passed to ``chat.completions.create``.
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion token limit passed to the API.
        output_key: Context key and XCom key the result is stored under.
    """

    @apply_defaults
    def __init__(
        self,
        prompt: str,
        model: str = "gpt-4",
        temperature: float = 0.7,
        max_tokens: int = 1000,
        output_key: str = "llm_output",
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.prompt = prompt
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.output_key = output_key

    def execute(self, context):
        """Render the prompt, call the model and publish the completion.

        Returns:
            The text of the first choice; may be ``None`` when the API
            returns a message without text content.

        Raises:
            KeyError: A prompt field is neither a context entry nor an
                available XCom key.
        """
        client = OpenAI()

        # Render the prompt template.
        rendered_prompt = self.prompt.format_map(_XComFallbackContext(context))

        self.log.info("调用LLM: %s", self.model)
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": rendered_prompt}],
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )

        result = response.choices[0].message.content
        context[self.output_key] = result

        # Writing to ``context`` only affects this task run; push the value
        # to XCom so that other tasks can read it under ``output_key``.
        ti = context.get("ti")
        if ti is not None:
            ti.xcom_push(key=self.output_key, value=result)

        # ``content`` can be None, which would make slicing fail.
        preview = (result or "")[:100]
        self.log.info("LLM输出: %s...", preview)
        return result

2. 文本处理Operator

class TextProcessingOperator(BaseOperator):
    """Operator that runs one canned text operation through an OpenAI model.

    Args:
        operation: One of ``summarize``, ``translate`` or ``analyze``; any
            other value falls back to a generic "process text" prompt.
        input_key: Context key (or, failing that, XCom key) of the input text.
        output_key: Context key and XCom key the result is stored under.
        model: Model name passed to ``chat.completions.create``
            (keyword-only).
    """

    # Prompt template per supported operation; ``{text}`` is the input text.
    PROMPT_TEMPLATES = {
        "summarize": "请用100字以内总结以下内容:\n{text}",
        "translate": "将以下内容翻译成英文:\n{text}",
        "analyze": "分析以下内容的关键要点:\n{text}",
    }
    # Used when ``operation`` is not one of the keys above.
    DEFAULT_PROMPT_TEMPLATE = "处理文本:{text}"

    @apply_defaults
    def __init__(
        self,
        operation: str,  # summarize, translate, analyze
        input_key: str = "text",
        output_key: str = "processed_text",
        *args,
        model: str = "gpt-4",
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.operation = operation
        self.input_key = input_key
        self.output_key = output_key
        self.model = model

    def _read_input(self, context):
        """Return the input text from the context, falling back to XCom."""
        if self.input_key in context:
            return context[self.input_key]
        ti = context.get("ti")
        value = ti.xcom_pull(key=self.input_key) if ti is not None else None
        if value is None:
            # Keep the exception type of a plain ``context[input_key]`` miss.
            raise KeyError(self.input_key)
        return value

    def execute(self, context):
        """Build the prompt for ``operation``, call the model, publish result.

        Returns:
            The text of the first choice returned by the API.

        Raises:
            KeyError: The input is neither in the context nor in XCom.
        """
        text = self._read_input(context)

        template = self.PROMPT_TEMPLATES.get(
            self.operation, self.DEFAULT_PROMPT_TEMPLATE
        )
        prompt = template.format(text=text)

        client = OpenAI()
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}]
        )

        result = response.choices[0].message.content
        context[self.output_key] = result

        # Writing to ``context`` only affects this task run; push the value
        # to XCom so that other tasks can read it under ``output_key``.
        ti = context.get("ti")
        if ti is not None:
            ti.xcom_push(key=self.output_key, value=result)

        self.log.info(f"{self.operation} 完成")
        return result

DAG定义

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime, timedelta

# Defaults handed to the DAG through its ``default_args`` argument.
default_args = dict(
    owner="ai_team",
    depends_on_past=False,
    start_date=datetime(2024, 1, 1),
    retries=2,
    retry_delay=timedelta(minutes=5),
)

def extract_data(**context):
    """Return the (simulated) extracted articles and publish them via XCom.

    Args:
        **context: Task context passed as keyword arguments. When it holds a
            ``ti`` entry, the data is pushed to XCom as ``extracted_data``.

    Returns:
        A dict with an ``articles`` list of ``{"title", "content"}`` dicts.
    """
    # Simulated data extraction.
    data = {
        "articles": [
            {"title": "AI趋势", "content": "人工智能正在快速发展..."},
            {"title": "LLM应用", "content": "大语言模型改变了开发方式..."}
        ]
    }

    # ``context`` is this call's own kwargs dict, so storing the data in it
    # would be invisible to every other task; publish it through XCom.
    ti = context.get("ti")
    if ti is not None:
        ti.xcom_push(key="extracted_data", value=data)
    return data

def validate_data(**context):
    """Check that the extraction step produced at least one article.

    Args:
        **context: Task context passed as keyword arguments. The data is read
            from ``extracted_data`` if present, otherwise from the XCom
            return value of the ``extract_data`` task via ``ti``.

    Returns:
        ``True`` when validation passes.

    Raises:
        KeyError: No extracted data is available.
        ValueError: The extracted data contains no articles.
    """
    import logging

    data = context.get("extracted_data")
    ti = context.get("ti")
    if data is None and ti is not None:
        # Fall back to what the upstream ``extract_data`` task returned.
        data = ti.xcom_pull(task_ids="extract_data")
    if data is None:
        raise KeyError("extracted_data")

    articles = data.get("articles", [])
    if not articles:
        raise ValueError("没有提取到文章数据")

    if ti is not None:
        ti.xcom_push(key="valid_data", value=data)

    # A plain function has no ``self``; use a module-level style logger.
    logging.getLogger(__name__).info(f"验证通过: {len(articles)} 篇文章")
    return True

def aggregate_results(**context):
    """Join the summaries into one report separated by blank lines.

    Args:
        **context: Task context passed as keyword arguments. Summaries are
            read from ``summaries`` if present, otherwise from the XCom
            return value of the ``summarize_articles`` task via ``ti``.

    Returns:
        The joined report; an empty string when there are no summaries.
    """
    summaries = context.get("summaries")
    ti = context.get("ti")
    if summaries is None and ti is not None:
        summaries = ti.xcom_pull(task_ids="summarize_articles")

    if summaries is None:
        summaries = []
    elif isinstance(summaries, str):
        # A single summary string must not be joined character by character.
        summaries = [summaries]

    aggregated = "\n\n".join(summaries)

    # Publish under the name downstream prompts refer to.
    if ti is not None:
        ti.xcom_push(key="final_report", value=aggregated)
    return aggregated

# Define the DAG: a daily, linear content pipeline without catch-up runs.
with DAG(
    dag_id="llm_content_pipeline",
    default_args=default_args,
    description="LLM内容处理工作流",
    schedule_interval="@daily",
    catchup=False,
) as dag:

    # Marker task for the beginning of the pipeline.
    start = EmptyOperator(task_id="start")

    # Plain-Python steps: extraction followed by validation.
    extract = PythonOperator(task_id="extract_data", python_callable=extract_data)
    validate = PythonOperator(task_id="validate_data", python_callable=validate_data)

    # LLM steps; each stores its output under ``output_key``.
    summarize = LLMOperator(
        task_id="summarize_articles",
        output_key="summaries",
        prompt="总结以下文章: {extracted_data}",
    )
    analyze = LLMOperator(
        task_id="analyze_trends",
        output_key="analysis",
        prompt="分析以下内容的趋势: {final_report}",
    )

    aggregate = PythonOperator(task_id="aggregate_results", python_callable=aggregate_results)

    # Marker task for the end of the pipeline.
    end = EmptyOperator(task_id="end")

    # Task dependencies: one strictly sequential chain.
    (
        start
        >> extract
        >> validate
        >> summarize
        >> aggregate
        >> analyze
        >> end
    )

错误处理

from airflow.exceptions import AirflowException
from airflow.operators.python import BranchPythonOperator

def check_quality(**context):
    """Choose the follow-up task id based on the quality of the LLM output.

    Args:
        **context: Task context passed as keyword arguments. The output is
            read from ``llm_output`` if present, otherwise from the XCom key
            ``llm_output`` via ``ti``. A missing or ``None`` output is
            treated as an empty string.

    Returns:
        ``"retry_task"`` for outputs shorter than 50 characters,
        ``"fallback_task"`` when the output contains "错误" or "抱歉",
        otherwise ``"continue_task"``.
    """
    output = context.get("llm_output")
    if output is None:
        ti = context.get("ti")
        if ti is not None:
            # NOTE(review): assumes the upstream task pushed its result under
            # the XCom key "llm_output" — confirm against the producer.
            output = ti.xcom_pull(key="llm_output")

    # ``None`` must not reach ``len()``.
    output = output or ""

    if len(output) < 50:
        return "retry_task"
    if "错误" in output or "抱歉" in output:
        return "fallback_task"
    return "continue_task"

def retry_with_fallback(**context):
    """Run the request on GPT-4 and degrade to GPT-3.5 if that attempt fails.

    Args:
        **context: Task context passed as keyword arguments (unused).

    Returns:
        The text of the first choice from whichever model answered.

    Raises:
        Exception: Whatever the client raises when it cannot be created or
            when the fallback request fails as well.
    """
    import logging

    # Create the client outside the ``try`` so the handler never refers to
    # an unbound name when construction itself fails.
    client = OpenAI()
    messages = [{"role": "user", "content": "处理任务"}]

    def complete(model):
        # One request/extract round trip for the given model.
        response = client.chat.completions.create(model=model, messages=messages)
        return response.choices[0].message.content

    try:
        # Try GPT-4 first.
        return complete("gpt-4")
    except Exception:
        # Best-effort degradation is intentional, but record why it happened.
        logging.getLogger(__name__).warning(
            "gpt-4 request failed; falling back to gpt-3.5-turbo", exc_info=True
        )
        # Degrade to GPT-3.5.
        return complete("gpt-3.5-turbo")

# Usage inside a DAG: branch on the quality check result.
quality_check = BranchPythonOperator(task_id="quality_check", python_callable=check_quality)

# One task per task id that ``check_quality`` can return.
retry_task = LLMOperator(task_id="retry_task", prompt="重试处理: {input}")
fallback_task = PythonOperator(task_id="fallback_task", python_callable=retry_with_fallback)
continue_task = EmptyOperator(task_id="continue_task")

# All three branches hang directly off the quality check.
quality_check >> [
    retry_task,
    fallback_task,
    continue_task,
]

最佳实践

  1. 任务原子化:每个LLM调用封装为独立任务
  2. 参数化提示:使用模板变量使提示可配置
  3. 监控告警:设置任务失败告警
  4. 资源限制:控制LLM API调用频率

总结

Airflow为LLM工作流提供了可靠的调度和监控能力。通过自定义Operator和DAG设计,可以构建复杂、可靠的AI处理管道。