Mock与测试替身:unittest.mock、patch与MagicMock
Mock与测试替身:unittest.mock、patch与MagicMock
在单元测试中,我们经常需要隔离外部依赖(数据库、API、文件系统等)。Mock对象可以模拟这些依赖,让测试专注于被测试代码的逻辑。本文将深入讲解Python的Mock机制。
什么是测试替身
测试替身(Test Double)是在测试中替代真实依赖的对象:
- Dummy:占位对象,不执行任何操作
- Stub:返回预设值的对象
- Spy:记录调用信息的包装器
- Mock:预设行为并验证调用的对象
- Fake:简化版的真实实现
from unittest.mock import Mock, MagicMock
# Build a bare Mock; attributes are created lazily on first access
mock = Mock()
print(mock)  # <Mock id='...'>
print(mock.some_method)  # <Mock name='mock.some_method' id='...'>

# Give a child method a fixed return value
mock.some_method.return_value = 42
print(mock.some_method())  # 42

# A list side_effect hands out one item per call, in order
mock.side_effect = [1, 2, 3]
for _ in range(3):
    print(mock())  # 1, then 2, then 3

# An exception instance as side_effect is raised when the mock is called
mock.side_effect = ValueError("测试错误")
try:
    mock()
except ValueError as err:
    print(f"捕获异常: {err}")
使用patch隔离依赖
import requests
from unittest.mock import patch, Mock
# Code under test
def get_user_info(user_id):
    """Fetch one user's JSON payload from the example users API.

    Sends a GET request to ``https://api.example.com/users/<user_id>`` via
    ``requests.get``.

    Returns:
        The value of ``response.json()`` when the status code is 200,
        otherwise ``None``.
    """
    reply = requests.get(f"https://api.example.com/users/{user_id}")
    # Anything other than 200 is treated as "no user"
    if reply.status_code != 200:
        return None
    return reply.json()
# Tests - using patch
class TestGetUserInfo:
    """Exercise get_user_info with ``requests.get`` replaced by a mock."""

    @patch('requests.get')
    def test_get_user_success(self, mock_get):
        """A 200 response yields the decoded JSON payload."""
        expected_user = {"id": 1, "name": "张三"}

        # Arrange: the patched requests.get hands back a canned response
        fake_response = Mock(status_code=200)
        fake_response.json.return_value = expected_user
        mock_get.return_value = fake_response

        # Act + assert on the result and on the URL that was requested
        assert get_user_info(1) == expected_user
        mock_get.assert_called_once_with("https://api.example.com/users/1")

    @patch('requests.get')
    def test_get_user_not_found(self, mock_get):
        """A 404 response makes get_user_info return None."""
        mock_get.return_value = Mock(status_code=404)

        assert get_user_info(999) is None
MagicMock详解
MagicMock是Mock的子类,默认预先配置好了大多数魔术方法(如__len__、__iter__、__str__),因此可以直接设置它们的返回值:
from unittest.mock import MagicMock
# MagicMock comes with the magic methods preconfigured
mock = MagicMock()

# __len__
mock.__len__.return_value = 5
print(len(mock))  # 5

# __iter__
# Use a list rather than iter([...]): a one-shot iterator would be exhausted
# after the first loop and every later iteration would yield nothing.
mock.__iter__.return_value = [1, 2, 3]
for item in mock:
    print(item)

# __str__
mock.__str__.return_value = "Mock对象"
print(str(mock))  # Mock对象

# __contains__
mock.__contains__.return_value = True
print("item" in mock)  # True

# Calling the mock itself
# The call result is configured through mock.return_value. mock.__call__ is
# the real bound method of the Mock class, so assigning
# mock.__call__.return_value raises AttributeError.
mock.return_value = "被调用了"
print(mock())  # 被调用了

# Chained attribute access
mock.chain.method.return_value = 42
result = mock.chain.method()
print(result)  # 42
patch的不同用法
from unittest.mock import patch, Mock
import os
# 1. Decorator form
@patch('os.path.exists')
def test_file_exists(mock_exists):
    """Patch os.path.exists for the duration of the test via a decorator.

    The decorator passes the replacement mock in as ``mock_exists``.
    """
    mock_exists.return_value = True
    # Identity check instead of ``== True`` (PEP 8: compare singletons with is)
    assert os.path.exists('test.txt') is True
    # Also confirm the patched function received the expected path
    mock_exists.assert_called_once_with('test.txt')
# 2. Context-manager form
def test_file_exists_context():
    """Patch os.path.exists only inside a ``with`` block."""
    with patch('os.path.exists') as mock_exists:
        mock_exists.return_value = False
        # Identity check instead of ``== False`` (PEP 8)
        assert os.path.exists('test.txt') is False
        # Also confirm the patched function received the expected path
        mock_exists.assert_called_once_with('test.txt')
# 3. Manual start/stop
def test_file_exists_manual():
    """Patch os.path.exists by starting and stopping a patcher by hand.

    ``patch(...)`` returns a patcher object; the mock itself is what
    ``start()`` returns. ``stop()`` runs in a ``finally`` clause so the real
    function is restored even when the test body raises.
    """
    patcher = patch('os.path.exists')
    mock_exists = patcher.start()
    try:
        # ... test code ...
        mock_exists.return_value = True
        assert os.path.exists('test.txt') is True
    finally:
        patcher.stop()
# 4. Patching an attribute of an object
class MyClass:
    """Small example class whose ``value`` is set per instance in __init__."""

    def __init__(self):
        self.value = 10


def test_patch_attribute():
    """Temporarily replace ``value`` on one MyClass instance.

    ``value`` exists only on instances (it is assigned in ``__init__``), so
    patching the class with ``patch.object(MyClass, 'value', 100)`` raises
    AttributeError. Even a class-level patch would be shadowed by the
    instance attribute, so the instance is what gets patched here.
    """
    obj = MyClass()
    with patch.object(obj, 'value', 100):
        assert obj.value == 100
    # The original value is restored when the context exits
    assert obj.value == 10
# 5. Patching a method on a class
class Service:
    """Example service whose ``process`` multiplies its input by two."""

    def process(self, data):
        """Return ``data * 2``."""
        doubled = data * 2
        return doubled


@patch.object(Service, 'process', return_value=100)
def test_patch_class_method(mock_process):
    """With ``Service.process`` patched, calls return the configured value."""
    outcome = Service().process(50)
    assert outcome == 100
模拟文件操作
from unittest.mock import patch, mock_open, Mock
import json
# Code under test
def read_config(filepath):
    """Open ``filepath`` in read mode and return its parsed JSON content."""
    with open(filepath, 'r') as config_file:
        parsed = json.load(config_file)
    return parsed
def write_log(filepath, message):
    """Append ``message`` followed by a newline to the file at ``filepath``."""
    line = message + '\n'
    with open(filepath, 'a') as log_file:
        log_file.write(line)
# Test reading a file
def test_read_config():
    """read_config parses the JSON text served by a mocked ``open``."""
    config_data = '{"debug": true, "version": "1.0"}'
    fake_open = mock_open(read_data=config_data)

    # builtins.open is swapped out, so no real file is touched
    with patch('builtins.open', fake_open):
        loaded = read_config('config.json')

        assert loaded == {"debug": True, "version": "1.0"}
        fake_open.assert_called_once_with('config.json', 'r')
# Test writing a file
def test_write_log():
    """write_log opens the file in append mode and writes message + newline."""
    fake_open = mock_open()

    with patch('builtins.open', fake_open):
        write_log('app.log', '测试消息')

        fake_open.assert_called_once_with('app.log', 'a')
        # return_value is the file handle that the mocked open() hands out
        file_handle = fake_open.return_value
        file_handle.write.assert_called_once_with('测试消息\n')
模拟数据库操作
from unittest.mock import Mock, patch, MagicMock
import pytest
# Code under test
class UserRepository:
    """Data-access helper that runs user queries through an injected ``db``.

    The ``db`` object must provide ``execute(sql, params)`` and a
    ``lastrowid`` attribute; both are used as-is.
    """

    def __init__(self, db):
        self.db = db

    def get_user(self, user_id):
        """Return the first row matching ``user_id``, or None if none match."""
        rows = self.db.execute("SELECT * FROM users WHERE id = %s", (user_id,))
        return rows[0] if rows else None

    def create_user(self, name, email):
        """Insert a user row and return ``db.lastrowid``."""
        insert_sql = "INSERT INTO users (name, email) VALUES (%s, %s)"
        self.db.execute(insert_sql, (name, email))
        return self.db.lastrowid
class UserService:
    """Registration logic layered on top of a user repository."""

    def __init__(self, user_repo):
        self.user_repo = user_repo

    def register_user(self, name, email):
        """Create a user unless ``get_user_by_email`` reports an existing one.

        Returns:
            Whatever ``user_repo.create_user(name, email)`` returns.

        Raises:
            ValueError: if ``get_user_by_email(email)`` returns a truthy value.
        """
        if self.get_user_by_email(email):
            raise ValueError("邮箱已存在")
        return self.user_repo.create_user(name, email)

    def get_user_by_email(self, email):
        """Simplified lookup: always returns None."""
        # Simplified implementation
        return None
# Tests
class TestUserRepository:
    """UserRepository tests that run against a Mock database handle."""

    @pytest.fixture
    def mock_db(self):
        """A fresh Mock standing in for the database."""
        return Mock()

    @pytest.fixture
    def repo(self, mock_db):
        """Repository wired to the mock database."""
        return UserRepository(db=mock_db)

    def test_get_user(self, repo, mock_db):
        """The first row returned by execute is handed back."""
        expected_row = {"id": 1, "name": "张三"}
        mock_db.execute.return_value = [expected_row]

        assert repo.get_user(1) == expected_row
        mock_db.execute.assert_called_once()

    def test_get_user_not_found(self, repo, mock_db):
        """An empty result set maps to None."""
        mock_db.execute.return_value = []

        assert repo.get_user(999) is None

    def test_create_user(self, repo, mock_db):
        """create_user returns the database's lastrowid."""
        mock_db.lastrowid = 1

        new_id = repo.create_user("张三", "test@example.com")

        assert new_id == 1
        mock_db.execute.assert_called_once()
class TestUserService:
    """UserService tests that use a Mock repository."""

    @pytest.fixture
    def mock_repo(self):
        """A fresh Mock standing in for the repository."""
        return Mock()

    @pytest.fixture
    def service(self, mock_repo):
        """Service wired to the mock repository."""
        return UserService(user_repo=mock_repo)

    def test_register_new_user(self, service, mock_repo):
        """With no existing user, registration delegates to create_user."""
        mock_repo.create_user.return_value = 1
        # Stub the lookup on this instance so it reports "no such email"
        service.get_user_by_email = Mock(return_value=None)

        new_id = service.register_user("张三", "new@example.com")

        assert new_id == 1
        mock_repo.create_user.assert_called_once_with("张三", "new@example.com")

    def test_register_duplicate_email(self, service, mock_repo):
        """An existing user for the email makes registration raise."""
        service.get_user_by_email = Mock(return_value={"id": 1})

        with pytest.raises(ValueError, match="邮箱已存在"):
            service.register_user("张三", "existing@example.com")
模拟异步代码
import pytest
from unittest.mock import AsyncMock, Mock, patch
import asyncio
# Async code under test
class AsyncService:
    """Thin async wrapper around an injected API client."""

    def __init__(self, api_client):
        self.api_client = api_client

    async def fetch_data(self, url):
        """Await ``api_client.get(url)`` and return the response's ``json()``.

        ``json()`` is called synchronously on the awaited response.
        """
        return (await self.api_client.get(url)).json()
# Tests for the async code
class TestAsyncService:
    """AsyncService tests that use an AsyncMock client."""

    @pytest.fixture
    def mock_client(self):
        """AsyncMock client: its methods return awaitables."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_client):
        """Service wired to the mock client."""
        return AsyncService(mock_client)

    @pytest.mark.asyncio
    async def test_fetch_data(self, service, mock_client):
        """fetch_data awaits client.get and returns response.json()."""
        # The response must be a plain Mock. Attributes of an AsyncMock are
        # themselves AsyncMocks, so response.json() would return a coroutine
        # instead of the dict, and fetch_data calls json() without awaiting.
        mock_response = Mock()
        mock_response.json.return_value = {"data": "test"}
        mock_client.get.return_value = mock_response

        result = await service.fetch_data("https://api.example.com")

        assert result == {"data": "test"}
        # Check that the call was actually awaited, not merely invoked
        mock_client.get.assert_awaited_once_with("https://api.example.com")
验证Mock调用
from unittest.mock import Mock, call
def test_call_verification():
    """Show which Mock attributes record which calls.

    ``call_count`` and ``call_args_list`` track only direct calls to the mock
    itself. Calls to child attributes such as ``mock.method()`` appear in
    ``mock_calls`` only.
    """
    mock = Mock()

    # Two direct calls plus one call to a child method
    mock("arg1", "arg2")
    mock("arg3", key="value")
    mock.method()

    # Direct calls only: the child call mock.method() is not counted here
    assert mock.call_count == 2
    assert mock.call_args_list == [
        call("arg1", "arg2"),
        call("arg3", key="value"),
    ]

    # mock_calls records direct calls and child calls together
    assert mock.mock_calls == [
        call("arg1", "arg2"),
        call("arg3", key="value"),
        call.method(),
    ]
    mock.method.assert_called_once_with()

    # Check that one particular call happened
    mock.assert_any_call("arg1", "arg2")

    # Check that a mock was never called
    other_mock = Mock()
    other_mock.assert_not_called()
# Test call order
def test_call_order():
    """Child-method calls must appear in the order they were made."""
    tracker = Mock()
    step_names = ("first", "second", "third")

    # Invoke tracker.first(), tracker.second(), tracker.third() in turn
    for step in step_names:
        getattr(tracker, step)()

    # call.<name>() builds the matching expected call object
    expected_order = [getattr(call, step)() for step in step_names]
    tracker.assert_has_calls(expected_order, any_order=False)
实战示例:API客户端测试
from unittest.mock import Mock, patch, MagicMock
import requests
import pytest
class APIClient:
    """Minimal JSON API client built on a ``requests.Session``."""

    def __init__(self, base_url):
        self.base_url = base_url
        self.session = requests.Session()

    def _url(self, endpoint):
        """Join the base URL and ``endpoint`` with a single slash."""
        return f"{self.base_url}/{endpoint}"

    def get(self, endpoint, params=None):
        """GET ``endpoint`` and return the decoded JSON body.

        ``raise_for_status()`` is called on the response before decoding.
        """
        reply = self.session.get(self._url(endpoint), params=params)
        reply.raise_for_status()
        return reply.json()

    def post(self, endpoint, data=None):
        """POST ``data`` as JSON to ``endpoint`` and return the decoded body.

        ``raise_for_status()`` is called on the response before decoding.
        """
        reply = self.session.post(self._url(endpoint), json=data)
        reply.raise_for_status()
        return reply.json()
class TestAPIClient:
    """APIClient tests that never create a real ``requests.Session``.

    The patch lives in the ``client`` fixture. A ``@patch`` decorator on each
    test would only take effect after pytest had already built the fixture,
    by which time ``APIClient.__init__`` has created a real Session.
    """

    @pytest.fixture
    def mock_session(self):
        """Mock that stands in for the client's HTTP session."""
        return Mock()

    @pytest.fixture
    def client(self, mock_session):
        """APIClient whose ``session`` is the mock session."""
        # Patch while the constructor runs so requests.Session() returns the mock
        with patch('requests.Session', return_value=mock_session):
            return APIClient("https://api.example.com")

    def test_get_success(self, client, mock_session):
        """get() returns the decoded JSON and requests the joined URL."""
        mock_response = Mock()
        mock_response.json.return_value = {"id": 1, "name": "test"}
        mock_session.get.return_value = mock_response

        result = client.get("users/1")

        assert result == {"id": 1, "name": "test"}
        mock_session.get.assert_called_once_with(
            "https://api.example.com/users/1",
            params=None,
        )

    def test_get_with_params(self, client, mock_session):
        """Query parameters are forwarded to session.get unchanged."""
        mock_response = Mock()
        mock_response.json.return_value = {"results": []}
        mock_session.get.return_value = mock_response

        result = client.get("users", params={"page": 1})

        assert result == {"results": []}
        mock_session.get.assert_called_once_with(
            "https://api.example.com/users",
            params={"page": 1},
        )

    def test_post_success(self, client, mock_session):
        """post() sends the payload as JSON and returns the decoded body."""
        mock_response = Mock()
        mock_response.json.return_value = {"id": 1}
        mock_session.post.return_value = mock_response

        result = client.post("users", data={"name": "new user"})

        assert result == {"id": 1}
        mock_session.post.assert_called_once_with(
            "https://api.example.com/users",
            json={"name": "new user"},
        )

    def test_get_error(self, client, mock_session):
        """An HTTPError from raise_for_status propagates out of get()."""
        mock_response = Mock()
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("404 Not Found")
        mock_session.get.return_value = mock_response

        with pytest.raises(requests.exceptions.HTTPError):
            client.get("nonexistent")
总结
Mock是单元测试中隔离依赖的利器。正确使用Mock可以让你的测试更加稳定、快速和可靠。记住:只Mock必要的依赖,不要过度Mock导致测试失去意义。