← 返回首页
🧠

OAuth与LLM集成

📂 llm ⏱ 3 min 460 words

--- title: "OAuth与LLM集成" description: "讲解如何将OAuth 2.0认证协议与LLM服务集成,实现安全的第三方授权访问" tags: ["OAuth", "授权", "第三方登录"] category: "llm" icon: "🧠"

OAuth与LLM集成

为什么使用OAuth

当用户需要通过自己的LLM账户授权第三方应用访问时,OAuth 2.0提供了安全的授权机制。它允许应用在不暴露用户凭证的情况下获取有限的访问权限。

OAuth 2.0基础流程

用户 → 第三方应用 → LLM提供商授权服务器 → 获取访问令牌 → 访问LLM API

实现OAuth客户端

授权码流程

from urllib.parse import urlencode
from datetime import datetime, timedelta
import secrets
import requests

class OAuthLLMClient:
    def __init__(self, client_id, client_secret, redirect_uri):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.authorization_url = "https://auth.llm-provider.com/authorize"
        self.token_url = "https://auth.llm-provider.com/token"
        self.state_tokens = {}
    
    def get_authorization_url(self, scopes: list) -> tuple:
        # 生成CSRF防护的状态值
        state = secrets.token_urlsafe(32)
        self.state_tokens[state] = {
            "created_at": datetime.now(),
            "expires_at": datetime.now() + timedelta(minutes=10)
        }
        
        params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": " ".join(scopes),
            "state": state
        }
        
        auth_url = f"{self.authorization_url}?{urlencode(params)}"
        return auth_url, state
    
    def exchange_code_for_token(self, authorization_code: str, 
                                 state: str) -> dict:
        # 验证状态值防止CSRF
        if state not in self.state_tokens:
            raise ValueError("无效的状态值")
        
        state_info = self.state_tokens.pop(state)
        if datetime.now() > state_info["expires_at"]:
            raise ValueError("授权码已过期")
        
        # 交换访问令牌
        token_data = {
            "grant_type": "authorization_code",
            "code": authorization_code,
            "redirect_uri": self.redirect_uri,
            "client_id": self.client_id,
            "client_secret": self.client_secret
        }
        
        response = requests.post(self.token_url, data=token_data)
        
        if response.status_code != 200:
            raise Exception(f"获取令牌失败: {response.json()}")
        
        return response.json()

令牌管理

令牌存储和刷新

import json
from pathlib import Path

class TokenManager:
    def __init__(self, token_storage_path: str = ".oauth_tokens.json"):
        self.storage_path = Path(token_storage_path)
        self.tokens = self._load_tokens()
    
    def _load_tokens(self) -> dict:
        if self.storage_path.exists():
            with open(self.storage_path, "r") as f:
                return json.load(f)
        return {}
    
    def save_token(self, user_id: str, token_data: dict):
        self.tokens[user_id] = {
            "access_token": token_data["access_token"],
            "refresh_token": token_data.get("refresh_token"),
            "expires_at": (datetime.now() + 
                          timedelta(seconds=token_data.get("expires_in", 3600))).isoformat(),
            "token_type": token_data.get("token_type", "Bearer")
        }
        
        with open(self.storage_path, "w") as f:
            json.dump(self.tokens, f, indent=2)
    
    def get_valid_token(self, user_id: str, 
                        refresh_func) -> str:
        if user_id not in self.tokens:
            raise ValueError("用户未授权")
        
        token_info = self.tokens[user_id]
        expires_at = datetime.fromisoformat(token_info["expires_at"])
        
        # 令牌未过期,直接返回
        if datetime.now() < expires_at:
            return token_info["access_token"]
        
        # 尝试刷新令牌
        if token_info.get("refresh_token"):
            new_token = refresh_func(token_info["refresh_token"])
            self.save_token(user_id, new_token)
            return new_token["access_token"]
        
        raise ValueError("访问令牌已过期,需要重新授权")
    
    def revoke_token(self, user_id: str):
        if user_id in self.tokens:
            del self.tokens[user_id]
            with open(self.storage_path, "w") as f:
                json.dump(self.tokens, f, indent=2)

安全最佳实践

PKCE扩展

公共客户端应使用PKCE增强安全性。

import hashlib
import base64
import secrets

class PKCEHelper:
    @staticmethod
    def generate_verifier() -> str:
        # 生成43-128字符的随机字符串
        return secrets.token_urlsafe(64)
    
    @staticmethod
    def generate_challenge(verifier: str) -> str:
        # SHA256哈希
        digest = hashlib.sha256(verifier.encode()).digest()
        
        # Base64 URL安全编码
        challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return challenge
    
    @staticmethod
    def get_auth_params(client_id: str, redirect_uri: str,
                       scopes: list, code_verifier: str) -> dict:
        code_challenge = PKCEHelper.generate_challenge(code_verifier)
        
        return {
            "response_type": "code",
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": " ".join(scopes),
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "state": secrets.token_urlsafe(32)
        }

作用域限制

只请求必要的最小权限。

class LLMScopeManager:
    # LLM相关的OAuth作用域定义
    SCOPES = {
        "models:read": "读取可用模型列表",
        "chat:write": "发送聊天请求",
        "completions:write": "生成补全",
        "usage:read": "查看使用统计",
        "keys:manage": "管理API密钥"
    }
    
    @classmethod
    def get_minimal_scopes(cls, use_case: str) -> list:
        """根据使用场景返回最小必需作用域"""
        minimal_scopes = {
            "chat_app": ["chat:write", "models:read"],
            "code_generator": ["completions:write", "models:read"],
            "analytics": ["usage:read"],
            "admin": list(cls.SCOPES.keys())
        }
        
        return minimal_scopes.get(use_case, ["chat:write"])
    
    @classmethod
    def validate_scopes(cls, requested: list, allowed: list) -> bool:
        return all(scope in allowed for scope in requested)

处理OAuth错误

class OAuthErrorHandler:
    @staticmethod
    def handle_error(error_response: dict) -> str:
        error_codes = {
            "invalid_request": "请求参数错误",
            "unauthorized_client": "客户端未授权",
            "access_denied": "用户拒绝授权",
            "unsupported_response_type": "不支持的响应类型",
            "invalid_scope": "请求的作用域无效",
            "server_error": "授权服务器内部错误",
            "temporarily_unavailable": "授权服务器暂时不可用"
        }
        
        error_code = error_response.get("error", "unknown")
        error_desc = error_response.get("error_description", "")
        
        message = error_codes.get(error_code, f"未知错误: {error_code}")
        if error_desc:
            message += f" - {error_desc}"
        
        return message

OAuth为LLM服务提供了安全、标准化的授权机制,特别适合需要用户授权的第三方应用集成场景。