← 返回首页
🧪

单元测试:unittest、pytest、断言与测试组织

📂 python ⏱ 4 min 757 words

单元测试:unittest、pytest、断言与测试组织

单元测试是保证代码质量的关键手段,能帮助我们在代码变更时快速发现问题。本文将介绍Python中最常用的测试框架和最佳实践。

为什么需要单元测试

unittest框架

unittest是Python标准库自带的测试框架:

import unittest

def add(a, b):
    return a + b

def divide(a, b):
    if b == 0:
        raise ValueError("除数不能为零")
    return a / b

class TestCalculator(unittest.TestCase):
    
    def setUp(self):
        """每个测试方法前执行"""
        self.test_data = [1, 2, 3, 4, 5]
    
    def tearDown(self):
        """每个测试方法后执行"""
        pass
    
    def test_add_positive(self):
        self.assertEqual(add(2, 3), 5)
    
    def test_add_negative(self):
        self.assertEqual(add(-1, -1), -2)
    
    def test_add_zero(self):
        self.assertEqual(add(0, 5), 5)
    
    def test_divide_normal(self):
        self.assertAlmostEqual(divide(10, 2), 5.0)
    
    def test_divide_by_zero(self):
        with self.assertRaises(ValueError):
            divide(10, 0)

if __name__ == '__main__':
    unittest.main()

pytest框架

pytest是更现代化的测试框架,语法更简洁:

# test_calculator.py
def add(a, b):
    return a + b

def divide(a, b):
    if b == 0:
        raise ValueError("除数不能为零")
    return a / b

# 简单的测试函数
def test_add():
    assert add(2, 3) == 5
    assert add(-1, -1) == -2
    assert add(0, 5) == 5

def test_divide():
    assert divide(10, 2) == 5.0
    assert divide(9, 3) == 3.0

def test_divide_by_zero():
    import pytest
    with pytest.raises(ValueError):
        divide(10, 0)

# 使用fixtures
import pytest

@pytest.fixture
def sample_list():
    return [1, 2, 3, 4, 5]

def test_list_length(sample_list):
    assert len(sample_list) == 5

def test_list_sum(sample_list):
    assert sum(sample_list) == 15

pytest进阶特性

import pytest

# 参数化测试
@pytest.mark.parametrize("a,b,expected", [
    (1, 2, 3),
    (5, 5, 10),
    (-1, 1, 0),
    (0, 0, 0),
])
def test_add_parametrize(a, b, expected):
    assert a + b == expected

# 标记测试
@pytest.mark.slow
def test_heavy_computation():
    # 耗时测试
    result = sum(range(1000000))
    assert result == 499999500000

@pytest.mark.skip(reason="暂时跳过")
def test_not_ready():
    pass

@pytest.mark.skipif(True, reason="条件不满足时跳过")
def test_conditional_skip():
    pass

# 异常测试
def test_raises_exception():
    with pytest.raises(ZeroDivisionError):
        1 / 0

def test_exception_message():
    with pytest.raises(ValueError, match="不能为负"):
        raise ValueError("值不能为负数")

# 临时目录和文件
def test_with_tmpdir(tmpdir):
    file = tmpdir.join("test.txt")
    file.write("hello")
    assert file.read() == "hello"

# fixtures依赖
@pytest.fixture
def database():
    # 设置
    db = {"users": []}
    yield db
    # 清理
    db.clear()

def test_add_user(database):
    database["users"].append("张三")
    assert len(database["users"]) == 1

测试类和方法

import pytest

class Calculator:
    def __init__(self):
        self.result = 0
    
    def add(self, value):
        self.result += value
        return self
    
    def subtract(self, value):
        self.result -= value
        return self
    
    def reset(self):
        self.result = 0
        return self

class TestCalculator:
    
    @pytest.fixture
    def calc(self):
        return Calculator()
    
    def test_initial_value(self, calc):
        assert calc.result == 0
    
    def test_add(self, calc):
        calc.add(5)
        assert calc.result == 5
    
    def test_chain_add(self, calc):
        calc.add(1).add(2).add(3)
        assert calc.result == 6
    
    def test_subtract(self, calc):
        calc.subtract(5)
        assert calc.result == -5
    
    def test_reset(self, calc):
        calc.add(10)
        calc.reset()
        assert calc.result == 0

测试组织最佳实践

# 项目结构
# project/
# ├── src/
# │   ├── calculator.py
# │   └── utils.py
# ├── tests/
# │   ├── __init__.py
# │   ├── conftest.py      # pytest fixtures
# │   ├── test_calculator.py
# │   └── test_utils.py
# └── pytest.ini            # pytest配置

# conftest.py - 共享的fixtures
import pytest

@pytest.fixture
def sample_user():
    return {
        "name": "张三",
        "age": 25,
        "email": "zhangsan@example.com"
    }

@pytest.fixture
def db_connection():
    # 模拟数据库连接
    class MockDB:
        def __init__(self):
            self.connected = True
        
        def execute(self, query):
            return []
        
        def close(self):
            self.connected = False
    
    db = MockDB()
    yield db
    db.close()

# test_calculator.py
def test_add():
    from src.calculator import add
    assert add(2, 3) == 5

# test_utils.py
def test_validate_email():
    from src.utils import validate_email
    assert validate_email("test@example.com") == True
    assert validate_email("invalid") == False

测试覆盖率

# 安装 pytest-cov
# pip install pytest-cov

# 运行测试并生成覆盖率报告
# pytest --cov=src --cov-report=html

# pytest.ini配置
"""
[tool:pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
markers =
    slow: marks tests as slow
    smoke: marks tests as smoke tests
"""

# 在代码中检查覆盖率
def complex_function(x, y, z):
    if x > 0:
        if y > 0:
            return x + y + z
        else:
            return x + z
    else:
        if z > 0:
            return y + z
        else:
            return 0

# 测试用例
def test_complex_positive():
    assert complex_function(1, 1, 1) == 3

def test_complex_negative():
    assert complex_function(-1, 1, 1) == 2

Mock测试

from unittest.mock import Mock, patch, MagicMock
import pytest

# 被测试的代码
class UserService:
    def __init__(self, db):
        self.db = db
    
    def get_user(self, user_id):
        return self.db.query(f"SELECT * FROM users WHERE id={user_id}")
    
    def create_user(self, name, email):
        if not email:
            raise ValueError("邮箱不能为空")
        return self.db.insert(f"INSERT INTO users (name, email) VALUES ('{name}', '{email}')")

# 测试代码
class TestUserService:
    
    @pytest.fixture
    def mock_db(self):
        return Mock()
    
    @pytest.fixture
    def service(self, mock_db):
        return UserService(mock_db)
    
    def test_get_user(self, service, mock_db):
        mock_db.query.return_value = {"id": 1, "name": "张三"}
        user = service.get_user(1)
        mock_db.query.assert_called_once()
        assert user["name"] == "张三"
    
    def test_create_user(self, service, mock_db):
        mock_db.insert.return_value = 1
        result = service.create_user("张三", "test@example.com")
        assert result == 1
    
    def test_create_user_invalid_email(self, service):
        with pytest.raises(ValueError):
            service.create_user("张三", "")

测试驱动开发(TDD)

import pytest

# TDD步骤:先写测试,再写实现

# 1. 编写失败的测试
def test_fizzbuzz():
    assert fizzbuzz(1) == "1"
    assert fizzbuzz(2) == "2"
    assert fizzbuzz(3) == "Fizz"
    assert fizzbuzz(5) == "Buzz"
    assert fizzbuzz(15) == "FizzBuzz"

# 2. 编写最小实现
def fizzbuzz(n):
    if n % 15 == 0:
        return "FizzBuzz"
    elif n % 3 == 0:
        return "Fizz"
    elif n % 5 == 0:
        return "Buzz"
    else:
        return str(n)

# 3. 重构优化
def fizzbuzz_refactored(n):
    if n % 15 == 0:
        return "FizzBuzz"
    if n % 3 == 0:
        return "Fizz"
    if n % 5 == 0:
        return "Buzz"
    return str(n)

运行测试

# 运行所有测试
pytest

# 运行特定文件
pytest tests/test_calculator.py

# 运行特定测试
pytest tests/test_calculator.py::test_add

# 运行标记的测试
pytest -m slow

# 显示详细输出
pytest -v

# 显示print输出
pytest -s

# 运行上次失败的测试
pytest --lf

# 停在第一个失败
pytest -x

# 生成覆盖率报告
pytest --cov=src --cov-report=html

总结

单元测试是专业软件开发的必备技能。推荐使用pytest框架,它的语法简洁、功能强大。记住:测试不是为了达到覆盖率指标,而是为了保证代码质量。编写清晰、有意义的测试,让它们成为你代码的守护者。