← 返回首页
🧪

测试金字塔:单元测试、集成测试与E2E的比例

📂 architecture ⏱ 8 min 1498 words

测试金字塔:单元测试、集成测试与E2E的比例

测试金字塔模型

测试金字塔是Mike Cohn提出的测试策略模型,描述了不同层次测试的理想比例:底层是大量快速的单元测试,中间是适量的集成测试,顶层是少量的端到端(E2E)测试。这种分层结构让测试套件既能快速给出反馈,又能为系统整体质量提供足够的信心。

# 测试金字塔模型
from dataclasses import dataclass, field
from typing import List, Dict, Callable
from enum import Enum

class TestLayer(Enum):
    """Layers of the test pyramid, listed from the base to the tip."""

    UNIT = "unit"  # tests of a single function or class
    INTEGRATION = "integration"  # tests of interactions between components
    E2E = "e2e"  # end-to-end tests of complete flows

@dataclass
class TestConfig:
    """Targets that a single layer of the test pyramid is expected to meet."""

    layer: TestLayer
    # Share of the whole suite this layer should make up, as a 0-1 fraction.
    # Despite the name, TestPyramid treats it as a test-count ratio (it is
    # multiplied by 100 to obtain the target percentage), not code coverage.
    target_coverage: float
    execution_time_target: int  # seconds
    reliability_target: float  # 0-1
    description: str = ""

class TestPyramid:
    """Track registered tests per layer and compare their mix with the targets.

    Every layer has a ``TestConfig`` whose ``target_coverage`` is the share of
    the whole suite that the layer should contribute (70% unit, 20%
    integration, 10% end-to-end).
    """

    # Default tolerances for validate_pyramid(), in percentage points.
    DEFAULT_TOLERANCE = 10
    DEFAULT_CRITICAL_DEVIATION = 20

    def __init__(self):
        """Create the three layer configurations and empty test lists."""
        self.layers = {
            TestLayer.UNIT: TestConfig(
                layer=TestLayer.UNIT,
                target_coverage=0.7,  # 70% of all tests should be unit tests
                execution_time_target=60,  # finish within 1 minute
                reliability_target=0.99,
                description="快速、独立、大量"
            ),
            TestLayer.INTEGRATION: TestConfig(
                layer=TestLayer.INTEGRATION,
                target_coverage=0.2,  # 20% of all tests should be integration tests
                execution_time_target=300,  # finish within 5 minutes
                reliability_target=0.95,
                description="验证组件间交互"
            ),
            TestLayer.E2E: TestConfig(
                layer=TestLayer.E2E,
                target_coverage=0.1,  # 10% of all tests should be E2E tests
                execution_time_target=600,  # finish within 10 minutes
                reliability_target=0.90,
                description="验证完整业务流程"
            )
        }
        self.tests = {layer: [] for layer in TestLayer}

    def add_test(self, layer: TestLayer, test_name: str,
                 test_fn: Callable, metadata: Dict = None):
        """Register a test under ``layer``.

        Args:
            layer: Pyramid layer the test belongs to.
            test_name: Name stored with the test.
            test_fn: Callable stored with the test; it is not invoked here.
            metadata: Optional extra data; an empty dict is stored when omitted.
        """
        self.tests[layer].append({
            "name": test_name,
            "fn": test_fn,
            "metadata": metadata or {}
        })

    def get_distribution(self) -> Dict:
        """Return per-layer count, actual percentage and target percentage.

        The result is keyed by the layer's string value. When no tests are
        registered every actual percentage is 0.
        """
        total = sum(len(tests) for tests in self.tests.values())

        return {
            layer.value: {
                "count": len(tests),
                "percentage": len(tests) / total * 100 if total > 0 else 0,
                "target_percentage": self.layers[layer].target_coverage * 100,
            }
            for layer, tests in self.tests.items()
        }

    def validate_pyramid(self, tolerance: float = DEFAULT_TOLERANCE,
                         critical_deviation: float = DEFAULT_CRITICAL_DEVIATION) -> Dict:
        """Check whether the actual test mix stays close to the targets.

        Args:
            tolerance: Largest allowed gap, in percentage points, between a
                layer's actual and target share before an issue is reported.
            critical_deviation: Gap, in percentage points, from which an issue
                is rated "critical" instead of "warning".

        Returns:
            A dict with ``healthy`` (True when no issue was found),
            ``distribution`` (see get_distribution) and ``issues`` (one entry
            per layer whose deviation exceeds ``tolerance``).
        """
        distribution = self.get_distribution()
        issues = []

        for layer_name, stats in distribution.items():
            actual = stats["percentage"]
            target = stats["target_percentage"]
            deviation = abs(actual - target)

            # Deviations up to the tolerance are considered normal noise.
            if deviation <= tolerance:
                continue

            issues.append({
                "layer": layer_name,
                "actual": actual,
                "target": target,
                "deviation": deviation,
                "severity": "warning" if deviation < critical_deviation else "critical"
            })

        return {
            "healthy": len(issues) == 0,
            "distribution": distribution,
            "issues": issues
        }

# Usage example
pyramid = TestPyramid()

# Register a 70 / 20 / 10 mix of unit, integration and E2E tests.
for layer, prefix, count in (
    (TestLayer.UNIT, "unit_test", 70),
    (TestLayer.INTEGRATION, "integration_test", 20),
    (TestLayer.E2E, "e2e_test", 10),
):
    for i in range(count):
        pyramid.add_test(layer, f"{prefix}_{i}", lambda: None)

validation = pyramid.validate_pyramid()
print(f"Pyramid healthy: {validation['healthy']}")

单元测试策略

单元测试是金字塔的基石,测试单个函数或类的行为。特点是快速、独立、大量。使用Mock隔离外部依赖,确保测试的确定性和速度。

# 单元测试框架
import unittest
from unittest.mock import Mock, patch, MagicMock
from typing import Any, Callable
from dataclasses import dataclass

class UnitTestBase(unittest.TestCase):
    """Base class for unit tests that keeps a registry of named mocks."""

    def setUp(self):
        """Start every test with an empty mock registry."""
        self.mocks = {}

    def tearDown(self):
        """Reset every registered Mock once the test has finished."""
        registered = [item for item in self.mocks.values() if isinstance(item, Mock)]
        for item in registered:
            item.reset_mock()

    def create_mock(self, name: str, spec: Any = None) -> Mock:
        """Create a Mock (optionally constrained by ``spec``) and register it under ``name``."""
        self.mocks[name] = Mock(spec=spec)
        return self.mocks[name]

    def assert_called_with(self, mock_name: str, *args, **kwargs):
        """Assert that the registered mock itself was last called with the given arguments."""
        target = self.mocks.get(mock_name)
        self.assertIsNotNone(target, f"Mock '{mock_name}' not found")
        target.assert_called_with(*args, **kwargs)

    def assert_not_called(self, mock_name: str):
        """Assert that the registered mock itself was never called."""
        target = self.mocks.get(mock_name)
        self.assertIsNotNone(target, f"Mock '{mock_name}' not found")
        target.assert_not_called()

# 业务逻辑单元测试
class UserService:
    """Create users through a repository and greet them by email."""

    def __init__(self, repository, email_service):
        self.repository = repository
        self.email_service = email_service

    def create_user(self, user_data: dict) -> dict:
        """Persist ``user_data`` and send a welcome mail to the saved user.

        Raises:
            ValueError: if ``user_data`` has no (or an empty) "email" entry.

        Returns:
            Whatever the repository's ``save`` returned.
        """
        email = user_data.get("email")
        if not email:
            raise ValueError("Email is required")

        saved_user = self.repository.save(user_data)
        # The address is read from the saved record, not from the input.
        self.email_service.send_welcome(saved_user["email"])
        return saved_user

class UserServiceTest(UnitTestBase):
    """Unit tests for UserService.create_user with mocked collaborators."""

    def setUp(self):
        """Build a UserService wired to a mock repository and mock email service."""
        super().setUp()
        self.repository = self.create_mock("repository")
        self.email_service = self.create_mock("email_service")
        self.service = UserService(self.repository, self.email_service)

    def test_create_user_success(self):
        """A valid payload is saved and a welcome email is sent."""
        user_data = {"name": "Test", "email": "test@example.com"}
        self.repository.save.return_value = {"id": 1, **user_data}

        result = self.service.create_user(user_data)

        self.assertEqual(result["email"], "test@example.com")
        # The service calls repository.save(...) and
        # email_service.send_welcome(...), so the assertions must target those
        # child mocks. The parent mocks themselves are never called.
        self.repository.save.assert_called_once_with(user_data)
        self.email_service.send_welcome.assert_called_once_with("test@example.com")

    def test_create_user_missing_email(self):
        """A payload without an email is rejected before any collaborator is used."""
        user_data = {"name": "Test"}

        with self.assertRaises(ValueError) as context:
            self.service.create_user(user_data)

        self.assertIn("Email is required", str(context.exception))
        # Check the methods the service would have called; asserting on the
        # parent mocks would pass even if save/send_welcome had been invoked.
        self.repository.save.assert_not_called()
        self.email_service.send_welcome.assert_not_called()

# 参数化测试
def parameterized_test(test_cases):
    """Decorator that runs a test method once per entry of ``test_cases``.

    The decorated method must accept ``(self, case)``. Each case runs inside
    ``self.subTest(case=case)``, so a failing case is reported individually
    and does not stop the remaining cases.

    Args:
        test_cases: Iterable of case objects handed to the test one at a time.

    Returns:
        A decorator producing a ``(self)``-only method.
    """
    import functools

    def decorator(test_fn):
        # Keep the original name and docstring so test reports show the real
        # test instead of "wrapper".
        @functools.wraps(test_fn)
        def wrapper(self):
            for case in test_cases:
                with self.subTest(case=case):
                    test_fn(self, case)
        return wrapper
    return decorator

class MathServiceTest(UnitTestBase):
    """Table-driven checks of integer addition."""

    @parameterized_test([
        {"a": 1, "b": 2, "expected": 3},
        {"a": 0, "b": 0, "expected": 0},
        {"a": -1, "b": 1, "expected": 0},
    ])
    def test_add(self, case):
        """Check that a + b equals the expected sum for one case."""
        left, right, expected = case["a"], case["b"], case["expected"]
        self.assertEqual(left + right, expected)

# 测试覆盖率分析
class CoverageAnalyzer:
    """Collect per-line execution flags and summarise them as coverage figures."""

    def __init__(self):
        # Maps module name -> {line number -> executed flag}.
        self.coverage_data = {}

    def record_coverage(self, module: str, line: int,
                       executed: bool = True):
        """Record whether ``line`` of ``module`` was executed.

        Recording the same line again overwrites the earlier flag.
        """
        self.coverage_data.setdefault(module, {})[line] = executed

    def calculate_coverage(self, module: str) -> Dict:
        """Return coverage figures for ``module``.

        Returns:
            A dict with ``coverage`` (percentage, 0 when nothing is recorded),
            ``total_lines``, ``covered_lines`` and ``uncovered_lines`` (line
            numbers in recording order). An unknown module yields the same
            keys with zero / empty values, so callers never need to special-
            case a missing ``uncovered_lines`` entry.
        """
        lines = self.coverage_data.get(module, {})
        total = len(lines)
        covered = sum(1 for executed in lines.values() if executed)

        return {
            "coverage": covered / total * 100 if total > 0 else 0,
            "total_lines": total,
            "covered_lines": covered,
            "uncovered_lines": [line for line, executed in lines.items() if not executed]
        }

    def get_report(self) -> Dict:
        """Return per-module figures plus the line-weighted overall coverage."""
        modules = {
            module: self.calculate_coverage(module)
            for module in self.coverage_data
        }
        total_lines = sum(stats["total_lines"] for stats in modules.values())
        total_covered = sum(stats["covered_lines"] for stats in modules.values())

        return {
            "total_modules": len(self.coverage_data),
            "modules": modules,
            "overall_coverage": total_covered / total_lines * 100 if total_lines > 0 else 0,
        }

集成测试策略

集成测试验证组件间的交互,包括服务间调用、数据库交互和外部API集成。可以使用Testcontainers启动真实的依赖服务(如数据库、消息队列),并用WireMock模拟外部HTTP API,从而获得接近真实的测试环境。

# 集成测试框架
from typing import Dict, List
import asyncio

class IntegrationTestBase(unittest.TestCase):
    """Base class for integration tests that need containers and seed data.

    ``containers`` and ``test_data`` are class-level dicts created in
    ``setUpClass`` and therefore shared by all tests of one test class.
    """

    @classmethod
    def setUpClass(cls):
        """Create the empty container and test-data registries for the class."""
        cls.containers = {}
        cls.test_data = {}

    @classmethod
    def tearDownClass(cls):
        """Stop every registered container and forget about it."""
        for container in cls.containers.values():
            container.stop()
        cls.containers.clear()

    def _start_container(self, name: str, container):
        """Register ``container`` under ``name`` and start it.

        The setup helpers are typically called from ``setUp`` and therefore run
        once per test. A container already registered under the same name is
        stopped first; otherwise it would be overwritten and never stopped.
        """
        previous = self.containers.get(name)
        if previous is not None:
            previous.stop()
        self.containers[name] = container
        container.start()

    def setup_database(self):
        """Start a fresh database container and load the seed data."""
        self._start_container("db", DatabaseContainer(
            image="postgres:14",
            port=5432
        ))

        # Populate the fixtures the tests rely on.
        self._seed_test_data()

    def setup_messaging(self):
        """Start a fresh message-queue container."""
        self._start_container("rabbitmq", MessageQueueContainer(
            image="rabbitmq:3",
            port=5672
        ))

    def _seed_test_data(self):
        """Store the fixed list of test users in ``test_data``."""
        self.test_data["users"] = [
            {"id": 1, "name": "Test User 1"},
            {"id": 2, "name": "Test User 2"}
        ]

# 数据库集成测试
class DatabaseIntegrationTest(IntegrationTestBase):
    """Integration tests for the user repository against the database container."""

    def setUp(self):
        """Start the database and bind a repository to its container."""
        self.setup_database()
        db_container = self.containers["db"]
        self.repository = UserRepository(db_container)

    def test_create_and_fetch_user(self):
        """A created user gets an id and can be read back by that id."""
        payload = {"name": "Test", "email": "test@example.com"}
        created = self.repository.create(payload)
        self.assertIsNotNone(created["id"])

        reloaded = self.repository.get_by_id(created["id"])
        self.assertEqual(reloaded["name"], "Test")

    def test_transaction_rollback(self):
        """A failure inside a transaction leaves no user behind."""
        try:
            with self.repository.transaction():
                self.repository.create({"name": "User 1"})
                raise Exception("Simulated error")
        except Exception:
            # The failure is deliberate; only the state afterwards matters.
            pass

        remaining = self.repository.list_all()
        self.assertEqual(len(remaining), 0)

# API集成测试
class APIIntegrationTest(IntegrationTestBase):
    """Integration test that drives the user API through its HTTP client."""

    def setUp(self):
        """Create the HTTP test client."""
        # NOTE(review): self.app is not assigned anywhere in the visible code;
        # presumably a subclass or fixture provides it. Confirm.
        self.client = TestClient(self.app)

    def test_full_user_flow(self):
        """Create, read, update and delete a user, checking each status code."""
        created = self.client.post("/users", json={
            "name": "Test User",
            "email": "test@example.com"
        })
        self.assertEqual(created.status_code, 201)
        user_id = created.json()["id"]
        user_url = f"/users/{user_id}"

        fetched = self.client.get(user_url)
        self.assertEqual(fetched.status_code, 200)

        updated = self.client.put(user_url, json={
            "name": "Updated Name"
        })
        self.assertEqual(updated.status_code, 200)

        deleted = self.client.delete(user_url)
        self.assertEqual(deleted.status_code, 204)

# 消息队列集成测试
class MessagingIntegrationTest(IntegrationTestBase):
    """Integration test for publishing to and consuming from the message queue."""

    def setUp(self):
        """Start the message queue and create a publisher/consumer pair on it."""
        self.setup_messaging()
        self.publisher = MessagePublisher(self.containers["rabbitmq"])
        self.consumer = MessageConsumer(self.containers["rabbitmq"])

    def test_publish_and_consume(self):
        """A published message is received unchanged by the consumer."""
        # unittest.TestCase does not await coroutine test methods: an
        # ``async def test_*`` would only create a coroutine object and pass
        # without running a single assertion. Drive the coroutine explicitly.
        asyncio.run(self._publish_and_consume())

    async def _publish_and_consume(self):
        """Publish one message to "users" and assert on what is consumed."""
        message = {"type": "user_created", "user_id": 123}

        await self.publisher.publish("users", message)

        received = await self.consumer.consume("users", timeout=5)

        self.assertEqual(received["type"], "user_created")
        self.assertEqual(received["user_id"], 123)

# 测试容器管理
class TestContainer:
    """Minimal stand-in for a container: tracks an image, ports and a status."""

    def __init__(self, image: str, ports: Dict = None):
        self.image = image
        # A missing or empty mapping is replaced by a fresh empty dict.
        self.ports = ports if ports else {}
        self.status = "stopped"

    def start(self):
        """Announce the start and mark the container as running."""
        print(f"Starting container: {self.image}")
        self.status = "running"

    def stop(self):
        """Announce the stop and mark the container as stopped."""
        print(f"Stopping container: {self.image}")
        self.status = "stopped"

    def is_running(self) -> bool:
        """Return True while the status is "running"."""
        running = self.status == "running"
        return running

class DatabaseContainer(TestContainer):
    """Test container that also exposes a PostgreSQL connection URL."""

    def __init__(self, image: str, port: int):
        port_mapping = {"port": port}
        super().__init__(image, port_mapping)
        self.connection_url = f"postgresql://test:test@localhost:{port}/testdb"

    def execute_query(self, query: str):
        """Print the query and return an empty result list (no real execution)."""
        print(f"Executing: {query}")
        rows = []
        return rows

class MessageQueueContainer(TestContainer):
    """Test container for a message queue listening on a single port."""

    def __init__(self, image: str, port: int):
        port_mapping = {"port": port}
        super().__init__(image, port_mapping)

E2E测试策略

E2E测试位于金字塔顶端,验证完整的用户流程。数量最少但信心最高。使用Page Object模式提高可维护性,选择性地运行关键路径测试。

# E2E测试框架
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time

class E2ETestBase(unittest.TestCase):
    """Base class for E2E tests sharing one headless Chrome per test class."""

    @classmethod
    def setUpClass(cls):
        """Start headless Chrome and create a 10-second explicit wait for it."""
        options = webdriver.ChromeOptions()
        options.add_argument("--headless")
        cls.driver = webdriver.Chrome(options=options)
        cls.wait = WebDriverWait(cls.driver, 10)

    @classmethod
    def tearDownClass(cls):
        """Close the browser."""
        cls.driver.quit()

    def take_screenshot(self, name: str):
        """Save a screenshot of the current page as ``screenshots/<name>.png``."""
        import os

        # Create the target directory first: when it is missing the screenshot
        # cannot be written, and the driver only signals that through a return
        # value that nothing checks.
        os.makedirs("screenshots", exist_ok=True)
        self.driver.save_screenshot(f"screenshots/{name}.png")

class PageObject:
    """Base page object: wraps a driver with explicit-wait helpers."""

    def __init__(self, driver):
        self.driver = driver
        self.wait = WebDriverWait(driver, 10)

    def find_element(self, locator: tuple):
        """Wait until the element for ``locator`` is present and return it."""
        is_present = EC.presence_of_element_located(locator)
        return self.wait.until(is_present)

    def click(self, locator: tuple):
        """Wait until the element for ``locator`` is clickable, then click it."""
        is_clickable = EC.element_to_be_clickable(locator)
        self.wait.until(is_clickable).click()

    def type_text(self, locator: tuple, text: str):
        """Clear the element for ``locator`` and type ``text`` into it."""
        input_field = self.find_element(locator)
        input_field.clear()
        input_field.send_keys(text)

class LoginPage(PageObject):
    """Page object for the login page."""

    USERNAME = (By.ID, "username")
    PASSWORD = (By.ID, "password")
    LOGIN_BUTTON = (By.ID, "login-btn")
    ERROR_MESSAGE = (By.CLASS_NAME, "error")

    def login(self, username: str, password: str):
        """Fill in both credentials, submit, and return the dashboard page object."""
        credentials = ((self.USERNAME, username), (self.PASSWORD, password))
        for locator, value in credentials:
            self.type_text(locator, value)
        self.click(self.LOGIN_BUTTON)
        return DashboardPage(self.driver)

    def get_error(self) -> str:
        """Return the text of the error message element."""
        error_element = self.find_element(self.ERROR_MESSAGE)
        return error_element.text

class DashboardPage(PageObject):
    """Page object for the dashboard page."""

    WELCOME = (By.CLASS_NAME, "welcome")
    LOGOUT = (By.ID, "logout")

    def get_welcome_message(self) -> str:
        """Return the text of the welcome element."""
        welcome_element = self.find_element(self.WELCOME)
        return welcome_element.text

    def logout(self):
        """Click the logout control and return a login page object."""
        self.click(self.LOGOUT)
        next_page = LoginPage(self.driver)
        return next_page

# E2E测试用例
class UserFlowTest(E2ETestBase):
    """End-to-end test of the login -> dashboard -> logout flow."""

    def test_complete_user_flow(self):
        """Log in, check the welcome message, log out and land on the login page."""
        # Log in
        login_page = LoginPage(self.driver)
        login_page.driver.get("https://example.com/login")

        dashboard = login_page.login("testuser", "password123")

        # Verify the dashboard
        welcome = dashboard.get_welcome_message()
        self.assertIn("Welcome", welcome)

        # Log out
        login_page = dashboard.logout()

        # Verify we are back on the login page. Go through the page object so
        # the lookup waits for the button; an immediate driver.find_element
        # right after the logout click races against the page transition.
        login_button = login_page.find_element(LoginPage.LOGIN_BUTTON)
        self.assertIsNotNone(login_button)

# 选择性E2E测试
class SelectiveE2ETest:
    """Decide which E2E tests to run for a given set of changed files."""

    def __init__(self):
        self.critical_paths = [
            "user_login",
            "create_order",
            "payment_process"
        ]

    def should_run_test(self, test_name: str,
                       changed_files: List[str]) -> bool:
        """Return True when ``test_name`` should run for ``changed_files``.

        Critical-path tests always run. Any other test runs when at least one
        changed file path contains one of the watched directory fragments.
        """
        if test_name in self.critical_paths:
            return True

        # Directory fragments watched for each critical path.
        relevant_files = {
            "user_login": ["auth/", "login/"],
            "create_order": ["order/", "cart/"],
            "payment_process": ["payment/", "checkout/"]
        }

        for fragments in relevant_files.values():
            for path in changed_files:
                for fragment in fragments:
                    if fragment in path:
                        return True

        return False

# 测试报告生成
class E2ETestReporter:
    """Render a list of E2E test results as a standalone HTML page."""

    def __init__(self, results: List[Dict]):
        # Each result is read via the keys "name", "status" and, optionally,
        # "duration".
        self.results = results

    def generate_html_report(self) -> str:
        """Return the HTML report as a string.

        The summary counts results whose status is "passed" or "failed"; each
        row shows name, status and duration (0 when missing, two decimals).
        Names and statuses are HTML-escaped so characters such as ``<`` or
        ``&`` cannot break or inject markup.
        """
        from html import escape

        passed = sum(1 for r in self.results if r["status"] == "passed")
        failed = sum(1 for r in self.results if r["status"] == "failed")

        # Collect fragments and join once instead of growing a string with +=.
        parts = [f"""
<!DOCTYPE html>
<html>
<head>
    <title>E2E Test Report</title>
    <style>
        .passed {{ color: green; }}
        .failed {{ color: red; }}
    </style>
</head>
<body>
    <h1>E2E Test Report</h1>
    <div class="summary">
        <p>Total: {len(self.results)}</p>
        <p class="passed">Passed: {passed}</p>
        <p class="failed">Failed: {failed}</p>
    </div>
    <table>
        <tr><th>Test</th><th>Status</th><th>Duration</th></tr>
"""]

        for result in self.results:
            # Anything other than "passed" is styled as a failure.
            status_class = "passed" if result["status"] == "passed" else "failed"
            name_text = escape(str(result['name']))
            status_text = escape(str(result['status']))
            parts.append(f"""
        <tr>
            <td>{name_text}</td>
            <td class="{status_class}">{status_text}</td>
            <td>{result.get('duration', 0):.2f}s</td>
        </tr>
""")

        parts.append("""
    </table>
</body>
</html>
""")
        return "".join(parts)

测试金字塔反模式

识别和避免测试金字塔的常见反模式:冰淇淋筒(E2E过多)、沙漏型(缺少集成测试)、笼型(单元测试不足)。

# 测试反模式检测
class TestAntiPatternDetector:
    """Detect unhealthy shapes in the split of unit / integration / E2E tests.

    ``test_distribution`` maps the layer names "unit", "integration" and
    "e2e" to a test count. Each value may be a plain number or a dict with a
    "count" entry, so the output of ``TestPyramid.get_distribution()`` can be
    passed in directly. Missing layers count as 0.
    """

    LAYERS = ("unit", "integration", "e2e")

    # Ratio thresholds (fractions of the whole suite).
    E2E_MAX_RATIO = 0.3          # more E2E than this -> ice cream cone
    INTEGRATION_MIN_RATIO = 0.1  # less integration than this -> hourglass
    UNIT_MIN_RATIO = 0.5         # less unit than this -> cage

    def __init__(self, test_distribution: Dict):
        self.distribution = test_distribution

    def _count(self, layer: str):
        """Return the test count of ``layer`` for either supported input shape."""
        value = self.distribution.get(layer, 0)
        if isinstance(value, dict):
            return value.get("count", 0)
        return value

    def _ratio(self, layer: str):
        """Return the share of ``layer`` in the suite, or None for an empty suite."""
        total = sum(self._count(name) for name in self.LAYERS)
        if total == 0:
            return None
        return self._count(layer) / total

    def detect_all(self) -> List[Dict]:
        """Run every detector and return the findings in a fixed order.

        The order is ice cream cone, hourglass, cage; detectors without a
        finding contribute nothing.
        """
        detectors = (
            self._detect_ice_cream_cone,
            self._detect_hourglass,
            self._detect_cage,
        )
        findings = (detect() for detect in detectors)
        return [finding for finding in findings if finding]

    def _detect_ice_cream_cone(self) -> Dict:
        """Report the ice cream cone anti-pattern (too many E2E tests).

        Returns the finding, or None when the E2E share is within bounds or
        the suite is empty.
        """
        ratio = self._ratio("e2e")
        if ratio is None or ratio <= self.E2E_MAX_RATIO:
            return None

        return {
            "pattern": "ice_cream_cone",
            "severity": "high",
            "description": "E2E测试比例过高,导致测试套件缓慢且脆弱",
            "recommendation": "将更多测试下沉到单元和集成测试层"
        }

    def _detect_hourglass(self) -> Dict:
        """Report the hourglass anti-pattern (too few integration tests).

        Returns the finding, or None when the integration share is sufficient
        or the suite is empty.
        """
        ratio = self._ratio("integration")
        if ratio is None or ratio >= self.INTEGRATION_MIN_RATIO:
            return None

        return {
            "pattern": "hourglass",
            "severity": "medium",
            "description": "集成测试不足,组件间交互验证不够",
            "recommendation": "增加集成测试覆盖关键的组件交互"
        }

    def _detect_cage(self) -> Dict:
        """Report the cage anti-pattern (too few unit tests).

        Returns the finding, or None when the unit share is sufficient or the
        suite is empty.
        """
        ratio = self._ratio("unit")
        if ratio is None or ratio >= self.UNIT_MIN_RATIO:
            return None

        return {
            "pattern": "cage",
            "severity": "high",
            "description": "单元测试不足,代码质量保障不够",
            "recommendation": "增加单元测试覆盖核心业务逻辑"
        }

    def generate_recommendations(self) -> List[str]:
        """Return one recommendation per finding, or a single all-clear message."""
        recommendations = [finding["recommendation"] for finding in self.detect_all()]

        if not recommendations:
            recommendations.append("测试金字塔结构健康,继续保持")

        return recommendations

# Usage example: 40% E2E tests is above the 30% limit, so the ice cream cone
# anti-pattern is reported.
detector = TestAntiPatternDetector({"unit": 50, "integration": 10, "e2e": 40})

patterns = detector.detect_all()
for pattern in patterns:
    name, description = pattern["pattern"], pattern["description"]
    print(f"Detected: {name} - {description}")