测试驱动开发
测试驱动开发
测试驱动开发(TDD)是一种通过先写测试再写代码的开发方法。本文将介绍TDD的核心原则、RED-GREEN-REFACTOR循环和最佳实践。
TDD核心原则
# 1. 只在测试失败时写代码
# 2. 消除重复设计
# 3. 保持小步前进
# 传统开发 vs TDD
# 传统:需求 → 代码 → 测试 → 修复
# TDD:测试 → 代码 → 重构 → 下一个测试
RED-GREEN-REFACTOR循环
# 第一步:RED - 写一个失败的测试
def test_add():
calculator = Calculator()
assert calculator.add(2, 3) == 5
# 运行测试,期望失败(没有Calculator类)
# 第二步:GREEN - 写最少的代码让测试通过
class Calculator:
def add(self, a, b):
return a + b
# 运行测试,期望通过
# 第三步:REFACTOR - 改进代码结构
class Calculator:
"""简单的计算器类"""
def add(self, a: int, b: int) -> int:
"""返回两个数的和"""
return a + b
# 运行测试,仍然通过
实战示例:购物车
# 第一个测试:空购物车总价为0
def test_empty_cart_total():
cart = ShoppingCart()
assert cart.total() == 0
# 实现最小代码
class ShoppingCart:
def total(self):
return 0
# 第二个测试:添加商品
def test_add_item():
cart = ShoppingCart()
cart.add_item("apple", 1.5, 2)
assert cart.total() == 3.0
# 更新实现
class ShoppingCart:
def __init__(self):
self.items = []
def add_item(self, name, price, quantity):
self.items.append({
"name": name,
"price": price,
"quantity": quantity
})
def total(self):
return sum(item["price"] * item["quantity"]
for item in self.items)
# 第三个测试:折扣
def test_discount():
cart = ShoppingCart()
cart.add_item("apple", 1.5, 10)
cart.apply_discount(0.1) # 10%折扣
assert cart.total() == 13.5
# 更新实现
class ShoppingCart:
def __init__(self):
self.items = []
self.discount = 0
def add_item(self, name, price, quantity):
self.items.append({
"name": name,
"price": price,
"quantity": quantity
})
def apply_discount(self, percentage):
self.discount = percentage
def total(self):
subtotal = sum(item["price"] * item["quantity"]
for item in self.items)
return subtotal * (1 - self.discount)
测试金字塔
# 单元测试(底层,最多)
class TestCalculatorUnit:
def test_add(self):
assert Calculator().add(2, 3) == 5
def test_subtract(self):
assert Calculator().subtract(5, 2) == 3
def test_divide_by_zero(self):
with pytest.raises(ZeroDivisionError):
Calculator().divide(1, 0)
# 集成测试(中间层,适量)
class TestCalculatorIntegration:
def test_with_logging(self):
calc = Calculator()
with patch('app.logger') as mock_logger:
result = calc.add(2, 3)
mock_logger.info.assert_called_with("计算: 2 + 3 = 5")
assert result == 5
def test_with_database(self, db_session):
calc = Calculator()
result = calc.add(2, 3)
db_session.save(Calculation(result=result))
assert db_session.query(Calculation).count() == 1
# 端到端测试(顶层,最少)
class TestCalculatorE2E:
def test_api_endpoint(self, client):
response = client.post("/api/calculate", json={
"operation": "add",
"a": 2,
"b": 3
})
assert response.status_code == 200
assert response.json["result"] == 5
Mock和Stub
from unittest.mock import Mock, patch, MagicMock
import pytest
# Mock外部依赖
class UserService:
def __init__(self, db):
self.db = db
def get_user(self, user_id):
return self.db.query(User).get(user_id)
def test_get_user():
mock_db = Mock()
mock_db.query.return_value.get.return_value = User(id=1, name="Alice")
service = UserService(mock_db)
user = service.get_user(1)
assert user.name == "Alice"
mock_db.query.assert_called_once_with(User)
# 使用pytest-mock
def test_send_email(mocker):
mock_smtp = mocker.patch('smtplib.SMTP')
mock_smtp.return_value.__enter__.return_value.sendmail.return_value = {}
service = EmailService()
service.send_email("test@example.com", "Hello", "Body")
mock_smtp.assert_called_once_with('smtp.example.com', 587)
# 测试数据库操作
@pytest.fixture
def db_session():
session = create_test_session()
yield session
session.rollback()
session.close()
def test_create_user(db_session):
service = UserService(db_session)
user = service.create_user("Alice", "alice@example.com")
assert user.id is not None
assert db_session.query(User).count() == 1
参数化测试
import pytest
# 参数化测试
@pytest.mark.parametrize("input,expected", [
(1, 1),
(2, 4),
(3, 9),
(4, 16),
(5, 25),
])
def test_square(input, expected):
calculator = Calculator()
assert calculator.square(input) == expected
# 多参数
@pytest.mark.parametrize("a,b,expected", [
(1, 2, 3),
(0, 0, 0),
(-1, 1, 0),
(100, 200, 300),
])
def test_add(a, b, expected):
calculator = Calculator()
assert calculator.add(a, b) == expected
# 条件测试
@pytest.mark.skipif(
sys.platform == "win32",
reason="Windows not supported"
)
def test_unix_only():
assert os.name == 'posix'
# 测试期望异常
def test_divide_by_zero():
calculator = Calculator()
with pytest.raises(ZeroDivisionError, match="除数不能为零"):
calculator.divide(1, 0)
测试覆盖率
# 测试覆盖率配置
# .coveragerc
[run]
source = src
omit =
*/tests/*
*/migrations/*
setup.py
[report]
exclude_lines =
pragma: no cover
def __repr__
if __name__ == .__main__
raise NotImplementedError
show_missing = True
fail_under = 80
# 使用pytest-cov
# pytest --cov=src --cov-report=html --cov-report=term
# 测试覆盖率策略
class TestCoverageStrategy:
"""确保关键路径有测试覆盖"""
def test_critical_payment_flow(self):
# 支付流程必须有测试
payment_service = PaymentService()
result = payment_service.process_payment(order_id=123)
assert result.success is True
def test_error_handling(self):
# 错误处理必须有测试
service = UserService()
with pytest.raises(UserNotFoundError):
service.get_user(999)
重构模式
# 重构前:重复代码
class OrderProcessor:
def process_order(self, order):
# 验证订单
if not order.items:
raise ValueError("订单不能为空")
if order.total <= 0:
raise ValueError("订单金额必须大于0")
# 计算折扣
if order.total > 1000:
discount = 0.1
elif order.total > 500:
discount = 0.05
else:
discount = 0
# 处理支付
# ... 重复代码
# 重构后:提取方法
class OrderProcessor:
def validate_order(self, order):
if not order.items:
raise ValueError("订单不能为空")
if order.total <= 0:
raise ValueError("订单金额必须大于0")
def calculate_discount(self, total):
if total > 1000:
return 0.1
elif total > 500:
return 0.05
return 0
def process_order(self, order):
self.validate_order(order)
discount = self.calculate_discount(order.total)
# ... 清晰的处理逻辑
测试命名规范
# 命名模式:test_<被测试功能>_<场景>_<预期结果>
class TestCalculator:
def test_add_positive_numbers_returns_sum(self):
assert Calculator().add(2, 3) == 5
def test_add_negative_numbers_returns_sum(self):
assert Calculator().add(-2, -3) == -5
def test_divide_by_zero_raises_error(self):
with pytest.raises(ZeroDivisionError):
Calculator().divide(1, 0)
def test_divide_non_zero_returns_quotient(self):
assert Calculator().divide(6, 2) == 3.0
# 测试套件组织
@pytest.mark.unit
class TestCalculatorUnit:
# 单元测试
@pytest.mark.integration
class TestCalculatorIntegration:
# 集成测试
@pytest.mark.slow
class TestCalculatorPerformance:
# 性能测试
最佳实践
- 小步前进:每次只添加一个测试
- 快速反馈:测试运行时间<1秒
- 独立测试:测试之间无依赖
- 可重复:每次运行结果相同
- 自验证:测试通过即代码正确
常见陷阱
- 测试实现而非行为:应该测试"做什么"而非"怎么做"
- 过度mock:真实集成测试很重要
- 脆弱测试:避免测试依赖实现细节
- 测试覆盖≠质量:高覆盖率不代表无bug
- 忽略边界条件:测试空值、极值、异常
总结
TDD通过RED-GREEN-REFACTOR循环确保代码质量。遵循测试金字塔原则,使用Mock隔离依赖,保持测试简洁,可以构建可维护的高质量代码。