工具使用
---
title: "工具使用"
description: "掌握LLM工具使用的核心技术,包括API调用、插件系统和工具集成的最佳实践"
tags: ["工具使用", "API调用", "插件系统", "工具集成"]
category: "llm"
icon: "🧠"
---
工具使用
什么是工具使用
工具使用是指大语言模型能够调用外部工具、API或函数来扩展其能力。通过工具使用,LLM可以获取实时信息、执行复杂计算、操作外部系统等,从而完成单纯依靠语言生成无法实现的任务。
核心架构
1. 工具注册与管理
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field
@dataclass
class ToolDefinition:
    """Describe a single tool: its metadata plus the callable behind it.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__``, so
    two definitions compare equal when all of their fields are equal.
    """

    # Key under which the tool is stored and looked up.
    name: str
    # Human-readable summary of what the tool does.
    description: str
    # Parameter schema, keyed by parameter name.
    parameters: Dict[str, Any]
    # Callable that is invoked when the tool runs.
    function: Callable
    # Whether ``function`` is asynchronous (per the field name); off by default.
    is_async: bool = False
class ToolRegistry:
    """In-memory catalogue of tools, keyed by tool name."""

    def __init__(self):
        # Public mapping of tool name -> definition.
        self.tools: Dict[str, "ToolDefinition"] = {}

    def register(self, name: str, description: str, parameters: Dict[str, Any],
                 func: Callable, is_async: bool = False) -> None:
        """Register a tool under ``name``.

        An existing entry with the same name is replaced.

        Args:
            name: Key used for later lookup and execution.
            description: Human-readable summary of the tool.
            parameters: Parameter schema, keyed by parameter name.
            func: Callable invoked by :meth:`execute`.
            is_async: Stored on the definition's ``is_async`` field.
        """
        self.tools[name] = ToolDefinition(
            name=name,
            description=description,
            parameters=parameters,
            function=func,
            is_async=is_async,
        )

    def get_tool(self, name: str) -> Optional["ToolDefinition"]:
        """Return the definition registered as ``name``, or ``None``."""
        return self.tools.get(name)

    def list_tools(self) -> List[Dict[str, Any]]:
        """Return name, description and parameters for every tool.

        The callable itself is left out of each entry.
        """
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.parameters,
            }
            for tool in self.tools.values()
        ]

    def execute(self, name: str, **kwargs) -> Any:
        """Call the tool registered as ``name`` with ``kwargs``.

        Because the first parameter is called ``name``, a tool argument that
        is also called ``name`` cannot be forwarded through this method.

        Returns:
            Whatever the tool's callable returns.

        Raises:
            ValueError: If no tool is registered under ``name``.
        """
        tool = self.get_tool(name)
        if tool is None:
            raise ValueError(f"Tool '{name}' not found")
        return tool.function(**kwargs)
# Usage example: expose a stubbed web search through a registry.
def search_web(query: str) -> str:
    """Return a canned result string for ``query``; no network call is made."""
    return f"搜索结果: 关于'{query}'的信息"


registry = ToolRegistry()

# Arguments are passed positionally: name, description, parameters, func.
registry.register(
    "search",
    "在互联网上搜索信息",
    {
        "query": {
            "type": "string",
            "description": "搜索查询",
        },
    },
    search_web,
)
2. 参数解析
import json
from typing import Any, Dict
class ParameterParser:
    """Turn raw argument text into typed keyword arguments for a tool."""

    # Lower-cased spellings that are read as boolean true.
    _TRUE_LITERALS = ("true", "1", "yes")

    def __init__(self, tool_registry: "ToolRegistry"):
        self.registry = tool_registry

    def parse_arguments(self, tool_name: str, raw_args: str) -> Dict[str, Any]:
        """Parse ``raw_args`` for ``tool_name`` and coerce the values.

        JSON is tried first; text that is not valid JSON falls back to the
        ``key=value, key=value`` format.

        Returns:
            The arguments that appear in the tool's parameter schema, with
            integer/number/boolean values converted.

        Raises:
            ValueError: If the tool is unknown, the JSON is not an object, a
                required parameter is missing, or a value cannot be coerced.
        """
        tool = self.registry.get_tool(tool_name)
        if tool is None:
            raise ValueError(f"Unknown tool: {tool_name}")

        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError:
            args = self._simple_parse(raw_args, tool.parameters)
        else:
            # Valid JSON such as ``[1, 2]`` or ``5`` carries no parameter names.
            if not isinstance(args, dict):
                raise ValueError(
                    f"Arguments for tool '{tool_name}' must be a JSON object, "
                    f"got {type(args).__name__}"
                )

        return self._validate_args(args, tool.parameters)

    def _simple_parse(self, raw_args: str, parameters: Dict) -> Dict:
        """Parse ``key=value`` pairs separated by commas.

        Parts without ``=`` are ignored, and surrounding whitespace and quote
        characters are stripped from values. Values stay strings.
        ``parameters`` is accepted for interface compatibility and is unused.
        """
        args = {}
        for part in raw_args.split(","):
            if "=" in part:
                key, value = part.split("=", 1)
                args[key.strip()] = value.strip().strip('"\'')
        return args

    def _validate_args(self, args: Dict, parameters: Dict) -> Dict:
        """Keep the arguments declared in ``parameters``, coerced by type.

        Arguments that are not in the schema are dropped.

        Raises:
            ValueError: If a parameter marked ``required`` is absent or a
                value cannot be converted to its declared type.
        """
        validated = {}
        for param_name, param_info in parameters.items():
            if param_name not in args:
                if param_info.get("required", False):
                    raise ValueError(f"Missing required parameter: {param_name}")
                continue
            validated[param_name] = self._coerce(
                param_name, args[param_name], param_info.get("type")
            )
        return validated

    def _coerce(self, param_name: str, value: Any, type_name: Any) -> Any:
        """Convert ``value`` to the schema type ``type_name``."""
        try:
            if type_name == "integer":
                return int(value)
            if type_name == "number":
                return float(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(
                f"Invalid {type_name} value for parameter '{param_name}': {value!r}"
            ) from exc

        if type_name == "boolean":
            # JSON input already yields real booleans; only text needs matching.
            if isinstance(value, bool):
                return value
            return str(value).lower() in self._TRUE_LITERALS

        return value
API调用集成
import requests
from typing import Dict, Any
class APICaller:
    """Small HTTP client that maps short API names to base URLs."""

    def __init__(self, timeout: float = 30.0):
        """Create the shared session.

        Args:
            timeout: Default timeout in seconds applied to every request that
                does not pass its own ``timeout``.
        """
        self.session = requests.Session()
        self.base_urls = {}
        # Headers are stored per API so that credentials registered for one
        # service are never sent to another one.
        self.api_headers = {}
        self.timeout = timeout

    def register_api(self, name: str, base_url: str, headers: Dict = None):
        """Register ``base_url`` (and optional default headers) under ``name``.

        Registering the same name again replaces both the URL and the headers.
        """
        self.base_urls[name] = base_url
        self.api_headers[name] = dict(headers) if headers else {}

    @staticmethod
    def _build_url(base_url: str, endpoint: str) -> str:
        """Join ``base_url`` and ``endpoint`` with exactly one slash."""
        if not endpoint or endpoint[0] in "?#":
            return f"{base_url}{endpoint}"
        return f"{base_url.rstrip('/')}/{endpoint.lstrip('/')}"

    def call(self, api_name: str, endpoint: str,
             method: str = "GET", **kwargs) -> Dict[str, Any]:
        """Send a request to a registered API and return the decoded JSON body.

        Args:
            api_name: Name given to :meth:`register_api`.
            endpoint: Path appended to the API's base URL.
            method: HTTP method.
            **kwargs: Forwarded to ``Session.request``. ``headers`` passed
                here override the API's registered headers key by key.

        Raises:
            ValueError: If ``api_name`` has not been registered.
            requests.HTTPError: If the response status is an error
                (raised by ``raise_for_status``).
        """
        base_url = self.base_urls.get(api_name)
        if not base_url:
            raise ValueError(f"Unknown API: {api_name}")

        headers = {
            **self.api_headers.get(api_name, {}),
            **(kwargs.pop("headers", None) or {}),
        }
        if headers:
            kwargs["headers"] = headers
        # Without a timeout a stalled server would block the caller forever.
        kwargs.setdefault("timeout", self.timeout)

        response = self.session.request(
            method=method,
            url=self._build_url(base_url, endpoint),
            **kwargs
        )
        response.raise_for_status()
        return response.json()

    def register_as_tool(self, registry: "ToolRegistry", api_name: str,
                         tool_name: str, description: str,
                         endpoint: str, method: str = "GET"):
        """Register one API endpoint as a tool in ``registry``.

        The tool forwards its keyword arguments to :meth:`call`, so an
        argument named ``params`` ends up as the request's ``params``.
        """
        def api_func(**kwargs):
            return self.call(api_name, endpoint, method, **kwargs)

        registry.register(
            name=tool_name,
            description=description,
            parameters={"params": {"type": "object"}},
            func=api_func
        )
# Usage example: register the weather API together with its key header.
api_caller = APICaller()
api_caller.register_api(
    name="weather",
    base_url="https://api.weatherapi.com/v1",
    headers={"key": "YOUR_API_KEY"},
)
插件系统设计
from abc import ABC, abstractmethod
from typing import List
class Plugin(ABC):
    """Abstract base class that every plugin must subclass.

    All four members are abstract, so a subclass can only be instantiated
    once it overrides each of them.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Name under which the plugin is stored by its manager."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable summary of the plugin."""

    @abstractmethod
    def get_tools(self) -> List[ToolDefinition]:
        """Return the tool definitions this plugin provides."""

    @abstractmethod
    def initialize(self, config: Dict):
        """Configure the plugin from ``config``."""
class WeatherPlugin(Plugin):
    """Example plugin that exposes a single ``get_weather`` tool."""

    @property
    def name(self) -> str:
        """Return the fixed plugin name."""
        return "weather"

    @property
    def description(self) -> str:
        """Return the fixed plugin description."""
        return "获取天气信息的插件"

    def get_tools(self) -> List[ToolDefinition]:
        """Return a one-element list holding the ``get_weather`` definition."""
        city_schema = {
            "city": {"type": "string", "description": "城市名称"},
        }
        weather_tool = ToolDefinition(
            name="get_weather",
            description="获取指定城市的天气信息",
            parameters=city_schema,
            function=self.get_weather,
        )
        return [weather_tool]

    def initialize(self, config: Dict):
        """Store ``config["api_key"]`` (``None`` when the key is absent)."""
        self.api_key = config.get("api_key")

    def get_weather(self, city: str) -> str:
        """Return a fixed weather string for ``city``.

        The text is hard-coded; neither ``api_key`` nor any network call is
        used here.
        """
        return f"{city}的天气: 晴天,温度25°C"
class PluginManager:
    """Load plugins and gather their tools in one shared registry."""

    def __init__(self):
        self.registry = ToolRegistry()
        # Loaded plugins, keyed by each plugin's ``name``.
        self.plugins: Dict[str, Plugin] = {}

    def load_plugin(self, plugin: Plugin, config: Dict = None):
        """Initialize ``plugin``, remember it, and register its tools.

        A missing or empty ``config`` is replaced by an empty dict. Tool
        definitions are written straight into ``registry.tools``, so a tool
        whose name is already present is overwritten.
        """
        settings = config if config else {}
        plugin.initialize(settings)
        self.plugins[plugin.name] = plugin

        registered = self.registry.tools
        for definition in plugin.get_tools():
            registered[definition.name] = definition

    def get_tools_for_llm(self) -> List[Dict]:
        """Return the registry's tool listing (name, description, parameters)."""
        tool_descriptions = self.registry.list_tools()
        return tool_descriptions
安全性考虑
import hashlib
import time
from functools import wraps
class ToolSecurity:
    """Rate limiting, audit logging and basic input screening for tools."""

    # Lower-case substrings rejected by validate_input. A deny-list like this
    # is only a coarse filter; it does not replace escaping/parameterization.
    _DANGEROUS_PATTERNS = ("<script>", "javascript:", "drop table")

    def __init__(self):
        # Bucket key -> monotonic timestamps of the calls still in the window.
        self.rate_limits = {}
        # One entry per permitted call; grows for the life of the instance.
        self.audit_log = []

    def rate_limit(self, max_calls: int, window_seconds: int):
        """Return a decorator allowing ``max_calls`` per ``window_seconds``.

        The limit applies per decorated function, whatever arguments it is
        called with. Each permitted call is appended to ``audit_log``.

        Raises (from the wrapped function):
            RuntimeError: If the limit for the current window is exhausted.
        """
        def decorator(func):
            func_name = getattr(func, "__name__", repr(func))
            # One bucket per function: keying on the arguments would let a
            # caller dodge the limit by varying them, and would create an
            # unbounded number of buckets.
            key = (
                f"{getattr(func, '__module__', '')}."
                f"{getattr(func, '__qualname__', func_name)}"
            )

            @wraps(func)
            def wrapper(*args, **kwargs):
                # Monotonic time keeps the window correct if the wall clock
                # is adjusted.
                now = time.monotonic()
                recent = [
                    stamp for stamp in self.rate_limits.get(key, ())
                    if now - stamp < window_seconds
                ]
                self.rate_limits[key] = recent

                if len(recent) >= max_calls:
                    raise RuntimeError(f"Rate limit exceeded for {func_name}")

                recent.append(now)
                self.audit_log.append({
                    "function": func_name,
                    "timestamp": time.time(),
                    "args": str(args)[:100]
                })
                return func(*args, **kwargs)
            return wrapper
        return decorator

    def validate_input(self, input_data: Any, max_length: int = 1000) -> bool:
        """Screen ``input_data``; only strings are checked.

        Returns:
            ``True`` when the input passes (non-string input always passes).

        Raises:
            ValueError: If a string is longer than ``max_length`` or contains
                one of the deny-listed substrings (case-insensitive).
        """
        if not isinstance(input_data, str):
            return True

        if len(input_data) > max_length:
            raise ValueError(f"Input too long: {len(input_data)} > {max_length}")

        lowered = input_data.lower()
        if any(pattern in lowered for pattern in self._DANGEROUS_PATTERNS):
            raise ValueError("Potentially dangerous input detected")
        return True
总结
工具使用是扩展LLM能力的关键技术。通过合理设计工具注册系统、参数解析器和安全机制,可以构建强大且安全的AI Agent系统。