← 返回首页
🧠

SQL生成:AI生成数据库查询

📂 llm ⏱ 3 min 420 words

---
title: "SQL生成:AI生成数据库查询"
description: "使用LLM自动生成SQL查询语句"
tags: ["SQL生成", "数据库查询", "AI", "LLM", "数据分析"]
category: "llm"
icon: "🗃️"
---

SQL生成:AI生成数据库查询

SQL生成概述

SQL生成是利用LLM将自然语言问题转换为SQL查询的技术,让非技术人员也能轻松查询数据库。

核心功能

1. 自然语言到SQL

import re
from typing import Dict, List, Optional

from openai import OpenAI

class NaturalLanguageToSQL:
    """Translate natural-language questions into SQL with an OpenAI chat model.

    The class only produces text; it never executes the SQL it returns.
    Model output is untrusted, so callers should validate it before running it.
    """

    # First Markdown fenced block: an optional info string (e.g. ``sql``) on
    # the opening fence line, then the body up to the closing fence. Following
    # Markdown, anything on the opening fence line is treated as the info
    # string and is not part of the query.
    _FENCED_BLOCK = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)
    # Fallback for single-line or unterminated fences: drop the markers only.
    _FENCE_MARKER = re.compile(r"```(?:sql\b)?", re.IGNORECASE)

    def __init__(self, model: str = "gpt-4"):
        """Create the OpenAI client and remember which chat model to use."""
        self.client = OpenAI()
        self.model = model

    def _chat(self, system_prompt: str, user_prompt: str, temperature: float) -> str:
        """Send one system+user exchange and return the reply text ("" if none)."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
        )
        # ``content`` is optional in the SDK response model, so normalise a
        # missing reply to an empty string instead of failing on ``None``.
        return response.choices[0].message.content or ""

    @classmethod
    def _extract_sql(cls, reply: str) -> str:
        """Return the SQL contained in a model reply, without Markdown fences.

        If the reply contains a complete fenced block, only that block's body
        is returned, so explanatory prose around it is discarded and the fence
        language tag is ignored regardless of case. Otherwise any stray fence
        markers are removed and the remaining text is returned stripped.
        """
        text = reply or ""
        fenced = cls._FENCED_BLOCK.search(text)
        if fenced:
            return fenced.group(1).strip()
        return cls._FENCE_MARKER.sub("", text).strip()

    def generate_sql(self, question: str, schema: Dict) -> str:
        """Generate a SQL query answering ``question`` for the given schema.

        Args:
            question: The natural-language question to translate.
            schema: Description of the database structure; it is interpolated
                into the prompt using its ``str()`` form.

        Returns:
            The SQL text with Markdown code fences removed, or an empty string
            when the model returned no text.
        """
        prompt = f"""你是一个SQL专家。请根据以下数据库结构,将自然语言问题转换为SQL查询。

数据库结构:
{schema}

问题:{question}

请只返回SQL查询,不要包含解释。"""

        reply = self._chat("你是一个SQL生成专家。", prompt, temperature=0.2)
        return self._extract_sql(reply)

    def explain_sql(self, sql: str) -> str:
        """Ask the model for a concise explanation of what ``sql`` does.

        Returns:
            The model's explanation, or an empty string when it returned no
            text.
        """
        prompt = f"""请解释以下SQL查询的功能:

{sql}

请提供简洁的解释。"""

        return self._chat("你是一个SQL解释专家。", prompt, temperature=0.3)

2. SQL优化器

class SQLOptimizer:
    """Ask an OpenAI chat model to optimize SQL and to suggest indexes.

    Every result is model-generated text and should be reviewed before use.
    """

    # Paragraph separator: a blank line, tolerating CRLF line endings and
    # lines that contain only whitespace.
    _PARAGRAPH_BREAK = re.compile(r"\n\s*\n")

    def __init__(self, model: str = "gpt-4"):
        """Create the OpenAI client and remember which chat model to use."""
        self.client = OpenAI()
        self.model = model

    def _chat(self, system_prompt: str, user_prompt: str, temperature: float) -> str:
        """Send one system+user exchange and return the reply text ("" if none)."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
        )
        # ``content`` is optional in the SDK response model; normalise None.
        return response.choices[0].message.content or ""

    def optimize_sql(self, sql: str, schema: Optional[Dict] = None) -> Dict:
        """Request an optimized version of ``sql`` from the model.

        Args:
            sql: The query to optimize.
            schema: Optional database structure; when falsy it is left out of
                the prompt.

        Returns:
            A dict with the keys ``optimized_sql``, ``explanation`` and
            ``index_suggestions`` (see ``_parse_optimization_response``).
        """
        schema_info = f"\n数据库结构:{schema}" if schema else ""

        prompt = f"""请优化以下SQL查询:{schema_info}

原始SQL:
{sql}

请提供:
1. 优化后的SQL
2. 优化说明
3. 建议的索引"""

        reply = self._chat("你是一个SQL优化专家。", prompt, temperature=0.2)
        return self._parse_optimization_response(reply)

    def _parse_optimization_response(self, response: str) -> Dict:
        """Split a model reply into its three expected paragraphs.

        The reply is stripped and split on blank lines into at most three
        parts, mapped in order to ``optimized_sql``, ``explanation`` and
        ``index_suggestions``. Missing parts become empty strings. Text after
        the third paragraph is kept in ``index_suggestions`` rather than
        dropped.

        NOTE(review): the mapping is purely positional. A reply that puts
        headings or multi-paragraph SQL in separate paragraphs will land in
        the wrong fields; confirm against real model output.
        """
        text = (response or "").strip()
        parts = [part.strip() for part in self._PARAGRAPH_BREAK.split(text, maxsplit=2)]
        # Pad so that unpacking always yields exactly three fields.
        parts.extend([""] * (3 - len(parts)))
        optimized_sql, explanation, index_suggestions = parts

        return {
            "optimized_sql": optimized_sql,
            "explanation": explanation,
            "index_suggestions": index_suggestions,
        }

    def suggest_indexes(self, sql: str, schema: Dict) -> str:
        """Ask the model which indexes would suit ``sql`` on ``schema``.

        Returns:
            The model's free-text suggestions, or an empty string when it
            returned no text.
        """
        prompt = f"""基于以下SQL查询和数据库结构,建议合适的索引:

SQL:
{sql}

数据库结构:
{schema}

请提供索引建议。"""

        return self._chat("你是一个数据库索引专家。", prompt, temperature=0.2)

3. SQL验证器

class SQLValidator:
    """Heuristic checks for generated SQL.

    Neither check parses SQL. Both are pattern-based screens, so a passing
    result is not proof that a query is valid or harmless. Running generated
    SQL with a read-only database account remains the real safeguard.
    """

    # Single-quoted string literal, with '' as the escaped quote.
    _STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")
    # A read query starts with SELECT, or with WITH for a common table expression.
    _LEADING_KEYWORD = re.compile(r"\s*(?:SELECT|WITH)\b", re.IGNORECASE)
    _FROM_KEYWORD = re.compile(r"\bFROM\b", re.IGNORECASE)

    # (label, matcher) pairs. The label is the text shown in the issue message
    # and is kept stable; the matcher adds word boundaries so identifiers such
    # as ``updated_at`` do not trigger, and lets UPDATE ... SET span lines.
    _DANGEROUS_OPERATIONS = (
        (r"DROP\s+TABLE", re.compile(r"\bDROP\s+TABLE\b", re.IGNORECASE)),
        (r"DELETE\s+FROM", re.compile(r"\bDELETE\s+FROM\b", re.IGNORECASE)),
        # Not followed by "(" so the numeric function TRUNCATE(x, d) is allowed.
        (r"TRUNCATE", re.compile(r"\bTRUNCATE\b(?!\s*\()", re.IGNORECASE)),
        (r"UPDATE.*SET", re.compile(r"\bUPDATE\b.*?\bSET\b", re.IGNORECASE | re.DOTALL)),
        (r"INSERT\s+INTO", re.compile(r"\bINSERT\s+INTO\b", re.IGNORECASE)),
    )

    @staticmethod
    def validate_syntax(sql: str) -> Dict:
        """Run basic structural checks on a query.

        Checks, in order: the statement starts with SELECT or WITH (leading
        whitespace allowed), it contains a FROM keyword, single quotes are
        paired, and parentheses are balanced and correctly nested. String
        literals are ignored for the keyword and parenthesis checks.

        Returns:
            ``{"is_valid": bool, "issues": list[str]}``.
        """
        issues = []

        # Blank out complete string literals so their contents cannot be
        # mistaken for keywords or parentheses. Replacing with a space keeps
        # neighbouring tokens separated.
        code = SQLValidator._STRING_LITERAL.sub(" ", sql)

        if not SQLValidator._LEADING_KEYWORD.match(code):
            issues.append("SQL应该以SELECT开头")

        if not SQLValidator._FROM_KEYWORD.search(code):
            issues.append("缺少FROM子句")

        # Any quote left after removing complete literals has no partner.
        if "'" in code:
            issues.append("引号不匹配")

        depth = 0
        for char in code:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth < 0:
                    # A closing parenthesis appeared before its opener.
                    break
        if depth != 0:
            issues.append("括号不匹配")

        return {
            "is_valid": not issues,
            "issues": issues,
        }

    @staticmethod
    def check_security(sql: str) -> Dict:
        """Flag statements that modify or destroy data.

        The raw text is scanned, including string literals, so a literal that
        contains e.g. ``delete from`` is flagged as well. This errs on the
        side of caution. The list covers DROP TABLE, DELETE FROM, TRUNCATE,
        UPDATE ... SET and INSERT INTO only.

        Returns:
            ``{"is_safe": bool, "issues": list[str]}``.
        """
        security_issues = [
            f"检测到危险操作: {label}"
            for label, matcher in SQLValidator._DANGEROUS_OPERATIONS
            if matcher.search(sql)
        ]

        return {
            "is_safe": not security_issues,
            "issues": security_issues,
        }

使用示例

# Describe the demo database: for each table, its column order and the SQL
# type of every column.
schema = {
    "users": {
        "columns": ["id", "name", "email", "created_at"],
        "types": dict(id="INT", name="VARCHAR", email="VARCHAR", created_at="DATETIME"),
    },
    "orders": {
        "columns": ["id", "user_id", "product", "amount", "created_at"],
        "types": dict(
            id="INT",
            user_id="INT",
            product="VARCHAR",
            amount="DECIMAL",
            created_at="DATETIME",
        ),
    },
}

# Build the generator with its default model.
nl2sql = NaturalLanguageToSQL()

# Translate the question into SQL and show it.
sql = nl2sql.generate_sql("查询所有用户的订单总数", schema)
print(f"生成的SQL: {sql}")

# Ask the model to describe what the generated query does.
explanation = nl2sql.explain_sql(sql)
print(f"SQL解释: {explanation}")

# Run the structural checks on the generated query.
validation = SQLValidator.validate_syntax(sql)
print(f"语法验证: {validation}")

# Screen the generated query for data-modifying statements.
security = SQLValidator.check_security(sql)
print(f"安全检查: {security}")

最佳实践

  1. 提供完整结构:向模型提供完整的数据库结构(表、列及其类型),以减少生成不存在的表名或列名
  2. 验证结果:执行前始终验证生成的SQL
  3. 安全检查:执行前检查SQL是否包含写入或删除等危险操作,并尽量使用只读账号执行
  4. 性能优化:对复杂或耗时的查询使用优化器改写,并根据建议补充索引

总结

SQL生成是让非技术人员也能轻松查询数据库的强大工具。通过LLM技术,可以快速将自然语言转换为SQL查询。