特征存储架构:离线/在线一致性与Feast实践
特征存储的核心价值
特征存储解决机器学习中最常见但最容易被忽视的问题:训练和推理使用不一致的特征计算逻辑,导致模型性能下降(训练-服务偏差)。特征存储通过统一的特征定义、计算逻辑和服务接口,确保特征在离线训练和在线推理中保持完全一致。
# 特征定义DSL
import math
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Callable, Optional
@dataclass
class FeatureDefinition:
    """Metadata describing a single feature in the registry."""

    # Unique name; FeatureRegistry uses it as the lookup key.
    name: str
    # Entity key the feature belongs to (e.g. "user_id"); used for filtering.
    entity: str
    # Declared value type as a free-form string (e.g. "float").
    dtype: str
    description: str
    owner: str
    # Lifetime of an online value; the online store uses it as the key expiry.
    ttl: timedelta = timedelta(days=1)
    # Optional free-form labels; None (not an empty list) when not supplied.
    tags: Optional[list] = None
class Feature:
    """A feature definition bundled with the callable that computes its values."""

    def __init__(self, definition: FeatureDefinition, transform: Callable):
        # Both are exposed as public attributes; other code reads feature.definition.
        self.definition, self.transform = definition, transform

    def compute(self, data: Any) -> Any:
        """Run the stored transform on *data* and return its result unchanged."""
        fn = self.transform
        return fn(data)
# 特征注册表
class FeatureRegistry:
    """In-memory mapping from feature name to Feature object."""

    def __init__(self):
        self.features = {}  # feature name -> Feature

    def register(self, feature: "Feature"):
        """Register *feature* under ``feature.definition.name``.

        An existing entry with the same name is replaced, as before. A warning
        is now emitted when the replacement is a different object, because a
        silently redefined feature is exactly the kind of inconsistency a
        registry exists to prevent.
        """
        name = feature.definition.name
        existing = self.features.get(name)
        if existing is not None and existing is not feature:
            warnings.warn(
                f"Feature {name!r} is already registered; overwriting it",
                stacklevel=2,
            )
        self.features[name] = feature
        print(f"Registered feature: {name}")

    def get(self, name: str) -> "Optional[Feature]":
        """Return the feature registered under *name*, or None if it is unknown."""
        return self.features.get(name)

    def list_features(self, entity: Optional[str] = None) -> list:
        """Return registered features, optionally only those for *entity*.

        A falsy *entity* (None or "") returns every registered feature.
        """
        if not entity:
            return list(self.features.values())
        return [f for f in self.features.values()
                if f.definition.entity == entity]
# 使用示例
# Example: register a 30-day average purchase amount feature.
registry = FeatureRegistry()
user_avg_purchase = Feature(
    FeatureDefinition(
        name="user_avg_purchase_30d",
        entity="user_id",
        dtype="float",
        description="用户过去30天平均消费金额",
        owner="data_team",
        tags=["monetary", "user"],
    ),
    # NOTE(review): "30D" is a time-based window; pandas only accepts such a
    # window when the rolled data has a datetime-like index. Confirm the input
    # frame is indexed by event time before relying on this transform.
    lambda df: (
        df.groupby("user_id")["amount"]
        .rolling("30D")
        .mean()
        .reset_index(level=0)
    ),
)
registry.register(user_avg_purchase)
离线特征计算
离线特征计算通常基于Spark或Flink处理海量历史数据,生成训练数据集和特征快照。计算结果存储在数据湖(HDFS/S3)中,支持批量查询和时间点回溯。
# 离线特征计算管道
import pandas as pd
from datetime import datetime, timedelta
class OfflineFeatureCompute:
    """Batch (offline) feature computation driven by a feature registry."""

    def __init__(self, feature_registry: "FeatureRegistry"):
        self.registry = feature_registry

    def compute_features(self, entity_df: pd.DataFrame,
                         feature_names: list,
                         as_of_date: datetime) -> pd.DataFrame:
        """Compute the requested features for every row of *entity_df*.

        Args:
            entity_df: input rows. It must contain an ``entity_id`` column and
                is passed whole to each feature's transform.
            feature_names: names looked up in the registry. Unknown names are
                skipped with a warning.
            as_of_date: intended point-in-time cutoff. NOTE(review): it is not
                applied yet. Transforms see all of *entity_df*, so callers must
                pre-filter rows to avoid leaking data from after the cutoff.

        Returns:
            A DataFrame holding the ``entity_id`` column plus one column per
            resolved feature, aligned on *entity_df*'s index.
        """
        # Double brackets select a one-column DataFrame (not a Series), so
        # feature columns can be added to the copy below.
        result = entity_df[["entity_id"]].copy()
        for feature_name in feature_names:
            feature = self.registry.get(feature_name)
            if feature is None:
                warnings.warn(f"Unknown feature skipped: {feature_name}",
                              stacklevel=2)
                continue
            result[feature_name] = feature.compute(entity_df)
        return result

    def generate_training_dataset(self,
                                  entities: pd.DataFrame,
                                  labels: pd.DataFrame,
                                  feature_names: list,
                                  snapshot_date: datetime) -> pd.DataFrame:
        """Build a training set: computed features inner-joined with labels.

        NOTE(review): the join is on ``entity_id`` only. No timestamp alignment
        between features and labels happens here, so point-in-time correctness
        depends entirely on the rows the caller passes in.
        """
        features_df = self.compute_features(
            entities, feature_names, snapshot_date
        )
        training_data = features_df.merge(
            labels, on="entity_id", how="inner"
        )
        print(f"Generated training dataset: {len(training_data)} samples, "
              f"{len(feature_names)} features")
        return training_data

    def backfill_features(self, feature_names: list,
                          start_date: datetime,
                          end_date: datetime,
                          interval_days: int = 1):
        """Step from *start_date* to *end_date* (inclusive) every *interval_days*.

        Placeholder: each step only prints the date. No features are computed
        or stored yet.

        Raises:
            ValueError: if *interval_days* is not positive, because the loop
                would otherwise never terminate.
        """
        if interval_days <= 0:
            raise ValueError("interval_days must be positive")
        step = timedelta(days=interval_days)
        current_date = start_date
        while current_date <= end_date:
            print(f"Computing features for {current_date.date()}")
            # TODO: compute feature_names as of current_date and persist them.
            current_date += step
# 特征计算示例
# Example: build a training set as of 2024-01-15.
# NOTE(review): user_df and purchase_labels are not defined in this snippet;
# the caller must provide them before this code runs.
compute = OfflineFeatureCompute(registry)
training_data = compute.generate_training_dataset(
    user_df,
    purchase_labels,
    ["user_avg_purchase_30d", "user_purchase_count_7d"],
    datetime(2024, 1, 15),
)
在线特征服务
在线特征服务需要毫秒级延迟,通常使用Redis作为在线存储,配合预计算和缓存策略。Feast的Online Store支持Redis、DynamoDB等后端。
# 在线特征服务
import redis
import json
from typing import Dict, List
class OnlineFeatureStore:
    """Redis-backed online feature store.

    Each (entity, feature) pair is stored as a Redis hash under the key
    ``feature:{entity_id}:{feature_name}``. The hash has a ``value`` field and
    a ``timestamp`` field.
    """

    def __init__(self, redis_client: "redis.Redis", feature_registry: "FeatureRegistry"):
        self.redis = redis_client
        self.registry = feature_registry

    @staticmethod
    def _key(entity_id: str, feature_name: str) -> str:
        """Build the Redis key for one (entity, feature) pair."""
        return f"feature:{entity_id}:{feature_name}"

    @staticmethod
    def _field(data: dict, name: str, default=None):
        """Read hash field *name*, accepting either bytes or str keys."""
        # redis-py returns bytes keys by default, but str keys when the client
        # was created with decode_responses=True.
        raw_key = name.encode()
        if raw_key in data:
            return data[raw_key]
        return data.get(name, default)

    def get_online_features(self,
                            entity_ids: List[str],
                            feature_names: List[str]) -> Dict[str, Dict]:
        """Fetch features for many entities in a single pipelined round trip.

        Returns:
            ``{entity_id: {feature_name: {"value": float, "timestamp": str}}}``.
            Features with no stored hash (never written or expired) are left
            out of the entity's inner dict. Every requested entity still gets
            an entry.
        """
        pairs = [(entity_id, feature_name)
                 for entity_id in entity_ids
                 for feature_name in feature_names]
        pipeline = self.redis.pipeline()
        for entity_id, feature_name in pairs:
            pipeline.hgetall(self._key(entity_id, feature_name))
        responses = pipeline.execute()

        results = {entity_id: {} for entity_id in entity_ids}
        # Responses come back in the order the commands were queued.
        for (entity_id, feature_name), data in zip(pairs, responses):
            if not data:
                continue
            timestamp = self._field(data, "timestamp", b"")
            if isinstance(timestamp, bytes):
                timestamp = timestamp.decode()
            results[entity_id][feature_name] = {
                "value": float(self._field(data, "value", 0)),
                "timestamp": timestamp,
            }
        return results

    def write_online_features(self,
                              entity_id: str,
                              features: Dict[str, Any]):
        """Write feature values for one entity, stamped with the current time.

        Each value is stored as ``str(value)``. The key expiry comes from the
        feature's registered ``ttl`` and defaults to 24 hours for unregistered
        features. When the ttl is None or under one second, no expiry is set.
        Redis ``EXPIRE`` with a non-positive timeout deletes the key instead of
        keeping it.
        """
        default_ttl = timedelta(hours=24)
        pipeline = self.redis.pipeline()
        for feature_name, value in features.items():
            key = self._key(entity_id, feature_name)
            feature = self.registry.get(feature_name)
            ttl = feature.definition.ttl if feature is not None else default_ttl
            pipeline.hset(key, mapping={
                "value": str(value),
                "timestamp": datetime.now().isoformat(),
            })
            ttl_seconds = int(ttl.total_seconds()) if ttl is not None else 0
            if ttl_seconds > 0:
                pipeline.expire(key, ttl_seconds)
        pipeline.execute()
# 使用示例
# Example: read two features for two users from Redis.
# NOTE(review): redis.Redis() with no arguments uses the client's default
# connection settings; adjust host/port for real deployments.
store = OnlineFeatureStore(redis.Redis(), registry)
features = store.get_online_features(
    ["user_123", "user_456"],
    ["user_avg_purchase_30d", "user_purchase_count_7d"],
)
Feast集成实践
Feast是应用广泛的开源特征存储实现之一,支持离线存储(Parquet/BigQuery)和在线存储(Redis/SQLite)。在线存储后端在feature_store.yaml中配置,无需在代码中导入;特征定义、加载和查询则通过FeatureStore API统一管理。
# Feast特征存储配置
# Feast feature-store definitions.
# Import notes:
# - The online store backend (Redis, SQLite, ...) is selected in
#   feature_store.yaml. Feast does not export a RedisOnlineStore class from its
#   top-level package, so nothing online-store specific is imported here.
# - Feast's legacy `Feature` class is deprecated in favour of `Field`, and
#   importing it would also shadow this file's own `Feature` class.
from feast import Entity, FeatureStore, FeatureView, Field, FileSource, ValueType
from feast.types import Float32, Int64

# Entity: the join key that feature rows are attached to.
user_entity = Entity(
    name="user_id",
    join_keys=["user_id"],
    value_type=ValueType.INT64,
    description="用户唯一标识",
)

# Feature view: the schema plus the batch source the values are read from.
# Field dtypes must be feast.types (Float32/Int64), not ValueType members.
user_features_view = FeatureView(
    name="user_features",
    entities=[user_entity],
    ttl=timedelta(days=1),
    schema=[
        Field(name="avg_purchase_30d", dtype=Float32),
        Field(name="purchase_count_7d", dtype=Int64),
        Field(name="last_purchase_days", dtype=Int64),
    ],
    source=FileSource(
        path="s3://features/user_features.parquet",
        timestamp_field="event_timestamp",
    ),
)

# Load the feature repository in the current directory.
store = FeatureStore(repo_path=".")

# Offline query (training data).
# NOTE(review): entity_df is not defined in this snippet; the caller must
# supply it (entity keys plus an event timestamp per row).
training_df = store.get_historical_features(
    entity_df=entity_df,
    features=[
        "user_features:avg_purchase_30d",
        "user_features:purchase_count_7d",
    ],
).to_df()

# Online query (inference).
online_features = store.get_online_features(
    features=[
        "user_features:avg_purchase_30d",
        "user_features:purchase_count_7d",
    ],
    entity_rows=[
        {"user_id": 123},
        {"user_id": 456},
    ],
).to_dict()
特征监控与质量保障
特征存储需要监控特征质量:数据新鲜度、缺失率、分布漂移。通过定期检查和告警机制,及时发现特征异常,避免影响模型效果。
# 特征质量监控
class FeatureMonitor:
    """Feature-quality checks: data freshness and distribution drift (PSI)."""

    def __init__(self, feature_store: "OnlineFeatureStore"):
        self.store = feature_store

    def check_freshness(self, feature_names: list,
                        max_staleness_hours: int = 24,
                        last_updates: Optional[dict] = None) -> dict:
        """Report how stale each feature is.

        Args:
            feature_names: features to check.
            max_staleness_hours: a feature counts as fresh when its staleness
                is strictly below this many hours.
            last_updates: optional ``{feature_name: datetime}`` of real last
                update times. These must be naive local times, comparable with
                ``datetime.now()``. Features missing from it fall back to a
                placeholder of "12 hours ago". NOTE(review): the placeholder is
                not read from any store.

        Returns:
            ``{feature_name: {"last_update": ISO string,
            "staleness_hours": float, "is_fresh": bool}}``.
        """
        # Take a single reference time so every feature is measured consistently.
        now = datetime.now()
        placeholder = now - timedelta(hours=12)
        known = last_updates or {}
        results = {}
        for feature_name in feature_names:
            last_update = known.get(feature_name, placeholder)
            staleness_hours = (now - last_update).total_seconds() / 3600
            results[feature_name] = {
                "last_update": last_update.isoformat(),
                "staleness_hours": staleness_hours,
                "is_fresh": staleness_hours < max_staleness_hours,
            }
        return results

    def check_distribution(self, feature_name: str,
                           reference_stats: dict,
                           current_stats: dict,
                           threshold: float = 0.2) -> dict:
        """Flag distribution drift when the PSI exceeds *threshold*.

        The default of 0.2 matches the previously hard-coded cut-off.
        """
        psi = self._calculate_psi(reference_stats, current_stats)
        drifted = psi > threshold
        return {
            "feature": feature_name,
            "psi": psi,
            "drift_detected": drifted,
            "recommendation": "investigate" if drifted else "normal",
        }

    def _calculate_psi(self, expected: dict, actual: dict) -> float:
        """Return the Population Stability Index of two binned distributions.

        Assumes each dict maps bin -> proportion (TODO: confirm with callers).
        A bin that appears on only one side counts as proportion 0 on the
        other side. A small epsilon keeps the logarithm finite.

        PSI = sum over bins of (a - e) * ln(a / e). Every term is >= 0, so no
        absolute value is needed.
        """
        eps = 1e-6
        # Deterministic union of bins: expected's order first, then new bins.
        bins = list(expected) + [k for k in actual if k not in expected]
        psi = 0.0
        for key in bins:
            e = expected.get(key, 0.0) + eps
            a = actual.get(key, 0.0) + eps
            psi += (a - e) * math.log(a / e)
        return psi