LLM单元测试
---
title: "LLM单元测试"
description: "深入讲解LLM系统的单元测试方法,包括提示模板测试、工具函数测试、参数校验测试、模拟对象使用以及测试用例设计"
tags: ["单元测试", "测试方法", "LLM开发", "代码质量"]
category: "llm"
icon: "🧠"
---
LLM单元测试
单元测试在LLM中的角色
单元测试是测试金字塔的基石,针对LLM系统中的最小可测试单元进行验证。虽然LLM推理本身具有不确定性,但围绕它的大量代码是确定性的,非常适合单元测试。
提示模板测试
提示模板是LLM应用的核心组件,需要确保模板渲染的正确性。
import pytest
from jinja2 import Template
class PromptTemplate:
    """Thin wrapper around a Jinja2 template used to build LLM prompts.

    Placeholders use Jinja2 syntax (``{{ name }}``). Rendering is strict: a
    placeholder with no supplied value is an error rather than being silently
    replaced by an empty string.
    """

    def __init__(self, template_str):
        """Compile ``template_str`` into a Jinja2 template.

        Args:
            template_str: Template source written in Jinja2 syntax.
        """
        # Local import: the module-level import only brings in ``Template``.
        from jinja2 import StrictUndefined

        # Jinja2's default ``Undefined`` renders a missing variable as "",
        # which would send an incomplete prompt to the model without any
        # error. ``StrictUndefined`` makes that case raise at render time.
        self.template = Template(template_str, undefined=StrictUndefined)

    def render(self, **kwargs):
        """Render the template using ``kwargs`` as the template variables.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: If rendering fails for any reason, including a
                placeholder whose variable was not supplied.
        """
        try:
            return self.template.render(**kwargs)
        except Exception as e:
            # Chain the original exception so the Jinja2 traceback survives.
            raise ValueError(f"Template rendering failed: {e}") from e
class TestPromptTemplate:
    """Unit tests for PromptTemplate rendering."""

    def setup_method(self):
        """Build a fresh three-variable template before every test."""
        # Jinja2 placeholders are written ``{{ name }}``; single braces such
        # as ``{name}`` are plain text to Jinja2 and are never substituted.
        self.template = PromptTemplate(
            "你是一个{{ role }}。请用{{ language }}回答以下问题:{{ question }}"
        )

    def test_basic_rendering(self):
        """Every supplied value appears in the rendered prompt."""
        result = self.template.render(
            role="翻译专家", language="中文", question="Hello"
        )
        assert "翻译专家" in result
        assert "中文" in result
        assert "Hello" in result

    def test_missing_variable(self):
        """Omitting template variables must raise ValueError."""
        with pytest.raises(ValueError):
            self.template.render(role="专家")

    def test_special_characters(self):
        """Markup in user input reaches the prompt unchanged."""
        payload = "<script>alert('xss')</script>"
        result = self.template.render(
            role="专家", language="中文", question=payload
        )
        # A prompt is plain text, not HTML: a Jinja2 ``Template`` does not
        # autoescape by default, so the payload must pass through verbatim.
        assert payload in result
工具函数测试
LLM应用通常包含大量工具函数,如文本处理、格式转换、API调用封装等。
class TextProcessor:
    """Stateless text helpers: entity extraction and length-limited truncation."""

    @staticmethod
    def extract_entities(text):
        """Extract e-mail addresses and phone numbers from ``text``.

        Args:
            text: The string to scan.

        Returns:
            A dict mapping an entity type (``'email'``, ``'phone'``) to the
            list of matches in order of appearance. Types with no match are
            omitted, so text without entities yields ``{}``.
        """
        import re

        patterns = {
            'email': r'[\w.-]+@[\w.-]+\.\w+',
            'phone': r'\d{3}-\d{4}-\d{4}',
        }
        entities = {}
        for entity_type, pattern in patterns.items():
            # re.ASCII limits \w and \d to ASCII. Without it, \w also matches
            # CJK characters, so "邮箱test@example.com" would be captured
            # with the leading "邮箱" as part of the address.
            matches = re.findall(pattern, text, flags=re.ASCII)
            if matches:
                entities[entity_type] = matches
        return entities

    @staticmethod
    def truncate_text(text, max_length=100, suffix="..."):
        """Shorten ``text`` so the result is at most ``max_length`` characters.

        Args:
            text: The string to shorten.
            max_length: Maximum length of the returned string.
            suffix: Marker appended when the text is cut.

        Returns:
            ``text`` unchanged if it already fits. Otherwise a prefix of
            ``text`` followed by ``suffix``, or a bare prefix of ``text`` when
            ``max_length`` is too small to hold the suffix.

        Raises:
            ValueError: If ``max_length`` is negative.
        """
        if max_length < 0:
            raise ValueError("max_length must be non-negative")
        if len(text) <= max_length:
            return text
        if max_length < len(suffix):
            # No room for the suffix: a negative slice bound here would cut
            # from the wrong end and return more than max_length characters.
            return text[:max_length]
        return text[:max_length - len(suffix)] + suffix
class TestTextProcessor:
    """Checks for TextProcessor's entity extraction and truncation helpers."""

    def test_extract_email(self):
        """Two addresses in one sentence are both reported under 'email'."""
        sentence = "联系我 test@example.com 或 support@company.com"
        entities = TextProcessor.extract_entities(sentence)
        assert 'email' in entities
        assert len(entities['email']) == 2

    def test_extract_no_entities(self):
        """Text without any entity yields an empty dict."""
        entities = TextProcessor.extract_entities("今天天气很好")
        assert entities == {}

    def test_truncate_short_text(self):
        """Text already within the limit is returned untouched."""
        shortened = TextProcessor.truncate_text("短文本", max_length=10)
        assert shortened == "短文本"

    def test_truncate_long_text(self):
        """Over-long text is cut to exactly the limit and ends with the suffix."""
        limit = 50
        shortened = TextProcessor.truncate_text("a" * 200, max_length=limit)
        assert len(shortened) == 50
        assert shortened.endswith("...")
参数校验测试
验证LLM请求参数的合法性:
class RequestValidator:
    """Validates the parameters of an LLM request before it is sent."""

    # Upper bound accepted for the ``max_tokens`` request parameter.
    MAX_TOKENS = 4096
    # Upper bound on prompt size, measured in characters (``len(prompt)``).
    MAX_PROMPT_LENGTH = 32000

    def validate(self, prompt, max_tokens=None, temperature=None):
        """Check request parameters and collect every problem found.

        Args:
            prompt: Prompt text. ``None``, empty and whitespace-only values
                are all reported as an empty prompt.
            max_tokens: Optional token limit; must be within
                ``1..MAX_TOKENS`` when given.
            temperature: Optional sampling temperature; must be within
                ``0..2`` when given.

        Returns:
            A list of error messages; an empty list means the request is valid.
        """
        errors = []

        if not prompt or not prompt.strip():
            errors.append("Prompt cannot be empty")
        # Guard on truthiness first: ``prompt`` may be None at this point and
        # ``len(None)`` would raise instead of reporting a validation error.
        if prompt and len(prompt) > self.MAX_PROMPT_LENGTH:
            errors.append(f"Prompt exceeds {self.MAX_PROMPT_LENGTH} chars")

        if max_tokens is not None and not 1 <= max_tokens <= self.MAX_TOKENS:
            errors.append(f"max_tokens must be 1-{self.MAX_TOKENS}")

        # Written as "not in range" so NaN is rejected as well: every
        # comparison with NaN is False, so a "< 0 or > 2" test lets it pass.
        if temperature is not None and not 0 <= temperature <= 2:
            errors.append("temperature must be 0-2")

        return errors
class TestRequestValidator:
    """Checks for RequestValidator.validate."""

    def setup_method(self):
        """Create a fresh validator before every test."""
        self.validator = RequestValidator()

    def test_valid_request(self):
        """A well-formed request produces no errors."""
        problems = self.validator.validate(
            "Hello", max_tokens=100, temperature=0.7
        )
        assert problems == []

    def test_empty_prompt(self):
        """An empty prompt is reported as exactly one error."""
        problems = self.validator.validate("")
        assert len(problems) == 1

    def test_invalid_temperature(self):
        """An out-of-range temperature is mentioned in the error list."""
        problems = self.validator.validate("test", temperature=3.0)
        mentions_temperature = False
        for message in problems:
            if "temperature" in message:
                mentions_temperature = True
                break
        assert mentions_temperature
Mock对象的使用
对于涉及外部API调用的代码,使用Mock替换真实的网络请求,以隔离外部依赖:
from unittest.mock import Mock, patch
class LLMClient:
    """Minimal HTTP client for a text-generation endpoint."""

    def __init__(self, api_key, timeout=30):
        """Store the credentials and request settings.

        Args:
            api_key: Token sent as a ``Bearer`` credential on every request.
            timeout: Seconds to wait for the server before giving up.
        """
        self.api_key = api_key
        self.timeout = timeout

    def generate(self, prompt):
        """Send ``prompt`` to the generation endpoint and return its text.

        Args:
            prompt: The prompt to submit.

        Returns:
            The value of the ``"text"`` field of the JSON response.

        Raises:
            requests.RequestException: On connection failure, timeout, or a
                4xx/5xx response status.
            KeyError: If the response JSON has no ``"text"`` field.
        """
        # ``requests`` is not imported anywhere in the visible file, so it is
        # imported here. ``patch('requests.post')`` still takes effect because
        # the attribute is looked up on the module at call time.
        import requests

        response = requests.post(
            "https://api.example.com/generate",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={"prompt": prompt},
            # requests waits indefinitely unless a timeout is given.
            timeout=self.timeout,
        )
        # Surface HTTP errors directly instead of failing later while
        # decoding an error body.
        response.raise_for_status()
        return response.json()["text"]
class TestLLMClient:
    """Tests for LLMClient with the HTTP layer replaced by a mock."""

    @patch('requests.post')
    def test_generate_success(self, mock_post):
        """The response text is returned and the request is built correctly."""
        mock_post.return_value.json.return_value = {"text": "响应内容"}
        client = LLMClient("test-key")

        result = client.generate("测试提示")

        assert result == "响应内容"
        mock_post.assert_called_once()
        # Check what was sent, not only that a call happened: a wrong URL,
        # credential or payload would otherwise go unnoticed.
        args, kwargs = mock_post.call_args
        assert args[0] == "https://api.example.com/generate"
        assert kwargs["headers"] == {"Authorization": "Bearer test-key"}
        assert kwargs["json"] == {"prompt": "测试提示"}

    @patch('requests.post')
    def test_generate_api_error(self, mock_post):
        """An error raised by the HTTP layer propagates to the caller."""
        mock_post.side_effect = Exception("API Error")
        client = LLMClient("test-key")
        # ``match`` pins the injected error. A bare ``raises(Exception)``
        # would also pass on unrelated failures such as a NameError.
        with pytest.raises(Exception, match="API Error"):
            client.generate("测试提示")
测试覆盖率目标
对于LLM应用代码,建议单元测试覆盖率达到80%以上。重点关注核心逻辑、边界条件和异常处理。使用pytest-cov等工具自动跟踪覆盖率,并在CI中设置覆盖率门槛(例如 `pytest --cov=your_package --cov-fail-under=80`)。