单元测试:unittest、pytest、断言与测试组织
单元测试:unittest、pytest、断言与测试组织
单元测试是保证代码质量的关键手段,能帮助我们在代码变更时快速发现问题。本文将介绍Python中最常用的测试框架和最佳实践。
为什么需要单元测试
- 提前发现问题:在开发阶段发现bug比生产环境更便宜
- 重构保障:修改代码时确保现有功能不被破坏
- 文档作用:测试用例展示函数的预期行为
- 设计辅助:编写可测试的代码通常设计更好
unittest框架
unittest是Python标准库自带的测试框架:
import unittest
def add(a, b):
return a + b
def divide(a, b):
if b == 0:
raise ValueError("除数不能为零")
return a / b
class TestCalculator(unittest.TestCase):
def setUp(self):
"""每个测试方法前执行"""
self.test_data = [1, 2, 3, 4, 5]
def tearDown(self):
"""每个测试方法后执行"""
pass
def test_add_positive(self):
self.assertEqual(add(2, 3), 5)
def test_add_negative(self):
self.assertEqual(add(-1, -1), -2)
def test_add_zero(self):
self.assertEqual(add(0, 5), 5)
def test_divide_normal(self):
self.assertAlmostEqual(divide(10, 2), 5.0)
def test_divide_by_zero(self):
with self.assertRaises(ValueError):
divide(10, 0)
if __name__ == '__main__':
unittest.main()
pytest框架
pytest是更现代化的测试框架,语法更简洁:
# test_calculator.py
def add(a, b):
return a + b
def divide(a, b):
if b == 0:
raise ValueError("除数不能为零")
return a / b
# 简单的测试函数
def test_add():
assert add(2, 3) == 5
assert add(-1, -1) == -2
assert add(0, 5) == 5
def test_divide():
assert divide(10, 2) == 5.0
assert divide(9, 3) == 3.0
def test_divide_by_zero():
import pytest
with pytest.raises(ValueError):
divide(10, 0)
# 使用fixtures
import pytest
@pytest.fixture
def sample_list():
return [1, 2, 3, 4, 5]
def test_list_length(sample_list):
assert len(sample_list) == 5
def test_list_sum(sample_list):
assert sum(sample_list) == 15
pytest进阶特性
import pytest
# 参数化测试
@pytest.mark.parametrize("a,b,expected", [
(1, 2, 3),
(5, 5, 10),
(-1, 1, 0),
(0, 0, 0),
])
def test_add_parametrize(a, b, expected):
assert a + b == expected
# 标记测试
@pytest.mark.slow
def test_heavy_computation():
# 耗时测试
result = sum(range(1000000))
assert result == 499999500000
@pytest.mark.skip(reason="暂时跳过")
def test_not_ready():
pass
@pytest.mark.skipif(True, reason="条件不满足时跳过")
def test_conditional_skip():
pass
# 异常测试
def test_raises_exception():
with pytest.raises(ZeroDivisionError):
1 / 0
def test_exception_message():
with pytest.raises(ValueError, match="不能为负"):
raise ValueError("值不能为负数")
# 临时目录和文件
def test_with_tmpdir(tmpdir):
file = tmpdir.join("test.txt")
file.write("hello")
assert file.read() == "hello"
# fixtures依赖
@pytest.fixture
def database():
# 设置
db = {"users": []}
yield db
# 清理
db.clear()
def test_add_user(database):
database["users"].append("张三")
assert len(database["users"]) == 1
测试类和方法
import pytest
class Calculator:
def __init__(self):
self.result = 0
def add(self, value):
self.result += value
return self
def subtract(self, value):
self.result -= value
return self
def reset(self):
self.result = 0
return self
class TestCalculator:
@pytest.fixture
def calc(self):
return Calculator()
def test_initial_value(self, calc):
assert calc.result == 0
def test_add(self, calc):
calc.add(5)
assert calc.result == 5
def test_chain_add(self, calc):
calc.add(1).add(2).add(3)
assert calc.result == 6
def test_subtract(self, calc):
calc.subtract(5)
assert calc.result == -5
def test_reset(self, calc):
calc.add(10)
calc.reset()
assert calc.result == 0
测试组织最佳实践
# 项目结构
# project/
# ├── src/
# │ ├── calculator.py
# │ └── utils.py
# ├── tests/
# │ ├── __init__.py
# │ ├── conftest.py # pytest fixtures
# │ ├── test_calculator.py
# │ └── test_utils.py
# └── pytest.ini # pytest配置
# conftest.py - 共享的fixtures
import pytest
@pytest.fixture
def sample_user():
return {
"name": "张三",
"age": 25,
"email": "zhangsan@example.com"
}
@pytest.fixture
def db_connection():
# 模拟数据库连接
class MockDB:
def __init__(self):
self.connected = True
def execute(self, query):
return []
def close(self):
self.connected = False
db = MockDB()
yield db
db.close()
# test_calculator.py
def test_add():
from src.calculator import add
assert add(2, 3) == 5
# test_utils.py
def test_validate_email():
from src.utils import validate_email
assert validate_email("test@example.com") == True
assert validate_email("invalid") == False
测试覆盖率
# 安装 pytest-cov
# pip install pytest-cov
# 运行测试并生成覆盖率报告
# pytest --cov=src --cov-report=html
# pytest.ini配置
"""
[tool:pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
markers =
slow: marks tests as slow
smoke: marks tests as smoke tests
"""
# 在代码中检查覆盖率
def complex_function(x, y, z):
if x > 0:
if y > 0:
return x + y + z
else:
return x + z
else:
if z > 0:
return y + z
else:
return 0
# 测试用例
def test_complex_positive():
assert complex_function(1, 1, 1) == 3
def test_complex_negative():
assert complex_function(-1, 1, 1) == 2
Mock测试
from unittest.mock import Mock, patch, MagicMock
import pytest
# 被测试的代码
class UserService:
def __init__(self, db):
self.db = db
def get_user(self, user_id):
return self.db.query(f"SELECT * FROM users WHERE id={user_id}")
def create_user(self, name, email):
if not email:
raise ValueError("邮箱不能为空")
return self.db.insert(f"INSERT INTO users (name, email) VALUES ('{name}', '{email}')")
# 测试代码
class TestUserService:
@pytest.fixture
def mock_db(self):
return Mock()
@pytest.fixture
def service(self, mock_db):
return UserService(mock_db)
def test_get_user(self, service, mock_db):
mock_db.query.return_value = {"id": 1, "name": "张三"}
user = service.get_user(1)
mock_db.query.assert_called_once()
assert user["name"] == "张三"
def test_create_user(self, service, mock_db):
mock_db.insert.return_value = 1
result = service.create_user("张三", "test@example.com")
assert result == 1
def test_create_user_invalid_email(self, service):
with pytest.raises(ValueError):
service.create_user("张三", "")
测试驱动开发(TDD)
import pytest
# TDD步骤:先写测试,再写实现
# 1. 编写失败的测试
def test_fizzbuzz():
assert fizzbuzz(1) == "1"
assert fizzbuzz(2) == "2"
assert fizzbuzz(3) == "Fizz"
assert fizzbuzz(5) == "Buzz"
assert fizzbuzz(15) == "FizzBuzz"
# 2. 编写最小实现
def fizzbuzz(n):
if n % 15 == 0:
return "FizzBuzz"
elif n % 3 == 0:
return "Fizz"
elif n % 5 == 0:
return "Buzz"
else:
return str(n)
# 3. 重构优化
def fizzbuzz_refactored(n):
if n % 15 == 0:
return "FizzBuzz"
if n % 3 == 0:
return "Fizz"
if n % 5 == 0:
return "Buzz"
return str(n)
运行测试
# 运行所有测试
pytest
# 运行特定文件
pytest tests/test_calculator.py
# 运行特定测试
pytest tests/test_calculator.py::test_add
# 运行标记的测试
pytest -m slow
# 显示详细输出
pytest -v
# 显示print输出
pytest -s
# 运行上次失败的测试
pytest --lf
# 停在第一个失败
pytest -x
# 生成覆盖率报告
pytest --cov=src --cov-report=html
总结
单元测试是专业软件开发的必备技能。推荐使用pytest框架,它的语法简洁、功能强大。记住:测试不是为了达到覆盖率指标,而是为了保证代码质量。编写清晰、有意义的测试,让它们成为你代码的守护者。