← 返回首页
🧪

测试数据管理:脱敏、工厂模式与夹具

📂 architecture ⏱ 8 min 1539 words

测试数据管理:脱敏、工厂模式与夹具

测试数据管理概述

测试数据是测试的基础,但管理测试数据面临多重挑战:数据隐私保护、数据一致性维护、数据生成效率和数据清理策略。完整的测试数据管理体系需要解决这些问题。

# 测试数据管理框架
from dataclasses import dataclass, field
from typing import Dict, List, Any, Callable, Optional
from datetime import datetime
import random
import string

@dataclass
class TestDataset:
    """A named list of test records plus creation time and free-form metadata."""
    # Identifier of the dataset.
    name: str
    # The records themselves, one dict per record.
    data: List[Dict]
    # Filled with datetime.now() when the instance is constructed.
    created_at: datetime = field(default_factory=datetime.now)
    # Extra information about the dataset; a fresh empty dict per instance.
    metadata: Dict = field(default_factory=dict)

class TestDataManager:
    """Registry of test datasets, data factories and sanitizers.

    Factories are callables that return one record per call; sanitizers are
    callables that take one record and return its sanitized counterpart.
    Datasets created or sanitized through the manager are stored by name and
    can be exported as JSON or CSV text.
    """

    def __init__(self):
        self.datasets: Dict[str, "TestDataset"] = {}
        self.factories: Dict[str, Callable] = {}
        self.sanitizers: Dict[str, Callable] = {}

    def register_factory(self, name: str, factory_fn: Callable):
        """Register ``factory_fn`` under ``name``, replacing any previous entry."""
        self.factories[name] = factory_fn

    def register_sanitizer(self, name: str, sanitizer_fn: Callable):
        """Register ``sanitizer_fn`` under ``name``, replacing any previous entry."""
        self.sanitizers[name] = sanitizer_fn

    def create_dataset(self, name: str, factory_name: str,
                       count: int, **kwargs) -> "TestDataset":
        """Build ``count`` records with a registered factory and store them.

        Args:
            name: Name under which the new dataset is stored.
            factory_name: Name of a factory registered via ``register_factory``.
            count: Number of records to generate.
            **kwargs: Forwarded unchanged to every factory call.

        Returns:
            The newly created dataset.

        Raises:
            ValueError: If no factory is registered under ``factory_name``.
        """
        factory = self.factories.get(factory_name)
        # Compare against None so that a callable object which happens to be
        # falsy is still accepted.
        if factory is None:
            raise ValueError(f"Factory not found: {factory_name}")

        data = [factory(**kwargs) for _ in range(count)]

        dataset = TestDataset(name=name, data=data)
        self.datasets[name] = dataset

        return dataset

    def sanitize_dataset(self, dataset_name: str,
                         sanitizer_name: str) -> "TestDataset":
        """Apply a registered sanitizer to every record of a stored dataset.

        The result is stored (and returned) as a new dataset named
        ``"<dataset_name>_sanitized"`` whose metadata records the original
        dataset name; the source dataset object is left in place.

        Raises:
            ValueError: If the dataset or the sanitizer is not registered.
        """
        dataset = self.datasets.get(dataset_name)
        sanitizer = self.sanitizers.get(sanitizer_name)

        if dataset is None or sanitizer is None:
            raise ValueError("Dataset or sanitizer not found")

        sanitized_data = [sanitizer(record) for record in dataset.data]

        sanitized_dataset = TestDataset(
            name=f"{dataset_name}_sanitized",
            data=sanitized_data,
            metadata={"original": dataset_name}
        )

        self.datasets[sanitized_dataset.name] = sanitized_dataset
        return sanitized_dataset

    def get_dataset(self, name: str) -> Optional["TestDataset"]:
        """Return the dataset stored under ``name``, or ``None`` if absent."""
        return self.datasets.get(name)

    def export_dataset(self, name: str, format: str = "json") -> str:
        """Serialize a stored dataset's records to text.

        Args:
            name: Name of the dataset to export.
            format: ``"json"`` (default) or ``"csv"``.

        Returns:
            The serialized records.

        Raises:
            ValueError: If the dataset does not exist or the format is not
                supported.
        """
        dataset = self.datasets.get(name)
        if dataset is None:
            raise ValueError(f"Dataset not found: {name}")

        if format == "json":
            import json
            # default=str stringifies values json cannot encode (e.g. datetime).
            return json.dumps(dataset.data, indent=2, default=str)
        elif format == "csv":
            return self._to_csv(dataset.data)

        raise ValueError(f"Unsupported format: {format}")

    def _to_csv(self, data: List[Dict]) -> str:
        """Render records as CSV text with a header row.

        Columns are taken from the keys of the first record; keys missing
        from later records yield empty cells. Values are converted with
        ``str`` and quoted by the ``csv`` module when they contain commas,
        quotes or line breaks, so the output can be parsed back reliably.
        Rows are separated by ``"\\n"`` and there is no trailing newline.
        """
        if not data:
            return ""

        import csv
        import io

        headers = list(data[0].keys())
        buffer = io.StringIO()
        writer = csv.writer(buffer, lineterminator="\n")
        writer.writerow(headers)
        for record in data:
            writer.writerow([str(record.get(h, "")) for h in headers])

        text = buffer.getvalue()
        # The writer terminates every row; drop only the final terminator.
        return text[:-1] if text.endswith("\n") else text

数据工厂模式

工厂模式是生成测试数据的常用方法,支持灵活配置和批量生成。通过组合不同的工厂,可以创建复杂的测试场景数据。

# 数据工厂框架
from faker import Faker
import random
from datetime import datetime, timedelta

class DataFactory:
    """Produces dict records from per-field generator callables.

    Each defined field has a zero-argument generator. When a record is
    created, a field's value comes from (in order of precedence) the keyword
    arguments of the call, the factory-wide overrides, or the generator.
    """

    def __init__(self):
        self.faker = Faker()
        self.field_generators = {}
        self.overrides = {}

    def define_field(self, field_name: str, generator: Callable):
        """Attach ``generator`` to ``field_name`` and return the factory."""
        self.field_generators[field_name] = generator
        return self

    def with_overrides(self, **overrides):
        """Merge fixed field values into the factory and return the factory."""
        for key, fixed_value in overrides.items():
            self.overrides[key] = fixed_value
        return self

    def create(self, **kwargs) -> Dict:
        """Build one record.

        Defined fields come first, in definition order. Keyword arguments
        that do not name a defined field are appended as extra fields.
        """
        def resolve(name, generator):
            # Call-site value beats factory override beats generated value.
            if name in kwargs:
                return kwargs[name]
            if name in self.overrides:
                return self.overrides[name]
            return generator()

        record = {
            name: resolve(name, generator)
            for name, generator in self.field_generators.items()
        }

        # Extra keyword arguments become additional fields.
        for key, value in kwargs.items():
            record.setdefault(key, value)

        return record

    def create_batch(self, count: int, **kwargs) -> List[Dict]:
        """Build ``count`` records, passing ``kwargs`` to every ``create`` call."""
        batch = []
        for _ in range(count):
            batch.append(self.create(**kwargs))
        return batch

# 用户工厂
class UserFactory(DataFactory):
    """Factory for user records: id, name, email, phone, created_at, status."""

    def __init__(self):
        super().__init__()
        # Field name -> zero-argument generator, registered in this order.
        generators = {
            "id": lambda: random.randint(1000, 9999),
            "name": lambda: self.faker.name(),
            "email": lambda: self.faker.email(),
            "phone": lambda: self.faker.phone_number(),
            "created_at": lambda: self.faker.date_time_this_year(),
            "status": lambda: random.choice(["active", "inactive", "pending"]),
        }
        for field_name, generator in generators.items():
            self.define_field(field_name, generator)

# 订单工厂
class OrderFactory(DataFactory):
    """Factory for order records with a random list of line items."""

    def __init__(self, user_factory: UserFactory = None):
        super().__init__()
        # Fall back to a fresh UserFactory when none (or a falsy one) is given.
        if not user_factory:
            user_factory = UserFactory()
        self.user_factory = user_factory

        generators = {
            "order_id": lambda: f"ORD{random.randint(100000, 999999)}",
            "user_id": lambda: random.randint(1000, 9999),
            "items": self._generate_items,
            "total": lambda: round(random.uniform(10, 500), 2),
            "status": lambda: random.choice(["pending", "paid", "shipped", "delivered"]),
            "created_at": lambda: self.faker.date_time_this_month(),
        }
        for field_name, generator in generators.items():
            self.define_field(field_name, generator)

    def _generate_items(self) -> List[Dict]:
        """Return between one and five random line items."""
        items = []
        for _ in range(random.randint(1, 5)):
            product_id = f"P{random.randint(100, 999)}"
            quantity = random.randint(1, 3)
            price = round(random.uniform(5, 100), 2)
            items.append({
                "product_id": product_id,
                "quantity": quantity,
                "price": price,
            })
        return items

# 产品工厂
class ProductFactory(DataFactory):
    """Factory for product records: id, name, description, price, category, stock."""

    def __init__(self):
        super().__init__()
        # Field name -> zero-argument generator, registered in this order.
        generators = {
            "product_id": lambda: f"P{random.randint(100, 999)}",
            "name": lambda: self.faker.catch_phrase(),
            "description": lambda: self.faker.text(max_nb_chars=200),
            "price": lambda: round(random.uniform(9.99, 999.99), 2),
            "category": lambda: random.choice(["electronics", "clothing", "books", "home"]),
            "stock": lambda: random.randint(0, 1000),
        }
        for field_name, generator in generators.items():
            self.define_field(field_name, generator)

# 复合工厂
class TestDataSuite:
    """Combines the user, order and product factories into whole scenarios."""

    def __init__(self):
        self.user_factory = UserFactory()
        self.order_factory = OrderFactory(self.user_factory)
        self.product_factory = ProductFactory()

    def create_ecommerce_scenario(self, user_count: int = 10,
                                  orders_per_user: int = 5) -> Dict:
        """Generate users, 20 products and ``orders_per_user`` orders per user.

        Every generated order carries the ``id`` of the user it was created
        for as its ``user_id``.

        Returns:
            A dict with the keys ``"users"``, ``"products"`` and ``"orders"``.
        """
        users = self.user_factory.create_batch(user_count)
        products = self.product_factory.create_batch(20)

        # One batch of orders per user, flattened into a single list.
        orders = [
            order
            for user in users
            for order in self.order_factory.create_batch(
                orders_per_user, user_id=user["id"]
            )
        ]

        return dict(users=users, products=products, orders=orders)

# Usage example: generate a small e-commerce scenario and report its size.
suite = TestDataSuite()
scenario = suite.create_ecommerce_scenario(user_count=5, orders_per_user=3)
print(f"Created: {len(scenario['users'])} users, {len(scenario['orders'])} orders")

数据脱敏技术

数据脱敏保护敏感信息,同时保持数据的测试价值。常用技术包括:替换、掩码、打乱、泛化和加密。

# 数据脱敏框架
import hashlib
import re
from typing import Callable, Dict

class DataSanitizer:
    """Sanitizes records by matching field names against regex patterns."""

    def __init__(self):
        self.sanitizers: Dict[str, Callable] = {}
        self.preserve_format = True

    def register_sanitizer(self, field_pattern: str,
                           sanitizer_fn: Callable):
        """Use ``sanitizer_fn`` for fields whose name matches ``field_pattern``."""
        self.sanitizers[field_pattern] = sanitizer_fn

    def sanitize(self, record: Dict) -> Dict:
        """Return a shallow copy of ``record`` with matching fields sanitized.

        Patterns are tried in registration order with ``re.match`` and only
        the first matching sanitizer is applied to a field. Fields without a
        matching pattern are copied unchanged.
        """
        sanitized = record.copy()

        for field_name in record:
            for pattern, sanitizer in self.sanitizers.items():
                if re.match(pattern, field_name) is None:
                    continue
                sanitized[field_name] = sanitizer(sanitized[field_name])
                # First matching pattern wins.
                break

        return sanitized

class EmailSanitizer:
    """Masks e-mail addresses.

    With ``preserve_domain=True`` (default) the first character of the local
    part and the domain are kept, e.g. ``j***@example.com``. Otherwise the
    fixed placeholder ``***@***.com`` is returned.
    """

    def __init__(self, preserve_domain: bool = True):
        self.preserve_domain = preserve_domain

    def __call__(self, email: str) -> str:
        """Return the masked form of ``email``.

        Values that are empty, not strings, or contain no ``@`` are replaced
        by ``***MASKED***`` so that nothing unrecognised leaks through.
        """
        if not isinstance(email, str) or "@" not in email:
            return "***MASKED***"

        if not self.preserve_domain:
            return "***@***.com"

        # Split on the LAST "@": the domain cannot contain one, while a
        # (quoted) local part may, and a plain split() would raise on it.
        local, _, domain = email.rpartition("@")
        masked_local = local[0] + "***" if local else "***"
        return f"{masked_local}@{domain}"

class PhoneSanitizer:
    """Masks phone numbers, keeping only the last four characters.

    Digits before the last four characters are replaced by ``*``; separators
    such as ``+``, ``-``, spaces and parentheses are kept so the value still
    looks like a phone number, e.g. ``+1-555-123-4567`` becomes
    ``+*-***-***-4567``.
    """

    def __call__(self, phone: str) -> str:
        """Return the masked form of ``phone``.

        Empty or non-string values become ``***MASKED***``; values of four
        characters or fewer are masked completely as ``****``.
        """
        if not phone or not isinstance(phone, str):
            return "***MASKED***"

        # Too short to reveal a tail without revealing the whole value.
        if len(phone) <= 4:
            return "****"

        head, tail = phone[:-4], phone[-4:]
        masked_head = "".join("*" if ch.isdigit() else ch for ch in head)
        return masked_head + tail

class NameSanitizer:
    """Replaces a personal name with a generated one."""

    def __init__(self, faker_instance):
        # Any object exposing a zero-argument ``name()`` method.
        self.faker = faker_instance

    def __call__(self, name: str) -> str:
        """Ignore ``name`` and return whatever ``self.faker.name()`` produces."""
        replacement = self.faker.name()
        return replacement

class AddressSanitizer:
    """Replaces a postal address with a generated one."""

    def __init__(self, faker_instance):
        # Any object exposing a zero-argument ``address()`` method.
        self.faker = faker_instance

    def __call__(self, address: str) -> str:
        """Ignore ``address`` and return whatever ``self.faker.address()`` produces."""
        replacement = self.faker.address()
        return replacement

class CreditCardSanitizer:
    """Masks a card number down to its last four characters."""

    def __call__(self, card: str) -> str:
        """Return ``****-****-****-`` followed by the last four characters.

        Falsy input yields ``***MASKED***``; input shorter than four
        characters yields ``****``.
        """
        if not card:
            return "***MASKED***"

        if len(card) < 4:
            return "****"

        # Only the final four characters survive.
        last_four = card[-4:]
        return "****-****-****-" + last_four

class SSNSanitizer:
    """Masks a social security number down to its last four characters."""

    def __call__(self, ssn: str) -> str:
        """Return ``***-**-`` followed by the last four characters.

        Falsy input yields ``***MASKED***``; input shorter than four
        characters yields the fully masked ``***-**-****``.
        """
        if not ssn:
            return "***MASKED***"

        if len(ssn) < 4:
            return "***-**-****"

        last_four = ssn[-4:]
        return "***-**-" + last_four

# 配置化脱敏
class SanitizerConfig:
    """Maps field names to sanitizer settings and builds the sanitizers."""

    def __init__(self):
        self.config = {
            "email": {"type": "email", "preserve_domain": True},
            "phone": {"type": "phone"},
            "name": {"type": "name"},
            "address": {"type": "address"},
            "credit_card": {"type": "credit_card"},
            "ssn": {"type": "ssn"},
            "ip_address": {"type": "ip"},
        }

    def get_sanitizer(self, field_name: str, faker_instance) -> Callable:
        """Return the sanitizer configured for ``field_name``.

        Fields without configuration, and configured types that have no
        dedicated sanitizer, get a callable that returns ``***MASKED***``
        for any input.
        """
        def mask_everything(x):
            return "***MASKED***"

        field_config = self.config.get(field_name)
        if not field_config:
            return mask_everything

        sanitizer_type = field_config["type"]

        # (type name, builder) pairs; builders are lazy so that only the
        # selected sanitizer is ever constructed.
        builders = (
            ("email", lambda: EmailSanitizer(field_config.get("preserve_domain", True))),
            ("phone", lambda: PhoneSanitizer()),
            ("name", lambda: NameSanitizer(faker_instance)),
            ("address", lambda: AddressSanitizer(faker_instance)),
            ("credit_card", lambda: CreditCardSanitizer()),
            ("ssn", lambda: SSNSanitizer()),
        )
        for known_type, build in builders:
            if sanitizer_type == known_type:
                return build()

        return mask_everything

# Usage example: sanitize one record with pattern-based rules.
from faker import Faker

faker = Faker()
sanitizer = DataSanitizer()

# Patterns are matched against field names; the first matching rule wins.
sanitizer.register_sanitizer(r".*email.*", EmailSanitizer())
sanitizer.register_sanitizer(r".*phone.*", PhoneSanitizer())
sanitizer.register_sanitizer(r".*name.*", NameSanitizer(faker))

original = {
    "user_id": 123,
    "name": "John Doe",
    "email": "john.doe@example.com",
    "phone": "+1-555-123-4567"
}

# sanitize() works on a copy, so `original` can still be printed afterwards.
sanitized = sanitizer.sanitize(original)
print(f"Original: {original}")
print(f"Sanitized: {sanitized}")

测试夹具管理

测试夹具是预先定义的测试数据集合,用于设置测试环境。夹具管理包括创建、加载、清理和版本控制。

# 测试夹具框架
import json
from pathlib import Path
from typing import Dict, List, Any
from datetime import datetime

class TestFixture:
    """A named bundle of fixture data with optional setup/teardown hooks.

    All mutators return the fixture itself so calls can be chained.
    """

    def __init__(self, name: str):
        self.name = name
        self.data: Dict[str, Any] = {}
        self.dependencies: List[str] = []
        # Hooks stay None until registered through setup() / teardown().
        self.setup_fn = self.teardown_fn = None

    def add_data(self, key: str, value: Any):
        """Store ``value`` under ``key`` in the fixture data."""
        self.data[key] = value
        return self

    def depends_on(self, fixture_name: str):
        """Record that the fixture named ``fixture_name`` is a dependency."""
        self.dependencies.append(fixture_name)
        return self

    def setup(self, fn: Callable):
        """Register ``fn`` as the setup hook.

        Returns the fixture, not ``fn``; when used as a decorator the
        decorated name therefore refers to the fixture afterwards.
        """
        self.setup_fn = fn
        return self

    def teardown(self, fn: Callable):
        """Register ``fn`` as the teardown hook; returns the fixture."""
        self.teardown_fn = fn
        return self

class FixtureManager:
    """Registers fixtures, loads them dependencies-first and tears them down.

    ``loaded_fixtures`` maps the name of every loaded fixture to its data in
    load order, which is what ``teardown_all`` walks in reverse.
    """

    def __init__(self):
        self.fixtures: Dict[str, "TestFixture"] = {}
        self.loaded_fixtures: Dict[str, Any] = {}

    def register(self, fixture: "TestFixture"):
        """Register ``fixture`` under its name, replacing any previous entry."""
        self.fixtures[fixture.name] = fixture

    def load(self, fixture_name: str) -> Any:
        """Load a fixture (and, first, its dependencies) and return its data.

        A fixture that is already loaded is returned from the cache without
        running its setup hook again.

        Raises:
            ValueError: If the fixture or one of its dependencies is not
                registered, or if the dependencies form a cycle.
        """
        return self._load(fixture_name, [])

    def _load(self, fixture_name: str, in_progress: List[str]) -> Any:
        """Load ``fixture_name``; ``in_progress`` is the current dependency chain."""
        if fixture_name in self.loaded_fixtures:
            return self.loaded_fixtures[fixture_name]

        # Seeing a name that is still being loaded means a dependency cycle;
        # fail with a readable message instead of recursing until the
        # interpreter's recursion limit is hit.
        if fixture_name in in_progress:
            start = in_progress.index(fixture_name)
            cycle = " -> ".join(in_progress[start:] + [fixture_name])
            raise ValueError(f"Circular fixture dependency: {cycle}")

        fixture = self.fixtures.get(fixture_name)
        if fixture is None:
            raise ValueError(f"Fixture not found: {fixture_name}")

        # Dependencies are loaded before the fixture's own setup runs.
        in_progress.append(fixture_name)
        try:
            for dep in fixture.dependencies:
                self._load(dep, in_progress)
        finally:
            in_progress.pop()

        if fixture.setup_fn:
            fixture.setup_fn()

        self.loaded_fixtures[fixture_name] = fixture.data
        return fixture.data

    def teardown_all(self):
        """Run teardown hooks in reverse load order, then forget all loads."""
        for fixture_name in reversed(list(self.loaded_fixtures.keys())):
            fixture = self.fixtures.get(fixture_name)
            if fixture and fixture.teardown_fn:
                fixture.teardown_fn()

        self.loaded_fixtures.clear()

    def export_fixture(self, name: str, filepath: str):
        """Write a fixture's name, data and dependencies to a JSON file.

        Setup/teardown hooks are not exported. Values json cannot encode are
        written as their ``str()`` form.

        Raises:
            ValueError: If no fixture is registered under ``name``.
        """
        fixture = self.fixtures.get(name)
        if fixture is None:
            raise ValueError(f"Fixture not found: {name}")

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump({
                "name": fixture.name,
                "data": fixture.data,
                "dependencies": fixture.dependencies
            }, f, indent=2, default=str)

    def import_fixture(self, filepath: str) -> "TestFixture":
        """Read a fixture from a JSON file, register it and return it.

        The file must contain ``"name"`` and ``"data"``; ``"dependencies"``
        defaults to an empty list. The imported fixture has no hooks.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        fixture = TestFixture(data["name"])
        fixture.data = data["data"]
        fixture.dependencies = data.get("dependencies", [])

        self.register(fixture)
        return fixture

# 夹具工厂
class FixtureFactory:
    """Builds predefined fixtures and registers them with a FixtureManager."""

    def __init__(self, manager: FixtureManager):
        self.manager = manager
    
    def create_database_fixtures(self, db_connection):
        """Register the "users" and "orders" database fixtures.

        The setup hooks insert the fixture rows through
        ``db_connection.execute`` using parameterized statements; the
        teardown hooks delete all rows from the corresponding table.
        The "orders" fixture declares a dependency on "users".

        NOTE(review): the ``?`` placeholders suggest a sqlite3-style
        connection - confirm against the caller.
        """
        # Users fixture
        user_fixture = TestFixture("users")
        user_fixture.add_data("users", [
            {"id": 1, "name": "Admin User", "role": "admin"},
            {"id": 2, "name": "Regular User", "role": "user"},
            {"id": 3, "name": "Guest User", "role": "guest"}
        ])
        
        # TestFixture.setup/teardown store the function and return the
        # fixture, so after decoration the function names below refer to the
        # fixture object; the hooks are only reachable through the fixture.
        @user_fixture.setup
        def setup_users():
            # Reads the fixture data at call time, not at definition time.
            for user in user_fixture.data["users"]:
                db_connection.execute(
                    "INSERT INTO users (id, name, role) VALUES (?, ?, ?)",
                    (user["id"], user["name"], user["role"])
                )
        
        @user_fixture.teardown
        def teardown_users():
            # Removes every row of the table, not only the fixture rows.
            db_connection.execute("DELETE FROM users")
        
        self.manager.register(user_fixture)
        
        # Orders fixture
        order_fixture = TestFixture("orders")
        order_fixture.depends_on("users")
        order_fixture.add_data("orders", [
            {"order_id": "ORD001", "user_id": 1, "total": 99.99},
            {"order_id": "ORD002", "user_id": 2, "total": 149.99}
        ])
        
        @order_fixture.setup
        def setup_orders():
            for order in order_fixture.data["orders"]:
                db_connection.execute(
                    "INSERT INTO orders (order_id, user_id, total) VALUES (?, ?, ?)",
                    (order["order_id"], order["user_id"], order["total"])
                )
        
        @order_fixture.teardown
        def teardown_orders():
            # Removes every row of the table, not only the fixture rows.
            db_connection.execute("DELETE FROM orders")
        
        self.manager.register(order_fixture)

# Usage example
manager = FixtureManager()
factory = FixtureFactory(manager)

# Create and load the fixtures (needs a real database connection)
# factory.create_database_fixtures(db_connection)
# users_data = manager.load("users")
# orders_data = manager.load("orders")

# Clean up once the tests have finished
# manager.teardown_all()

数据子集化

数据子集化从生产数据中提取有代表性的子集用于测试,减少数据量同时保持数据特征。

# 数据子集化器
from typing import List, Dict
import random

class DataSubsetter:
    """Extracts smaller subsets from a list of records.

    Supported strategies (see ``create_subset``):

    * ``"random"``     - uniform random sample.
    * ``"stratified"`` - random sample that keeps the proportions of the
      groups defined by the ``stratify_by`` field.
    * ``"boundary"``   - records holding the minimum or maximum of
      ``boundary_field`` plus records close to its average.
    * ``"time_based"`` - records whose ``time_field`` lies within the last
      ``days`` days.
    """

    def __init__(self):
        self.strategies = {
            "random": self._random_subset,
            "stratified": self._stratified_subset,
            "boundary": self._boundary_subset,
            "time_based": self._time_based_subset
        }

    def create_subset(self, data: List[Dict], strategy: str,
                      size: int = None, **kwargs) -> List[Dict]:
        """Create a subset of ``data`` using the named strategy.

        Args:
            data: Records to choose from.
            strategy: One of the keys of ``self.strategies``.
            size: Maximum number of records. ``None`` means 10% of the data
                for the sampling strategies and "no limit" for the
                filtering strategies.
            **kwargs: Strategy-specific options (``stratify_by``,
                ``boundary_field``, ``time_field``, ``days``).

        Raises:
            ValueError: If ``strategy`` is unknown.
        """
        subset_fn = self.strategies.get(strategy)
        if subset_fn is None:
            raise ValueError(f"Unknown strategy: {strategy}")

        return subset_fn(data, size, **kwargs)

    def _random_subset(self, data: List[Dict], size: int,
                       **kwargs) -> List[Dict]:
        """Return a uniform random sample of at most ``size`` records."""
        if size is None:
            size = len(data) // 10  # default: 10%

        return random.sample(data, min(size, len(data)))

    def _stratified_subset(self, data: List[Dict], size: int,
                           stratify_by: str, **kwargs) -> List[Dict]:
        """Return a random sample that preserves group proportions.

        Records are grouped by ``record.get(stratify_by, "default")``. The
        result contains exactly ``min(size, len(data))`` records: every
        group first receives the integer part of its proportional share and
        the remaining slots go to the groups with the largest fractional
        share (largest-remainder method), so truncation no longer shrinks
        the sample.
        """
        if size is None:
            size = len(data) // 10

        if not data:
            return []

        # Group the records by the stratification field.
        groups = {}
        for record in data:
            groups.setdefault(record.get(stratify_by, "default"), []).append(record)

        total = len(data)
        target = min(size, total)

        # Integer arithmetic keeps the shares exact:
        # share = target * group_len / total = base + remainder / total.
        shares = {
            key: divmod(target * len(members), total)
            for key, members in groups.items()
        }
        counts = {key: base for key, (base, _) in shares.items()}

        leftover = target - sum(counts.values())
        by_remainder = sorted(groups, key=lambda key: shares[key][1], reverse=True)
        for key in by_remainder[:leftover]:
            counts[key] += 1

        subset = []
        for key, members in groups.items():
            subset.extend(random.sample(members, counts[key]))

        return subset

    def _boundary_subset(self, data: List[Dict], size: int,
                         boundary_field: str, **kwargs) -> List[Dict]:
        """Return records at the extremes of ``boundary_field`` or near its mean.

        A record qualifies when its value equals the minimum or maximum, or
        lies within 10% of the value range around the average. Records that
        lack the field (or hold ``None``) are ignored. If no record has a
        value, the data itself (limited to ``size``) is returned.
        """
        if not data:
            return []

        # Only records that actually carry a value can be compared.
        candidates = [r for r in data if r.get(boundary_field) is not None]

        if not candidates:
            return data[:size] if size is not None else data

        values = [r[boundary_field] for r in candidates]
        min_val = min(values)
        max_val = max(values)
        avg_val = sum(values) / len(values)
        tolerance = (max_val - min_val) * 0.1

        boundary_records = [
            r for r in candidates
            if r[boundary_field] in (min_val, max_val)
            or abs(r[boundary_field] - avg_val) < tolerance
        ]

        # "is not None" so that size=0 means "no records", not "no limit".
        return boundary_records[:size] if size is not None else boundary_records

    def _time_based_subset(self, data: List[Dict], size: int,
                           time_field: str, days: int = 7, **kwargs) -> List[Dict]:
        """Return records whose ``time_field`` is within the last ``days`` days.

        Values may be ``datetime`` objects or ISO-8601 strings (a trailing
        ``Z`` is accepted). Timezone-aware values are compared against an
        aware UTC cutoff and naive values against a naive local cutoff,
        because Python refuses to compare aware with naive datetimes.
        Records with a missing or empty value are skipped.
        """
        from datetime import datetime, timedelta, timezone

        window = timedelta(days=days)
        naive_cutoff = datetime.now() - window
        aware_cutoff = datetime.now(timezone.utc) - window

        subset = []
        for record in data:
            record_time = record.get(time_field)
            if not record_time:
                continue
            if isinstance(record_time, str):
                record_time = datetime.fromisoformat(record_time.replace('Z', '+00:00'))

            is_naive = getattr(record_time, "tzinfo", None) is None
            cutoff = naive_cutoff if is_naive else aware_cutoff
            if record_time >= cutoff:
                subset.append(record)

        return subset[:size] if size is not None else subset

# Usage example
subsetter = DataSubsetter()

# Stand-in for a large production data set
production_data = [
    {"id": i, "type": random.choice(["A", "B", "C"]), "value": random.randint(1, 100)}
    for i in range(10000)
]

# Random subset
random_subset = subsetter.create_subset(production_data, "random", size=100)
print(f"Random subset: {len(random_subset)} records")

# Stratified subset, grouped by the "type" field
stratified_subset = subsetter.create_subset(
    production_data, "stratified", 
    size=100, stratify_by="type"
)
print(f"Stratified subset: {len(stratified_subset)} records")