← 返回首页
🎭

Mock与测试替身:unittest.mock、patch与MagicMock

📂 python ⏱ 5 min 863 words

Mock与测试替身:unittest.mock、patch与MagicMock

在单元测试中,我们经常需要隔离外部依赖(数据库、API、文件系统等)。Mock对象可以模拟这些依赖,让测试专注于被测试代码的逻辑。本文将深入讲解Python的Mock机制。

什么是测试替身

测试替身(Test Double)是在测试中替代真实依赖的对象,常见的有Dummy、Stub、Fake、Spy和Mock等。unittest.mock中的Mock对象既可以预设返回值(充当Stub),又会记录自己被如何调用(充当Spy):

from unittest.mock import Mock, MagicMock

# Create a Mock object; any attribute access auto-creates a child Mock.
mock = Mock()
print(mock)            # <Mock id='...'>
print(mock.some_method) # <Mock name='mock.some_method' id='...'>

# Configure a fixed return value for a child method.
mock.some_method.return_value = 42
print(mock.some_method())  # 42

# An iterable side_effect makes successive calls return successive items
# (not values chosen by the call arguments); a 4th call would raise StopIteration.
mock.side_effect = [1, 2, 3]
print(mock())  # 1
print(mock())  # 2
print(mock())  # 3

# An exception instance as side_effect is raised whenever the mock is called.
mock.side_effect = ValueError("测试错误")
try:
    mock()
except ValueError as e:
    print(f"捕获异常: {e}")

使用patch隔离依赖

import requests
from unittest.mock import patch, Mock

# 被测试的代码
def get_user_info(user_id, timeout=10):
    """Fetch a user record from the example API.

    Args:
        user_id: Identifier interpolated into the request URL.
        timeout: Seconds ``requests`` waits for the server. Without a timeout
            ``requests.get`` can block indefinitely on an unresponsive host.

    Returns:
        The decoded JSON body when the server answers 200, otherwise ``None``.
    """
    response = requests.get(
        f"https://api.example.com/users/{user_id}", timeout=timeout
    )
    if response.status_code == 200:
        return response.json()
    return None

# Test code - isolate the HTTP call with patch
class TestGetUserInfo:
    """Unit tests for ``get_user_info`` with ``requests.get`` patched out."""

    # get_user_info looks up ``requests.get`` through the module attribute at
    # call time, so patching 'requests.get' replaces what it actually calls.
    @patch('requests.get')
    def test_get_user_success(self, mock_get):
        # Fake a successful response object.
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"id": 1, "name": "张三"}
        mock_get.return_value = mock_response

        result = get_user_info(1)

        assert result == {"id": 1, "name": "张三"}
        mock_get.assert_called_once_with(
            "https://api.example.com/users/1", timeout=10
        )

    @patch('requests.get')
    def test_get_user_not_found(self, mock_get):
        mock_response = Mock()
        mock_response.status_code = 404
        mock_get.return_value = mock_response

        result = get_user_info(999)

        assert result is None
        # The body must not be decoded for a non-200 response.
        mock_response.json.assert_not_called()

MagicMock详解

MagicMock是Mock的子类,它预先配置了大多数魔术方法(如__len__、__iter__、__str__、__contains__),可以直接设置它们的返回值;普通Mock则不会预先配置这些方法:

from unittest.mock import MagicMock

# MagicMock preconfigures most magic methods, so they can be configured directly.
mock = MagicMock()

# __len__
mock.__len__.return_value = 5
print(len(mock))  # 5

# __iter__: use a list rather than iter([...]). An iterator return value is
# consumed by the first loop, and every later iteration would yield nothing.
mock.__iter__.return_value = [1, 2, 3]
for item in mock:
    print(item)

# __str__
mock.__str__.return_value = "Mock对象"
print(str(mock))  # Mock对象

# __contains__
mock.__contains__.return_value = True
print("item" in mock)  # True

# Calling the mock itself: configure mock.return_value. __call__ is not one of
# MagicMock's configurable magic methods; mock.__call__ is an ordinary bound
# method, so setting .return_value on it raises AttributeError.
mock.return_value = "被调用了"
print(mock())  # 被调用了

# Chained calls: intermediate attributes are created automatically.
mock.chain.method.return_value = 42
result = mock.chain.method()
print(result)  # 42

patch的不同用法

from unittest.mock import patch, Mock
import os

# 1. 装饰器方式
@patch('os.path.exists')
def test_file_exists(mock_exists):
    """Decorator form: the patch is active for the whole test function."""
    mock_exists.return_value = True
    assert os.path.exists('test.txt') is True
    # Verify the code under test really went through the mock.
    mock_exists.assert_called_once_with('test.txt')

# 2. 上下文管理器方式
def test_file_exists_context():
    """Context-manager form: the patch is active only inside the with block."""
    with patch('os.path.exists') as mock_exists:
        mock_exists.return_value = False
        assert os.path.exists('test.txt') is False
        mock_exists.assert_called_once_with('test.txt')

# 3. 手动启动/停止
def test_file_exists_manual():
    """Manual form: start() the patcher and guarantee stop() with try/finally."""
    patcher = patch('os.path.exists')
    # start() activates the patch and returns the replacement mock.
    mock_exists = patcher.start()
    try:
        mock_exists.return_value = True
        assert os.path.exists('test.txt') is True
    finally:
        # Always undo the patch, even if an assertion above fails; otherwise
        # os.path.exists would stay mocked for every test that runs later.
        patcher.stop()

# 4. patch对象的属性
class MyClass:
    """Sample class whose ``value`` is an instance attribute set in __init__."""

    def __init__(self):
        self.value = 10

def test_patch_attribute():
    """patch.object on an instance attribute.

    ``value`` exists only on instances because it is assigned in __init__.
    Patching ``MyClass`` itself would raise AttributeError ("does not have the
    attribute 'value'"). Even with create=True, __init__ would overwrite it
    with 10. Patch the instance instead; the original value comes back when
    the with block exits.
    """
    obj = MyClass()
    with patch.object(obj, 'value', 100):
        assert obj.value == 100
    assert obj.value == 10

# 5. patch类的方法
class Service:
    """Sample service with a single method to patch."""

    def process(self, data):
        return data * 2

@patch.object(Service, 'process')
def test_patch_class_method(mock_process):
    """Replace a method on the class so every instance sees the mock."""
    mock_process.return_value = 100
    service = Service()
    result = service.process(50)
    assert result == 100
    # A plain MagicMock stored on the class is not a descriptor, so ``self`` is
    # not bound: the mock receives only the explicit argument.
    mock_process.assert_called_once_with(50)

模拟文件操作

from unittest.mock import patch, mock_open, Mock
import json

# 被测试的代码
def read_config(filepath):
    """Load and return the JSON document stored at ``filepath``.

    The file is opened as UTF-8 explicitly, so the result does not depend on
    the platform's locale encoding.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

def write_log(filepath, message):
    """Append ``message`` followed by a newline to ``filepath`` (UTF-8)."""
    # Explicit encoding: non-ASCII messages such as Chinese text would fail
    # or be mis-encoded where the default encoding is not UTF-8.
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write(message + '\n')

# Test file reading: mock_open feeds read_data to the fake file handle.
def test_read_config():
    config_data = '{"debug": true, "version": "1.0"}'

    with patch('builtins.open', mock_open(read_data=config_data)) as mock_file:
        result = read_config('config.json')

        assert result == {"debug": True, "version": "1.0"}
        mock_file.assert_called_once_with('config.json', 'r', encoding='utf-8')

# Test file writing: the handle returned by mock_file() records write() calls.
def test_write_log():
    with patch('builtins.open', mock_open()) as mock_file:
        write_log('app.log', '测试消息')

        mock_file.assert_called_once_with('app.log', 'a', encoding='utf-8')
        mock_file().write.assert_called_once_with('测试消息\n')

模拟数据库操作

from unittest.mock import Mock, patch, MagicMock
import pytest

# 被测试的代码
class UserRepository:
    """Thin data-access layer over an injected ``db`` object.

    ``db`` must provide ``execute(sql, params)`` and a ``lastrowid`` attribute
    (the tests in this article supply a Mock). NOTE(review): confirm this
    matches the real driver's API.
    """

    def __init__(self, db):
        self.db = db

    def get_user(self, user_id):
        """Return the first row ``execute`` yields for ``user_id``, else None."""
        rows = self.db.execute("SELECT * FROM users WHERE id = %s", (user_id,))
        return rows[0] if rows else None

    def create_user(self, name, email):
        """Insert a user row and return ``db.lastrowid``."""
        params = (name, email)
        self.db.execute("INSERT INTO users (name, email) VALUES (%s, %s)", params)
        return self.db.lastrowid

class UserService:
    """Registration rules layered on top of a user repository."""

    def __init__(self, user_repo):
        self.user_repo = user_repo

    def register_user(self, name, email):
        """Create a user unless the email is taken; return the repo's result.

        Raises:
            ValueError: if ``get_user_by_email`` returns a truthy record.
        """
        if self.get_user_by_email(email):
            raise ValueError("邮箱已存在")
        return self.user_repo.create_user(name, email)

    def get_user_by_email(self, email):
        """Look up a user by email. Simplified stub: always returns None."""
        return None

# 测试代码
class TestUserRepository:
    """UserRepository tests with the database replaced by a Mock."""

    @pytest.fixture
    def mock_db(self):
        return Mock()

    @pytest.fixture
    def repo(self, mock_db):
        return UserRepository(mock_db)

    def test_get_user(self, repo, mock_db):
        mock_db.execute.return_value = [{"id": 1, "name": "张三"}]

        user = repo.get_user(1)

        assert user == {"id": 1, "name": "张三"}
        mock_db.execute.assert_called_once()
        # The id must reach execute() as the bound query parameter.
        _sql, params = mock_db.execute.call_args.args
        assert params == (1,)

    def test_get_user_not_found(self, repo, mock_db):
        mock_db.execute.return_value = []

        user = repo.get_user(999)

        assert user is None

    def test_create_user(self, repo, mock_db):
        mock_db.lastrowid = 1

        user_id = repo.create_user("张三", "test@example.com")

        assert user_id == 1
        mock_db.execute.assert_called_once()
        # Name and email are passed as bound parameters, in column order.
        _sql, params = mock_db.execute.call_args.args
        assert params == ("张三", "test@example.com")

class TestUserService:
    """UserService tests with the repository replaced by a Mock."""

    @pytest.fixture
    def mock_repo(self):
        return Mock()

    @pytest.fixture
    def service(self, mock_repo):
        return UserService(mock_repo)

    def test_register_new_user(self, service, mock_repo):
        mock_repo.create_user.return_value = 1
        # Replace the email lookup on this instance only.
        service.get_user_by_email = Mock(return_value=None)

        user_id = service.register_user("张三", "new@example.com")

        assert user_id == 1
        service.get_user_by_email.assert_called_once_with("new@example.com")
        mock_repo.create_user.assert_called_once_with("张三", "new@example.com")

    def test_register_duplicate_email(self, service, mock_repo):
        service.get_user_by_email = Mock(return_value={"id": 1})

        with pytest.raises(ValueError, match="邮箱已存在"):
            service.register_user("张三", "existing@example.com")

        # A duplicate must be rejected before anything is written.
        mock_repo.create_user.assert_not_called()

模拟异步代码

import pytest
from unittest.mock import AsyncMock, Mock, patch
import asyncio

# 被测试的异步代码
class AsyncService:
    """Async wrapper that fetches JSON through an injected API client."""

    def __init__(self, api_client):
        self.api_client = api_client

    async def fetch_data(self, url):
        """Await ``api_client.get(url)`` and return the response's ``json()``.

        Note: ``json()`` is called synchronously; only ``get`` is awaited.
        """
        return (await self.api_client.get(url)).json()

# 测试异步代码
class TestAsyncService:
    """AsyncService tests (the asyncio marker comes from the pytest-asyncio plugin)."""

    @pytest.fixture
    def mock_client(self):
        # AsyncMock: its child methods return awaitables when called.
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_client):
        return AsyncService(mock_client)

    @pytest.mark.asyncio
    async def test_fetch_data(self, service, mock_client):
        # fetch_data calls response.json() without awaiting it, so the response
        # must be a plain Mock. The children of an AsyncMock are AsyncMocks too,
        # so json() on one would return a coroutine instead of the dict.
        mock_response = Mock()
        mock_response.json.return_value = {"data": "test"}
        mock_client.get.return_value = mock_response

        result = await service.fetch_data("https://api.example.com")

        assert result == {"data": "test"}
        # assert_awaited_* checks that the call was actually awaited.
        mock_client.get.assert_awaited_once_with("https://api.example.com")

验证Mock调用

from unittest.mock import Mock, call

def test_call_verification():
    """Inspect recorded calls via call_count, call_args_list and mock_calls."""
    mock = Mock()

    # Call the mock directly twice and one child method once.
    mock("arg1", "arg2")
    mock("arg3", key="value")
    mock.method()

    # call_count / call_args_list track calls to the mock itself only;
    # mock.method() is a call on a child mock and is not included here.
    assert mock.call_count == 2
    assert mock.call_args_list == [
        call("arg1", "arg2"),
        call("arg3", key="value"),
    ]

    # mock_calls records calls to the mock AND to its children.
    assert mock.mock_calls == [
        call("arg1", "arg2"),
        call("arg3", key="value"),
        call.method(),
    ]
    mock.method.assert_called_once_with()

    # Check that one particular call happened.
    mock.assert_any_call("arg1", "arg2")

    # Check that a mock was never called.
    other_mock = Mock()
    other_mock.assert_not_called()

# 测试顺序
def test_call_order():
    """Verify that child methods were called in a specific order."""
    recorder = Mock()

    recorder.first()
    recorder.second()
    recorder.third()

    # With any_order=False the expected calls must appear as a consecutive,
    # ordered run within recorder.mock_calls.
    recorder.assert_has_calls(
        [call.first(), call.second(), call.third()],
        any_order=False,
    )

实战示例:API客户端测试

from unittest.mock import Mock, patch, MagicMock
import requests
import pytest

class APIClient:
    """Minimal JSON-over-HTTP client built on a ``requests.Session``."""

    def __init__(self, base_url):
        self.base_url = base_url
        self.session = requests.Session()

    def _url(self, endpoint):
        """Join ``base_url`` and ``endpoint`` with '/'."""
        return f"{self.base_url}/{endpoint}"

    def get(self, endpoint, params=None):
        """GET ``endpoint`` with optional query ``params``; return decoded JSON.

        Propagates whatever ``raise_for_status()`` raises for error responses.
        """
        resp = self.session.get(self._url(endpoint), params=params)
        resp.raise_for_status()
        return resp.json()

    def post(self, endpoint, data=None):
        """POST ``data`` as the JSON body to ``endpoint``; return decoded JSON.

        Propagates whatever ``raise_for_status()`` raises for error responses.
        """
        resp = self.session.post(self._url(endpoint), json=data)
        resp.raise_for_status()
        return resp.json()

class TestAPIClient:
    """APIClient tests with ``requests.Session`` replaced by a Mock."""

    @pytest.fixture
    def mock_session(self):
        # Stands in for the Session the client would normally create.
        return Mock()

    @pytest.fixture
    def client(self, mock_session):
        # Patch requests.Session while the client is being constructed, so no
        # real Session is ever created. A @patch decorator on a test method
        # would be too late: pytest builds fixtures before calling the test.
        with patch('requests.Session', return_value=mock_session):
            api_client = APIClient("https://api.example.com")
        return api_client

    def test_get_success(self, client, mock_session):
        mock_response = Mock()
        mock_response.json.return_value = {"id": 1, "name": "test"}
        mock_session.get.return_value = mock_response

        result = client.get("users/1")

        assert result == {"id": 1, "name": "test"}
        mock_session.get.assert_called_once_with(
            "https://api.example.com/users/1",
            params=None
        )

    def test_get_with_params(self, client, mock_session):
        mock_response = Mock()
        mock_response.json.return_value = {"results": []}
        mock_session.get.return_value = mock_response

        result = client.get("users", params={"page": 1})

        assert result == {"results": []}
        mock_session.get.assert_called_once_with(
            "https://api.example.com/users",
            params={"page": 1}
        )

    def test_post_success(self, client, mock_session):
        mock_response = Mock()
        mock_response.json.return_value = {"id": 1}
        mock_session.post.return_value = mock_response

        result = client.post("users", data={"name": "new user"})

        assert result == {"id": 1}
        # APIClient.post sends ``data`` as the JSON body.
        mock_session.post.assert_called_once_with(
            "https://api.example.com/users",
            json={"name": "new user"}
        )

    def test_get_error(self, client, mock_session):
        mock_response = Mock()
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found")
        mock_session.get.return_value = mock_response

        with pytest.raises(requests.exceptions.HTTPError):
            client.get("nonexistent")

        # raise_for_status() fires before the body would be decoded.
        mock_response.json.assert_not_called()

总结

Mock是单元测试中隔离依赖的利器。正确使用Mock可以让你的测试更加稳定、快速和可靠。记住:只Mock必要的依赖,不要过度Mock导致测试失去意义。另外,patch要替换被测代码查找该名称的位置:例如被测模块写的是`from requests import get`,就应patch`被测模块.get`,而不是`requests.get`。