← 返回首页
🧠

JWT在LLM中的应用

📂 llm ⏱ 3 min 537 words

---
title: "JWT在LLM中的应用"
description: "详解JWT令牌在LLM应用中的实现,包括生成、验证、刷新等最佳实践"
tags: ["JWT", "令牌", "身份验证"]
category: "llm"
icon: "🧠"
---

JWT在LLM中的应用

JWT简介

JSON Web Token(JWT)是一种开放标准(RFC 7519),用于在各方之间安全地传输信息。在LLM应用中,JWT常用于服务间认证和用户会话管理。

JWT结构

JWT由三部分组成:Header(头部)、Payload(载荷)、Signature(签名),三者分别经Base64url编码后以“.”连接。签名只能保证令牌未被篡改,载荷本身并未加密,任何持有令牌的人都能读取其内容,因此不要在载荷中存放密码等敏感信息。

import base64
import json
from datetime import datetime, timedelta

class SimpleJWT:
    """Minimal HMAC-signed JSON Web Token implementation (RFC 7519).

    Only the symmetric HMAC algorithms HS256/HS384/HS512 are supported. The
    algorithm chosen in the constructor is both written to the token header
    and actually used to compute the signature. :meth:`decode` rejects tokens
    whose header names a different algorithm.

    Note: the payload is only base64url-encoded, not encrypted. Anyone who
    holds a token can read its claims.
    """

    # JWS "alg" header value -> digest name accepted by hmac.new(digestmod=...).
    _DIGESTS = {"HS256": "sha256", "HS384": "sha384", "HS512": "sha512"}

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """Create a signer/verifier.

        Args:
            secret_key: Shared HMAC secret; UTF-8 encoded before use.
            algorithm: One of ``"HS256"``, ``"HS384"`` or ``"HS512"``.

        Raises:
            ValueError: If ``algorithm`` is not a supported HMAC algorithm.
        """
        # Reject unknown values so the header can never advertise an
        # algorithm other than the one really used for signing.
        if algorithm not in self._DIGESTS:
            raise ValueError(f"不支持的算法: {algorithm}")
        self.secret_key = secret_key
        self.algorithm = algorithm

    def base64url_encode(self, data: bytes) -> str:
        """Return unpadded base64url text for ``data``."""
        return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

    def base64url_decode(self, data: str) -> bytes:
        """Decode unpadded base64url text, restoring the stripped ``=`` padding."""
        return base64.urlsafe_b64decode(data + "=" * (-len(data) % 4))

    def _sign(self, message: bytes) -> bytes:
        """Return the raw HMAC digest of ``message`` for the configured algorithm."""
        import hmac

        return hmac.new(
            self.secret_key.encode(), message, self._DIGESTS[self.algorithm]
        ).digest()

    def encode(self, payload: dict, expires_delta: timedelta = None) -> str:
        """Serialize and sign ``payload`` into a compact JWT string.

        ``exp`` and ``iat`` are added as integer Unix timestamps when absent.
        The caller's dict is left unmodified.

        Args:
            payload: JSON-serializable claims.
            expires_delta: Lifetime used when ``payload`` has no ``exp``;
                one hour when ``None``.

        Returns:
            ``header.payload.signature``, each part base64url-encoded.
        """
        import time

        # Work on a copy so the caller's dict is not mutated.
        claims = dict(payload)
        # time.time() is a true Unix timestamp. datetime.utcnow().timestamp()
        # interprets the naive UTC value as local time and is therefore off
        # by the host's UTC offset.
        now = int(time.time())

        if "exp" not in claims:
            lifetime = timedelta(hours=1) if expires_delta is None else expires_delta
            claims["exp"] = now + int(lifetime.total_seconds())

        if "iat" not in claims:
            claims["iat"] = now

        header = {"alg": self.algorithm, "typ": "JWT"}
        header_encoded = self.base64url_encode(json.dumps(header).encode())
        payload_encoded = self.base64url_encode(json.dumps(claims).encode())

        signing_input = f"{header_encoded}.{payload_encoded}".encode()
        signature_encoded = self.base64url_encode(self._sign(signing_input))

        return f"{header_encoded}.{payload_encoded}.{signature_encoded}"

    def decode(self, token: str) -> dict:
        """Verify ``token`` and return its claims.

        Raises:
            ValueError: If the token is malformed, the signature does not
                match, the header ``alg`` differs from this instance's
                algorithm, or ``exp`` lies in the past.
        """
        import hmac
        import time

        parts = token.split(".")
        if len(parts) != 3:
            raise ValueError("无效的JWT格式")

        header_encoded, payload_encoded, signature_encoded = parts

        # Verify the signature before trusting anything inside the token.
        signing_input = f"{header_encoded}.{payload_encoded}".encode()
        expected_signature = self._sign(signing_input)
        actual_signature = self.base64url_decode(signature_encoded)

        # Constant-time comparison so timing does not reveal matching bytes.
        if not hmac.compare_digest(expected_signature, actual_signature):
            raise ValueError("签名验证失败")

        header = json.loads(self.base64url_decode(header_encoded))
        if not isinstance(header, dict) or header.get("alg") != self.algorithm:
            raise ValueError("无效的JWT算法")

        payload = json.loads(self.base64url_decode(payload_encoded))

        # Same clock source as encode(), so exp comparisons are consistent.
        exp = payload.get("exp")
        if exp is not None and time.time() > exp:
            raise ValueError("令牌已过期")

        return payload

LLM应用中的JWT使用

服务间认证

from datetime import datetime, timedelta
from typing import Optional

class LLMServiceAuth:
    """Issue and validate JWTs for service-to-service calls.

    Tokens carry the issuer (``iss``), the calling service (``sub``), a list
    of ``permissions`` and a ``type`` claim fixed to ``"service"``.
    """

    def __init__(self, secret_key: str, issuer: str):
        self.jwt = SimpleJWT(secret_key)
        self.issuer = issuer

    def create_service_token(self, service_name: str,
                             permissions: list,
                             expires_in: timedelta = timedelta(hours=1)) -> str:
        """Return a signed service token for ``service_name``.

        Args:
            service_name: Identity of the calling service, stored as ``sub``.
            permissions: Permission names granted to the service.
            expires_in: Token lifetime (default one hour).
        """
        payload = {
            "iss": self.issuer,
            "sub": service_name,
            "permissions": permissions,
            "type": "service",
        }
        return self.jwt.encode(payload, expires_in)

    def validate_service_token(self, token: str,
                               required_permission: str) -> dict:
        """Verify ``token`` and check that it grants ``required_permission``.

        Returns:
            The decoded claims.

        Raises:
            ValueError: Bad format/signature, expired, wrong issuer, or not a
                service token; the message is prefixed with "令牌验证失败".
            PermissionError: The token does not grant ``required_permission``.
        """
        try:
            payload = self.jwt.decode(token)

            # Only tokens minted by this issuer are accepted.
            if payload.get("iss") != self.issuer:
                raise ValueError("无效的发行者")

            # Reject other token kinds that happen to share key and issuer.
            if payload.get("type") != "service":
                raise ValueError("无效的令牌类型")

            permissions = payload.get("permissions", [])
            # Require a real list: on a string, ``in`` would do substring
            # matching ("read" in "read_write" is True).
            if (not isinstance(permissions, list)
                    or required_permission not in permissions):
                raise PermissionError(f"缺少必要权限: {required_permission}")

            return payload

        except ValueError as e:
            # Chain the cause so the original error stays in the traceback.
            raise ValueError(f"令牌验证失败: {e}") from e

用户会话管理

class LLMUserSession:
    """Issue access/refresh token pairs for users and mint new access tokens.

    Access tokens carry the user's roles. Refresh tokens carry a random
    ``jti`` recorded in ``self.refresh_tokens`` so it can be revoked. That
    record is a plain dict held only by this instance.
    """

    def __init__(self, secret_key: str, access_token_ttl: int = 3600,
                 refresh_token_ttl: int = 604800):
        """
        Args:
            secret_key: HMAC secret used for both access and refresh tokens.
            access_token_ttl: Access-token lifetime in seconds (default 1 hour).
            refresh_token_ttl: Refresh-token lifetime in seconds (default 7 days).
        """
        self.jwt = SimpleJWT(secret_key)
        self.access_token_ttl = access_token_ttl
        self.refresh_token_ttl = refresh_token_ttl
        # jti -> {"user_id", "expires_at" (naive UTC datetime), "is_revoked"}
        self.refresh_tokens = {}

    def _issue_access_token(self, user_id: str, roles: list) -> str:
        """Sign an access token for ``user_id`` with ``roles``."""
        return self.jwt.encode(
            {"user_id": user_id, "roles": roles, "type": "access"},
            timedelta(seconds=self.access_token_ttl),
        )

    def create_tokens(self, user_id: str, roles: list) -> dict:
        """Create an access token and a revocable refresh token for ``user_id``.

        Returns:
            Dict with ``access_token``, ``refresh_token``, ``token_type`` and
            ``expires_in`` (the access-token lifetime in seconds).
        """
        # Imported here because nothing above in this snippet imports it.
        import secrets

        access_token = self._issue_access_token(user_id, roles)

        # Unique id for this refresh token; it is the key used for revocation.
        jti = secrets.token_hex(16)
        refresh_token = self.jwt.encode(
            {"user_id": user_id, "type": "refresh", "jti": jti},
            timedelta(seconds=self.refresh_token_ttl),
        )

        # The jti is already known, so there is no need to decode the token
        # that was just built.
        self.refresh_tokens[jti] = {
            "user_id": user_id,
            "expires_at": datetime.utcnow() + timedelta(seconds=self.refresh_token_ttl),
            "is_revoked": False,
        }

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            "expires_in": self.access_token_ttl,
        }

    def refresh_access_token(self, refresh_token: str) -> dict:
        """Exchange a valid, unrevoked refresh token for a new access token.

        Raises:
            ValueError: The token is invalid, expired, not a refresh token,
                unknown to this instance, or revoked; the message is prefixed
                with "刷新令牌验证失败".
        """
        try:
            payload = self.jwt.decode(refresh_token)

            # An access token must not be usable as a refresh token.
            if payload.get("type") != "refresh":
                raise ValueError("无效的刷新令牌类型")

            token_info = self.refresh_tokens.get(payload.get("jti"))
            if token_info is None:
                raise ValueError("刷新令牌不存在")
            if token_info["is_revoked"]:
                raise ValueError("刷新令牌已被撤销")

            user_id = payload["user_id"]
            # Roles are re-read so role changes apply to new access tokens.
            return {
                "access_token": self._issue_access_token(
                    user_id, self._get_user_roles(user_id)
                ),
                "token_type": "Bearer",
                "expires_in": self.access_token_ttl,
            }

        except ValueError as e:
            # Chain the cause so the original error stays in the traceback.
            raise ValueError(f"刷新令牌验证失败: {e}") from e

    def revoke_token(self, jti: str):
        """Mark the refresh token with id ``jti`` as revoked; unknown ids are ignored."""
        if jti in self.refresh_tokens:
            self.refresh_tokens[jti]["is_revoked"] = True

    def _get_user_roles(self, user_id: str) -> list:
        """Return the roles for ``user_id``.

        Placeholder: intended to query a database, but currently returns
        ``["user"]`` for every user.
        """
        return ["user"]

JWT安全最佳实践

使用强密钥

def generate_strong_secret(length: int = 64) -> str:
    """Return a random URL-safe secret suitable as an HMAC signing key.

    Args:
        length: Number of random *bytes* drawn via the ``secrets`` module.
            The result is base64url text and therefore about 4/3 longer
            (64 bytes -> 86 characters).
    """
    # Imported here so the snippet runs on its own; ``secrets`` is not
    # imported anywhere above.
    import secrets

    return secrets.token_urlsafe(length)

设置合理的过期时间

# Recommended lifetime for each token kind; "ttl" is expressed in seconds.
TOKEN_CONFIGS = dict(
    access_token=dict(
        ttl=15 * 60,  # 15 minutes
        description="短生命周期,用于API访问",
    ),
    refresh_token=dict(
        ttl=7 * 24 * 60 * 60,  # 7 days
        description="长生命周期,用于获取新访问令牌",
    ),
    service_token=dict(
        ttl=60 * 60,  # 1 hour
        description="服务间调用",
    ),
)

实现令牌黑名单

class JWTBlacklist:
    """In-memory set of revoked token ids (``jti`` claims).

    When a revocation is recorded together with the token's expiry,
    :meth:`purge_expired` can later drop it. Once a token has expired it is
    rejected anyway, so the blacklist entry is no longer needed. Entries
    revoked without an expiry are kept indefinitely.
    """

    def __init__(self):
        self.blacklisted = set()
        # jti -> Unix timestamp after which the entry may be purged.
        self._expires_at = {}

    def revoke(self, jti: str, expires_at: float = None):
        """Mark ``jti`` as revoked.

        Args:
            jti: Token id to revoke.
            expires_at: Optional Unix timestamp (typically the token's
                ``exp``). When omitted, the entry is never purged.
        """
        self.blacklisted.add(jti)
        if expires_at is None:
            self._expires_at.pop(jti, None)
        else:
            self._expires_at[jti] = expires_at

    def is_revoked(self, jti: str) -> bool:
        """Return True if ``jti`` has been revoked and not yet purged."""
        return jti in self.blacklisted

    def purge_expired(self, now: float = None) -> int:
        """Forget revoked ids whose expiry is strictly before ``now``.

        Args:
            now: Unix timestamp to compare against; defaults to ``time.time()``.

        Returns:
            Number of entries removed.
        """
        import time

        if now is None:
            now = time.time()
        # Strictly earlier: a token whose exp equals "now" may still be accepted.
        expired = [jti for jti, exp in self._expires_at.items() if exp < now]
        for jti in expired:
            del self._expires_at[jti]
            self.blacklisted.discard(jti)
        return len(expired)

JWT为LLM应用提供了灵活、安全的身份认证机制,特别适合微服务架构和无状态应用。需要注意的是,令牌撤销(如上文的刷新令牌存储和黑名单)仍然需要服务端保存状态,因此“无状态”只是相对而言。