← 返回首页
🧠

LLM单元测试

📂 llm ⏱ 2 min 381 words

--- title: "LLM单元测试" description: "深入讲解LLM系统的单元测试方法,包括提示模板测试、工具函数测试、参数校验测试、模拟对象使用以及测试用例设计" tags: ["单元测试", "测试方法", "LLM开发", "代码质量"] category: "llm" icon: "🧠"

LLM单元测试

单元测试在LLM中的角色

单元测试是测试金字塔的基石,针对LLM系统中的最小可测试单元进行验证。虽然LLM推理本身具有不确定性,但围绕它的大量代码是确定性的,非常适合单元测试。

提示模板测试

提示模板是LLM应用的核心组件,需要确保模板渲染的正确性。

import pytest
from jinja2 import Template

class PromptTemplate:
    def __init__(self, template_str):
        self.template = Template(template_str)
    
    def render(self, **kwargs):
        try:
            return self.template.render(**kwargs)
        except Exception as e:
            raise ValueError(f"Template rendering failed: {e}")

class TestPromptTemplate:
    def setup_method(self):
        self.template = PromptTemplate(
            "你是一个{role}。请用{language}回答以下问题:{question}"
        )
    
    def test_basic_rendering(self):
        result = self.template.render(
            role="翻译专家", language="中文", question="Hello"
        )
        assert "翻译专家" in result
        assert "中文" in result
        assert "Hello" in result
    
    def test_missing_variable(self):
        with pytest.raises(ValueError):
            self.template.render(role="专家")
    
    def test_special_characters(self):
        result = self.template.render(
            role="专家", language="中文", question="<script>alert('xss')</script>"
        )
        assert "<script>" not in result or "&lt;script&gt;" in result

工具函数测试

LLM应用通常包含大量工具函数,如文本处理、格式转换、API调用封装等。

class TextProcessor:
    @staticmethod
    def extract_entities(text):
        """从文本中提取实体"""
        import re
        patterns = {
            'email': r'[\w.-]+@[\w.-]+\.\w+',
            'phone': r'\d{3}-\d{4}-\d{4}',
        }
        entities = {}
        for entity_type, pattern in patterns.items():
            matches = re.findall(pattern, text)
            if matches:
                entities[entity_type] = matches
        return entities
    
    @staticmethod
    def truncate_text(text, max_length=100, suffix="..."):
        if len(text) <= max_length:
            return text
        return text[:max_length - len(suffix)] + suffix

class TestTextProcessor:
    def test_extract_email(self):
        text = "联系我 test@example.com 或 support@company.com"
        result = TextProcessor.extract_entities(text)
        assert 'email' in result
        assert len(result['email']) == 2
    
    def test_extract_no_entities(self):
        result = TextProcessor.extract_entities("今天天气很好")
        assert result == {}
    
    def test_truncate_short_text(self):
        result = TextProcessor.truncate_text("短文本", max_length=10)
        assert result == "短文本"
    
    def test_truncate_long_text(self):
        result = TextProcessor.truncate_text("a" * 200, max_length=50)
        assert len(result) == 50
        assert result.endswith("...")

参数校验测试

验证LLM请求参数的合法性:

class RequestValidator:
    MAX_TOKENS = 4096
    MAX_PROMPT_LENGTH = 32000
    
    def validate(self, prompt, max_tokens=None, temperature=None):
        errors = []
        if not prompt or not prompt.strip():
            errors.append("Prompt cannot be empty")
        if len(prompt) > self.MAX_PROMPT_LENGTH:
            errors.append(f"Prompt exceeds {self.MAX_PROMPT_LENGTH} chars")
        if max_tokens is not None:
            if max_tokens < 1 or max_tokens > self.MAX_TOKENS:
                errors.append(f"max_tokens must be 1-{self.MAX_TOKENS}")
        if temperature is not None:
            if temperature < 0 or temperature > 2:
                errors.append("temperature must be 0-2")
        return errors

class TestRequestValidator:
    def setup_method(self):
        self.validator = RequestValidator()
    
    def test_valid_request(self):
        errors = self.validator.validate("Hello", max_tokens=100, temperature=0.7)
        assert errors == []
    
    def test_empty_prompt(self):
        errors = self.validator.validate("")
        assert len(errors) == 1
    
    def test_invalid_temperature(self):
        errors = self.validator.validate("test", temperature=3.0)
        assert any("temperature" in e for e in errors)

Mock对象的使用

对于涉及外部API调用的代码,使用Mock隔离依赖:

from unittest.mock import Mock, patch

class LLMClient:
    def __init__(self, api_key):
        self.api_key = api_key
    
    def generate(self, prompt):
        response = requests.post(
            "https://api.example.com/generate",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={"prompt": prompt}
        )
        return response.json()["text"]

class TestLLMClient:
    @patch('requests.post')
    def test_generate_success(self, mock_post):
        mock_post.return_value.json.return_value = {"text": "响应内容"}
        client = LLMClient("test-key")
        result = client.generate("测试提示")
        assert result == "响应内容"
        mock_post.assert_called_once()
    
    @patch('requests.post')
    def test_generate_api_error(self, mock_post):
        mock_post.side_effect = Exception("API Error")
        client = LLMClient("test-key")
        with pytest.raises(Exception):
            client.generate("测试提示")

测试覆盖率目标

对于LLM应用代码,建议单元测试覆盖率达到80%以上。重点关注核心逻辑、边界条件和异常处理。使用pytest-cov等工具自动跟踪覆盖率,并在CI中设置覆盖率门槛。