测试金字塔:单元测试、集成测试与E2E的比例
测试金字塔模型
测试金字塔是Mike Cohn提出的测试策略模型,描述了不同层次测试的理想比例:底层是大量快速的单元测试,中间是适量的集成测试,顶层是少量的端到端测试。这种分层结构确保测试既有速度又有信心。
# 测试金字塔模型
from dataclasses import dataclass, field
from typing import List, Dict, Callable
from enum import Enum
class TestLayer(Enum):
    """Layers of the test pyramid, declared from the base (unit) to the top (E2E)."""

    UNIT = "unit"  # single functions/classes in isolation
    INTEGRATION = "integration"  # interactions between components
    E2E = "e2e"  # complete business flows
@dataclass
class TestConfig:
    """Targets for a single pyramid layer."""

    layer: TestLayer
    # Share of the whole suite that should live in this layer (0-1, e.g. 0.7).
    target_coverage: float
    # Time budget for running the whole layer, in seconds.
    execution_time_target: int
    # Expected pass reliability of the layer, between 0 and 1.
    reliability_target: float
    # Short human-readable summary of the layer.
    description: str = field(default="")
class TestPyramid:
    """Registry of tests per pyramid layer with a health check against targets.

    The default targets follow the 70/20/10 split (unit / integration / E2E).
    Each layer also carries an execution-time budget and a reliability target.
    """

    def __init__(self):
        # Per-layer targets. ``target_coverage`` is the share of all tests that
        # should live in the layer; the three shares sum to 1.0.
        self.layers = {
            TestLayer.UNIT: TestConfig(
                layer=TestLayer.UNIT,
                target_coverage=0.7,  # 70% of tests should be unit tests
                execution_time_target=60,  # finish within 1 minute
                reliability_target=0.99,
                description="快速、独立、大量"
            ),
            TestLayer.INTEGRATION: TestConfig(
                layer=TestLayer.INTEGRATION,
                target_coverage=0.2,  # 20% of tests should be integration tests
                execution_time_target=300,  # finish within 5 minutes
                reliability_target=0.95,
                description="验证组件间交互"
            ),
            TestLayer.E2E: TestConfig(
                layer=TestLayer.E2E,
                target_coverage=0.1,  # 10% of tests should be E2E tests
                execution_time_target=600,  # finish within 10 minutes
                reliability_target=0.90,
                description="验证完整业务流程"
            )
        }
        # Registered test records, one list per layer (in TestLayer order).
        self.tests = {layer: [] for layer in TestLayer}

    def add_test(self, layer: TestLayer, test_name: str,
                 test_fn: Callable, metadata: Dict = None):
        """Register a test under ``layer``.

        Args:
            layer: pyramid layer the test belongs to.
            test_name: display name of the test.
            test_fn: callable implementing the test; it is stored, not run.
            metadata: optional extra information; a fresh empty dict is
                stored when omitted (or falsy).
        """
        self.tests[layer].append({
            "name": test_name,
            "fn": test_fn,
            "metadata": metadata or {}
        })

    def get_distribution(self) -> Dict:
        """Return per-layer counts, actual percentage and target percentage.

        Keys are the layer values ("unit", "integration", "e2e"). The
        percentage is 0 for every layer while no tests are registered.
        """
        total = sum(len(tests) for tests in self.tests.values())
        distribution = {}
        for layer, tests in self.tests.items():
            count = len(tests)
            distribution[layer.value] = {
                "count": count,
                "percentage": count / total * 100 if total > 0 else 0,
                "target_percentage": self.layers[layer].target_coverage * 100
            }
        return distribution

    def validate_pyramid(self, tolerance: float = 10.0,
                         critical_threshold: float = 20.0) -> Dict:
        """Check how far each layer's share is from its target.

        Args:
            tolerance: allowed absolute deviation, in percentage points; a
                layer is reported only when its deviation exceeds this.
            critical_threshold: deviations at or above this many percentage
                points are reported as "critical", smaller ones as "warning".

        Returns:
            ``{"healthy": bool, "distribution": ..., "issues": [...]}``, where
            ``distribution`` is the result of ``get_distribution``.
        """
        distribution = self.get_distribution()
        issues = []
        for layer_name, stats in distribution.items():
            actual = stats["percentage"]
            target = stats["target_percentage"]
            deviation = abs(actual - target)
            if deviation > tolerance:
                issues.append({
                    "layer": layer_name,
                    "actual": actual,
                    "target": target,
                    "deviation": deviation,
                    "severity": "warning" if deviation < critical_threshold else "critical"
                })
        return {
            "healthy": len(issues) == 0,
            "distribution": distribution,
            "issues": issues
        }
# Usage example: register an ideal 70/20/10 mix of placeholder tests and
# check the pyramid's health.
pyramid = TestPyramid()
for _layer, _prefix, _count in (
    (TestLayer.UNIT, "unit_test", 70),  # unit tests
    (TestLayer.INTEGRATION, "integration_test", 20),  # integration tests
    (TestLayer.E2E, "e2e_test", 10),  # E2E tests
):
    for i in range(_count):
        pyramid.add_test(_layer, f"{_prefix}_{i}", lambda: None)
validation = pyramid.validate_pyramid()
print(f"Pyramid healthy: {validation['healthy']}")
单元测试策略
单元测试是金字塔的基石,测试单个函数或类的行为。特点是快速、独立、大量。使用Mock隔离外部依赖,确保测试的确定性和速度。
# 单元测试框架
import unittest
from unittest.mock import Mock, patch, MagicMock
from typing import Any, Callable
from dataclasses import dataclass
class UnitTestBase(unittest.TestCase):
    """Base class for unit tests that keeps a registry of named Mock objects.

    Mocks are created and registered with ``create_mock`` and can then be
    asserted on by name. A name may be dotted (``"repository.save"``) to reach
    an attribute of a registered mock. Calls made through a collaborator's
    methods are recorded on those child mocks, not on the root mock.
    """

    def setUp(self):
        """Start every test with an empty mock registry."""
        self.mocks = {}

    def tearDown(self):
        """Reset every registered Mock (including its child mocks)."""
        for mock in self.mocks.values():
            if isinstance(mock, Mock):
                mock.reset_mock()

    def create_mock(self, name: str, spec: Any = None) -> Mock:
        """Create a Mock (restricted to ``spec`` when given) registered as ``name``."""
        mock = Mock(spec=spec)
        self.mocks[name] = mock
        return mock

    def _resolve_mock(self, mock_name: str):
        """Return the mock for ``mock_name``, or fail the test if there is none.

        An exact registry entry wins. Otherwise a dotted name is resolved as
        a registered root mock followed by attribute lookups; a spec'd mock
        yields nothing for attributes outside its spec.
        """
        mock = self.mocks.get(mock_name)
        if mock is None and "." in mock_name:
            root_name, *attr_path = mock_name.split(".")
            mock = self.mocks.get(root_name)
            for attr in attr_path:
                if mock is None:
                    break
                mock = getattr(mock, attr, None)
        self.assertIsNotNone(mock, f"Mock '{mock_name}' not found")
        return mock

    def assert_called_with(self, mock_name: str, *args, **kwargs):
        """Assert the named mock's most recent call used exactly these arguments."""
        self._resolve_mock(mock_name).assert_called_with(*args, **kwargs)

    def assert_not_called(self, mock_name: str):
        """Assert the named mock itself was never called."""
        self._resolve_mock(mock_name).assert_not_called()
# 业务逻辑单元测试
class UserService:
    """Create users through a repository and send them a welcome email."""

    def __init__(self, repository, email_service):
        # Collaborators are injected so tests can pass in fakes or mocks.
        self.repository = repository
        self.email_service = email_service

    def create_user(self, user_data: dict) -> dict:
        """Persist ``user_data`` and email the saved user.

        Raises:
            ValueError: if ``user_data`` has no truthy "email"; in that case
                neither collaborator is touched.

        Returns:
            Whatever ``repository.save`` returned; its "email" entry is the
            address the welcome email is sent to.
        """
        email = user_data.get("email")
        if not email:
            raise ValueError("Email is required")
        saved = self.repository.save(user_data)
        self.email_service.send_welcome(saved["email"])
        return saved
class UserServiceTest(UnitTestBase):
    """Unit tests for UserService with its repository and email service mocked."""

    def setUp(self):
        super().setUp()
        self.repository = self.create_mock("repository")
        self.email_service = self.create_mock("email_service")
        self.service = UserService(self.repository, self.email_service)

    def test_create_user_success(self):
        """A valid user is saved once and receives exactly one welcome email."""
        user_data = {"name": "Test", "email": "test@example.com"}
        self.repository.save.return_value = {"id": 1, **user_data}
        result = self.service.create_user(user_data)
        self.assertEqual(result["email"], "test@example.com")
        # The service calls *methods* on its collaborators, so the calls are
        # recorded on the child mocks (repository.save, email_service.send_welcome).
        # Asserting on the collaborator mocks themselves would always fail.
        self.repository.save.assert_called_once_with(user_data)
        self.email_service.send_welcome.assert_called_once_with("test@example.com")

    def test_create_user_missing_email(self):
        """A missing email raises ValueError before any collaborator is used."""
        user_data = {"name": "Test"}
        with self.assertRaises(ValueError) as context:
            self.service.create_user(user_data)
        self.assertIn("Email is required", str(context.exception))
        # mock_calls records calls to the mock and all of its child mocks, so an
        # empty list proves no interaction at all (not just no direct call).
        self.assertEqual(self.repository.mock_calls, [])
        self.assertEqual(self.email_service.mock_calls, [])
# 参数化测试
def parameterized_test(test_cases):
    """Decorator running a test method once per entry of ``test_cases``.

    Each case runs inside ``self.subTest(case=case)``, so a failing case is
    reported on its own and does not stop the remaining cases.

    Args:
        test_cases: iterable of cases; each is passed to the decorated method
            as its second positional argument. It is materialized once at
            decoration time, so a generator also works on repeated runs.
    """
    from functools import wraps

    cases = tuple(test_cases)

    def decorator(test_fn):
        # wraps() keeps __name__/__doc__ so runners report the real test name
        # and docstring instead of "wrapper".
        @wraps(test_fn)
        def wrapper(self):
            for case in cases:
                with self.subTest(case=case):
                    test_fn(self, case)
        return wrapper
    return decorator
class MathServiceTest(UnitTestBase):
    """Table-driven checks of integer addition."""

    @parameterized_test([
        {"a": 1, "b": 2, "expected": 3},
        {"a": 0, "b": 0, "expected": 0},
        {"a": -1, "b": 1, "expected": 0},
    ])
    def test_add(self, case):
        """For every case, ``a + b`` equals ``expected``."""
        lhs, rhs, want = case["a"], case["b"], case["expected"]
        self.assertEqual(lhs + rhs, want)
# 测试覆盖率分析
class CoverageAnalyzer:
    """In-memory, per-module line-coverage bookkeeping."""

    def __init__(self):
        # module name -> {line number -> executed flag}
        self.coverage_data = {}

    def record_coverage(self, module: str, line: int,
                        executed: bool = True):
        """Record whether ``line`` of ``module`` was executed.

        Records are merged: once a line has been recorded as executed it stays
        covered, even if a later record for the same line reports
        ``executed=False`` (e.g. a second run that did not reach it).
        """
        lines = self.coverage_data.setdefault(module, {})
        lines[line] = lines.get(line, False) or executed

    def calculate_coverage(self, module: str) -> Dict:
        """Return coverage statistics for ``module``.

        Returns:
            Dict with "coverage" (percent; 0 when nothing is recorded),
            "total_lines", "covered_lines" and "uncovered_lines". The same
            keys are present for a module with no records.
        """
        lines = self.coverage_data.get(module, {})
        total = len(lines)
        covered = sum(1 for was_run in lines.values() if was_run)
        return {
            "coverage": covered / total * 100 if total > 0 else 0,
            "total_lines": total,
            "covered_lines": covered,
            "uncovered_lines": [n for n, was_run in lines.items() if not was_run]
        }

    def get_report(self) -> Dict:
        """Return per-module statistics plus line-weighted overall coverage."""
        report = {
            "total_modules": len(self.coverage_data),
            "modules": {}
        }
        total_lines = 0
        total_covered = 0
        for module in self.coverage_data:
            coverage = self.calculate_coverage(module)
            report["modules"][module] = coverage
            total_lines += coverage["total_lines"]
            total_covered += coverage["covered_lines"]
        report["overall_coverage"] = total_covered / total_lines * 100 if total_lines > 0 else 0
        return report
集成测试策略
集成测试验证组件间的交互,包括服务间调用、数据库交互和外部API集成。实际项目中通常借助Testcontainers启动真实的数据库和消息队列,并用WireMock模拟外部API;下面示例中的容器类只是打印日志、记录状态的简化占位实现,用于演示测试结构。
# 集成测试框架
from typing import Dict, List
import asyncio
class IntegrationTestBase(unittest.TestCase):
    """Base class for integration tests backed by (stub) service containers.

    Containers live in the class-level ``containers`` dict. They are started
    by the ``setup_*`` helpers and stopped in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Create the per-class container registry and test-data store."""
        super().setUpClass()
        cls.containers = {}
        cls.test_data = {}

    @classmethod
    def tearDownClass(cls):
        """Stop every tracked container, even if stopping one of them fails.

        The registry is cleared afterwards. The first stop error (if any) is
        re-raised once all containers have been attempted.
        """
        errors = []
        for container in cls.containers.values():
            try:
                container.stop()
            except Exception as exc:  # keep going so no container is leaked
                errors.append(exc)
        cls.containers.clear()
        super().tearDownClass()
        if errors:
            raise errors[0]

    def _replace_container(self, key, container):
        """Register ``container`` under ``key`` and start it.

        Any container previously registered under the same key is stopped
        first. The ``setup_*`` helpers may run once per test (from setUp);
        without this the overwritten container would keep running.
        """
        previous = self.containers.get(key)
        if previous is not None:
            previous.stop()
        self.containers[key] = container
        container.start()

    def setup_database(self):
        """Start a PostgreSQL test container under "db" and seed test data."""
        self._replace_container("db", DatabaseContainer(
            image="postgres:14",
            port=5432
        ))
        self._seed_test_data()

    def setup_messaging(self):
        """Start a RabbitMQ test container under "rabbitmq"."""
        self._replace_container("rabbitmq", MessageQueueContainer(
            image="rabbitmq:3",
            port=5672
        ))

    def _seed_test_data(self):
        """Store the baseline users in ``test_data``."""
        self.test_data["users"] = [
            {"id": 1, "name": "Test User 1"},
            {"id": 2, "name": "Test User 2"}
        ]
# 数据库集成测试
class DatabaseIntegrationTest(IntegrationTestBase):
    """Integration tests for UserRepository against a database container.

    NOTE(review): ``UserRepository`` is not defined in this section; confirm
    where it comes from.
    """

    class _SimulatedError(Exception):
        """Raised on purpose inside a transaction to force a rollback."""

    def setUp(self):
        super().setUp()
        # A fresh database container per test keeps tests isolated.
        self.setup_database()
        self.repository = UserRepository(self.containers["db"])

    def tearDown(self):
        """Stop this test's database container so containers do not pile up."""
        container = self.containers.pop("db", None)
        if container is not None:
            container.stop()
        super().tearDown()

    def test_create_and_fetch_user(self):
        """A created user can be fetched back by its id."""
        user = self.repository.create({"name": "Test", "email": "test@example.com"})
        self.assertIsNotNone(user["id"])
        fetched = self.repository.get_by_id(user["id"])
        self.assertEqual(fetched["name"], "Test")

    def test_transaction_rollback(self):
        """Changes made inside a failed transaction are rolled back."""
        # Only the sentinel error is expected. A genuine failure raised by
        # transaction() or create() still surfaces instead of being swallowed
        # by a blanket ``except Exception``.
        with self.assertRaises(self._SimulatedError):
            with self.repository.transaction():
                self.repository.create({"name": "User 1"})
                raise self._SimulatedError("Simulated error")
        users = self.repository.list_all()
        self.assertEqual(len(users), 0)
# API集成测试
class APIIntegrationTest(IntegrationTestBase):
    """CRUD flow test for the /users HTTP API through a test client.

    NOTE(review): ``TestClient`` is not defined in this section; confirm its
    source.
    """

    # Application under test. Nothing in this class assigns it, so a subclass
    # or fixture must set it (NOTE(review): confirm where the app is provided).
    app = None

    def setUp(self):
        super().setUp()
        # Fail with an actionable message instead of an opaque AttributeError.
        if self.app is None:
            self.fail("APIIntegrationTest.app is not set; assign the application under test")
        self.client = TestClient(self.app)

    def test_full_user_flow(self):
        """Create, read, update and delete a user through the HTTP API."""
        response = self.client.post("/users", json={
            "name": "Test User",
            "email": "test@example.com"
        })
        self.assertEqual(response.status_code, 201)
        user_id = response.json()["id"]

        response = self.client.get(f"/users/{user_id}")
        self.assertEqual(response.status_code, 200)

        response = self.client.put(f"/users/{user_id}", json={
            "name": "Updated Name"
        })
        self.assertEqual(response.status_code, 200)

        response = self.client.delete(f"/users/{user_id}")
        self.assertEqual(response.status_code, 204)
# 消息队列集成测试
class MessagingIntegrationTest(IntegrationTestBase):
    """Integration test for publishing to and consuming from the message queue."""

    def setUp(self):
        super().setUp()
        self.setup_messaging()
        self.publisher = MessagePublisher(self.containers["rabbitmq"])
        self.consumer = MessageConsumer(self.containers["rabbitmq"])

    def test_publish_and_consume(self):
        """A published message is received unchanged by the consumer."""
        # A plain unittest.TestCase does not await coroutine test methods: an
        # ``async def test_...`` returns an un-awaited coroutine and "passes"
        # without running a single assertion. Drive the coroutine explicitly.
        asyncio.run(self._publish_and_consume())

    async def _publish_and_consume(self):
        """Publish one message to "users" and assert the consumer receives it."""
        message = {"type": "user_created", "user_id": 123}
        await self.publisher.publish("users", message)
        received = await self.consumer.consume("users", timeout=5)
        self.assertEqual(received["type"], "user_created")
        self.assertEqual(received["user_id"], 123)
# 测试容器管理
class TestContainer:
    """Minimal stand-in for a test container.

    Nothing is actually launched: ``start``/``stop`` print a message and flip
    ``status`` between "running" and "stopped".
    """

    def __init__(self, image: str, ports: Dict = None):
        self.image = image
        # A missing or empty mapping is replaced by a new empty dict.
        self.ports = ports if ports else {}
        self.status = "stopped"

    def start(self):
        """Mark the container as running."""
        print(f"Starting container: {self.image}")
        self.status = "running"

    def stop(self):
        """Mark the container as stopped."""
        print(f"Stopping container: {self.image}")
        self.status = "stopped"

    def is_running(self) -> bool:
        """Return True while ``status`` is "running"."""
        return self.status == "running"
class DatabaseContainer(TestContainer):
    """Stub PostgreSQL container that exposes a connection URL for tests."""

    def __init__(self, image: str, port: int):
        super().__init__(image, ports={"port": port})
        # Fixed test credentials and database name; only the port varies.
        self.connection_url = f"postgresql://test:test@localhost:{port}/testdb"

    def execute_query(self, query: str):
        """Log ``query`` and return an empty result list (no real database)."""
        print(f"Executing: {query}")
        rows = []
        return rows
class MessageQueueContainer(TestContainer):
    """Stub message-broker container; it only records its image and port."""

    def __init__(self, image: str, port: int):
        port_mapping = {"port": port}
        super().__init__(image, ports=port_mapping)
E2E测试策略
E2E测试位于金字塔顶端,验证完整的用户流程。数量最少但信心最高。使用Page Object模式提高可维护性,并根据代码变更选择性地运行E2E测试(关键路径测试始终运行)。
# E2E测试框架
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time
class E2ETestBase(unittest.TestCase):
    """Base class for browser E2E tests sharing one headless Chrome per class."""

    @classmethod
    def setUpClass(cls):
        """Start a headless Chrome driver and a 10-second explicit wait."""
        super().setUpClass()
        options = webdriver.ChromeOptions()
        options.add_argument("--headless")
        cls.driver = webdriver.Chrome(options=options)
        cls.wait = WebDriverWait(cls.driver, 10)

    @classmethod
    def tearDownClass(cls):
        """Quit the browser session."""
        cls.driver.quit()
        super().tearDownClass()

    def take_screenshot(self, name: str, directory: str = "screenshots"):
        """Save a screenshot to ``<directory>/<name>.png``.

        The directory is created first if it does not exist; without it the
        driver cannot write the file.

        Returns:
            The result of the driver's ``save_screenshot``. Selenium documents
            this as True on success and False on an IOError.
        """
        import os

        os.makedirs(directory, exist_ok=True)
        return self.driver.save_screenshot(f"{directory}/{name}.png")
class PageObject:
    """Base page object: a WebDriver plus a 10-second explicit wait."""

    def __init__(self, driver):
        self.driver = driver
        self.wait = WebDriverWait(driver, 10)

    def find_element(self, locator: tuple):
        """Wait until the element at ``locator`` is present in the DOM, then return it."""
        present = EC.presence_of_element_located(locator)
        return self.wait.until(present)

    def click(self, locator: tuple):
        """Wait until the element at ``locator`` is clickable, then click it."""
        clickable = EC.element_to_be_clickable(locator)
        target = self.wait.until(clickable)
        target.click()

    def type_text(self, locator: tuple, text: str):
        """Clear the element at ``locator`` and type ``text`` into it."""
        field_element = self.find_element(locator)
        field_element.clear()
        field_element.send_keys(text)
class LoginPage(PageObject):
    """Page object for the login screen."""

    # Locators as (By strategy, value) pairs.
    USERNAME = (By.ID, "username")
    PASSWORD = (By.ID, "password")
    LOGIN_BUTTON = (By.ID, "login-btn")
    ERROR_MESSAGE = (By.CLASS_NAME, "error")

    def login(self, username: str, password: str):
        """Fill in the credentials, submit, and return a DashboardPage.

        The method does not check that the login succeeded.
        """
        for locator, value in ((self.USERNAME, username), (self.PASSWORD, password)):
            self.type_text(locator, value)
        self.click(self.LOGIN_BUTTON)
        return DashboardPage(self.driver)

    def get_error(self) -> str:
        """Return the text of the login error element."""
        error_element = self.find_element(self.ERROR_MESSAGE)
        return error_element.text
class DashboardPage(PageObject):
    """Page object for the dashboard shown after login."""

    # Locators as (By strategy, value) pairs.
    WELCOME = (By.CLASS_NAME, "welcome")
    LOGOUT = (By.ID, "logout")

    def get_welcome_message(self) -> str:
        """Return the text of the welcome element."""
        banner = self.find_element(self.WELCOME)
        return banner.text

    def logout(self):
        """Click the logout control and return a LoginPage on the same driver."""
        self.click(self.LOGOUT)
        return LoginPage(self.driver)
# E2E测试用例
class UserFlowTest(E2ETestBase):
    """E2E test of the login -> dashboard -> logout flow."""

    # Entry point of the flow; override in a subclass to target another environment.
    LOGIN_URL = "https://example.com/login"

    def test_complete_user_flow(self):
        """Log in, check the welcome text, log out, and land on the login page."""
        login_page = LoginPage(self.driver)
        login_page.driver.get(self.LOGIN_URL)
        dashboard = login_page.login("testuser", "password123")

        welcome = dashboard.get_welcome_message()
        self.assertIn("Welcome", welcome)

        login_page = dashboard.logout()
        # Use the page object's explicit wait. A bare driver.find_element
        # right after the logout click can run before the login page has
        # rendered and fail intermittently.
        login_button = login_page.find_element(LoginPage.LOGIN_BUTTON)
        self.assertIsNotNone(login_button)
# 选择性E2E测试
class SelectiveE2ETest:
    """Decide which E2E tests to run for a set of changed files.

    Critical-path tests always run. Any other test runs when at least one
    changed file path contains one of the path markers in ``relevant_files``.
    """

    def __init__(self, critical_paths: List[str] = None,
                 relevant_files: Dict[str, List[str]] = None):
        """
        Args:
            critical_paths: test names that always run. Defaults to the login,
                order and payment flows.
            relevant_files: area name -> path substrings (e.g. "auth/").
                Defaults to the areas of the default critical paths.
        """
        if critical_paths is None:
            critical_paths = ["user_login", "create_order", "payment_process"]
        if relevant_files is None:
            relevant_files = {
                "user_login": ["auth/", "login/"],
                "create_order": ["order/", "cart/"],
                "payment_process": ["payment/", "checkout/"]
            }
        self.critical_paths = list(critical_paths)
        self.relevant_files = dict(relevant_files)

    def should_run_test(self, test_name: str,
                        changed_files: List[str]) -> bool:
        """Return True if ``test_name`` should run for ``changed_files``."""
        if test_name in self.critical_paths:
            return True
        # NOTE(review): the mapping is not looked up by ``test_name``. A change
        # in ANY listed area triggers every non-critical test. Confirm this is
        # intended rather than ``self.relevant_files.get(test_name, [])``.
        # Markers are matched as substrings, so "auth/" also matches "oauth/".
        return any(
            marker in path
            for markers in self.relevant_files.values()
            for marker in markers
            for path in changed_files
        )
# 测试报告生成
class E2ETestReporter:
    """Render E2E test results as a self-contained HTML report."""

    def __init__(self, results: List[Dict]):
        # Each result needs "name" and "status". "duration" (seconds) is
        # optional and is shown as 0.00s when missing.
        self.results = results

    def generate_html_report(self) -> str:
        """Return an HTML page with pass/fail totals and one table row per result.

        Names and statuses are HTML-escaped, so a test name containing markup
        (``<``, ``&``, quotes) cannot break or inject into the page. Any status
        other than "passed" is styled with the ``failed`` class. The summary's
        Failed count includes only results whose status is exactly "failed".
        """
        from html import escape

        passed = sum(1 for r in self.results if r["status"] == "passed")
        failed = sum(1 for r in self.results if r["status"] == "failed")

        # Collect lines and join once instead of growing a string with +=.
        lines = [
            "<!DOCTYPE html>",
            "<html>",
            "<head>",
            "<title>E2E Test Report</title>",
            "<style>",
            ".passed { color: green; }",
            ".failed { color: red; }",
            "</style>",
            "</head>",
            "<body>",
            "<h1>E2E Test Report</h1>",
            '<div class="summary">',
            f"<p>Total: {len(self.results)}</p>",
            f'<p class="passed">Passed: {passed}</p>',
            f'<p class="failed">Failed: {failed}</p>',
            "</div>",
            "<table>",
            "<tr><th>Test</th><th>Status</th><th>Duration</th></tr>",
        ]
        for result in self.results:
            status_class = "passed" if result["status"] == "passed" else "failed"
            lines += [
                "<tr>",
                f"<td>{escape(str(result['name']))}</td>",
                f'<td class="{status_class}">{escape(str(result["status"]))}</td>',
                f"<td>{result.get('duration', 0):.2f}s</td>",
                "</tr>",
            ]
        lines += ["</table>", "</body>", "</html>"]
        return "\n".join(lines) + "\n"
测试金字塔反模式
识别和避免测试金字塔的常见反模式:冰淇淋筒(E2E过多)、沙漏型(缺少集成测试)、笼型(单元测试不足)。下面的检测器使用的判定阈值分别为:E2E测试占比超过30%、集成测试占比低于10%、单元测试占比低于50%。
# 测试反模式检测
class TestAntiPatternDetector:
    """Detect test-pyramid anti-patterns from per-layer test counts.

    ``test_distribution`` maps "unit", "integration" and "e2e" to either a
    plain test count or a dict carrying a "count" key (the per-layer shape
    produced by ``TestPyramid.get_distribution``). Missing layers count as 0.
    """

    def __init__(self, test_distribution: Dict):
        self.distribution = test_distribution

    def detect_all(self) -> List[Dict]:
        """Run every detector and return the anti-patterns found.

        The order is fixed: ice-cream cone, hourglass, cage.
        """
        detectors = (
            self._detect_ice_cream_cone,
            self._detect_hourglass,
            self._detect_cage,
        )
        return [found for found in (detect() for detect in detectors) if found]

    def _count(self, layer: str):
        """Return the number of tests recorded for ``layer`` (0 if absent)."""
        value = self.distribution.get(layer, 0)
        if isinstance(value, dict):
            value = value.get("count", 0)
        return value

    def _ratios(self):
        """Return (unit, integration, e2e) shares of all tests, or None if there are none."""
        unit = self._count("unit")
        integration = self._count("integration")
        e2e = self._count("e2e")
        total = unit + integration + e2e
        if total == 0:
            return None
        return unit / total, integration / total, e2e / total

    def _detect_ice_cream_cone(self) -> "Dict | None":
        """Ice-cream cone: E2E tests make up more than 30% of the suite."""
        ratios = self._ratios()
        if ratios is not None and ratios[2] > 0.3:
            return {
                "pattern": "ice_cream_cone",
                "severity": "high",
                "description": "E2E测试比例过高,导致测试套件缓慢且脆弱",
                "recommendation": "将更多测试下沉到单元和集成测试层"
            }
        return None

    def _detect_hourglass(self) -> "Dict | None":
        """Hourglass: integration tests make up less than 10% of the suite."""
        ratios = self._ratios()
        if ratios is not None and ratios[1] < 0.1:
            return {
                "pattern": "hourglass",
                "severity": "medium",
                "description": "集成测试不足,组件间交互验证不够",
                "recommendation": "增加集成测试覆盖关键的组件交互"
            }
        return None

    def _detect_cage(self) -> "Dict | None":
        """Cage: unit tests make up less than 50% of the suite."""
        ratios = self._ratios()
        if ratios is not None and ratios[0] < 0.5:
            return {
                "pattern": "cage",
                "severity": "high",
                "description": "单元测试不足,代码质量保障不够",
                "recommendation": "增加单元测试覆盖核心业务逻辑"
            }
        return None

    def generate_recommendations(self) -> List[str]:
        """Return one recommendation per detected anti-pattern, or a 'healthy' note."""
        recommendations = [pattern["recommendation"] for pattern in self.detect_all()]
        if not recommendations:
            recommendations.append("测试金字塔结构健康,继续保持")
        return recommendations
# Usage example: a suite that is 40% E2E gets reported as an ice-cream cone.
detector = TestAntiPatternDetector({"unit": 50, "integration": 10, "e2e": 40})
patterns = detector.detect_all()
for pattern in patterns:
    print(f"Detected: {pattern['pattern']} - {pattern['description']}")