API测试架构：自动化Mock与流量录制
API测试策略
API测试是微服务架构中的关键测试层次，需要验证接口契约、数据格式、错误处理和性能。完整的API测试策略包括：单元测试（Mock外部依赖）、集成测试（真实服务交互）、契约测试（接口一致性）和流量测试（生产流量回放）。
# API测试框架核心
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional, Callable
from enum import Enum
import json
import time
class HTTPMethod(Enum):
    """HTTP verbs an APIRequest can use.

    Every member's value is its upper-case method name; APITestRunner passes
    ``.value`` straight to the HTTP client.
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
@dataclass
class APIRequest:
    """One HTTP request to be issued by a test.

    Attributes:
        method: HTTP verb.
        url: Resource path; APITestRunner prefixes it with its ``base_url``.
        headers: Request headers (empty by default).
        body: Payload handed to the HTTP client as JSON; ``None`` sends no body.
        query_params: Query-string parameters (empty by default).
    """

    method: HTTPMethod
    url: str
    headers: Dict[str, str] = field(default_factory=dict)
    body: Any = None
    query_params: Dict[str, str] = field(default_factory=dict)
@dataclass
class APIResponse:
    """Client-agnostic view of an HTTP response.

    Attributes:
        status_code: HTTP status code.
        headers: Response headers as a plain dict.
        body: Parsed response body.
        latency_ms: Measured request duration in milliseconds.
        raw_response: The underlying client response object, if retained.
    """

    status_code: int
    headers: Dict[str, str]
    body: Any
    latency_ms: float
    raw_response: Any = None
@dataclass
class APITestCase:
    """Declarative description of a single API test.

    Attributes:
        name: Identifier reported in the test result.
        request: The request to send.
        expected_status: HTTP status code the response must have.
        expected_body: Expected response body; ``None`` skips the body check.
        expected_headers: Expected response headers; ``None`` means no expectation.
        assertions: Extra callables, each called with the APIResponse; a falsy
            return value fails the test.
        tags: Free-form labels; not read by the runner itself.
    """

    name: str
    request: APIRequest
    expected_status: int
    # Optional[...] because None is the "no expectation" default
    # (PEP 484 disallows implicit Optional for a None default).
    expected_body: Optional[Dict] = None
    expected_headers: Optional[Dict] = None
    assertions: List[Callable] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
class APITestRunner:
    """Execute APITestCase objects against a live HTTP service.

    Every ``run_test`` call appends its result dict to ``self.results``;
    ``generate_report`` aggregates all results collected so far.
    """

    def __init__(self, base_url: str, timeout: Optional[float] = 30.0):
        """Create a runner.

        Args:
            base_url: Prefix prepended verbatim to each ``APIRequest.url``.
            timeout: Seconds to wait for a response, forwarded to ``requests``.
                ``None`` disables the timeout (a hung server then blocks forever).
        """
        self.base_url = base_url
        self.timeout = timeout
        self.results: List[Dict] = []
        self.mock_server = None

    def run_test(self, test_case: "APITestCase") -> Dict:
        """Run one test case, record the outcome and return it.

        The result dict contains ``name``, ``status`` and ``latency_ms``.
        ``status`` is "passed" when every check holds, "failed" when any check
        fails, and "error" when an exception is raised (e.g. a network failure
        or an assertion callable that raises); error results carry ``error``
        instead of ``checks``.
        """
        start_time = time.perf_counter()
        try:
            response = self._send_request(test_case.request)

            status_ok = response.status_code == test_case.expected_status

            # Only None skips the check, so an expected {} or [] is still verified.
            body_ok = (
                test_case.expected_body is None
                or self._validate_body(response.body, test_case.expected_body)
            )
            headers_ok = (
                not test_case.expected_headers
                or self._validate_headers(response.headers, test_case.expected_headers)
            )
            # all() stops at the first falsy assertion, like a loop with break.
            assertions_passed = all(check(response) for check in test_case.assertions)

            passed = status_ok and body_ok and headers_ok and assertions_passed
            result = {
                "name": test_case.name,
                "status": "passed" if passed else "failed",
                "latency_ms": (time.perf_counter() - start_time) * 1000,
                "checks": {
                    "status_code": status_ok,
                    "response_body": body_ok,
                    "response_headers": headers_ok,
                    "assertions": assertions_passed,
                },
            }
        except Exception as e:
            # Deliberately broad: one broken test must not abort the whole run.
            result = {
                "name": test_case.name,
                "status": "error",
                "error": str(e),
                "latency_ms": (time.perf_counter() - start_time) * 1000,
            }
        self.results.append(result)
        return result

    def _send_request(self, request: "APIRequest") -> "APIResponse":
        """Send *request* relative to ``base_url`` and wrap the reply in an APIResponse."""
        import requests  # third-party; imported lazily so the module loads without it

        start_time = time.perf_counter()
        response = requests.request(
            method=request.method.value,
            url=f"{self.base_url}{request.url}",
            headers=request.headers,
            # Pass the body as-is: None sends nothing, while falsy payloads such
            # as {}, [], 0 or False are still sent.
            json=request.body,
            params=request.query_params,
            timeout=self.timeout,
        )
        latency = (time.perf_counter() - start_time) * 1000
        return APIResponse(
            status_code=response.status_code,
            headers=dict(response.headers),
            body=self._parse_body(response),
            latency_ms=latency,
            raw_response=response,
        )

    @staticmethod
    def _parse_body(response: Any) -> Any:
        """Decode a response body: JSON if possible, raw text otherwise, None if empty."""
        if not response.text:
            return None
        try:
            return response.json()
        except ValueError:
            # Non-JSON payload (HTML error page, plain text): keep the text so the
            # other checks can still report "failed" instead of a blanket "error".
            return response.text

    def _validate_body(self, actual: Any, expected: Any) -> bool:
        """Return True if *actual* contains *expected*.

        Dicts are compared as subsets, recursively: each key of *expected* must
        be present in *actual* with a matching value; extra keys in *actual*
        are allowed. Any other value must compare equal.
        """
        if isinstance(expected, dict):
            # A non-dict value (None, list, str, ...) cannot contain the expected
            # keys; without this guard `key in actual` could raise or mis-match.
            if not isinstance(actual, dict):
                return False
            return all(
                key in actual and self._validate_body(actual[key], value)
                for key, value in expected.items()
            )
        return actual == expected

    @staticmethod
    def _validate_headers(actual: Optional[Dict[str, str]],
                          expected: Dict[str, str]) -> bool:
        """Return True if every expected header is present with an equal value.

        Header names are compared case-insensitively (HTTP header names are
        case-insensitive and ``dict(response.headers)`` drops requests'
        case-insensitive lookup); values must match exactly.
        """
        lowered = {str(name).lower(): value for name, value in (actual or {}).items()}
        return all(
            lowered.get(str(name).lower()) == value for name, value in expected.items()
        )

    @staticmethod
    def _percentile(sorted_values: List[float], pct: float) -> float:
        """Nearest-rank percentile of an ascending list; 0 for an empty list.

        The 1-based rank is ceil(pct * n), clamped to [1, n].
        """
        import math

        if not sorted_values:
            return 0
        # round() absorbs float noise in pct * n (e.g. 0.07 * 100 == 7.000000000000001).
        rank = math.ceil(round(pct * len(sorted_values), 9))
        rank = min(max(rank, 1), len(sorted_values))
        return sorted_values[rank - 1]

    def generate_report(self) -> Dict:
        """Summarize all collected results.

        ``failed`` counts every result that did not pass (errors included);
        ``errors`` reports how many of those raised an exception. Latency
        percentiles use the nearest-rank method.
        """
        total = len(self.results)
        passed = sum(1 for r in self.results if r["status"] == "passed")
        errors = sum(1 for r in self.results if r["status"] == "error")
        latencies = sorted(r["latency_ms"] for r in self.results if "latency_ms" in r)
        return {
            "summary": {
                "total": total,
                "passed": passed,
                "failed": total - passed,
                "errors": errors,
                "pass_rate": passed / total if total > 0 else 0,
            },
            "performance": {
                "avg_latency_ms": sum(latencies) / len(latencies) if latencies else 0,
                "p95_latency_ms": self._percentile(latencies, 0.95),
                "p99_latency_ms": self._percentile(latencies, 0.99),
            },
            "results": self.results,
        }
# Usage example: guarded so importing this module neither sends a live HTTP
# request nor prints anything.
if __name__ == "__main__":
    runner = APITestRunner("https://api.example.com")
    test_get_user = APITestCase(
        name="get_user_by_id",
        request=APIRequest(
            method=HTTPMethod.GET,
            url="/users/123",
            headers={"Accept": "application/json"},
        ),
        expected_status=200,
        expected_body={"id": 123, "name": "Test User"},
    )
    result = runner.run_test(test_get_user)
    print(f"Test result: {result['status']}")
Mock服务架构
Mock服务模拟外部API的行为，用于实现测试隔离和场景模拟，支持静态配置、动态响应和录制回放等模式。
# Mock服务框架
from typing import Dict, List, Any, Callable
from dataclasses import dataclass, field
import json
import re
from datetime import datetime
@dataclass
class MockEndpoint:
    """A canned response served by MockServer for a method + path pattern.

    ``path`` may contain ``{name}`` placeholders, each matching one path
    segment. ``matchers`` holds extra conditions as dicts: type "header"
    (with ``name`` and ``value``) or type "body" (with ``value``).
    """

    method: str
    path: str
    response_status: int
    response_body: Any
    response_headers: Dict[str, str] = field(default_factory=dict)
    # Simulated response delay in milliseconds; only has an effect if
    # MockServer.handle_request honours it.
    delay_ms: int = 0
    matchers: List[Dict] = field(default_factory=list)
class MockServer:
    """In-process mock of an HTTP API that matches requests to registered endpoints.

    Requests are handled by calling ``handle_request`` directly; this class does
    not open a socket itself. Every handled request is appended to ``request_log``.
    """

    def __init__(self, port: int = 8080):
        self.port = port  # stored only; nothing in this class binds to it
        self.endpoints: List["MockEndpoint"] = []
        self.request_log: List[Dict] = []
        self.running = False

    def add_endpoint(self, method: str, path: str,
                     status: int, body: Any, **kwargs):
        """Register an endpoint and return ``self`` so calls can be chained.

        Extra keyword arguments (``response_headers``, ``delay_ms``,
        ``matchers``) are forwarded to MockEndpoint.
        """
        endpoint = MockEndpoint(
            method=method.upper(),
            path=path,
            response_status=status,
            response_body=body,
            **kwargs,
        )
        self.endpoints.append(endpoint)
        return self

    def handle_request(self, method: str, path: str,
                       headers: Optional[Dict] = None, body: Any = None) -> Dict:
        """Log a request and return the response of the first matching endpoint.

        Endpoints are tried in registration order; the first whose method, path
        pattern and matchers all accept the request wins. When that endpoint's
        ``delay_ms`` is positive, the call blocks that long to simulate latency.

        Returns:
            ``{"status": int, "headers": dict, "body": Any}``; a 404 response
            when no endpoint matches.
        """
        self.request_log.append({
            "method": method,
            "path": path,
            "headers": headers,
            "body": body,
            "timestamp": datetime.now().isoformat(),
        })

        wanted = method.upper()
        for endpoint in self.endpoints:
            if endpoint.method != wanted or not self._match_path(endpoint.path, path):
                continue
            if not self._check_matchers(endpoint, headers, body):
                continue
            if endpoint.delay_ms > 0:
                time.sleep(endpoint.delay_ms / 1000)
            return {
                "status": endpoint.response_status,
                # Copy so callers cannot mutate the registered endpoint.
                "headers": dict(endpoint.response_headers),
                "body": endpoint.response_body,
            }
        return {"status": 404, "headers": {}, "body": {"error": "Not found"}}

    def _match_path(self, pattern: str, path: str) -> bool:
        """Return True if *path* matches *pattern* in full.

        ``{name}`` placeholders match one non-empty segment (no ``/``); every
        other character matches literally, so ``.`` in ``/v1.0`` is not a wildcard.
        """
        # With a capturing group, re.split alternates literal text and placeholder names.
        pieces = re.split(r"\{(\w+)\}", pattern)
        regex = "".join(
            f"(?P<{piece}>[^/]+)" if index % 2 else re.escape(piece)
            for index, piece in enumerate(pieces)
        )
        return re.fullmatch(regex, path) is not None

    def _check_matchers(self, endpoint: "MockEndpoint",
                        headers: Optional[Dict], body: Any) -> bool:
        """Return True if every matcher on *endpoint* accepts the request.

        A "header" matcher compares the named request header (name matched
        case-insensitively, as HTTP header names are) to ``value``; a "body"
        matcher requires the whole body to equal ``value``. Matchers of other
        types are ignored.
        """
        # headers defaults to None in handle_request; treat that as "no headers".
        lowered = {str(name).lower(): value for name, value in (headers or {}).items()}
        for matcher in endpoint.matchers:
            kind = matcher["type"]
            if kind == "header":
                if lowered.get(str(matcher["name"]).lower()) != matcher["value"]:
                    return False
            elif kind == "body" and body != matcher["value"]:
                return False
        return True

    def get_request_count(self, path: str = None) -> int:
        """Count logged requests, optionally only those for exactly *path*."""
        if path:
            return sum(1 for r in self.request_log if r["path"] == path)
        return len(self.request_log)

    def clear_request_log(self):
        """Forget all logged requests."""
        self.request_log.clear()
# Dynamic mock configuration
class DynamicMockConfig:
    """Named sets of endpoint definitions ("scenarios"), one of which is active."""

    def __init__(self):
        self.scenarios = {}  # scenario name -> list of endpoint dicts
        self.current_scenario = "default"

    def add_scenario(self, name: str, endpoints: List[Dict]):
        """Store *endpoints* under *name*, replacing any existing scenario of that name."""
        self.scenarios[name] = endpoints

    def set_scenario(self, name: str):
        """Activate scenario *name*; unknown names are silently ignored."""
        if name not in self.scenarios:
            return
        self.current_scenario = name

    def get_endpoints(self) -> List[Dict]:
        """Return the active scenario's endpoints, or a new empty list if undefined."""
        active = self.current_scenario
        if active in self.scenarios:
            return self.scenarios[active]
        return []
# Usage example (runs only when this file is executed as a script, so that
# importing the module has no side effects).
if __name__ == "__main__":
    mock_server = MockServer()
    # Register mock endpoints
    mock_server.add_endpoint(
        method="GET",
        path="/users/{id}",
        status=200,
        body={"id": 123, "name": "Test User"},
    )
    mock_server.add_endpoint(
        method="POST",
        path="/users",
        status=201,
        body={"id": 124, "name": "New User"},
    )
    # Simulate a request
    response = mock_server.handle_request("GET", "/users/123")
    print(f"Response: {response}")
    # Inspect the request log
    print(f"Request count: {mock_server.get_request_count()}")
流量录制与回放
流量录制捕获生产环境的真实请求，用于测试回放和回归验证，支持录制模式、回放模式和对比分析。
# 流量录制器
from typing import Dict, List, Any
from dataclasses import dataclass, field
from datetime import datetime
import json
@dataclass
class RecordedRequest:
    """Request half of a captured request/response exchange."""

    method: str
    url: str  # appended to the replay target's base URL when replayed
    headers: Dict[str, str]
    body: Any
    # Capture time; TrafficRecorder.record defaults it to time.time().
    timestamp: float
@dataclass
class RecordedResponse:
    """Response half of a captured request/response exchange."""

    status_code: int
    headers: Dict[str, str]
    body: Any
    latency_ms: float  # observed latency at capture time, in milliseconds
@dataclass
class TrafficRecord:
    """A captured request/response pair plus the moment the record was created."""

    request: RecordedRequest
    response: RecordedResponse
    # Naive local time from datetime.now, taken when the record object is built.
    recorded_at: datetime = field(default_factory=datetime.now)
class TrafficRecorder:
    """Capture request/response pairs and persist them as JSON for replay.

    Records live in ``self.records``. ``export_records`` and ``import_records``
    round-trip them through a JSON file, including request timestamps,
    response latencies and capture times.
    """

    def __init__(self):
        self.records: List["TrafficRecord"] = []
        self.recording = False  # record() is a no-op while this is False

    def start_recording(self):
        """Start accepting records."""
        self.recording = True
        print("Traffic recording started")

    def stop_recording(self):
        """Stop accepting records and print how many records are held."""
        self.recording = False
        print(f"Traffic recording stopped. Captured {len(self.records)} requests")

    def record(self, request: Dict, response: Dict):
        """Store one request/response pair if recording is active.

        Args:
            request: Must contain "method" and "url"; "headers", "body" and
                "timestamp" (default: current ``time.time()``) are optional.
            response: Must contain "status_code"; "headers", "body" and
                "latency_ms" (default 0) are optional.
        """
        if not self.recording:
            return
        self.records.append(
            TrafficRecord(
                request=RecordedRequest(
                    method=request["method"],
                    url=request["url"],
                    headers=request.get("headers", {}),
                    body=request.get("body"),
                    timestamp=request.get("timestamp", time.time()),
                ),
                response=RecordedResponse(
                    status_code=response["status_code"],
                    headers=response.get("headers", {}),
                    body=response.get("body"),
                    latency_ms=response.get("latency_ms", 0),
                ),
            )
        )

    @staticmethod
    def _record_to_dict(record: "TrafficRecord") -> Dict:
        """Convert a record to the JSON-serialisable dict used in export files."""
        req, resp = record.request, record.response
        return {
            "request": {
                "method": req.method,
                "url": req.url,
                "headers": req.headers,
                "body": req.body,
                "timestamp": req.timestamp,
            },
            "response": {
                "status_code": resp.status_code,
                "headers": resp.headers,
                "body": resp.body,
                "latency_ms": resp.latency_ms,
            },
            "recorded_at": record.recorded_at.isoformat(),
        }

    @staticmethod
    def _record_from_dict(item: Dict) -> "TrafficRecord":
        """Rebuild a record from an exported dict.

        Missing "timestamp" / "latency_ms" keys default to 0; a missing
        "recorded_at" falls back to the TrafficRecord default (import time).
        """
        req, resp = item["request"], item["response"]
        extra = {}
        if item.get("recorded_at"):
            extra["recorded_at"] = datetime.fromisoformat(item["recorded_at"])
        return TrafficRecord(
            request=RecordedRequest(
                method=req["method"],
                url=req["url"],
                headers=req.get("headers", {}),
                body=req.get("body"),
                timestamp=req.get("timestamp", 0),
            ),
            response=RecordedResponse(
                status_code=resp["status_code"],
                headers=resp.get("headers", {}),
                body=resp.get("body"),
                latency_ms=resp.get("latency_ms", 0),
            ),
            **extra,
        )

    def export_records(self, filepath: str):
        """Write all records to *filepath* as an indented UTF-8 JSON array."""
        data = [self._record_to_dict(r) for r in self.records]
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        print(f"Exported {len(self.records)} records to {filepath}")

    def import_records(self, filepath: str):
        """Replace ``self.records`` with the records stored in *filepath*."""
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.records = [self._record_from_dict(item) for item in data]
        print(f"Imported {len(self.records)} records")
class TrafficReplayer:
    """Replay recorded traffic against a target service and compare status codes."""

    def __init__(self, target_url: str, timeout: Optional[float] = 30.0):
        """Create a replayer.

        Args:
            target_url: Base URL prepended verbatim to each recorded URL.
            timeout: Per-request timeout in seconds forwarded to ``requests``;
                ``None`` waits indefinitely.
        """
        self.target_url = target_url
        self.timeout = timeout
        self.results: List[Dict] = []

    def replay(self, records: List["TrafficRecord"],
               speed: float = 1.0) -> List[Dict]:
        """Re-send each recorded request, in order, to ``target_url``.

        Each request appends one entry to ``self.results``: on success it holds
        the expected/actual status, whether they match, and the latency; if the
        request raises, it holds ``error`` instead. Results accumulate across
        calls and the full accumulated list is returned.

        NOTE(review): *speed* is accepted but currently unused; requests are
        sent back-to-back rather than at the recorded pace.
        """
        import requests  # third-party; imported lazily so the module loads without it

        for record in records:
            req = record.request
            summary = {"method": req.method, "url": req.url}
            start_time = time.perf_counter()
            try:
                response = requests.request(
                    method=req.method,
                    url=f"{self.target_url}{req.url}",
                    headers=req.headers,
                    # Pass the body as-is: None sends nothing, while falsy
                    # payloads such as {} or [] are still replayed faithfully.
                    json=req.body,
                    timeout=self.timeout,
                )
            except Exception as e:
                # Best effort: record the failure and keep replaying the rest.
                self.results.append({"request": summary, "error": str(e)})
                continue
            latency = (time.perf_counter() - start_time) * 1000
            self.results.append({
                "request": summary,
                "expected_status": record.response.status_code,
                "actual_status": response.status_code,
                "status_match": response.status_code == record.response.status_code,
                "latency_ms": latency,
            })
        return self.results

    def compare_results(self) -> Dict:
        """Summarize replay results.

        ``status_match_rate`` is computed over all results (errors included).
        ``mismatches`` lists results whose status differed; ``errors`` lists
        requests that raised, which carry no status to compare.
        """
        total = len(self.results)
        status_matches = sum(1 for r in self.results if r.get("status_match"))
        return {
            "total_requests": total,
            "status_match_rate": status_matches / total if total > 0 else 0,
            "mismatches": [r for r in self.results if not r.get("status_match", True)],
            "errors": [r for r in self.results if "error" in r],
        }
# Usage example: guarded so importing this module neither writes traffic.json
# nor sends requests to the staging server.
if __name__ == "__main__":
    recorder = TrafficRecorder()
    recorder.start_recording()
    # Simulated capture
    recorder.record(
        request={"method": "GET", "url": "/users/123"},
        response={"status_code": 200, "body": {"id": 123}},
    )
    recorder.stop_recording()
    recorder.export_records("traffic.json")
    # Replay
    replayer = TrafficReplayer("https://api-staging.example.com")
    replayer.replay(recorder.records)
    print(f"Comparison: {replayer.compare_results()}")
API监控与告警
API监控持续追踪接口的健康状况，包括可用性、延迟、错误率和流量模式，并通过告警规则及时发现问题。
# API监控器
from collections import defaultdict
from datetime import datetime, timedelta
import statistics
class APIMonitor:
    """Track per-endpoint request metrics and raise threshold-based alerts.

    Metrics are keyed ``"METHOD:endpoint"`` (e.g. ``"GET:/users"``). After each
    recorded request, every applicable alert rule is evaluated over the
    requests seen within the sliding window.

    NOTE: recorded metrics are never pruned, so memory grows with request
    volume. While a threshold stays exceeded, a new alert is appended on every
    recorded request for that endpoint.
    """

    def __init__(self, window_minutes: float = 5):
        """Create a monitor whose alert rules look back *window_minutes* (default 5)."""
        self.window = timedelta(minutes=window_minutes)
        self.metrics = defaultdict(list)
        self.alert_rules = []
        self.alerts = []

    def record_request(self, endpoint: str, method: str,
                       status_code: int, latency_ms: float):
        """Store one observed request and evaluate the alert rules for it."""
        key = f"{method}:{endpoint}"
        self.metrics[key].append({
            "status_code": status_code,
            "latency_ms": latency_ms,
            "timestamp": datetime.now(),
        })
        self._check_alerts(key, status_code, latency_ms)

    def _check_alerts(self, endpoint: str, status_code: int,
                      latency_ms: float):
        """Evaluate every rule that applies to *endpoint* (a "METHOD:path" key)."""
        for rule in self.alert_rules:
            if not self._rule_applies(rule, endpoint):
                continue
            if rule["type"] == "error_rate":
                self._check_error_rate(endpoint, rule)
            elif rule["type"] == "latency":
                self._check_latency(endpoint, rule)

    @staticmethod
    def _rule_applies(rule: Dict, endpoint_key: str) -> bool:
        """Return True if *rule* is unscoped or its ``endpoints`` list names this endpoint.

        Scoped rules may list the full key ("GET:/users") or just the path
        ("/users", which then matches every method).
        """
        scope = rule.get("endpoints")
        if not scope:
            return True
        _, _, path = endpoint_key.partition(":")
        return endpoint_key in scope or path in scope

    def _recent(self, endpoint: str) -> List[Dict]:
        """Return the metric entries for *endpoint* recorded within the window."""
        cutoff = datetime.now() - self.window
        return [r for r in self.metrics.get(endpoint, []) if r["timestamp"] > cutoff]

    @staticmethod
    def _percentile(sorted_values: List[float], pct: float) -> float:
        """Nearest-rank percentile of an ascending list; 0 for an empty list."""
        import math

        if not sorted_values:
            return 0
        # round() absorbs float noise in pct * n (e.g. 0.07 * 100 == 7.000000000000001).
        rank = math.ceil(round(pct * len(sorted_values), 9))
        rank = min(max(rank, 1), len(sorted_values))
        return sorted_values[rank - 1]

    def _add_alert(self, alert_type: str, endpoint: str, value: float, rule: Dict):
        """Append an alert entry for a rule whose threshold was exceeded."""
        self.alerts.append({
            "type": alert_type,
            "endpoint": endpoint,
            "value": value,
            "threshold": rule["threshold"],
            "timestamp": datetime.now(),
        })

    def _check_error_rate(self, endpoint: str, rule: Dict):
        """Alert when the share of 5xx responses in the window exceeds the threshold."""
        recent = self._recent(endpoint)
        if not recent:
            return
        error_count = sum(1 for r in recent if r["status_code"] >= 500)
        error_rate = error_count / len(recent)
        if error_rate > rule["threshold"]:
            self._add_alert("error_rate", endpoint, error_rate, rule)

    def _check_latency(self, endpoint: str, rule: Dict):
        """Alert when the p99 latency in the window exceeds the threshold."""
        recent = self._recent(endpoint)
        if not recent:
            return
        p99_latency = self._percentile(sorted(r["latency_ms"] for r in recent), 0.99)
        if p99_latency > rule["threshold"]:
            self._add_alert("latency", endpoint, p99_latency, rule)

    def add_alert_rule(self, rule_type: str, threshold: float,
                       endpoints: List[str] = None):
        """Register an alert rule.

        Args:
            rule_type: "error_rate" or "latency"; other types are ignored.
            threshold: Alert when the measured value is strictly greater.
            endpoints: Optional scope: full keys ("GET:/users") or paths
                ("/users"). ``None`` or empty applies the rule to every endpoint.
        """
        self.alert_rules.append({
            "type": rule_type,
            "threshold": threshold,
            "endpoints": endpoints,
        })

    def get_endpoint_stats(self, endpoint: str) -> Dict:
        """Return all-time statistics for a "METHOD:path" key, or {"total": 0}."""
        entries = self.metrics.get(endpoint, [])
        if not entries:
            return {"total": 0}
        latencies = sorted(r["latency_ms"] for r in entries)
        successes = sum(1 for r in entries if r["status_code"] < 400)
        return {
            "total": len(entries),
            "success_rate": successes / len(entries),
            "avg_latency_ms": statistics.mean(latencies),
            "p95_latency_ms": self._percentile(latencies, 0.95),
            "p99_latency_ms": self._percentile(latencies, 0.99),
        }

    def generate_dashboard(self) -> Dict:
        """Return overall counts, per-endpoint statistics and the last 10 alerts."""
        return {
            "total_endpoints": len(self.metrics),
            "total_requests": sum(len(r) for r in self.metrics.values()),
            "endpoints": {
                endpoint: self.get_endpoint_stats(endpoint)
                for endpoint in self.metrics
            },
            "recent_alerts": self.alerts[-10:],
        }
# API contract validation
class APIContractValidator:
    """Check response payloads against lightweight, JSON-Schema-style contracts.

    A schema may define ``required`` (field names that must be present) and
    ``properties`` (field name -> ``{"type": <JSON Schema type name>}``).
    Only top-level fields are inspected; unknown type names are not checked.
    """

    # JSON Schema type name -> predicate. bool is excluded from the numeric types
    # because isinstance(True, int) is True in Python, while JSON Schema treats
    # booleans as a separate type.
    _TYPE_CHECKS = {
        "string": lambda v: isinstance(v, str),
        "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
        "integer": lambda v: isinstance(v, int) and not isinstance(v, bool),
        "boolean": lambda v: isinstance(v, bool),
        "object": lambda v: isinstance(v, dict),
        "array": lambda v: isinstance(v, (list, tuple)),
        "null": lambda v: v is None,
    }

    def __init__(self):
        self.schemas = {}  # endpoint -> schema dict

    def register_schema(self, endpoint: str, schema: Dict):
        """Associate *schema* with *endpoint*, replacing any previous schema."""
        self.schemas[endpoint] = schema

    def validate_response(self, endpoint: str, response: Dict) -> Dict:
        """Validate *response* against the schema registered for *endpoint*.

        Returns ``{"valid": True}`` on success and ``{"valid": True, "reason":
        "No schema defined"}`` when no (non-empty) schema is registered.
        Otherwise returns ``{"valid": False, "reason": ...}`` for the first
        kind of failure found: non-object payload, then missing required
        fields, then type errors.
        """
        schema = self.schemas.get(endpoint)
        if not schema:
            return {"valid": True, "reason": "No schema defined"}

        # A non-dict payload (None, list, str, ...) cannot satisfy an object
        # contract; `f not in response` on it would mis-match or raise TypeError.
        if not isinstance(response, dict):
            return {
                "valid": False,
                "reason": f"Expected an object, got {type(response).__name__}",
            }

        missing = [f for f in schema.get("required", []) if f not in response]
        if missing:
            return {"valid": False, "reason": f"Missing required fields: {missing}"}

        type_errors = []
        for field_name, field_schema in schema.get("properties", {}).items():
            if field_name not in response:
                continue
            expected_type = field_schema.get("type")
            # Non-string type specs (e.g. a list of types) are unhashable and unsupported here.
            check = self._TYPE_CHECKS.get(expected_type) if isinstance(expected_type, str) else None
            if check is not None and not check(response[field_name]):
                type_errors.append(f"{field_name}: expected {expected_type}")
        if type_errors:
            return {"valid": False, "reason": f"Type errors: {type_errors}"}
        return {"valid": True}