← 返回首页
🧠

工具使用

📂 llm ⏱ 4 min 670 words

---
title: "工具使用"
description: "掌握LLM工具使用的核心技术,包括API调用、插件系统和工具集成的最佳实践"
tags: ["工具使用", "API调用", "插件系统", "工具集成"]
category: "llm"
icon: "🧠"
---

工具使用

什么是工具使用

工具使用是指大语言模型能够调用外部工具、API或函数来扩展其能力。模型本身并不直接执行工具:它根据工具描述生成结构化的调用请求(工具名和参数),由宿主程序执行对应的函数,再把结果返回给模型继续推理。通过工具使用,LLM可以获取实时信息、执行复杂计算、操作外部系统等,从而完成单纯依靠语言生成无法实现的任务。

核心架构

1. 工具注册与管理

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field

@dataclass
class ToolDefinition:
    """Everything needed to describe and invoke one tool.

    Positional field order: name, description, parameters, function, is_async.
    """

    name: str  # identifier the model uses to select the tool
    description: str  # text describing the tool to the model
    parameters: Dict[str, Any]  # parameter name -> schema dict
    function: Callable  # callable that implements the tool
    is_async: bool = field(default=False)  # flag only; defaults to False

class ToolRegistry:
    """In-memory catalogue of tools, keyed by tool name."""

    def __init__(self):
        self.tools: Dict[str, ToolDefinition] = {}

    def register(self, name: str, description: str, parameters: Dict, func: Callable,
                 is_async: bool = False):
        """Register a tool, replacing any tool already registered under ``name``.

        Args:
            name: Name the model uses to call the tool.
            description: Text describing the tool to the model.
            parameters: Mapping of parameter name to its schema dict.
            func: Callable that ``execute`` invokes with keyword arguments.
            is_async: Stored on the ToolDefinition. ``execute`` calls ``func``
                directly, so for a coroutine function it returns the coroutine
                without awaiting it.
        """
        self.tools[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters,
            function=func,
            is_async=is_async,
        )

    def get_tool(self, name: str) -> Optional[ToolDefinition]:
        """Return the tool registered under ``name``, or None if there is none."""
        return self.tools.get(name)

    def list_tools(self) -> List[Dict]:
        """Describe every registered tool as a dict with name, description and parameters.

        The callable itself is left out, so the result can be handed to a model.
        """
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            }
            for tool in self.tools.values()
        ]

    def execute(self, name: str, **kwargs) -> Any:
        """Call the tool ``name`` with ``kwargs`` and return its result.

        Raises:
            ValueError: If no tool is registered under ``name``.
        """
        tool = self.get_tool(name)
        if tool is None:
            raise ValueError(f"Tool '{name}' not found")
        return tool.function(**kwargs)

# Usage example: start with an empty registry.
registry: ToolRegistry = ToolRegistry()

def search_web(query: str) -> str:
    """Return a placeholder search result for *query* (demo stub; performs no search)."""
    result = f"搜索结果: 关于'{query}'的信息"
    return result

# Expose search_web to the model as the "search" tool with one string parameter.
registry.register(
    func=search_web,
    name="search",
    description="在互联网上搜索信息",
    parameters={
        "query": {"type": "string", "description": "搜索查询"},
    },
)

2. 参数解析

import json
from typing import Any, Dict

class ParameterParser:
    """Parse and validate the argument text a model produces for a tool call."""

    def __init__(self, tool_registry: "ToolRegistry"):
        self.registry = tool_registry

    def parse_arguments(self, tool_name: str, raw_args: str) -> Dict[str, Any]:
        """Parse ``raw_args`` for ``tool_name`` and return validated keyword arguments.

        ``raw_args`` is tried as JSON first; if it is not valid JSON it is read
        as comma-separated ``key=value`` pairs.

        Raises:
            ValueError: Unknown tool, JSON that is not an object, a missing
                required parameter, or a value that cannot be converted to its
                declared type.
        """
        tool = self.registry.get_tool(tool_name)
        if tool is None:
            raise ValueError(f"Unknown tool: {tool_name}")

        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError:
            # Fall back to the lenient key=value format.
            args = self._simple_parse(raw_args, tool.parameters)

        # Valid JSON such as a list or bare string cannot be mapped to named parameters.
        if not isinstance(args, dict):
            raise ValueError(
                f"Arguments for tool '{tool_name}' must be a JSON object, "
                f"got {type(args).__name__}"
            )

        return self._validate_args(args, tool.parameters)

    def _simple_parse(self, raw_args: str, parameters: Dict) -> Dict:
        """Parse ``key=value, key2="value"`` text into a dict of strings.

        Parts without ``=`` are ignored and surrounding quotes are stripped.
        Values that themselves contain commas are split incorrectly;
        ``parameters`` is accepted for interface symmetry but not used.
        """
        args = {}
        for part in raw_args.split(","):
            if "=" in part:
                key, value = part.split("=", 1)
                args[key.strip()] = value.strip().strip('"\'')
        return args

    def _validate_args(self, args: Dict, parameters: Dict) -> Dict:
        """Keep only declared parameters, coerce them to their declared type, check required ones.

        Keys in ``args`` that are not declared in ``parameters`` are dropped.

        Raises:
            ValueError: If a required parameter is missing or a value cannot be coerced.
        """
        validated = {}
        for param_name, param_info in parameters.items():
            if param_name in args:
                validated[param_name] = self._coerce(
                    param_name, args[param_name], param_info.get("type")
                )
            elif param_info.get("required", False):
                raise ValueError(f"Missing required parameter: {param_name}")
        return validated

    @staticmethod
    def _coerce(param_name: str, value: Any, expected_type: Any) -> Any:
        """Convert ``value`` to the schema ``expected_type``; unknown types pass through unchanged."""
        try:
            if expected_type == "integer":
                return int(value)
            if expected_type == "number":
                return float(value)
            if expected_type == "boolean":
                # JSON already yields real booleans; only text/numbers need interpreting.
                if isinstance(value, bool):
                    return value
                return str(value).strip().lower() in ("true", "1", "yes")
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Parameter '{param_name}' expects {expected_type}, got {value!r}"
            ) from err
        return value

API调用集成

import requests
from typing import Dict, Any

class APICaller:
    """Call named HTTP APIs through one shared ``requests.Session``.

    Headers given to ``register_api`` are kept per API and sent only with
    requests to that API, so one service's credentials never reach another host.
    """

    def __init__(self, timeout: float = 30.0):
        """Create the session.

        Args:
            timeout: Default per-request timeout in seconds, used when a call
                does not pass its own ``timeout``; ``None`` disables it.
        """
        self.session = requests.Session()
        self.base_urls = {}
        # API name -> headers sent only with that API's requests.
        self.api_headers = {}
        self.timeout = timeout

    def register_api(self, name: str, base_url: str, headers: Dict = None):
        """Register an API under ``name``.

        Args:
            name: Identifier used by ``call`` and ``register_as_tool``.
            base_url: URL prefix; endpoints are appended to it verbatim.
            headers: Optional headers sent only with requests to this API.
        """
        self.base_urls[name] = base_url
        # Copy so later changes to the caller's dict do not leak into requests.
        self.api_headers[name] = dict(headers or {})

    def call(self, api_name: str, endpoint: str,
             method: str = "GET", **kwargs) -> Dict[str, Any]:
        """Send a request to ``base_url + endpoint`` and return the decoded JSON body.

        Extra keyword arguments are forwarded to ``requests.Session.request``;
        a ``headers`` argument is merged over the API's registered headers.
        An empty response body yields ``{}``.

        Raises:
            ValueError: If ``api_name`` was not registered.
            requests.HTTPError: If the response has an error status.
        """
        base_url = self.base_urls.get(api_name)
        if not base_url:
            raise ValueError(f"Unknown API: {api_name}")

        url = f"{base_url}{endpoint}"
        # Registered headers first; explicit per-call headers take precedence.
        headers = {
            **self.api_headers.get(api_name, {}),
            **(kwargs.pop("headers", None) or {}),
        }
        # Without a timeout, requests can block indefinitely on an unresponsive server.
        kwargs.setdefault("timeout", self.timeout)

        response = self.session.request(
            method=method,
            url=url,
            headers=headers,
            **kwargs
        )
        response.raise_for_status()
        # Responses such as 204 No Content have no JSON to decode.
        if not response.content:
            return {}
        return response.json()

    def register_as_tool(self, registry: "ToolRegistry", api_name: str,
                         tool_name: str, description: str,
                         endpoint: str, method: str = "GET"):
        """Expose ``endpoint`` of ``api_name`` as a tool that takes one ``params`` object."""
        def api_func(params: Dict = None):
            # Forward only the declared ``params`` argument: model-supplied keyword
            # arguments must not reach ``requests`` directly (e.g. verify=False).
            return self.call(api_name, endpoint, method, params=params)

        registry.register(
            name=tool_name,
            description=description,
            parameters={"params": {"type": "object"}},
            func=api_func
        )

# Usage example: register the weather API with its key header.
# NOTE(review): confirm whether weatherapi.com expects the key as a header or a query parameter.
api_caller = APICaller()
api_caller.register_api(
    name="weather",
    base_url="https://api.weatherapi.com/v1",
    headers={"key": "YOUR_API_KEY"},
)

插件系统设计

from abc import ABC, abstractmethod
from typing import List

class Plugin(ABC):
    """Abstract base class for plugins that contribute tools.

    Concrete plugins provide a name, a description, the tools they offer and
    an ``initialize`` hook that receives a configuration mapping.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name identifying the plugin."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Short human-readable description of the plugin."""

    @abstractmethod
    def get_tools(self) -> List[ToolDefinition]:
        """Return the ToolDefinition objects this plugin provides."""

    @abstractmethod
    def initialize(self, config: Dict):
        """Configure the plugin from ``config``."""

class WeatherPlugin(Plugin):
    """Demo plugin exposing a single ``get_weather`` tool backed by a stub."""

    @property
    def name(self) -> str:
        return "weather"

    @property
    def description(self) -> str:
        return "获取天气信息的插件"

    def get_tools(self) -> List[ToolDefinition]:
        weather_tool = ToolDefinition(
            name="get_weather",
            description="获取指定城市的天气信息",
            parameters={"city": {"type": "string", "description": "城市名称"}},
            function=self.get_weather,
        )
        return [weather_tool]

    def initialize(self, config: Dict):
        # Kept on the instance; the stub get_weather below does not read it.
        self.api_key = config.get("api_key")

    def get_weather(self, city: str) -> str:
        """Return a fixed sample forecast for *city* (stub; no API call is made)."""
        return f"{city}的天气: 晴天,温度25°C"

class PluginManager:
    """Load plugins and collect their tools in a shared ToolRegistry."""

    def __init__(self):
        self.plugins: Dict[str, Plugin] = {}
        self.registry = ToolRegistry()
        # tool name -> name of the plugin that registered it
        self.tool_owners: Dict[str, str] = {}

    def load_plugin(self, plugin: "Plugin", config: Dict = None):
        """Initialize ``plugin`` with ``config`` (default ``{}``) and register its tools.

        Loading a plugin whose name is already loaded replaces that plugin and
        drops the tools it registered before.

        Raises:
            ValueError: If the plugin lists the same tool name twice, or a tool
                name is already registered by something other than this plugin.
                Nothing is registered in that case.
        """
        plugin.initialize(config or {})
        tools = plugin.get_tools()

        # Validate every tool first so a rejected plugin leaves no partial registration.
        seen = set()
        for tool_def in tools:
            if tool_def.name in seen:
                raise ValueError(
                    f"Plugin '{plugin.name}' defines tool '{tool_def.name}' more than once"
                )
            seen.add(tool_def.name)
            # Refuse to shadow a tool that another plugin (or direct registration) owns.
            if (tool_def.name in self.registry.tools
                    and self.tool_owners.get(tool_def.name) != plugin.name):
                raise ValueError(
                    f"Tool '{tool_def.name}' is already registered; "
                    f"plugin '{plugin.name}' cannot override it"
                )

        # Remove tools left over from a previous load of the same plugin.
        stale = [name for name, owner in self.tool_owners.items() if owner == plugin.name]
        for name in stale:
            self.registry.tools.pop(name, None)
            del self.tool_owners[name]

        self.plugins[plugin.name] = plugin
        for tool_def in tools:
            self.registry.tools[tool_def.name] = tool_def
            self.tool_owners[tool_def.name] = plugin.name

    def get_tools_for_llm(self) -> List[Dict]:
        """Return the registry's tool descriptions (name, description, parameters)."""
        return self.registry.list_tools()

安全性考虑

import hashlib
import time
from functools import wraps

class ToolSecurity:
    """Rate limiting, audit logging and basic input screening for tool calls."""

    def __init__(self):
        # rate-limit key -> time.monotonic() timestamps of recent allowed calls
        self.rate_limits = {}
        # one entry per allowed call: function name, wall-clock timestamp, truncated args
        self.audit_log = []

    def rate_limit(self, max_calls: int, window_seconds: int, *, per_args: bool = False):
        """Return a decorator allowing at most ``max_calls`` calls per ``window_seconds``.

        Args:
            max_calls: Calls allowed inside one sliding window.
            window_seconds: Window length in seconds.
            per_args: If True, count calls separately for each distinct
                (args, kwargs) combination. By default all calls to the decorated
                function share one budget, so varying the arguments cannot bypass
                the limit.

        The decorated function raises RuntimeError when the limit is exceeded;
        rejected calls are neither counted nor added to ``audit_log``.
        """
        def decorator(func):
            func_key = f"{func.__module__}.{func.__qualname__}"

            @wraps(func)
            def wrapper(*args, **kwargs):
                key = func_key
                if per_args:
                    # Include kwargs too; sorting makes the fingerprint order-independent.
                    fingerprint = repr((args, sorted(kwargs.items())))
                    key = f"{func_key}:{hashlib.md5(fingerprint.encode()).hexdigest()}"

                # Monotonic clock: wall-clock adjustments must not open or close the window.
                now = time.monotonic()
                recent = [
                    t for t in self.rate_limits.get(key, [])
                    if now - t < window_seconds
                ]
                self.rate_limits[key] = recent

                if len(recent) >= max_calls:
                    raise RuntimeError(f"Rate limit exceeded for {func.__name__}")

                recent.append(now)
                self.audit_log.append({
                    "function": func.__name__,
                    "timestamp": time.time(),
                    "args": str(args)[:100]
                })

                return func(*args, **kwargs)
            return wrapper
        return decorator

    def validate_input(self, input_data: Any, max_length: int = 1000) -> bool:
        """Reject over-long strings and strings containing a small denylist of patterns.

        Non-string input is accepted unchanged. This is a heuristic screen only;
        it is not a substitute for escaping output or parameterized queries.

        Returns:
            True when the input passes.

        Raises:
            ValueError: If the string is longer than ``max_length`` or matches a pattern.
        """
        if isinstance(input_data, str):
            if len(input_data) > max_length:
                raise ValueError(f"Input too long: {len(input_data)} > {max_length}")
            # Collapse whitespace runs so "DROP   TABLE" or "drop\ttable" cannot slip through.
            normalized = " ".join(input_data.lower().split())
            dangerous_patterns = ["<script>", "javascript:", "drop table"]
            if any(pattern in normalized for pattern in dangerous_patterns):
                raise ValueError("Potentially dangerous input detected")
        return True

总结

工具使用是扩展LLM能力的关键技术。通过合理设计工具注册系统、参数解析器和安全机制,可以构建强大且安全的AI Agent系统。需要注意的是,本文中的关键字黑名单和速率限制只是示例;生产环境还应结合参数化查询、输出转义、最小权限和请求超时等措施。