← 返回首页
🤖

特征存储架构:离线在线一致性与Feast实践

📂 architecture ⏱ 3 min 600 words

特征存储架构:离线在线一致性与Feast实践

特征存储的核心价值

特征存储解决的是机器学习中一个常见却容易被忽视的问题:训练和推理使用了不一致的特征计算逻辑,导致线上模型效果低于离线评估,即训练-服务偏差(training-serving skew)。特征存储通过统一的特征定义、计算逻辑和服务接口,让离线训练和在线推理使用同一套特征,从而尽可能保证两者一致。

# 特征定义DSL
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Callable, Optional

@dataclass
class FeatureDefinition:
    """Metadata describing a single feature (pure data; nothing is computed here)."""

    name: str  # feature name; FeatureRegistry uses it as the lookup key
    entity: str  # entity key the feature belongs to, e.g. "user_id"
    dtype: str  # declared value type as a string, e.g. "float"
    description: str  # human-readable description
    owner: str  # owning team or person
    # Validity window of a stored value; OnlineFeatureStore uses it as the Redis expiry.
    ttl: timedelta = timedelta(days=1)
    # Optional free-form labels; None when not provided.
    tags: Optional[list] = None

class Feature:
    """Pair a feature's metadata with the callable that computes it.

    Attributes:
        definition: metadata for the feature (name, entity, dtype, ...).
        transform: callable applied to raw input data by ``compute``.
    """

    def __init__(self, definition: FeatureDefinition, transform: Callable):
        self.definition, self.transform = definition, transform

    def compute(self, data: Any) -> Any:
        """Run the stored transform on *data* and return its result as-is."""
        fn = self.transform
        return fn(data)

# Feature registry
class FeatureRegistry:
    """In-memory mapping from feature name to Feature object.

    Registering a name that already exists replaces the previous entry.
    """

    def __init__(self):
        self.features = {}  # feature name -> Feature

    def register(self, feature: "Feature"):
        """Store *feature* under ``feature.definition.name``, overwriting any existing entry."""
        self.features[feature.definition.name] = feature
        print(f"Registered feature: {feature.definition.name}")

    def get(self, name: str) -> Optional["Feature"]:
        """Return the feature registered as *name*, or None if it is unknown."""
        return self.features.get(name)

    def list_features(self, entity: Optional[str] = None) -> list:
        """Return registered features in registration order.

        Args:
            entity: if given (and non-empty), keep only features whose
                ``definition.entity`` equals it; a falsy value returns all.
        """
        if not entity:
            return list(self.features.values())
        return [f for f in self.features.values() if f.definition.entity == entity]

# Usage example: register a per-user 30-day average purchase feature.
registry = FeatureRegistry()


def _avg_purchase_30d(df):
    """Per-user rolling 30-day mean of the ``amount`` column.

    NOTE(review): ``rolling("30D")`` is a time-based window, so *df* presumably
    needs a DatetimeIndex — TODO confirm with callers. ``reset_index(level=0)``
    turns the ``user_id`` index level back into a column.
    """
    amounts_by_user = df.groupby("user_id")["amount"]
    windowed_mean = amounts_by_user.rolling("30D").mean()
    return windowed_mean.reset_index(level=0)


avg_purchase_definition = FeatureDefinition(
    name="user_avg_purchase_30d",
    entity="user_id",
    dtype="float",
    description="用户过去30天平均消费金额",
    owner="data_team",
    tags=["monetary", "user"],
)
user_avg_purchase = Feature(
    definition=avg_purchase_definition,
    transform=_avg_purchase_30d,
)
registry.register(user_avg_purchase)

离线特征计算

离线特征计算通常基于Spark或Flink处理海量历史数据,生成训练数据集和特征快照。计算结果存储在数据湖(HDFS/S3)中,支持批量查询和时间点回溯。

# 离线特征计算管道
import pandas as pd
from datetime import datetime, timedelta

class OfflineFeatureCompute:
    """Batch (offline) feature computation backed by a FeatureRegistry.

    Each requested feature is looked up by name and computed by calling its
    ``compute`` method on the supplied entity DataFrame.
    """

    def __init__(self, feature_registry: "FeatureRegistry"):
        self.registry = feature_registry

    def compute_features(self, entity_df: pd.DataFrame,
                         feature_names: list,
                         as_of_date: datetime) -> pd.DataFrame:
        """Compute the requested features for the entities in *entity_df*.

        Args:
            entity_df: input frame; must contain an ``entity_id`` column. It is
                passed unchanged to every feature's transform.
            feature_names: names to look up in the registry. Unregistered
                names are reported and skipped (no column is produced).
            as_of_date: intended point in time of the computation.
                NOTE(review): currently unused — transforms are not filtered by
                this date, so this is not yet a point-in-time lookup.

        Returns:
            A new DataFrame holding ``entity_id`` plus one column per feature
            that was computed.
        """
        # Double brackets select a one-column DataFrame (not a Series), so the
        # feature columns can be assigned onto a copy without touching entity_df.
        result = entity_df[["entity_id"]].copy()

        for feature_name in feature_names:
            feature = self.registry.get(feature_name)
            if feature is None:
                # Best-effort: skip unknown features, but say so instead of
                # silently producing a dataset with fewer columns.
                print(f"Skipping unregistered feature: {feature_name}")
                continue
            result[feature_name] = feature.compute(entity_df)

        return result

    def generate_training_dataset(self,
                                  entities: pd.DataFrame,
                                  labels: pd.DataFrame,
                                  feature_names: list,
                                  snapshot_date: datetime) -> pd.DataFrame:
        """Build a training set: computed features inner-joined with labels.

        Args:
            entities: entity frame passed to ``compute_features``.
            labels: label frame; must contain an ``entity_id`` column.
            feature_names: features to compute.
            snapshot_date: forwarded to ``compute_features`` as ``as_of_date``.

        Returns:
            Rows whose ``entity_id`` appears in both the features and labels.
        """
        features_df = self.compute_features(entities, feature_names, snapshot_date)

        # NOTE: this is a plain inner join on entity_id. It does NOT check that
        # feature timestamps precede label timestamps, so on its own it cannot
        # prevent label leakage; that needs a point-in-time join (for example
        # pd.merge_asof on timestamp columns).
        training_data = features_df.merge(labels, on="entity_id", how="inner")

        # Count only features that were actually produced (unregistered names
        # are skipped by compute_features).
        n_computed = sum(name in features_df.columns for name in feature_names)
        print(f"Generated training dataset: {len(training_data)} samples, "
              f"{n_computed} features")
        return training_data

    def backfill_features(self, feature_names: list,
                          start_date: datetime,
                          end_date: datetime,
                          interval_days: int = 1):
        """Walk from *start_date* to *end_date* (inclusive) in fixed steps.

        The per-date computation itself is a stub: each step only logs the date.

        Raises:
            ValueError: if *interval_days* is not positive. Otherwise the loop
                would never reach *end_date*.
        """
        if interval_days <= 0:
            raise ValueError(f"interval_days must be positive, got {interval_days}")

        step = timedelta(days=interval_days)
        current_date = start_date
        while current_date <= end_date:
            print(f"Computing features for {current_date.date()}")
            # TODO: compute and persist feature_names as of current_date.
            current_date += step

# Example: build a training set as of 2024-01-15.
# NOTE(review): user_df and purchase_labels are not defined in this snippet.
# compute_features reads an "entity_id" column and labels are joined on
# "entity_id", so both presumably need that column. "user_purchase_count_7d"
# is not registered by the earlier example, so it would be skipped.
compute = OfflineFeatureCompute(registry)
offline_feature_names = ["user_avg_purchase_30d", "user_purchase_count_7d"]
snapshot = datetime(2024, 1, 15)
training_data = compute.generate_training_dataset(
    user_df, purchase_labels, offline_feature_names, snapshot
)

在线特征服务

在线特征服务需要毫秒级延迟,通常使用Redis作为在线存储,配合预计算和缓存策略。Feast的Online Store支持Redis、DynamoDB等后端。

# 在线特征服务
import redis
import json
from typing import Dict, List

class OnlineFeatureStore:
    """Redis-backed online feature store.

    Each (entity, feature) pair is a Redis hash at
    ``feature:{entity_id}:{feature_name}`` with ``value`` and ``timestamp``
    fields. It expires after the feature's TTL, or 24h for unregistered names.
    """

    def __init__(self, redis_client: "redis.Redis", feature_registry: "FeatureRegistry"):
        self.redis = redis_client
        self.registry = feature_registry

    @staticmethod
    def _key(entity_id: str, feature_name: str) -> str:
        """Build the Redis key for one (entity, feature) pair."""
        return f"feature:{entity_id}:{feature_name}"

    @staticmethod
    def _decode_hash(data: dict) -> Dict[str, str]:
        """Normalize an HGETALL reply so keys and values are ``str``.

        redis-py returns bytes by default and str when the client was built
        with ``decode_responses=True``. Accepting both stops reads from
        silently falling back to defaults in the latter mode.
        """
        def as_str(item):
            return item.decode() if isinstance(item, bytes) else item

        return {as_str(k): as_str(v) for k, v in data.items()}

    def get_online_features(self,
                            entity_ids: List[str],
                            feature_names: List[str]) -> Dict[str, Dict]:
        """Batch-read features for several entities in one pipeline round trip.

        Returns:
            ``{entity_id: {feature_name: {"value": float, "timestamp": str}}}``.
            Missing or expired keys are omitted from the inner dict. Values
            are parsed with ``float`` (0.0 if the hash has no ``value`` field),
            so non-numeric stored values raise ValueError.
        """
        pipeline = self.redis.pipeline()
        for entity_id in entity_ids:
            for feature_name in feature_names:
                pipeline.hgetall(self._key(entity_id, feature_name))

        # Replies come back in the same entity-major order they were queued.
        responses = iter(pipeline.execute())

        results = {}
        for entity_id in entity_ids:
            row = results[entity_id] = {}
            for feature_name in feature_names:
                data = next(responses)
                if not data:
                    continue
                fields = self._decode_hash(data)
                row[feature_name] = {
                    "value": float(fields.get("value", 0)),
                    "timestamp": fields.get("timestamp", ""),
                }

        return results

    def write_online_features(self,
                              entity_id: str,
                              features: Dict[str, Any]):
        """Write feature values for one entity, each with its registry TTL.

        Values are stored via ``str(value)``. All features in a call share one
        write timestamp (local time, ISO 8601).
        """
        pipeline = self.redis.pipeline()
        written_at = datetime.now().isoformat()

        for feature_name, value in features.items():
            key = self._key(entity_id, feature_name)
            feature = self.registry.get(feature_name)
            ttl = feature.definition.ttl if feature is not None else timedelta(hours=24)

            pipeline.hset(key, mapping={
                "value": str(value),
                "timestamp": written_at,
            })
            pipeline.expire(key, int(ttl.total_seconds()))

        pipeline.execute()

# Usage example: batch-read two features for two users from Redis.
redis_client = redis.Redis()
store = OnlineFeatureStore(redis_client, registry)
online_feature_names = ["user_avg_purchase_30d", "user_purchase_count_7d"]
features = store.get_online_features(
    entity_ids=["user_123", "user_456"],
    feature_names=online_feature_names,
)

Feast集成实践

Feast是最常用的开源特征存储之一,支持离线存储(如Parquet文件、BigQuery)和在线存储(如Redis、SQLite)。它通过`FeatureStore` API统一管理特征定义的注册(`apply`)、数据向在线存储的加载(`materialize`)以及离线/在线查询。

# Feast feature-store configuration.
# Only the Feast names used below are imported. Importing feast's `Feature`
# would shadow the custom Feature class defined earlier in this article, and
# `RedisOnlineStore` is not exported from the top-level `feast` package (the
# online store is selected in feature_store.yaml instead).
from feast import Entity, FeatureStore, FeatureView, Field, FileSource, ValueType
from feast.types import Float32, Int64

# Entity: the join key that features are attached to.
user_entity = Entity(
    name="user_id",
    join_keys=["user_id"],
    value_type=ValueType.INT64,
    description="用户唯一标识",
)

# Feature view: schema plus offline source for a group of user features.
# Field.dtype takes a feast.types type (Float32 / Int64), not a ValueType,
# and `entities` takes Entity objects.
user_features_view = FeatureView(
    name="user_features",
    entities=[user_entity],
    ttl=timedelta(days=1),
    schema=[
        Field(name="avg_purchase_30d", dtype=Float32),
        Field(name="purchase_count_7d", dtype=Int64),
        Field(name="last_purchase_days", dtype=Int64),
    ],
    source=FileSource(
        path="s3://features/user_features.parquet",
        timestamp_field="event_timestamp",
    ),
)

# Initialize the Feature Store from the repo's feature_store.yaml.
store = FeatureStore(repo_path=".")

# Register the definitions above (equivalent to running `feast apply`);
# without this the queries below cannot resolve "user_features".
store.apply([user_entity, user_features_view])

# Offline query (training). NOTE: entity_df is not defined in this snippet;
# Feast expects it to contain the join key ("user_id") and an
# "event_timestamp" column used for the point-in-time join.
training_df = store.get_historical_features(
    entity_df=entity_df,
    features=[
        "user_features:avg_purchase_30d",
        "user_features:purchase_count_7d",
    ],
).to_df()

# Load the latest values from the offline source into the online store;
# otherwise online lookups return empty values.
store.materialize_incremental(end_date=datetime.now())

# Online query (inference).
online_features = store.get_online_features(
    features=[
        "user_features:avg_purchase_30d",
        "user_features:purchase_count_7d",
    ],
    entity_rows=[
        {"user_id": 123},
        {"user_id": 456},
    ],
).to_dict()

特征监控与质量保障

特征存储需要监控特征质量:数据新鲜度、缺失率、分布漂移。通过定期检查和告警机制,及时发现特征异常,避免影响模型效果。

# Feature quality monitoring
class FeatureMonitor:
    """Basic feature-quality checks: freshness and distribution drift (PSI).

    Args:
        feature_store: the online store to monitor. NOTE(review): not read by
            any check yet — check_freshness still uses a placeholder timestamp.
    """

    def __init__(self, feature_store: "OnlineFeatureStore"):
        self.store = feature_store

    def check_freshness(self, feature_names: list,
                        max_staleness_hours: int = 24) -> dict:
        """Report how stale each feature is.

        Returns:
            ``{name: {"last_update": iso str, "staleness_hours": float,
            "is_fresh": bool}}``, where a feature is fresh when its staleness
            is strictly below *max_staleness_hours*.
        """
        results = {}
        for feature_name in feature_names:
            now = datetime.now()
            # PLACEHOLDER: a real implementation would read the feature's last
            # write time from the store; for now every feature is 12h old.
            last_update = now - timedelta(hours=12)
            staleness_hours = (now - last_update).total_seconds() / 3600

            results[feature_name] = {
                "last_update": last_update.isoformat(),
                "staleness_hours": staleness_hours,
                "is_fresh": staleness_hours < max_staleness_hours,
            }
        return results

    def check_distribution(self, feature_name: str,
                           reference_stats: dict,
                           current_stats: dict,
                           threshold: float = 0.2) -> dict:
        """Flag distribution drift when the PSI exceeds *threshold*.

        Args:
            reference_stats / current_stats: bin -> value mappings compared by
                ``_calculate_psi``.
            threshold: PSI above which drift is reported (0.2 is a common
                "significant shift" rule of thumb).
        """
        psi = self._calculate_psi(reference_stats, current_stats)
        drifted = psi > threshold
        return {
            "feature": feature_name,
            "psi": psi,
            "drift_detected": drifted,
            "recommendation": "investigate" if drifted else "normal",
        }

    def _calculate_psi(self, expected: dict, actual: dict) -> float:
        """Population Stability Index between two binned distributions.

        PSI = sum_i (a_i - e_i) * ln(a_i / e_i). Every term is >= 0.
        A bin present in only one dict counts as 0 in the other, so a bin
        that vanishes (or appears) still contributes. A small epsilon
        avoids log(0) and division by zero.
        NOTE(review): assumes the values are per-bin proportions — confirm.
        """
        import math

        eps = 1e-6
        # Union of bins in a deterministic order (expected first, then new ones).
        bins = dict.fromkeys([*expected, *actual])
        psi = 0.0
        for key in bins:
            e = expected.get(key, 0.0) + eps
            a = actual.get(key, 0.0) + eps
            psi += (a - e) * math.log(a / e)
        return psi