← 返回首页
测试

测试驱动开发

📂 python ⏱ 4 min 659 words

测试驱动开发

测试驱动开发(TDD)是一种通过先写测试再写代码的开发方法。本文将介绍TDD的核心原则、RED-GREEN-REFACTOR循环和最佳实践。

TDD核心原则

# 1. 只在测试失败时写代码
# 2. 消除重复设计
# 3. 保持小步前进

# 传统开发 vs TDD
# 传统:需求 → 代码 → 测试 → 修复
# TDD:测试 → 代码 → 重构 → 下一个测试

RED-GREEN-REFACTOR循环

# 第一步:RED - 写一个失败的测试
def test_add():
    calculator = Calculator()
    assert calculator.add(2, 3) == 5

# 运行测试,期望失败(没有Calculator类)

# 第二步:GREEN - 写最少的代码让测试通过
class Calculator:
    def add(self, a, b):
        return a + b

# 运行测试,期望通过

# 第三步:REFACTOR - 改进代码结构
class Calculator:
    """简单的计算器类"""
    
    def add(self, a: int, b: int) -> int:
        """返回两个数的和"""
        return a + b

# 运行测试,仍然通过

实战示例:购物车

# 第一个测试:空购物车总价为0
def test_empty_cart_total():
    cart = ShoppingCart()
    assert cart.total() == 0

# 实现最小代码
class ShoppingCart:
    def total(self):
        return 0

# 第二个测试:添加商品
def test_add_item():
    cart = ShoppingCart()
    cart.add_item("apple", 1.5, 2)
    assert cart.total() == 3.0

# 更新实现
class ShoppingCart:
    def __init__(self):
        self.items = []
    
    def add_item(self, name, price, quantity):
        self.items.append({
            "name": name,
            "price": price,
            "quantity": quantity
        })
    
    def total(self):
        return sum(item["price"] * item["quantity"] 
                  for item in self.items)

# 第三个测试:折扣
def test_discount():
    cart = ShoppingCart()
    cart.add_item("apple", 1.5, 10)
    cart.apply_discount(0.1)  # 10%折扣
    assert cart.total() == 13.5

# 更新实现
class ShoppingCart:
    def __init__(self):
        self.items = []
        self.discount = 0
    
    def add_item(self, name, price, quantity):
        self.items.append({
            "name": name,
            "price": price,
            "quantity": quantity
        })
    
    def apply_discount(self, percentage):
        self.discount = percentage
    
    def total(self):
        subtotal = sum(item["price"] * item["quantity"] 
                      for item in self.items)
        return subtotal * (1 - self.discount)

测试金字塔

# 单元测试(底层,最多)
class TestCalculatorUnit:
    def test_add(self):
        assert Calculator().add(2, 3) == 5
    
    def test_subtract(self):
        assert Calculator().subtract(5, 2) == 3
    
    def test_divide_by_zero(self):
        with pytest.raises(ZeroDivisionError):
            Calculator().divide(1, 0)

# 集成测试(中间层,适量)
class TestCalculatorIntegration:
    def test_with_logging(self):
        calc = Calculator()
        with patch('app.logger') as mock_logger:
            result = calc.add(2, 3)
            mock_logger.info.assert_called_with("计算: 2 + 3 = 5")
            assert result == 5
    
    def test_with_database(self, db_session):
        calc = Calculator()
        result = calc.add(2, 3)
        db_session.save(Calculation(result=result))
        assert db_session.query(Calculation).count() == 1

# 端到端测试(顶层,最少)
class TestCalculatorE2E:
    def test_api_endpoint(self, client):
        response = client.post("/api/calculate", json={
            "operation": "add",
            "a": 2,
            "b": 3
        })
        assert response.status_code == 200
        assert response.json["result"] == 5

Mock和Stub

from unittest.mock import Mock, patch, MagicMock
import pytest

# Mock外部依赖
class UserService:
    def __init__(self, db):
        self.db = db
    
    def get_user(self, user_id):
        return self.db.query(User).get(user_id)

def test_get_user():
    mock_db = Mock()
    mock_db.query.return_value.get.return_value = User(id=1, name="Alice")
    
    service = UserService(mock_db)
    user = service.get_user(1)
    
    assert user.name == "Alice"
    mock_db.query.assert_called_once_with(User)

# 使用pytest-mock
def test_send_email(mocker):
    mock_smtp = mocker.patch('smtplib.SMTP')
    mock_smtp.return_value.__enter__.return_value.sendmail.return_value = {}
    
    service = EmailService()
    service.send_email("test@example.com", "Hello", "Body")
    
    mock_smtp.assert_called_once_with('smtp.example.com', 587)

# 测试数据库操作
@pytest.fixture
def db_session():
    session = create_test_session()
    yield session
    session.rollback()
    session.close()

def test_create_user(db_session):
    service = UserService(db_session)
    user = service.create_user("Alice", "alice@example.com")
    
    assert user.id is not None
    assert db_session.query(User).count() == 1

参数化测试

import pytest

# 参数化测试
@pytest.mark.parametrize("input,expected", [
    (1, 1),
    (2, 4),
    (3, 9),
    (4, 16),
    (5, 25),
])
def test_square(input, expected):
    calculator = Calculator()
    assert calculator.square(input) == expected

# 多参数
@pytest.mark.parametrize("a,b,expected", [
    (1, 2, 3),
    (0, 0, 0),
    (-1, 1, 0),
    (100, 200, 300),
])
def test_add(a, b, expected):
    calculator = Calculator()
    assert calculator.add(a, b) == expected

# 条件测试
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="Windows not supported"
)
def test_unix_only():
    assert os.name == 'posix'

# 测试期望异常
def test_divide_by_zero():
    calculator = Calculator()
    with pytest.raises(ZeroDivisionError, match="除数不能为零"):
        calculator.divide(1, 0)

测试覆盖率

# 测试覆盖率配置
# .coveragerc
[run]
source = src
omit = 
    */tests/*
    */migrations/*
    setup.py

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    if __name__ == .__main__
    raise NotImplementedError
show_missing = True
fail_under = 80

# 使用pytest-cov
# pytest --cov=src --cov-report=html --cov-report=term

# 测试覆盖率策略
class TestCoverageStrategy:
    """确保关键路径有测试覆盖"""
    
    def test_critical_payment_flow(self):
        # 支付流程必须有测试
        payment_service = PaymentService()
        result = payment_service.process_payment(order_id=123)
        assert result.success is True
    
    def test_error_handling(self):
        # 错误处理必须有测试
        service = UserService()
        with pytest.raises(UserNotFoundError):
            service.get_user(999)

重构模式

# 重构前:重复代码
class OrderProcessor:
    def process_order(self, order):
        # 验证订单
        if not order.items:
            raise ValueError("订单不能为空")
        if order.total <= 0:
            raise ValueError("订单金额必须大于0")
        
        # 计算折扣
        if order.total > 1000:
            discount = 0.1
        elif order.total > 500:
            discount = 0.05
        else:
            discount = 0
        
        # 处理支付
        # ... 重复代码

# 重构后:提取方法
class OrderProcessor:
    def validate_order(self, order):
        if not order.items:
            raise ValueError("订单不能为空")
        if order.total <= 0:
            raise ValueError("订单金额必须大于0")
    
    def calculate_discount(self, total):
        if total > 1000:
            return 0.1
        elif total > 500:
            return 0.05
        return 0
    
    def process_order(self, order):
        self.validate_order(order)
        discount = self.calculate_discount(order.total)
        # ... 清晰的处理逻辑

测试命名规范

# 命名模式:test_<被测试功能>_<场景>_<预期结果>
class TestCalculator:
    def test_add_positive_numbers_returns_sum(self):
        assert Calculator().add(2, 3) == 5
    
    def test_add_negative_numbers_returns_sum(self):
        assert Calculator().add(-2, -3) == -5
    
    def test_divide_by_zero_raises_error(self):
        with pytest.raises(ZeroDivisionError):
            Calculator().divide(1, 0)
    
    def test_divide_non_zero_returns_quotient(self):
        assert Calculator().divide(6, 2) == 3.0

# 测试套件组织
@pytest.mark.unit
class TestCalculatorUnit:
    # 单元测试

@pytest.mark.integration
class TestCalculatorIntegration:
    # 集成测试

@pytest.mark.slow
class TestCalculatorPerformance:
    # 性能测试

最佳实践

  1. 小步前进:每次只添加一个测试
  2. 快速反馈:测试运行时间<1秒
  3. 独立测试:测试之间无依赖
  4. 可重复:每次运行结果相同
  5. 自验证:测试通过即代码正确

常见陷阱

总结

TDD通过RED-GREEN-REFACTOR循环确保代码质量。遵循测试金字塔原则,使用Mock隔离依赖,保持测试简洁,可以构建可维护的高质量代码。