Airflow与LLM:构建可靠的AI工作流调度
---
title: "Airflow与LLM:构建可靠的AI工作流调度"
description: "使用Apache Airflow调度和管理LLM任务,实现可靠的AI工作流"
tags: ["Airflow", "LLM", "工作流调度", "DAG", "MLOps"]
category: "llm"
icon: "🌬️"
---
Airflow与LLM:构建可靠的AI工作流调度
Airflow概述
Apache Airflow是一个工作流编排平台,通过DAG(有向无环图)定义和调度任务。结合LLM,可以构建可靠的AI工作流。
LLM任务定义
1. LLMOperator
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from openai import OpenAI
class LLMOperator(BaseOperator):
    """Operator that sends one rendered prompt to an OpenAI chat model.

    ``prompt`` is a ``str.format`` template. Placeholders are filled from the
    task context first; a placeholder that is missing from the context is
    looked up as an XCom key, so values published by other tasks can be
    referenced by name.

    Args:
        prompt: ``str.format``-style template used as the user message.
        model: Chat model name passed to the API.
        temperature: Sampling temperature passed to the API.
        max_tokens: Completion token limit passed to the API.
        output_key: Key under which the reply is stored in the context and
            pushed to XCom.
    """

    # NOTE(review): the ``@apply_defaults`` decorator was dropped here because
    # Airflow 2.x applies ``default_args`` without it - confirm the deployed
    # Airflow version before merging.
    def __init__(
        self,
        prompt: str,
        model: str = "gpt-4",
        temperature: float = 0.7,
        max_tokens: int = 1000,
        output_key: str = "llm_output",
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.prompt = prompt
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.output_key = output_key

    def _render_prompt(self, context, ti):
        """Fill the prompt template, pulling unknown placeholders from XCom.

        Raises:
            KeyError: If a placeholder is found neither in the context nor as
                an XCom key.
        """
        values = dict(context)
        while True:
            try:
                return self.prompt.format(**values)
            except KeyError as exc:
                missing = exc.args[0]
                # Re-raise when there is nowhere else to look, or when the
                # name was already supplied: the KeyError then comes from a
                # nested lookup such as "{a[b]}" and retrying would never end.
                if ti is None or not isinstance(missing, str) or missing in values:
                    raise
                pulled = ti.xcom_pull(key=missing)
                if pulled is None:
                    raise
                values[missing] = pulled

    def execute(self, context):
        """Render the prompt, call the model, and publish the reply.

        Returns:
            The reply text, or ``""`` when the API returns no content.
        """
        client = OpenAI()
        ti = context.get("ti")

        rendered_prompt = self._render_prompt(context, ti)
        self.log.info(f"调用LLM: {self.model}")

        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": rendered_prompt}],
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

        # The reply content can be None; normalise it so the slice below and
        # downstream string handling do not fail.
        result = response.choices[0].message.content or ""

        context[self.output_key] = result
        # Also publish under ``output_key`` via XCom so that other tasks can
        # read the value; skipped when no task instance is in the context.
        if ti is not None:
            ti.xcom_push(key=self.output_key, value=result)

        self.log.info(f"LLM输出: {result[:100]}...")
        return result
2. 文本处理Operator
class TextProcessingOperator(BaseOperator):
    """Operator that summarizes, translates, or analyzes a piece of text.

    Args:
        operation: One of ``summarize``, ``translate``, ``analyze``; any other
            value falls back to a generic "process this text" prompt.
        input_key: Key of the input text, looked up in the context first and
            then as an XCom key.
        output_key: Key under which the result is stored in the context and
            pushed to XCom.
        model: Chat model name passed to the API.
    """

    def __init__(
        self,
        operation: str,  # summarize, translate, analyze
        input_key: str = "text",
        output_key: str = "processed_text",
        model: str = "gpt-4",
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.operation = operation
        self.input_key = input_key
        self.output_key = output_key
        self.model = model

    def execute(self, context):
        """Build the prompt for ``operation``, call the model, publish the reply.

        Returns:
            The reply content returned by the API.

        Raises:
            KeyError: If the input text is found neither in the context nor
                as an XCom key.
        """
        ti = context.get("ti")

        if self.input_key in context:
            text = context[self.input_key]
        else:
            # Fall back to XCom so text produced by another task can be used.
            text = ti.xcom_pull(key=self.input_key) if ti is not None else None
            if text is None:
                raise KeyError(self.input_key)

        prompts = {
            "summarize": f"请用100字以内总结以下内容:\n{text}",
            "translate": f"将以下内容翻译成英文:\n{text}",
            "analyze": f"分析以下内容的关键要点:\n{text}",
        }
        prompt = prompts.get(self.operation, f"处理文本:{text}")

        client = OpenAI()
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )
        result = response.choices[0].message.content

        context[self.output_key] = result
        # Also publish under ``output_key`` via XCom so that other tasks can
        # read the value; skipped when no task instance is in the context.
        if ti is not None:
            ti.xcom_push(key=self.output_key, value=result)

        self.log.info(f"{self.operation} 完成")
        return result
DAG定义
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.empty import EmptyOperator
from datetime import datetime, timedelta
# Defaults applied to every task in the DAG: owner, no dependence on the
# previous run, start date, and two retries spaced five minutes apart.
default_args = dict(
    owner="ai_team",
    depends_on_past=False,
    start_date=datetime(2024, 1, 1),
    retries=2,
    retry_delay=timedelta(minutes=5),
)
def extract_data(**context):
    """Return the sample article payload and publish it under ``extracted_data``.

    Args:
        **context: Task context keyword arguments. When it contains a task
            instance under ``ti``, the payload is also pushed to XCom.

    Returns:
        A dict with an ``articles`` list of ``{"title", "content"}`` dicts.
    """
    # Simulated extraction result.
    data = {
        "articles": [
            {"title": "AI趋势", "content": "人工智能正在快速发展..."},
            {"title": "LLM应用", "content": "大语言模型改变了开发方式..."},
        ]
    }

    # Writing into ``context`` would only change this call's own kwargs dict,
    # so the payload is published through XCom instead.
    ti = context.get("ti")
    if ti is not None:
        ti.xcom_push(key="extracted_data", value=data)
    return data
def validate_data(**context):
    """Check that the extracted payload contains at least one article.

    The payload is taken from ``context["extracted_data"]`` when present,
    otherwise from the return value of the ``extract_data`` task via XCom.

    Returns:
        ``True`` when validation passes.

    Raises:
        KeyError: If no extracted payload can be found.
        ValueError: If the payload contains no articles.
    """
    import logging

    ti = context.get("ti")
    data = context.get("extracted_data")
    if data is None and ti is not None:
        data = ti.xcom_pull(task_ids="extract_data")
    if data is None:
        raise KeyError("extracted_data")

    articles = data.get("articles", [])
    if not articles:
        raise ValueError("没有提取到文章数据")

    # Writing into ``context`` would only change this call's own kwargs dict,
    # so the validated payload is published through XCom instead.
    if ti is not None:
        ti.xcom_push(key="valid_data", value=data)

    # A plain function has no ``self``; use a module-level style logger.
    logging.getLogger(__name__).info("验证通过: %s 篇文章", len(articles))
    return True
def aggregate_results(**context):
    """Join the available summaries into one report separated by blank lines.

    Summaries are taken from ``context["summaries"]`` when present, otherwise
    from the return value of the ``summarize_articles`` task via XCom. A
    single string is treated as one summary.

    Returns:
        The joined report; ``""`` when there are no summaries.
    """
    ti = context.get("ti")
    summaries = context.get("summaries")
    if summaries is None and ti is not None:
        summaries = ti.xcom_pull(task_ids="summarize_articles")

    if summaries is None:
        summaries = []
    elif isinstance(summaries, str):
        # Joining a bare string would put the separator between characters.
        summaries = [summaries]

    aggregated = "\n\n".join(summaries)

    # Writing into ``context`` would only change this call's own kwargs dict,
    # so the report is published through XCom instead.
    if ti is not None:
        ti.xcom_push(key="final_report", value=aggregated)
    return aggregated
# Build the DAG object first, then register its tasks inside its context.
dag = DAG(
    "llm_content_pipeline",
    default_args=default_args,
    description="LLM内容处理工作流",
    schedule_interval="@daily",
    catchup=False,
)

with dag:
    # Boundary markers for the pipeline.
    start = EmptyOperator(task_id="start")

    # Plain-Python stages.
    extract = PythonOperator(
        task_id="extract_data",
        python_callable=extract_data,
    )
    validate = PythonOperator(
        task_id="validate_data",
        python_callable=validate_data,
    )

    # LLM stages.
    summarize = LLMOperator(
        task_id="summarize_articles",
        prompt="总结以下文章: {extracted_data}",
        output_key="summaries",
    )
    analyze = LLMOperator(
        task_id="analyze_trends",
        prompt="分析以下内容的趋势: {final_report}",
        output_key="analysis",
    )

    aggregate = PythonOperator(
        task_id="aggregate_results",
        python_callable=aggregate_results,
    )

    end = EmptyOperator(task_id="end")

    # Task dependencies, one edge per line:
    # start -> extract -> validate -> summarize -> aggregate -> analyze -> end
    start >> extract
    extract >> validate
    validate >> summarize
    summarize >> aggregate
    aggregate >> analyze
    analyze >> end
错误处理
from airflow.exceptions import AirflowException
from airflow.operators.python import BranchPythonOperator
def check_quality(**context):
    """Choose the follow-up task id based on the quality of the LLM output.

    The output is taken from ``context["llm_output"]`` when present, otherwise
    from the ``llm_output`` XCom key. A missing or ``None`` output is treated
    as empty.

    Returns:
        ``"retry_task"`` when the output is shorter than 50 characters,
        ``"fallback_task"`` when it contains an error/apology marker,
        otherwise ``"continue_task"``.
    """
    ti = context.get("ti")
    output = context.get("llm_output")
    if output is None and ti is not None:
        output = ti.xcom_pull(key="llm_output")
    # Guard against None so the length check below cannot raise.
    output = output or ""

    if len(output) < 50:
        return "retry_task"
    if any(marker in output for marker in ("错误", "抱歉")):
        return "fallback_task"
    return "continue_task"
def retry_with_fallback(**context):
    """Run the request on GPT-4 and fall back to GPT-3.5 if that call fails.

    Returns:
        The reply content from whichever model answered.

    Raises:
        Exception: Whatever the client constructor or the fallback call
            raises; only a failure of the primary call is absorbed.
    """
    import logging

    messages = [{"role": "user", "content": "处理任务"}]

    # Create the client outside the try block: if construction fails there is
    # nothing to fall back with, and the handler must never see an unbound
    # ``client``.
    client = OpenAI()

    try:
        # Primary attempt with GPT-4.
        response = client.chat.completions.create(
            model="gpt-4",
            messages=messages,
        )
    except Exception:
        # Deliberately broad best-effort fallback, but record why it happened
        # instead of swallowing the failure silently.
        logging.getLogger(__name__).warning(
            "GPT-4调用失败,降级到gpt-3.5-turbo", exc_info=True
        )
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
        )
    return response.choices[0].message.content
# Usage inside a DAG: a branching quality gate followed by its three
# candidate branches.
quality_check = BranchPythonOperator(
    task_id="quality_check",
    python_callable=check_quality,
)

# Branch taken when the output is too short: ask the LLM again.
retry_task = LLMOperator(
    task_id="retry_task",
    prompt="重试处理: {input}",
)

# Branch taken when the output looks like an error or apology.
fallback_task = PythonOperator(
    task_id="fallback_task",
    python_callable=retry_with_fallback,
)

# Branch taken when the output is acceptable.
continue_task = EmptyOperator(task_id="continue_task")

# All three branches hang directly off the quality gate.
quality_check.set_downstream([retry_task, fallback_task, continue_task])
最佳实践
- 任务原子化:每个LLM调用封装为独立任务
- 参数化提示:使用模板变量使提示可配置
- 监控告警:设置任务失败告警(例如通过 `on_failure_callback` 回调或 `email_on_failure` 邮件通知)
- 资源限制:控制LLM API调用频率(例如使用Airflow的Pool限制同时运行的LLM任务数量)
总结
Airflow为LLM工作流提供了可靠的调度和监控能力。通过自定义Operator和DAG设计,可以构建复杂、可靠的AI处理管道。