模型服务架构:Triton、TensorFlow Serving与Seldon
模型服务架构:Triton、TensorFlow Serving与Seldon
模型服务核心挑战
模型服务需要在低延迟、高吞吐和资源利用率之间取得平衡。核心挑战包括:模型加载与切换、请求批处理、GPU资源调度、多模型并发服务、A/B测试与灰度发布。选择合适的推理引擎和部署架构是成功的关键。
# 模型服务抽象层
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
import time
@dataclass
class PredictionRequest:
    """Inference request addressed to a single named model."""

    # Name of the model that should serve this request.
    model_name: str
    # Input payload as a string-keyed mapping.
    inputs: Dict[str, Any]
    # Optional per-request parameters; each instance gets its own empty dict.
    parameters: Dict[str, Any] = field(default_factory=dict)
    # Caller-supplied correlation id; empty string when none is given.
    request_id: str = ""
@dataclass
class PredictionResponse:
    """Result of one inference call, together with serving metadata."""

    # Output payload as a string-keyed mapping.
    outputs: Dict[str, Any]
    # Name of the model the response is attributed to.
    model_name: str
    # Version string of that model.
    model_version: str
    # Latency of the call, in milliseconds.
    latency_ms: float
    # Correlation id; empty string when none is given.
    request_id: str = ""
class ModelBackend(ABC):
    """Abstract interface for a model-serving backend.

    Concrete backends must implement all four methods below before they can
    be instantiated.
    """

    @abstractmethod
    def load_model(self, model_name: str, model_path: str,
                   version: str = "1") -> bool:
        """Load ``model_name`` at ``version`` from ``model_path``.

        Returns a bool (presumably a success flag; implementations decide).
        """

    @abstractmethod
    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Run inference for ``request`` and return the response."""

    @abstractmethod
    def unload_model(self, model_name: str) -> bool:
        """Unload ``model_name``.

        Returns a bool (presumably a success flag; implementations decide).
        """

    @abstractmethod
    def get_model_status(self, model_name: str) -> Dict:
        """Return a dict describing the status of ``model_name``."""
class TritonBackend(ModelBackend):
    """In-memory stand-in for a Triton-backed ``ModelBackend``.

    No request is sent to ``server_url``; the class only records which models
    were "loaded" and answers every prediction with a fixed output.
    """

    def __init__(self, server_url: str):
        self.loaded_models = {}
        self.server_url = server_url

    def load_model(self, model_name: str, model_path: str,
                   version: str = "1") -> bool:
        """Record ``model_name`` as ready and return ``True``.

        ``model_path`` is accepted to satisfy the interface but is not read.
        """
        print(f"Triton: Loading model {model_name} v{version}")
        entry = {
            "version": version,
            "status": "ready",
            "load_time": time.time(),
        }
        self.loaded_models[model_name] = entry
        return True

    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Return a fixed prediction for ``request``.

        The reported version comes from the load record and falls back to
        ``"1"`` when the model was never loaded.
        """
        started = time.time()
        # Simulated Triton inference call.
        outputs = {"prediction": [0.95, 0.87]}
        elapsed_ms = (time.time() - started) * 1000
        model_info = self.loaded_models.get(request.model_name, {})
        return PredictionResponse(
            outputs=outputs,
            model_name=request.model_name,
            model_version=model_info.get("version", "1"),
            latency_ms=elapsed_ms,
            request_id=request.request_id,
        )

    def unload_model(self, model_name: str) -> bool:
        """Forget ``model_name`` if it is loaded; always return ``True``."""
        self.loaded_models.pop(model_name, None)
        return True

    def get_model_status(self, model_name: str) -> Dict:
        """Return the load record, or ``{"status": "not_found"}`` if absent."""
        not_found = {"status": "not_found"}
        return self.loaded_models.get(model_name, not_found)
Triton Inference Server架构
NVIDIA Triton是高性能推理服务器,支持多种深度学习框架(TensorRT、PyTorch、TensorFlow),提供动态批处理、模型并发执行和内存池优化。其架构采用前端代理+后端执行的模式,支持GPU和CPU混合部署。
# Triton配置与部署
triton_config = {
    # Server-wide settings: endpoint ports, model repository, control mode.
    "server": {
        "http_port": 8000,
        "grpc_port": 8001,
        "metrics_port": 8002,
        "model_repository": "/models",
        "model_control_mode": "explicit",
        "rate_limit": "execution_count",
    },
    # Per-model settings, keyed by model name.
    "models": {
        "resnet50": {
            "platform": "tensorrt_plan",
            "max_batch_size": 64,
            "dynamic_batching": {
                "preferred_batch_size": [16, 32],
                "max_queue_delay_microseconds": 100,
            },
            # Instance count and the GPU ids the instances are placed on.
            "instance_group": [
                {"count": 2, "kind": "KIND_GPU", "gpus": [0, 1]},
            ],
        },
        "transformer": {
            "platform": "pytorch_libtorch",
            "max_batch_size": 32,
            "sequence_batching": {
                "max_sequence_length": 512,
            },
            "instance_group": [
                {"count": 1, "kind": "KIND_GPU", "gpus": [0]},
            ],
        },
    },
}
# Triton客户端调用
class TritonClient:
    """Stub Triton client that returns canned responses (no network I/O)."""

    def __init__(self, url: str):
        self.url = url

    def infer(self, model_name: str, inputs: dict,
              model_version: str = "") -> dict:
        """Return a canned inference response for ``model_name``.

        ``inputs`` and ``model_version`` are accepted but not used.
        """
        # Triton inference request (simulated).
        response = {"model_name": model_name}
        response["outputs"] = {"prediction": [0.95]}
        response["parameters"] = {"batch_size": 1}
        return response

    def get_model_repository_index(self) -> list:
        """Return a fixed one-entry repository index."""
        resnet_entry = {"name": "resnet50", "version": "1", "state": "READY"}
        return [resnet_entry]

    def load_model(self, model_name: str, version: str = "1"):
        """Print a load message for ``model_name``; returns ``None``."""
        print(f"Loading model: {model_name} v{version}")
TensorFlow Serving架构
TF Serving专为TensorFlow模型优化,支持gRPC和REST API,提供模型版本管理、热加载和批量推理。其架构由模型服务器(ModelServer)和一组可插拔组件构成:Source从存储中发现可用的模型版本,Loader负责模型的加载与卸载,Manager管理模型(Servable)的生命周期与版本切换,并支持通过自定义Servable进行扩展。
# TF Serving配置与调用
import tensorflow as tf
import numpy as np
class TFServingClient:
    """Minimal client for the TensorFlow Serving REST predict API.

    Args:
        serving_url: Base URL of the REST endpoint, e.g. ``http://host:8501``.
            A trailing slash is tolerated.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout error. Pass ``None`` to wait indefinitely.
    """

    def __init__(self, serving_url: str, timeout: Optional[float] = 30.0):
        self.serving_url = serving_url
        self.timeout = timeout

    def _predict_url(self, model_name: str,
                     version: Optional[str] = None) -> str:
        """Build the ``:predict`` URL, optionally pinned to ``version``."""
        base = self.serving_url.rstrip("/")
        if version:
            return f"{base}/v1/models/{model_name}/versions/{version}:predict"
        return f"{base}/v1/models/{model_name}:predict"

    def _post(self, url: str, payload: dict) -> dict:
        """POST ``payload`` as JSON and return the decoded JSON body."""
        # Imported lazily so the module loads even without ``requests``.
        import requests

        # Without a timeout, a stalled server would block the caller forever.
        response = requests.post(url, json=payload, timeout=self.timeout)
        return response.json()

    def predict(self, model_name: str, input_data: dict) -> dict:
        """Send one instance to the model's default version and signature.

        The instance is wrapped as ``{"input": input_data}``.

        Returns:
            The decoded JSON response body.
        """
        payload = {"instances": [{"input": input_data}]}
        return self._post(self._predict_url(model_name), payload)

    def predict_with_signature(self, model_name: str,
                               signature_name: str,
                               input_data: dict,
                               version: str = "1") -> dict:
        """Predict using a specific serving signature.

        The REST API has no ``/signatures/...`` path segment; the signature
        is selected with the ``signature_name`` field of the request body.

        Args:
            model_name: Name of the served model.
            signature_name: Serving signature to invoke.
            input_data: A single instance, sent as-is.
            version: Model version to target (defaults to ``"1"``).

        Returns:
            The decoded JSON response body.
        """
        payload = {
            "signature_name": signature_name,
            "instances": [input_data],
        }
        return self._post(self._predict_url(model_name, version), payload)
# 批量预测示例
class BatchPredictor:
    """Buffer prediction requests and send them once the buffer is full.

    Buffered requests are still sent to the client one at a time; the class
    only groups *when* they are sent.
    """

    def __init__(self, client: TFServingClient, batch_size: int = 32):
        self.batch_size = batch_size
        self.client = client
        self.buffer = []

    def add_request(self, model_name: str, input_data: dict):
        """Queue a request.

        Returns the results of ``flush()`` when the buffer has reached
        ``batch_size``, otherwise ``None``.
        """
        self.buffer.append({"model": model_name, "input": input_data})
        buffer_full = len(self.buffer) >= self.batch_size
        return self.flush() if buffer_full else None

    def flush(self) -> list:
        """Send every buffered request and return the results in order.

        Returns an empty list when nothing is buffered. The buffer is cleared
        only after all requests have completed.
        """
        if not self.buffer:
            return []
        results = [
            self.client.predict(entry["model"], entry["input"])
            for entry in self.buffer
        ]
        self.buffer = []
        return results
Seldon Core架构
Seldon Core提供Kubernetes原生的模型部署方案,支持复杂的推理图(Inference Graph),包括预处理、模型推理、后处理等环节的组合。支持多种ML框架和自定义容器。
# Seldon推理图定义
seldon_deployment = {
    "apiVersion": "machinelearning.seldon.io/v1",
    "kind": "SeldonDeployment",
    "metadata": {
        "name": "my-model",
        "namespace": "seldon",
    },
    "spec": {
        "predictors": [
            {
                "name": "default",
                "replicas": 3,
                # Requests enter at the graph root and flow down to the
                # children, so the preprocessing TRANSFORMER must be the
                # parent of the MODEL it feeds, not its child.
                "graph": {
                    "name": "preprocessor",
                    "type": "TRANSFORMER",
                    "implementation": "MY_PREPROCESSOR",
                    "children": [
                        {
                            "name": "classifier",
                            "type": "MODEL",
                            "implementation": "SKLEARN_SERVER",
                            "modelUri": "gs://models/classifier",
                            "children": [],
                        },
                    ],
                },
                "componentSpecs": [{
                    "spec": {
                        # The container name matches the "classifier" node.
                        "containers": [{
                            "name": "classifier",
                            "image": "my-model:latest",
                            "resources": {
                                "requests": {
                                    "cpu": "1",
                                    "memory": "2Gi",
                                },
                                "limits": {
                                    "nvidia.com/gpu": 1,
                                },
                            },
                        }],
                    },
                }],
            },
        ],
    },
}
# Seldon Python包装器
class ModelWrapper:
    """Placeholder model wrapper exposing load, predict and feedback hooks."""

    def __init__(self):
        self.model = None

    def load(self):
        """Load the model (here: store a placeholder string)."""
        self.model = "loaded_model"

    def predict(self, X, features_names=None):
        """Return one fixed score of 0.95 per item in ``X``.

        ``features_names`` is accepted but not used.
        """
        # X is the input data; only its length matters here.
        scores = [0.95 for _ in range(len(X))]
        return {"prediction": scores}

    def feedback(self, features, reward, truth=None):
        """Feedback hook intended for online learning; only prints the reward."""
        print(f"Received feedback: reward={reward}")
# 多模型路由
class ModelRouter:
    """Route requests across several models held in one mapping.

    Args:
        models: Mapping consulted by both ``route`` and ``predict``.
    """

    def __init__(self, models: dict):
        self.models = models

    def route(self, request: dict) -> str:
        """Pick a target for ``request`` from its ``"model_type"`` field.

        Looks up ``request["model_type"]`` (``"default"`` when absent) in
        ``self.models`` and returns the stored value, or the literal string
        ``"default_model"`` when the key is missing.

        NOTE(review): the annotation says ``str``, but ``predict`` treats the
        values of ``self.models`` as model objects. Confirm which mapping
        callers actually pass.
        """
        model_type = request.get("model_type", "default")
        return self.models.get(model_type, "default_model")

    def predict(self, model_name: str, input_data: dict) -> dict:
        """Run ``input_data`` through the model registered as ``model_name``.

        Returns:
            The model's ``predict`` result, or ``{"error": "Model not found"}``
            when no model is registered under that name.
        """
        model = self.models.get(model_name)
        # Compare against None explicitly: a registered model may be falsy
        # (for example an empty container-like model object) and must still
        # be used rather than reported as missing.
        if model is not None:
            return model.predict(input_data)
        return {"error": "Model not found"}
性能优化策略
模型服务性能优化包括:动态批处理提升GPU利用率、模型量化降低计算开销、模型缓存减少加载时间、请求队列平滑流量峰值。监控指标包括P99延迟、吞吐量、GPU利用率和错误率。
# 动态批处理
class DynamicBatcher:
    """Group requests into batches by size, with a maximum wait for stragglers.

    ``add_request`` releases a batch as soon as ``max_batch_size`` requests
    are pending. ``poll`` releases a partial batch once the oldest pending
    request has waited ``max_wait_ms``, so a trickle of requests is not held
    back forever. ``flush`` releases whatever is pending immediately.

    Args:
        max_batch_size: Largest batch to release; must be at least 1.
        max_wait_ms: Longest time, in milliseconds, the oldest pending
            request may wait before ``poll`` releases a partial batch.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than 1 (such a batcher
            could never drain its queue).
    """

    def __init__(self, max_batch_size: int = 64,
                 max_wait_ms: float = 10.0):
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be at least 1")
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.pending_requests = []
        # Monotonic timestamp of the oldest pending request; None when idle.
        self._oldest_pending_at = None

    def add_request(self, request: dict) -> Optional[list]:
        """Queue ``request``.

        Returns:
            A full batch once ``max_batch_size`` requests are pending,
            otherwise ``None``.
        """
        if not self.pending_requests:
            # First request of a new batch starts the wait clock.
            self._oldest_pending_at = time.monotonic()
        self.pending_requests.append(request)
        if len(self.pending_requests) >= self.max_batch_size:
            return self._process_batch()
        return None

    def poll(self) -> Optional[list]:
        """Release a partial batch if the oldest request has waited long enough.

        Intended to be called periodically by the serving loop.

        Returns:
            A batch when ``max_wait_ms`` has elapsed for the oldest pending
            request, otherwise ``None``.
        """
        if not self.pending_requests:
            return None
        started = getattr(self, "_oldest_pending_at", None)
        if started is None:
            # Requests were placed in the list without going through
            # add_request; start their wait clock now.
            self._oldest_pending_at = time.monotonic()
            started = self._oldest_pending_at
        waited_ms = (time.monotonic() - started) * 1000.0
        if waited_ms >= self.max_wait_ms:
            return self._process_batch()
        return None

    def flush(self) -> list:
        """Release up to ``max_batch_size`` pending requests right away.

        Returns an empty list when nothing is pending.
        """
        return self._process_batch()

    def _process_batch(self) -> list:
        """Pop and return the next batch of at most ``max_batch_size`` items."""
        batch = self.pending_requests[:self.max_batch_size]
        self.pending_requests = self.pending_requests[self.max_batch_size:]
        # Restart the wait clock for any leftovers, or go idle.
        self._oldest_pending_at = (
            time.monotonic() if self.pending_requests else None
        )
        return batch
# 模型量化
class ModelQuantizer:
    """Placeholder quantization helpers plus a size-comparison utility."""

    def quantize_int8(self, model_path: str, output_path: str):
        """INT8 quantization hook; prints a message and returns ``output_path``.

        No file is read or written here.
        """
        print(f"Quantizing {model_path} to INT8")
        return output_path

    def quantize_fp16(self, model_path: str, output_path: str):
        """FP16 quantization hook; prints a message and returns ``output_path``.

        No file is read or written here.
        """
        print(f"Quantizing {model_path} to FP16")
        return output_path

    def get_quantization_stats(self, original_size: int,
                               quantized_size: int) -> dict:
        """Compare two sizes given in bytes.

        Returns:
            Both sizes converted to MiB and the ratio
            ``original_size / quantized_size``.

        Raises:
            ZeroDivisionError: If ``quantized_size`` is 0.
        """
        stats = {}
        stats["original_mb"] = original_size / 1024 / 1024
        stats["quantized_mb"] = quantized_size / 1024 / 1024
        stats["compression_ratio"] = original_size / quantized_size
        return stats
# 性能监控
class ServingMetrics:
    """Accumulate request counts, latencies and errors for a serving process."""

    def __init__(self):
        self.request_count = 0
        self.errors = 0
        self.latencies = []

    def record_request(self, latency_ms: float, success: bool):
        """Record one request's latency and whether it succeeded."""
        self.request_count += 1
        self.latencies.append(latency_ms)
        if success:
            return
        self.errors += 1

    def get_percentiles(self) -> dict:
        """Return p50/p95/p99 of the recorded latencies.

        Each percentile is the element at index ``int(n * fraction)`` of the
        sorted latencies. All three are 0 when nothing has been recorded.
        """
        ordered = sorted(self.latencies)
        count = len(ordered)
        if count == 0:
            return {"p50": 0, "p95": 0, "p99": 0}
        fractions = (("p50", 0.5), ("p95", 0.95), ("p99", 0.99))
        return {
            label: ordered[int(count * fraction)]
            for label, fraction in fractions
        }

    def get_summary(self) -> dict:
        """Return total requests, error rate and latency percentiles."""
        # Guard the denominator so an empty recorder reports a 0.0 error rate.
        denominator = max(self.request_count, 1)
        return {
            "total_requests": self.request_count,
            "error_rate": self.errors / denominator,
            "percentiles": self.get_percentiles(),
        }