Distributed Tracing and LLMs
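---
title: "Distributed Tracing and LLMs"
description: "Explores how distributed tracing applies to LLM applications, including cross-service tracing and multi-model collaboration scenarios."
tags: ["distributed tracing", "LLM", "microservices"]
category: "llm"
icon: "🧠"
---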
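Why Distributed Tracing Is Necessary
Modern LLM applications typically use a microservice architecture, where a single user request may pass through several cooperating services: a gateway service, a prompt-processing service, a model inference service, a post-processing service, and so on. Distributed tracing helps us understand how these services interact.
Distributed tracing matters especially in LLM scenarios because:
- Model inference usually takes a long time
- Several models may need to cooperate to complete a task
- External tool calls add to system complexity
- Errors can occur at any stage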
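Core Principles
Trace Context Propagation
In a distributed system, the trace context has to travel with each request from one service to the next. W3C Trace Context is currently the most widely adopted standard for this propagation: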
import uuid
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class TraceContext:
    trace_id: str  # 32 lowercase hex characters (128 bits)
    parent_id: int  # 64-bit span ID, rendered as 16 hex characters
    trace_flags: int = 1  # sampling flag (0x01 = sampled)
    trace_state: Optional[Dict[str, str]] = None

    def to_headers(self) -> Dict[str, str]:
        # traceparent layout: version-traceid-parentid-flags
        headers = {
            "traceparent": f"00-{self.trace_id}-{self.parent_id:016x}-{self.trace_flags:02x}"
        }
        if self.trace_state:
            headers["tracestate"] = ",".join(f"{k}={v}" for k, v in self.trace_state.items())
        return headers

    @classmethod
    def from_headers(cls, headers: Dict[str, str]) -> Optional["TraceContext"]:
        traceparent = headers.get("traceparent")
        if not traceparent:
            return None
        parts = traceparent.split("-")
        if len(parts) != 4:
            return None
        try:
            return cls(
                trace_id=parts[1],
                parent_id=int(parts[2], 16),
                trace_flags=int(parts[3], 16),
            )
        except ValueError:
            # Non-hex characters in parent-id or flags
            return None
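
The `from_headers` method above checks only the number of dash-separated fields. A stricter check might look like the following minimal sketch, which assumes version `00` headers only; `parse_traceparent` is a hypothetical helper name, not part of any library. It enforces the field lengths, lowercase hex, and the rule that all-zero IDs are invalid.

```python
import re
from typing import Optional, Tuple

# Version "00", 32-hex trace-id, 16-hex parent-id, 2-hex flags (lowercase hex only)
_TRACEPARENT_RE = re.compile(r"^00-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})$")


def parse_traceparent(value: str) -> Optional[Tuple[str, str, bool]]:
    """Return (trace_id, parent_id, sampled), or None if the header is invalid."""
    match = _TRACEPARENT_RE.match(value.strip())
    if match is None:
        return None
    trace_id, parent_id, flags = match.groups()
    # All-zero IDs are explicitly invalid in W3C Trace Context
    if trace_id == "0" * 32 or parent_id == "0" * 16:
        return None
    sampled = bool(int(flags, 16) & 0x01)  # lowest flag bit is "sampled"
    return trace_id, parent_id, sampled


result = parse_traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
```

Returning `None` for anything malformed lets the caller fall back to starting a fresh trace instead of propagating a broken header.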
Cross-Service Tracing
Propagating the trace context across service boundaries:
import uuid
from contextlib import asynccontextmanager

import httpx


class DistributedTracer:
    def __init__(self, service_name: str):
        self.service_name = service_name

    @asynccontextmanager
    async def trace_request(self, url: str, method: str = "POST"):
        # W3C IDs: 128-bit trace ID (32 hex chars) and 64-bit span ID
        trace_id = uuid.uuid4().hex
        span_id = uuid.uuid4().int >> 64
        context = TraceContext(trace_id=trace_id, parent_id=span_id)
        headers = context.to_headers()
        headers["X-Service-Name"] = self.service_name
        async with httpx.AsyncClient() as client:
            try:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=headers,
                )
                yield response
            except Exception:
                # Record error details here before re-raising
                raise
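
On the receiving side, a service usually continues the caller's trace rather than starting a new one. The following is a minimal sketch of that step, assuming header keys have already been lowercased; `child_headers` is a hypothetical helper name. It keeps the incoming trace ID and flags, generates a fresh span ID for the current service, and forwards `tracestate` unchanged.

```python
import secrets
from typing import Dict


def child_headers(incoming: Dict[str, str]) -> Dict[str, str]:
    """Build outgoing headers that continue the caller's trace if one exists."""
    parts = incoming.get("traceparent", "").split("-")
    if len(parts) == 4 and len(parts[1]) == 32:
        trace_id, flags = parts[1], parts[3]  # keep the caller's trace ID and flags
    else:
        trace_id, flags = secrets.token_hex(16), "01"  # no usable parent: new trace
    span_id = secrets.token_hex(8)  # fresh 64-bit ID for this service's span
    headers = {"traceparent": f"00-{trace_id}-{span_id}-{flags}"}
    if "tracestate" in incoming:
        headers["tracestate"] = incoming["tracestate"]  # forward vendor state as-is
    return headers


incoming = {
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
    "tracestate": "vendor=abc",
}
outgoing = child_headers(incoming)
```

Because only the span ID changes, every hop stays attached to the same trace while still being distinguishable as its own span.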
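Distributed Tracing Scenarios in LLM Applications
Tracing Multi-Model Collaboration
When several LLMs cooperate on a task, tracing makes the call chain easier to follow: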
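class MultiModelTracer:
    def __init__(self, tracer: DistributedTracer, base_url: str = "http://localhost:8000"):
        self.tracer = tracer
        # base_url is a placeholder; point it at your inference endpoint
        self.base_url = base_url

    async def orchestrate_models(self, query: str):
        trace_id = uuid.uuid4().hex

        # Stage 1: intent detection
        intent = await self.call_model(
            model="gpt-3.5-turbo",
            prompt=f"Identify the user's intent: {query}",
            trace_id=trace_id,
            span_name="intent_detection",
        )

        # Stage 2: route to a specialized model based on the intent.
        # call_model returns the raw JSON body, so a real implementation
        # would extract the intent label from it before comparing.
        if intent == "code_generation":
            result = await self.call_model(
                model="gpt-4",
                prompt=query,
                trace_id=trace_id,
                span_name="code_generation",
            )
        else:
            result = await self.call_model(
                model="gpt-3.5-turbo",
                prompt=query,
                trace_id=trace_id,
                span_name="general_response",
            )
        return result

    async def call_model(self, model: str, prompt: str,
                         trace_id: str, span_name: str):
        async with self.tracer.trace_request(
            url=f"{self.base_url}/v1/chat/completions",
            method="POST",
        ) as response:
            # Details of the model call; collected here for illustration,
            # a real tracer would attach them to the span
            span_attributes = {
                "trace.id": trace_id,
                "llm.model": model,
                # Whitespace word count, only a rough proxy for token count
                "llm.prompt_tokens": len(prompt.split()),
                "span.name": span_name,
            }
            # httpx's Response.json() is synchronous, so it is not awaited
            return response.json()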
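Tracing RAG Applications
Retrieval-augmented generation (RAG) applications involve several steps, and each one should be covered by the trace: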
import uuid


class RAGTracer:
    async def rag_query(self, query: str):
        trace_id = uuid.uuid4().hex
        # 1. Retrieval
        documents = await self.retrieve_documents(query, trace_id)
        # 2. Reranking
        reranked_docs = await self.rerank_documents(documents, query, trace_id)
        # 3. Generation
        response = await self.generate_response(query, reranked_docs, trace_id)
        return response

    async def retrieve_documents(self, query: str, trace_id: str):
        # Trace the retrieval step
        span_attributes = {
            "rag.retrieval.query": query,
            "rag.retrieval.method": "vector_search",
        }
        documents = []  # placeholder: run the retrieval here...
        return documents

    async def rerank_documents(self, documents: list, query: str, trace_id: str):
        # Placeholder: reranking logic goes here; documents pass through unchanged
        return documents

    async def generate_response(self, query: str, context: list, trace_id: str):
        # Trace the generation step
        span_attributes = {
            "rag.generation.context_count": len(context),
            "rag.generation.total_context_length": sum(len(d) for d in context),
        }
        response = ""  # placeholder: run the generation here...
        return response
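
To see which RAG stage is slow or failing, each stage can record its own duration and outcome under the shared trace ID. The following is a minimal sketch using an async decorator. `traced_stage`, `stage_records`, and the two stage functions are hypothetical names for illustration, and the stage bodies are placeholders.

```python
import asyncio
import functools
import time
from typing import Any, Dict, List

stage_records: List[Dict[str, Any]] = []  # stand-in for a span exporter


def traced_stage(stage: str):
    """Record duration and outcome of an async pipeline stage."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, trace_id: str, **kwargs):
            record = {"trace_id": trace_id, "stage": stage, "status": "ok"}
            started = time.perf_counter()
            try:
                return await func(*args, trace_id=trace_id, **kwargs)
            except Exception as exc:
                record["status"] = "error"
                record["error.type"] = type(exc).__name__
                raise  # the caller still sees the original exception
            finally:
                record["duration_ms"] = (time.perf_counter() - started) * 1000
                stage_records.append(record)
        return wrapper
    return decorator


@traced_stage("retrieval")
async def retrieve(query: str, *, trace_id: str) -> list:
    await asyncio.sleep(0.01)  # placeholder for a vector search
    return ["doc-a", "doc-b"]


@traced_stage("generation")
async def generate(query: str, docs: list, *, trace_id: str) -> str:
    raise RuntimeError("model timeout")  # simulate a failing stage


async def main() -> None:
    docs = await retrieve("what is tracing?", trace_id="t-1")
    try:
        await generate("what is tracing?", docs, trace_id="t-1")
    except RuntimeError:
        pass  # the failure is already captured in stage_records


asyncio.run(main())
```

Writing the record in `finally` ensures a stage that raises still leaves a timing entry, which is what ties an error back to a specific step of the trace.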
Data Collection and Storage
Collecting Data with OpenTelemetry
OpenTelemetry is the de facto standard for distributed tracing:
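from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor


def setup_tracing(service_name: str):
    # Attach service.name so backends can group spans by service
    resource = Resource.create({"service.name": service_name})
    provider = TracerProvider(resource=resource)
    # With no arguments the exporter targets the default OTLP/gRPC endpoint
    processor = BatchSpanProcessor(OTLPSpanExporter())
    provider.add_span_processor(processor)
    trace.set_tracer_provider(provider)
    return trace.get_tracer(service_name)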
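Storing Trace Data
Choose a storage backend that fits your deployment:
- Jaeger: lightweight, suited to small and medium deployments
- Zipkin: mature and stable, with an active community
- Tempo: part of the Grafana ecosystem, integrates well with Prometheus
- Cloud services: AWS X-Ray, Google Cloud Trace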
Visualization and Analysis
Trace Waterfall Charts
Rendering trace data as a waterfall chart:
def generate_waterfall_chart(trace_data: dict) -> str:
    # calculate_depth and create_bar are helpers assumed to be defined elsewhere
    spans = trace_data["spans"]
    chart = []
    # One row per span, ordered by start time and indented by nesting depth
    for span in sorted(spans, key=lambda x: x["start"]):
        indent = calculate_depth(span, spans)
        bar = create_bar(span["duration_ms"], max_duration=1000)
        chart.append(f"{' ' * indent}{span['name']}: {bar} {span['duration_ms']:.1f}ms")
    return "\n".join(chart)
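
The two helpers used by the waterfall function are not defined in the text. One possible shape for them is sketched below. It assumes each span dict carries a `span_id` key alongside `parent_span_id`, which is an assumption about the data layout, and the bar width of 20 characters is an arbitrary choice.

```python
from typing import Dict, List


def calculate_depth(span: dict, spans: List[dict]) -> int:
    """Count ancestors by following parent_span_id links."""
    by_id: Dict[str, dict] = {s["span_id"]: s for s in spans}
    depth = 0
    parent_id = span.get("parent_span_id")
    while parent_id in by_id and depth < len(spans):  # length guard against cycles
        depth += 1
        parent_id = by_id[parent_id].get("parent_span_id")
    return depth


def create_bar(duration_ms: float, max_duration: float, width: int = 20) -> str:
    """Scale a duration to a bar of 1..width characters, clamping at max_duration."""
    fraction = min(duration_ms, max_duration) / max_duration
    return "#" * max(1, round(fraction * width))


spans = [
    {"span_id": "a", "parent_span_id": None, "name": "rag_query", "start": 0, "duration_ms": 1000.0},
    {"span_id": "b", "parent_span_id": "a", "name": "retrieve", "start": 5, "duration_ms": 250.0},
    {"span_id": "c", "parent_span_id": "a", "name": "generate", "start": 260, "duration_ms": 500.0},
]
for span in sorted(spans, key=lambda s: s["start"]):
    indent = "  " * calculate_depth(span, spans)
    bar = create_bar(span["duration_ms"], max_duration=1000)
    print(f"{indent}{span['name']}: {bar} {span['duration_ms']:.1f}ms")
```

Clamping at `max_duration` keeps a single very slow span, such as a long model inference, from stretching the chart beyond its fixed width.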
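Best Practices
- Standardized naming: use a consistent span naming convention
- Sampling strategy: adjust the sampling rate dynamically based on system load
- Context propagation: make sure every service handles trace headers correctly
- Error correlation: attach the trace ID to errors so they are easier to investigate
- Performance monitoring: monitor the overhead of the tracing system itself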
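
As an illustration of the sampling practice above, the following is a minimal sketch of deterministic, ratio-based head sampling with a simple load adjustment. `should_sample` and `adjust_ratio` are hypothetical helpers, and the budget-based policy and the 0.001 floor are assumptions chosen for the example rather than recommended settings.

```python
def should_sample(trace_id: str, ratio: float) -> bool:
    """Keep roughly `ratio` of all traces, decided from the trace ID alone.

    Every service that sees the same trace ID reaches the same verdict,
    so a trace is either kept in full or dropped in full.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be between 0 and 1")
    # Treat the low 64 bits of the hex trace ID as a uniform random number
    low_bits = int(trace_id[-16:], 16)
    return low_bits < int(ratio * (1 << 64))


def adjust_ratio(current: float, spans_per_sec: float, budget_per_sec: float) -> float:
    """Scale the ratio so the exported span rate moves toward the budget."""
    if spans_per_sec <= 0:
        return 1.0  # nothing is being exported, so sample everything
    return max(0.001, min(1.0, current * budget_per_sec / spans_per_sec))


trace_id = "4bf92f3577b34da6a3ce929d0e0e4736"
keep_all = should_sample(trace_id, 1.0)
new_ratio = adjust_ratio(current=0.5, spans_per_sec=2000, budget_per_sec=1000)
```

Here `spans_per_sec` is the rate observed at the current ratio, so halving the ratio when the rate is twice the budget brings the export volume back in line.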
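Distributed tracing is an essential tool for understanding how complex LLM applications behave. Implemented correctly, it gives you an end-to-end view of the system at runtime.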