← 返回首页
🧪

API测试架构:自动化Mock与流量录制

📂 architecture ⏱ 8 min 1498 words

API测试架构:自动化Mock与流量录制

API测试策略

API测试是微服务架构中的关键测试层次,需要验证接口契约、数据格式、错误处理和性能。完整的API测试策略包括:单元测试(Mock外部依赖)、集成测试(真实服务交互)、契约测试(接口一致性)和流量测试(生产流量回放)。

# API测试框架核心
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional, Callable
from enum import Enum
import json
import time

class HTTPMethod(Enum):
    """HTTP verbs a test request may use.

    Each member's value is the upper-case method name sent on the wire.
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"

@dataclass
class APIRequest:
    """Declarative description of one outgoing HTTP request.

    APITestRunner prefixes ``url`` with its base URL and passes ``body`` to
    the HTTP client as JSON.
    """

    method: HTTPMethod                                        # HTTP verb
    url: str                                                  # path relative to the runner's base URL
    headers: Dict[str, str] = field(default_factory=dict)     # request headers
    body: Any = None                                          # JSON payload, if any
    query_params: Dict[str, str] = field(default_factory=dict)  # query-string parameters

@dataclass
class APIResponse:
    """Normalized result of one HTTP call made by APITestRunner."""

    status_code: int          # HTTP status code
    headers: Dict[str, str]   # response headers as a plain dict
    body: Any                 # decoded response body (may be None)
    latency_ms: float         # measured request duration in milliseconds
    raw_response: Any = None  # underlying client response object, when kept

@dataclass
class APITestCase:
    """A single declarative API test case.

    Attributes:
        name: identifier reported in the test results.
        request: the request to send.
        expected_status: HTTP status code the response must have.
        expected_body: optional expected response body (None = not checked).
        expected_headers: optional expected response headers (None = not specified).
        assertions: extra callables that receive the response; a falsy
            return value fails the test.
        tags: free-form labels for grouping/filtering.
    """

    name: str
    request: APIRequest
    expected_status: int
    # Explicit Optional: a bare ``Dict = None`` is an implicit Optional, which PEP 484 disallows.
    expected_body: Optional[Dict] = None
    expected_headers: Optional[Dict] = None
    assertions: List[Callable] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)

class APITestRunner:
    """Run ``APITestCase`` objects against a base URL and aggregate the results.

    Every ``run_test`` call appends its result dict to ``self.results``;
    ``generate_report`` summarizes everything collected so far.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.results: List[Dict] = []
        self.mock_server = None  # not read anywhere in this class

    def run_test(self, test_case: APITestCase) -> Dict:
        """Send one test case's request, evaluate its checks, and record the result.

        Returns:
            A dict with ``name``, ``status`` ("passed", "failed" or "error")
            and ``latency_ms``. Passed/failed results carry a ``checks`` dict;
            error results carry an ``error`` message. Any exception raised
            while sending or checking (including by a custom assertion) is
            reported as status "error" instead of propagating.
        """
        start_time = time.time()

        try:
            response = self._send_request(test_case.request)

            status_ok = response.status_code == test_case.expected_status

            body_ok = True
            if test_case.expected_body:
                body_ok = self._validate_body(response.body, test_case.expected_body)

            headers_ok = True
            if test_case.expected_headers:
                headers_ok = self._validate_headers(response.headers, test_case.expected_headers)

            # all() stops at the first falsy assertion, like a loop with break.
            assertions_passed = all(check(response) for check in test_case.assertions)

            latency = (time.time() - start_time) * 1000
            passed = status_ok and body_ok and headers_ok and assertions_passed

            result = {
                "name": test_case.name,
                "status": "passed" if passed else "failed",
                "latency_ms": latency,
                "checks": {
                    "status_code": status_ok,
                    "response_body": body_ok,
                    "response_headers": headers_ok,
                    "assertions": assertions_passed,
                },
            }

        except Exception as e:
            result = {
                "name": test_case.name,
                "status": "error",
                "error": str(e),
                "latency_ms": (time.time() - start_time) * 1000,
            }

        self.results.append(result)
        return result

    def _send_request(self, request: APIRequest) -> APIResponse:
        """Send *request* to ``base_url + request.url`` and wrap the reply in an APIResponse."""
        import requests

        body = request.body
        # Only a missing (None) or empty-string body means "send nothing".
        # Falsy JSON values such as {}, [], 0 and False are still sent.
        payload = None if body is None or body == "" else body

        start_time = time.time()
        response = requests.request(
            method=request.method.value,
            url=f"{self.base_url}{request.url}",
            headers=request.headers,
            json=payload,
            params=request.query_params,
        )
        latency = (time.time() - start_time) * 1000

        return APIResponse(
            status_code=response.status_code,
            headers=dict(response.headers),
            body=self._decode_body(response),
            latency_ms=latency,
            raw_response=response,
        )

    @staticmethod
    def _decode_body(response: Any) -> Any:
        """Return the parsed JSON body, the raw text if it is not JSON, or None if empty."""
        if not response.text:
            return None
        try:
            return response.json()
        except ValueError:
            # HTML error pages etc. should still be comparable, not crash the test.
            return response.text

    def _validate_body(self, actual: Any, expected: Any) -> bool:
        """Return True if *expected* is a recursive subset of *actual*.

        Dicts are matched key by key, and extra keys in *actual* are allowed.
        Any other value must compare equal. A non-dict *actual* never matches
        a dict *expected*.
        """
        if isinstance(expected, dict):
            if not isinstance(actual, dict):
                return False
            return all(
                key in actual and self._validate_body(actual[key], value)
                for key, value in expected.items()
            )
        return actual == expected

    @staticmethod
    def _validate_headers(actual: Dict[str, str], expected: Dict[str, str]) -> bool:
        """Return True if every expected header is present with an equal value.

        Header names are compared case-insensitively, as HTTP requires.
        """
        lowered = {name.lower(): value for name, value in (actual or {}).items()}
        return all(
            name.lower() in lowered and lowered[name.lower()] == value
            for name, value in expected.items()
        )

    @staticmethod
    def _percentile(sorted_values: List[float], q: float) -> float:
        """Return the element at index int(n*q) of an ascending list, or 0 if it is empty."""
        if not sorted_values:
            return 0
        return sorted_values[min(int(len(sorted_values) * q), len(sorted_values) - 1)]

    def generate_report(self) -> Dict:
        """Summarize all collected results: pass/fail counts and latency statistics."""
        total = len(self.results)
        passed = sum(1 for r in self.results if r["status"] == "passed")
        errors = sum(1 for r in self.results if r["status"] == "error")

        latencies = sorted(r["latency_ms"] for r in self.results if "latency_ms" in r)

        return {
            "summary": {
                "total": total,
                "passed": passed,
                # "failed" counts every non-passing result, errors included.
                "failed": total - passed,
                "errors": errors,
                "pass_rate": passed / total if total > 0 else 0,
            },
            "performance": {
                "avg_latency_ms": sum(latencies) / len(latencies) if latencies else 0,
                "p95_latency_ms": self._percentile(latencies, 0.95),
                "p99_latency_ms": self._percentile(latencies, 0.99),
            },
            "results": self.results,
        }

# Usage example: build one GET test case and run it against the example API.
runner = APITestRunner("https://api.example.com")

get_user_request = APIRequest(
    method=HTTPMethod.GET,
    url="/users/123",
    headers={"Accept": "application/json"},
)

test_get_user = APITestCase(
    name="get_user_by_id",
    request=get_user_request,
    expected_status=200,
    expected_body={"id": 123, "name": "Test User"},
)

result = runner.run_test(test_get_user)
print(f"Test result: {result['status']}")

Mock服务架构

Mock服务模拟外部API的行为,用于实现测试隔离和场景模拟,支持静态配置、动态响应和录制回放等模式。

# Mock服务框架
from typing import Dict, List, Any, Callable
from dataclasses import dataclass, field
import json
import re
from datetime import datetime

@dataclass
class MockEndpoint:
    """One canned route served by MockServer.

    ``path`` may contain ``{name}`` placeholders. ``matchers`` holds extra
    header/body conditions that MockServer checks before using this endpoint.
    """

    method: str                                                # upper-cased HTTP verb
    path: str                                                  # route pattern
    response_status: int                                       # status code to return
    response_body: Any                                         # body to return
    response_headers: Dict[str, str] = field(default_factory=dict)
    delay_ms: int = 0                                          # response delay (milliseconds, per the name)
    matchers: List[Dict] = field(default_factory=list)

class MockServer:
    """In-memory mock of an HTTP API.

    Endpoints are registered with ``add_endpoint``, and requests are resolved
    by calling ``handle_request`` directly. This class opens no socket itself:
    ``port`` and ``running`` are stored but not used here.
    """

    def __init__(self, port: int = 8080):
        self.port = port
        self.endpoints: List[MockEndpoint] = []
        self.request_log: List[Dict] = []
        self.running = False

    def add_endpoint(self, method: str, path: str,
                     status: int, body: Any, **kwargs):
        """Register an endpoint and return ``self`` so calls can be chained.

        Extra keyword arguments (``response_headers``, ``delay_ms``,
        ``matchers``) are passed straight to ``MockEndpoint``.
        """
        endpoint = MockEndpoint(
            method=method.upper(),
            path=path,
            response_status=status,
            response_body=body,
            **kwargs
        )
        self.endpoints.append(endpoint)
        return self

    def handle_request(self, method: str, path: str,
                       headers: Dict = None, body: Any = None) -> Dict:
        """Log a request and return the first matching endpoint's canned response.

        Endpoints are tried in registration order. A match requires the same
        method (case-insensitive), a matching path and satisfied matchers.
        If the endpoint has ``delay_ms > 0``, the call sleeps that long first.

        Returns:
            A dict with ``status``, ``headers`` and ``body``. Unmatched
            requests get a 404 response.
        """
        self.request_log.append({
            "method": method,
            "path": path,
            "headers": headers,
            "body": body,
            "timestamp": datetime.now().isoformat()
        })

        wanted_method = method.upper()
        for endpoint in self.endpoints:
            if endpoint.method != wanted_method:
                continue
            if not self._match_path(endpoint.path, path):
                continue
            if not self._check_matchers(endpoint, headers, body):
                continue
            if endpoint.delay_ms > 0:
                # Honor the configured latency so slow-path/timeout tests are meaningful.
                time.sleep(endpoint.delay_ms / 1000)
            return {
                "status": endpoint.response_status,
                "headers": endpoint.response_headers,
                "body": endpoint.response_body
            }

        # Same shape as a matched response, so callers can always read "headers".
        return {"status": 404, "headers": {}, "body": {"error": "Not found"}}

    def _match_path(self, pattern: str, path: str) -> bool:
        """Return True if the whole of *path* matches route *pattern*.

        ``{name}`` placeholders match one path segment (no ``/``). The rest of
        the pattern is treated as a regular expression.
        """
        regex = re.sub(r'\{(\w+)\}', r'(?P<\1>[^/]+)', pattern)
        # fullmatch rather than "^...$": "$" also accepts a trailing newline.
        return re.fullmatch(regex, path) is not None

    def _check_matchers(self, endpoint: MockEndpoint,
                        headers: Dict, body: Any) -> bool:
        """Return True if every matcher on *endpoint* is satisfied.

        ``header`` matchers compare one request header, with the name matched
        case-insensitively. ``body`` matchers require the whole body to be
        equal. Matchers of any other type are ignored.
        """
        # headers is None when the caller sent none; treat that as "no headers".
        lowered = {str(name).lower(): value for name, value in (headers or {}).items()}
        for matcher in endpoint.matchers:
            kind = matcher["type"]
            if kind == "header":
                if lowered.get(matcher["name"].lower()) != matcher["value"]:
                    return False
            elif kind == "body":
                if body != matcher["value"]:
                    return False
        return True

    def get_request_count(self, path: str = None) -> int:
        """Return how many requests were logged, optionally only those for *path*."""
        if path:
            return sum(1 for r in self.request_log if r["path"] == path)
        return len(self.request_log)

    def clear_request_log(self):
        """Discard all logged requests."""
        self.request_log.clear()

# Dynamic mock configuration: named endpoint sets, one of which is active.
class DynamicMockConfig:
    """Holds named scenarios (lists of endpoint definitions) and tracks the active one."""

    def __init__(self):
        self.scenarios = {}
        self.current_scenario = "default"

    def add_scenario(self, name: str, endpoints: List[Dict]):
        """Register or replace the endpoint list stored under *name*."""
        self.scenarios[name] = endpoints

    def set_scenario(self, name: str):
        """Make *name* the active scenario; unknown names leave the current one unchanged."""
        if name not in self.scenarios:
            return
        self.current_scenario = name

    def get_endpoints(self) -> List[Dict]:
        """Return the active scenario's endpoint list, or an empty list if none is registered."""
        active = self.current_scenario
        if active in self.scenarios:
            return self.scenarios[active]
        return []

# Usage example: register two canned endpoints, then exercise one of them.
mock_server = MockServer()

# add_endpoint returns the server itself, so registrations can be chained.
(
    mock_server
    .add_endpoint(method="GET", path="/users/{id}", status=200,
                  body={"id": 123, "name": "Test User"})
    .add_endpoint(method="POST", path="/users", status=201,
                  body={"id": 124, "name": "New User"})
)

# Simulate an incoming request
response = mock_server.handle_request("GET", "/users/123")
print(f"Response: {response}")

# Inspect the request log
print(f"Request count: {mock_server.get_request_count()}")

流量录制与回放

流量录制捕获生产环境的真实请求,用于测试回放和回归验证。支持录制模式、回放模式和对比分析。

# 流量录制器
from typing import Dict, List, Any
from dataclasses import dataclass, field
from datetime import datetime
import json

@dataclass
class RecordedRequest:
    """Request half of one captured traffic record."""

    method: str              # HTTP method as captured
    url: str                 # request path/URL as captured
    headers: Dict[str, str]  # request headers
    body: Any                # request payload, if any
    timestamp: float         # capture time value supplied by the recorder

@dataclass
class RecordedResponse:
    """Response half of one captured traffic record."""

    status_code: int         # HTTP status returned
    headers: Dict[str, str]  # response headers
    body: Any                # response payload, if any
    latency_ms: float        # latency reported at capture time, in milliseconds

@dataclass
class TrafficRecord:
    """One captured request/response pair plus the local time it was captured."""

    request: RecordedRequest
    response: RecordedResponse
    # default_factory runs per instance, so each record gets its own capture time
    recorded_at: datetime = field(default_factory=datetime.now)

class TrafficRecorder:
    """Capture request/response pairs while recording is on and persist them as JSON.

    Records are kept in memory in ``self.records``. ``export_records`` and
    ``import_records`` round-trip them through a JSON file, including request
    timestamps, response latencies and capture times.
    """

    def __init__(self):
        self.records: List[TrafficRecord] = []
        self.recording = False

    def start_recording(self):
        """Turn recording on; later ``record`` calls are stored."""
        self.recording = True
        print("Traffic recording started")

    def stop_recording(self):
        """Turn recording off and report how many records are held in total."""
        self.recording = False
        print(f"Traffic recording stopped. Captured {len(self.records)} requests")

    def record(self, request: Dict, response: Dict):
        """Store one request/response pair. Does nothing unless recording is on.

        ``request`` must contain "method" and "url", and ``response`` must
        contain "status_code". Other keys are optional: headers default to {},
        bodies to None, the request timestamp to ``time.time()`` and the
        latency to 0.
        """
        if not self.recording:
            return

        self.records.append(
            TrafficRecord(
                request=RecordedRequest(
                    method=request["method"],
                    url=request["url"],
                    headers=request.get("headers", {}),
                    body=request.get("body"),
                    timestamp=request.get("timestamp", time.time()),
                ),
                response=RecordedResponse(
                    status_code=response["status_code"],
                    headers=response.get("headers", {}),
                    body=response.get("body"),
                    latency_ms=response.get("latency_ms", 0),
                ),
            )
        )

    def export_records(self, filepath: str):
        """Write all records to *filepath* as a UTF-8 JSON array."""
        data = [self._record_to_dict(r) for r in self.records]

        with open(filepath, 'w', encoding="utf-8") as f:
            json.dump(data, f, indent=2)

        print(f"Exported {len(self.records)} records to {filepath}")

    def import_records(self, filepath: str):
        """Replace ``self.records`` with the records stored in *filepath*.

        Files without timestamps or latencies (older export format) still
        load: missing values default to 0, and a missing ``recorded_at``
        defaults to the import time.
        """
        with open(filepath, 'r', encoding="utf-8") as f:
            data = json.load(f)

        self.records = [self._record_from_dict(entry) for entry in data]

        print(f"Imported {len(self.records)} records")

    @staticmethod
    def _record_to_dict(record: "TrafficRecord") -> Dict:
        """Serialize one record into a JSON-compatible dict."""
        return {
            "request": {
                "method": record.request.method,
                "url": record.request.url,
                "headers": record.request.headers,
                "body": record.request.body,
                # Persisted so that import_records can restore it.
                "timestamp": record.request.timestamp,
            },
            "response": {
                "status_code": record.response.status_code,
                "headers": record.response.headers,
                "body": record.response.body,
                # Persisted so that import_records can restore it.
                "latency_ms": record.response.latency_ms,
            },
            "recorded_at": record.recorded_at.isoformat(),
        }

    @staticmethod
    def _record_from_dict(entry: Dict) -> "TrafficRecord":
        """Rebuild one record from a dict produced by ``_record_to_dict``."""
        req = entry["request"]
        resp = entry["response"]
        record = TrafficRecord(
            request=RecordedRequest(
                method=req["method"],
                url=req["url"],
                headers=req.get("headers", {}),
                body=req.get("body"),
                timestamp=req.get("timestamp", 0),
            ),
            response=RecordedResponse(
                status_code=resp["status_code"],
                headers=resp.get("headers", {}),
                body=resp.get("body"),
                latency_ms=resp.get("latency_ms", 0),
            ),
        )
        if "recorded_at" in entry:
            # Written with datetime.isoformat(), so fromisoformat round-trips it.
            record.recorded_at = datetime.fromisoformat(entry["recorded_at"])
        return record

class TrafficReplayer:
    """Re-send recorded requests to a target base URL and compare status codes."""

    def __init__(self, target_url: str):
        self.target_url = target_url
        # Results accumulate across replay() calls; replay() returns this same list.
        self.results: List[Dict] = []

    def replay(self, records: List[TrafficRecord],
               speed: float = 1.0) -> List[Dict]:
        """Send each record's request to ``target_url`` in order and log the outcome.

        Args:
            records: recorded traffic to replay.
            speed: accepted for API compatibility but currently unused;
                requests are sent back-to-back without pacing.

        Returns:
            ``self.results``, extended by one entry per record. Each entry is
            either a status comparison (with ``status_match``) or an entry
            carrying an ``error`` message.
        """
        import requests

        for record in records:
            req = record.request
            summary = {"method": req.method, "url": req.url}
            # Only a missing (None) or empty-string body means "send nothing";
            # falsy JSON values such as {}, [], 0 and False are still sent.
            payload = None if req.body is None or req.body == "" else req.body

            start_time = time.time()
            try:
                response = requests.request(
                    method=req.method,
                    url=f"{self.target_url}{req.url}",
                    headers=req.headers,
                    json=payload,
                )
            except Exception as e:
                # One failed request must not abort the whole replay.
                self.results.append({"request": summary, "error": str(e)})
                continue

            latency = (time.time() - start_time) * 1000
            expected_status = record.response.status_code

            self.results.append({
                "request": summary,
                "expected_status": expected_status,
                "actual_status": response.status_code,
                "status_match": response.status_code == expected_status,
                "latency_ms": latency,
            })

        return self.results

    def compare_results(self) -> Dict:
        """Summarize replay outcomes.

        Entries that raised count as non-matching in ``status_match_rate``.
        They are listed under ``errors`` rather than ``mismatches``, which
        only holds entries whose status codes differed.
        """
        total = len(self.results)
        status_matches = sum(1 for r in self.results if r.get("status_match"))

        return {
            "total_requests": total,
            "status_match_rate": status_matches / total if total > 0 else 0,
            "mismatches": [r for r in self.results if not r.get("status_match", True)],
            "errors": [r for r in self.results if "error" in r],
        }

# Usage example: record one exchange, save it, then replay it against staging.
recorder = TrafficRecorder()
recorder.start_recording()

# Simulated capture of a single request/response pair
recorder.record(
    request={"method": "GET", "url": "/users/123"},
    response={"status_code": 200, "body": {"id": 123}},
)

recorder.stop_recording()
recorder.export_records("traffic.json")

# Replay the captured traffic against the staging environment
replayer = TrafficReplayer("https://api-staging.example.com")
replayer.replay(recorder.records)
comparison = replayer.compare_results()
print(f"Comparison: {comparison}")

API监控与告警

API监控持续追踪接口的健康状况,包括可用性、延迟、错误率和流量模式,并通过告警规则及时发现问题。

# API监控器
from collections import defaultdict
from datetime import datetime, timedelta
import statistics

class APIMonitor:
    """In-memory per-endpoint request metrics with threshold-based alerting.

    Samples are stored under ``"METHOD:endpoint"`` keys. Alert rules are
    re-evaluated on every recorded request over a sliding look-back window
    (5 minutes by default).
    """

    def __init__(self, window_minutes: float = 5):
        self.metrics = defaultdict(list)
        self.alert_rules = []
        self.alerts = []
        # Look-back window used by the error-rate and latency checks.
        self.window = timedelta(minutes=window_minutes)

    def record_request(self, endpoint: str, method: str,
                       status_code: int, latency_ms: float):
        """Store one request sample and evaluate alert rules for its endpoint."""
        key = f"{method}:{endpoint}"

        self.metrics[key].append({
            "status_code": status_code,
            "latency_ms": latency_ms,
            "timestamp": datetime.now()
        })

        self._check_alerts(key, status_code, latency_ms)

    def _check_alerts(self, endpoint: str, status_code: int,
                      latency_ms: float):
        """Evaluate every rule that applies to *endpoint* (a "METHOD:path" key).

        ``status_code`` and ``latency_ms`` are kept for interface compatibility;
        the checks read the stored samples in the window instead.
        """
        for rule in self.alert_rules:
            if not self._rule_applies(rule, endpoint):
                continue
            if rule["type"] == "error_rate":
                self._check_error_rate(endpoint, rule)
            elif rule["type"] == "latency":
                self._check_latency(endpoint, rule)

    @staticmethod
    def _rule_applies(rule: Dict, key: str) -> bool:
        """Return True if *rule* has no endpoint filter or its filter names *key*.

        A filter entry may be the full "METHOD:path" key or just the path.
        A None or empty filter applies the rule to every endpoint.
        """
        endpoints = rule.get("endpoints")
        if not endpoints:
            return True
        path = key.split(":", 1)[-1]
        return key in endpoints or path in endpoints

    def _recent(self, endpoint: str) -> List[Dict]:
        """Return the samples for *endpoint* recorded within the look-back window."""
        cutoff = datetime.now() - self.window
        return [r for r in self.metrics[endpoint] if r["timestamp"] > cutoff]

    @staticmethod
    def _percentile(sorted_values: List[float], q: float) -> float:
        """Return the element at index int(n*q) of a non-empty ascending list."""
        return sorted_values[min(int(len(sorted_values) * q), len(sorted_values) - 1)]

    def _raise_alert(self, alert_type: str, endpoint: str,
                     value: float, threshold: float):
        """Append one alert entry to ``self.alerts``."""
        self.alerts.append({
            "type": alert_type,
            "endpoint": endpoint,
            "value": value,
            "threshold": threshold,
            "timestamp": datetime.now()
        })

    def _check_error_rate(self, endpoint: str, rule: Dict):
        """Alert when the share of 5xx responses in the window exceeds the threshold."""
        recent = self._recent(endpoint)
        if not recent:
            return

        error_count = sum(1 for r in recent if r["status_code"] >= 500)
        error_rate = error_count / len(recent)

        if error_rate > rule["threshold"]:
            self._raise_alert("error_rate", endpoint, error_rate, rule["threshold"])

    def _check_latency(self, endpoint: str, rule: Dict):
        """Alert when the p99 latency in the window exceeds the threshold."""
        recent = self._recent(endpoint)
        if not recent:
            return

        latencies = sorted(r["latency_ms"] for r in recent)
        p99_latency = self._percentile(latencies, 0.99)

        if p99_latency > rule["threshold"]:
            self._raise_alert("latency", endpoint, p99_latency, rule["threshold"])

    def add_alert_rule(self, rule_type: str, threshold: float,
                       endpoints: List[str] = None):
        """Add a rule of type "error_rate" or "latency".

        *endpoints*, when given, restricts the rule to those endpoints. Entries
        may be "METHOD:path" keys or bare paths.
        """
        self.alert_rules.append({
            "type": rule_type,
            "threshold": threshold,
            "endpoints": endpoints
        })

    def get_endpoint_stats(self, endpoint: str) -> Dict:
        """Return count, success rate and latency stats for a "METHOD:path" key."""
        samples = self.metrics.get(endpoint, [])

        if not samples:
            return {"total": 0}

        latencies = sorted(s["latency_ms"] for s in samples)
        successes = sum(1 for s in samples if s["status_code"] < 400)

        return {
            "total": len(samples),
            "success_rate": successes / len(samples),
            "avg_latency_ms": statistics.mean(latencies),
            "p95_latency_ms": self._percentile(latencies, 0.95),
            "p99_latency_ms": self._percentile(latencies, 0.99)
        }

    def generate_dashboard(self) -> Dict:
        """Return overall counts, per-endpoint stats and the last 10 alerts."""
        return {
            "total_endpoints": len(self.metrics),
            "total_requests": sum(len(r) for r in self.metrics.values()),
            "endpoints": {
                endpoint: self.get_endpoint_stats(endpoint)
                for endpoint in self.metrics
            },
            "recent_alerts": self.alerts[-10:]
        }

# API contract validation
class APIContractValidator:
    """Validate response bodies against minimal JSON-Schema-style definitions.

    A schema may have ``required`` (a list of field names) and ``properties``
    (field name -> ``{"type": ...}``). Supported types are string, number,
    integer, boolean, array, object and null; other types are not checked.
    """

    def __init__(self):
        self.schemas = {}

    def register_schema(self, endpoint: str, schema: Dict):
        """Register (or replace) the schema used for *endpoint*."""
        self.schemas[endpoint] = schema

    @staticmethod
    def _matches_type(value: Any, expected_type: Any) -> bool:
        """Return True if *value* conforms to JSON Schema type *expected_type*."""
        # bool is a subclass of int in Python, but JSON Schema does not treat
        # booleans as numbers, so it is excluded explicitly.
        if expected_type == "string":
            return isinstance(value, str)
        if expected_type == "number":
            return isinstance(value, (int, float)) and not isinstance(value, bool)
        if expected_type == "integer":
            return isinstance(value, int) and not isinstance(value, bool)
        if expected_type == "boolean":
            return isinstance(value, bool)
        if expected_type == "array":
            return isinstance(value, list)
        if expected_type == "object":
            return isinstance(value, dict)
        if expected_type == "null":
            return value is None
        return True  # missing/unknown type: not checked

    def validate_response(self, endpoint: str, response: Dict) -> Dict:
        """Validate *response* against the schema registered for *endpoint*.

        Returns:
            ``{"valid": True}`` on success (with a ``reason`` when no schema
            is registered), otherwise ``{"valid": False, "reason": ...}``.
        """
        schema = self.schemas.get(endpoint)
        if not schema:
            return {"valid": True, "reason": "No schema defined"}

        required_fields = schema.get("required", [])
        properties = schema.get("properties", {})

        # Field checks only make sense on an object; a None/list/str body would
        # otherwise crash the membership tests or give meaningless results.
        if not isinstance(response, dict):
            if required_fields or properties:
                return {
                    "valid": False,
                    "reason": f"Expected a JSON object, got {type(response).__name__}"
                }
            return {"valid": True}

        missing = [f for f in required_fields if f not in response]
        if missing:
            return {
                "valid": False,
                "reason": f"Missing required fields: {missing}"
            }

        type_errors = [
            f"{field_name}: expected {field_schema.get('type')}"
            for field_name, field_schema in properties.items()
            if field_name in response
            and not self._matches_type(response[field_name], field_schema.get("type"))
        ]

        if type_errors:
            return {
                "valid": False,
                "reason": f"Type errors: {type_errors}"
            }

        return {"valid": True}