← 返回首页
🚀

Flask进阶:扩展、RESTful API、认证与数据库集成

📂 python ⏱ 6 min 1029 words

Flask进阶:扩展、RESTful API、认证与数据库集成

掌握Flask基础后，本文将带你深入学习Flask的高级特性，包括扩展系统、RESTful API设计、用户认证和数据库集成，让你能够构建生产级的Web应用。

Flask扩展系统

Flask通过扩展机制提供丰富的功能：

# Commonly used Flask extensions (package name - purpose)
"""
Flask-SQLAlchemy    - 数据库集成
Flask-Migrate       - 数据库迁移
Flask-Login         - 用户认证
Flask-WTF           - 表单处理
Flask-RESTful       - RESTful API
Flask-JWT-Extended  - JWT认证
Flask-CORS          - 跨域支持
Flask-Mail          - 邮件发送
Flask-Caching       - 缓存
"""

# Install the extensions used in this article (flask-restful is needed by the
# Flask-RESTful section below):
# pip install flask-sqlalchemy flask-migrate flask-login flask-cors flask-jwt-extended flask-restful

扩展初始化模式

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_login import LoginManager
from flask_cors import CORS

# Extension objects are created unbound at import time; create_app() later
# attaches each one to a concrete Flask app via init_app().
db = SQLAlchemy()
migrate = Migrate()
login_manager, cors = LoginManager(), CORS()

def create_app(config_name="development"):
    """Application factory.

    Args:
        config_name: Key into the ``config`` mapping from the ``config`` module.

    Returns:
        A configured ``Flask`` app with db, migrate, login_manager and cors
        initialised and the ``api`` (/api) and ``auth`` (/auth) blueprints
        registered.
    """
    from config import config

    application = Flask(__name__)
    application.config.from_object(config[config_name])

    # Bind the module-level extension instances to this particular app.
    db.init_app(application)
    migrate.init_app(application, db)
    for extension in (login_manager, cors):
        extension.init_app(application)

    # Blueprints are imported lazily to avoid circular imports at module load.
    from .api import api_bp
    from .auth import auth_bp

    for blueprint, prefix in ((api_bp, "/api"), (auth_bp, "/auth")):
        application.register_blueprint(blueprint, url_prefix=prefix)

    return application

RESTful API设计

使用Flask-RESTful

from flask import Flask
from flask_restful import Api, Resource, reqparse, fields, marshal_with
from flask_sqlalchemy import SQLAlchemy

# Single-module setup: SQLAlchemy and Flask-RESTful bind directly to this app,
# so the database URI must be configured before SQLAlchemy(app) is created.
app = Flask(__name__)
app.config.update(SQLALCHEMY_DATABASE_URI="sqlite:///api.db")
db = SQLAlchemy(app)
api = Api(app)

# Data model
class User(db.Model):
    """Minimal user record exposed through the REST resources below."""

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)

    def to_dict(self):
        """Return the public fields (id, username, email) as a plain dict."""
        return {attr: getattr(self, attr) for attr in ("id", "username", "email")}

# Output schema used by @marshal_with to serialise User objects.
user_fields = dict(
    id=fields.Integer,
    username=fields.String,
    email=fields.String,
)

# Request parser shared by POST (create) and PUT (full update); both fields
# are mandatory. RequestParser.add_argument returns the parser, so calls chain.
user_parser = (
    reqparse.RequestParser()
    .add_argument("username", type=str, required=True, help="用户名不能为空")
    .add_argument("email", type=str, required=True, help="邮箱不能为空")
)

# API resources
class UserResource(Resource):
    """Single-user endpoint: read, replace and delete a user by primary key."""

    @marshal_with(user_fields)
    def get(self, user_id):
        """Return the user (serialised via ``user_fields``) or abort with 404."""
        user = User.query.get_or_404(user_id)
        return user

    def put(self, user_id):
        """Replace username and email; both are required by ``user_parser``."""
        user = User.query.get_or_404(user_id)
        args = user_parser.parse_args()
        user.username = args["username"]
        user.email = args["email"]
        db.session.commit()
        return user.to_dict()

    def delete(self, user_id):
        """Delete the user and answer ``204 No Content``.

        A 204 response must not carry a body, and Werkzeug strips it anyway.
        The previous ``{"message": ...}`` payload therefore never reached the
        client, so return an explicitly empty body instead.
        """
        user = User.query.get_or_404(user_id)
        db.session.delete(user)
        db.session.commit()
        return "", 204

class UserListResource(Resource):
    """Collection endpoint: list every user or create a new one."""

    @marshal_with(user_fields)
    def get(self):
        """Return all users, serialised through ``user_fields``."""
        every_user = User.query.all()
        return every_user

    def post(self):
        """Create a user from the parsed request body and answer 201."""
        parsed = user_parser.parse_args()
        new_user = User(username=parsed["username"], email=parsed["email"])
        db.session.add(new_user)
        db.session.commit()
        return new_user.to_dict(), 201

# Route registration: the collection URL and the single-item URL.
api.add_resource(UserListResource, "/api/users")
api.add_resource(UserResource, "/api/users/<int:user_id>")

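RESTful最佳实践（以下示例假设已在别处定义 db 实例以及 Article 模型，后者包含 title、content、category、author_id、created_at 等字段并提供 to_dict() 方法）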

from flask import Blueprint, jsonify, request
from functools import wraps

# Blueprint grouping the article endpoints below.
# NOTE(review): this snippet assumes `db` and an `Article` model are defined elsewhere.
api = Blueprint(name="api", import_name=__name__)

# Unified response envelope
def success_response(data=None, message="success", status_code=200):
    """Wrap *data* in the unified success envelope.

    Returns:
        A ``(Response, status_code)`` tuple whose JSON body is
        ``{"code": status_code, "message": message, "data": data}``.
    """
    return jsonify(code=status_code, message=message, data=data), status_code

def error_response(message="error", status_code=400, errors=None):
    """Wrap an error in the unified envelope.

    Returns:
        A ``(Response, status_code)`` tuple whose JSON body is
        ``{"code": status_code, "message": message, "errors": errors}``.
    """
    return jsonify(code=status_code, message=message, errors=errors), status_code

# Pagination helper (a plain function, not a decorator)
def paginate(query, page=1, per_page=10, max_per_page=100):
    """Paginate *query* and return a JSON-serialisable summary dict.

    Args:
        query: Object with a Flask-SQLAlchemy-style
            ``paginate(page=, per_page=, error_out=)`` method whose items
            expose ``to_dict()``.
        page: 1-based page number; values below 1 are clamped to 1.
        per_page: Items per page; clamped to at least 1 and at most
            ``max_per_page`` so a client cannot request an unbounded page.
        max_per_page: Upper bound for ``per_page``; ``None`` disables the cap.

    Returns:
        dict with ``items`` (list of ``to_dict()`` results), ``total``,
        ``page``, ``pages``, ``has_next`` and ``has_prev``.
    """
    page = max(page, 1)
    per_page = max(per_page, 1)
    if max_per_page is not None:
        per_page = min(per_page, max_per_page)

    # error_out=False: per the Flask-SQLAlchemy docs, an out-of-range page
    # yields an empty result instead of aborting with 404.
    pagination = query.paginate(page=page, per_page=per_page, error_out=False)
    return {
        "items": [item.to_dict() for item in pagination.items],
        "total": pagination.total,
        "page": pagination.page,
        "pages": pagination.pages,
        "has_next": pagination.has_next,
        "has_prev": pagination.has_prev
    }

# API examples
@api.route("/articles", methods=["GET"])
def get_articles():
    """List articles with optional filtering, sorting and pagination.

    Query parameters:
        page, per_page: pagination; page >= 1, per_page clamped to [1, 100].
        category: exact match on ``Article.category``.
        q: substring match on ``Article.title``; ``%`` and ``_`` match literally.
        sort: column to sort by (default ``created_at``); must be a column of
            Article's table, otherwise the request is rejected with 400.
        order: ``desc`` (default) sorts descending; any other value ascending.
    """
    page = max(request.args.get("page", 1, type=int), 1)
    per_page = min(max(request.args.get("per_page", 10, type=int), 1), 100)

    # Filtering by query parameters
    query = Article.query

    category = request.args.get("category")
    if category:
        query = query.filter_by(category=category)

    keyword = request.args.get("q")
    if keyword:
        # autoescape=True: user-supplied % and _ are matched literally instead
        # of acting as LIKE wildcards.
        query = query.filter(Article.title.contains(keyword, autoescape=True))

    # Sorting. Only real table columns are accepted: getattr() on raw user
    # input could otherwise reach arbitrary class attributes, or raise
    # AttributeError (an unhandled 500) for unknown names.
    sort = request.args.get("sort", "created_at")
    if sort not in Article.__table__.columns.keys():
        return error_response("不支持的排序字段")
    column = getattr(Article, sort)
    order = request.args.get("order", "desc")
    query = query.order_by(column.desc() if order == "desc" else column.asc())

    result = paginate(query, page, per_page)
    return success_response(result)

@api.route("/articles", methods=["POST"])
def create_article():
    """Create an article from a JSON object body; ``title`` is required.

    Returns 201 with the serialised article. Returns 400 through
    ``error_response`` when the body is missing, is not a JSON object, or
    has no title.
    """
    # silent=True: a missing or invalid JSON body yields None instead of Flask
    # raising its own error, so the client always gets the unified envelope.
    data = request.get_json(silent=True)

    # Validation: the body must be a JSON object (a list or scalar has no
    # .get()) and must contain a non-empty title.
    if not isinstance(data, dict) or not data.get("title"):
        return error_response("标题不能为空")

    article = Article(
        title=data["title"],
        content=data.get("content", ""),
        author_id=data.get("author_id")
    )
    db.session.add(article)
    db.session.commit()

    return success_response(article.to_dict(), "创建成功", 201)

用户认证

JWT认证实现

from flask import Flask, request, jsonify
from flask_jwt_extended import (
    JWTManager, create_access_token, create_refresh_token,
    jwt_required, get_jwt_identity, get_jwt
)
from werkzeug.security import generate_password_hash, check_password_hash
from datetime import timedelta

import os

app = Flask(__name__)
# Read the signing key from the environment. The literal fallback is only for
# local experiments: anyone who knows it can forge valid tokens.
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "your-secret-key")
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)

jwt = JWTManager(app)

# Simplified in-memory user store: username -> User (demo only, not persisted).
users_db = dict()

class User:
    """In-memory demo user that stores a password hash, never the raw password."""

    def __init__(self, username, email, password):
        # Sequential id: one more than the number of users already stored.
        self.id = 1 + len(users_db)
        self.username, self.email = username, email
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True when *password* matches the stored hash."""
        return check_password_hash(self.password_hash, password)

    def to_dict(self):
        """Public representation (the password hash is excluded)."""
        return dict(id=self.id, username=self.username, email=self.email)

@app.route("/auth/register", methods=["POST"])
def register():
    """Register a user and return it together with a fresh token pair.

    Expects a JSON object with non-empty ``username``, ``email`` and
    ``password``. Responds 400 when a field is missing or the username is
    already taken, otherwise 201.
    """
    data = request.get_json(silent=True)
    # Validate up front: indexing a missing key (or a None body) used to
    # raise and surface as a 500.
    if not isinstance(data, dict) or not all(
        data.get(field) for field in ("username", "email", "password")
    ):
        return jsonify({"error": "用户名、邮箱和密码不能为空"}), 400

    if users_db.get(data["username"]):
        return jsonify({"error": "用户名已存在"}), 400

    user = User(data["username"], data["email"], data["password"])
    users_db[user.username] = user

    # Recent flask-jwt-extended / PyJWT versions require the identity ("sub"
    # claim) to be a string; tokens with an int subject are rejected on decode.
    identity = str(user.id)
    # Embed the role so role_required() has a claim to check. This demo User
    # has no role attribute, so it defaults to "user".
    claims = {"role": getattr(user, "role", "user")}
    access_token = create_access_token(identity=identity, additional_claims=claims)
    refresh_token = create_refresh_token(identity=identity, additional_claims=claims)

    return jsonify({
        "user": user.to_dict(),
        "access_token": access_token,
        "refresh_token": refresh_token
    }), 201

@app.route("/auth/login", methods=["POST"])
def login():
    """Authenticate with username/password and return a fresh token pair.

    Expects a JSON object with ``username`` and ``password``. A missing,
    malformed or wrong credential always yields the same 401 response, so
    the endpoint does not reveal which part was wrong.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get("username")
    password = data.get("password")

    # Only strings are acceptable: a JSON list would be unhashable in the dict
    # lookup, and check_password_hash() cannot hash None.
    user = users_db.get(username) if isinstance(username, str) else None
    if user is None or not isinstance(password, str) or not user.check_password(password):
        return jsonify({"error": "用户名或密码错误"}), 401

    # String identity (required by recent flask-jwt-extended / PyJWT) plus the
    # role claim consumed by role_required(); this demo User has no role
    # attribute, so it defaults to "user".
    identity = str(user.id)
    claims = {"role": getattr(user, "role", "user")}
    access_token = create_access_token(identity=identity, additional_claims=claims)
    refresh_token = create_refresh_token(identity=identity, additional_claims=claims)

    return jsonify({
        "user": user.to_dict(),
        "access_token": access_token,
        "refresh_token": refresh_token
    })

@app.route("/auth/refresh", methods=["POST"])
@jwt_required(refresh=True)
def refresh():
    """Exchange a valid refresh token for a new access token.

    Any ``role`` claim on the refresh token is copied to the new access
    token. Without it, refreshed tokens would lose their role and start
    failing role_required() checks.
    """
    identity = get_jwt_identity()
    role = get_jwt().get("role")
    claims = {"role": role} if role is not None else None
    access_token = create_access_token(identity=identity, additional_claims=claims)
    return jsonify({"access_token": access_token})

@app.route("/auth/me", methods=["GET"])
@jwt_required()
def get_current_user():
    """Return the identity carried by the access token.

    NOTE: placeholder — the user is not looked up yet; only the token's
    identity is echoed back.
    """
    return jsonify({"user_id": get_jwt_identity()})

# JWT error handlers
@jwt.expired_token_loader
def expired_token_callback(jwt_header, jwt_payload):
    """Reply 401 with a JSON error body when the token has expired."""
    body = {"error": "令牌已过期"}
    return jsonify(body), 401

@jwt.invalid_token_loader
def invalid_token_callback(error):
    """Reply 401 with a JSON error body when the token cannot be validated."""
    body = {"error": "无效的令牌"}
    return jsonify(body), 401

角色权限控制

from functools import wraps
from flask_jwt_extended import get_jwt

def role_required(*roles):
    """Decorator factory: require a valid JWT whose ``role`` claim is in *roles*.

    Token validation is delegated to ``jwt_required()``. When the ``role``
    claim is missing or not in *roles*, responds 403 with a JSON error.
    """
    allowed = roles

    def decorator(fn):
        @wraps(fn)
        @jwt_required()
        def wrapper(*args, **kwargs):
            if get_jwt().get("role") in allowed:
                return fn(*args, **kwargs)
            return jsonify({"error": "权限不足"}), 403

        return wrapper

    return decorator

@app.route("/admin/users", methods=["GET"])
@role_required("admin", "superadmin")
def admin_users():
    """Admin-only user listing (placeholder: always an empty list)."""
    return jsonify(users=[])

@app.route("/api/data", methods=["GET"])
@role_required("user", "admin", "superadmin")
def get_data():
    """Data endpoint available to any of the three listed roles."""
    return jsonify(data="sensitive data")

数据库集成

Flask-SQLAlchemy完整示例

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from datetime import datetime

# Configure the database before binding SQLAlchemy and Migrate to the app.
app = Flask(__name__)
app.config.update(
    SQLALCHEMY_DATABASE_URI="sqlite:///app.db",
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
)

db = SQLAlchemy(app)
migrate = Migrate(app, db)

# Model definitions
class User(db.Model):
    """Account that owns posts; serialised together with its post count."""

    __tablename__ = "users"

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    # lazy="dynamic": ``self.posts`` is a query object, so it can be counted
    # without loading every Post; the backref adds ``Post.author``.
    posts = db.relationship("Post", backref="author", lazy="dynamic")

    def to_dict(self):
        """Return id, username, email and the number of posts the user has."""
        payload = {"id": self.id, "username": self.username, "email": self.email}
        payload["post_count"] = self.posts.count()
        return payload

class Post(db.Model):
    """A post linked to its author through ``user_id``."""

    __tablename__ = "posts"

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    content = db.Column(db.Text)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"))
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    def to_dict(self):
        """Serialise the post; ``author`` is the author's username or None.

        ``created_at`` is serialised as None rather than raising
        AttributeError while the post has not been flushed yet. The column
        default is only applied on INSERT, so the attribute is still None
        before that.
        """
        created_at = self.created_at
        return {
            "id": self.id,
            "title": self.title,
            "content": self.content,
            "author": self.author.username if self.author else None,
            "created_at": created_at.isoformat() if created_at is not None else None,
        }

# Database operations
def init_db():
    """Create all tables registered on ``db`` inside a temporary app context."""
    context = app.app_context()
    with context:
        db.create_all()

def add_user(username, email):
    """Insert a user and return the committed instance.

    Column attributes are reloaded before the app context closes, so the
    returned object stays readable afterwards. Lazy relationships such as
    ``posts`` still need an active session.
    """
    with app.app_context():
        user = User(username=username, email=email)
        db.session.add(user)
        db.session.commit()
        # commit() expires the instance, and leaving the app context removes
        # the session. Without reloading here, reading e.g. user.id on the
        # returned object would raise DetachedInstanceError.
        db.session.refresh(user)
        return user

应用工厂模式

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager

# Unbound extensions; create_app() attaches them to each new app.
db, login_manager = SQLAlchemy(), LoginManager()

def create_app(config_name="development"):
    """Application factory: configure, init extensions, register blueprints.

    Args:
        config_name: Key into the ``config`` mapping from the ``config`` module.

    Returns:
        The configured Flask app. All tables are created with
        ``db.create_all()`` before it is returned.
    """
    from config import config

    application = Flask(__name__)
    application.config.from_object(config[config_name])

    db.init_app(application)
    login_manager.init_app(application)
    # Endpoint Flask-Login redirects unauthenticated users to.
    login_manager.login_view = "auth.login"

    from .main import main as main_blueprint
    from .auth import auth as auth_blueprint
    from .api import api as api_blueprint

    application.register_blueprint(main_blueprint)
    application.register_blueprint(auth_blueprint, url_prefix="/auth")
    application.register_blueprint(api_blueprint, url_prefix="/api")

    # Convenient for demos; once migrations are used, schema changes should
    # go through them instead.
    with application.app_context():
        db.create_all()

    return application

# config.py
import os

# Absolute directory of this file; the SQLite fallback databases live here.
basedir = os.path.abspath(os.path.dirname(__file__))


class Config:
    """Settings shared by every environment."""

    SECRET_KEY = os.environ.get("SECRET_KEY") or "hard-to-guess-string"
    SQLALCHEMY_TRACK_MODIFICATIONS = False


class DevelopmentConfig(Config):
    """Local development: debug mode, file-based SQLite fallback."""

    DEBUG = True
    # "sqlite:///" + an absolute path gives the URL form SQLAlchemy expects for
    # absolute SQLite paths (four slashes on POSIX). The old "sqlite://" prefix
    # produced a three-slash URL that SQLAlchemy reads as a relative path.
    SQLALCHEMY_DATABASE_URI = (
        os.environ.get("DEV_DATABASE_URL")
        or "sqlite:///" + os.path.join(basedir, "dev.db")
    )


class TestingConfig(Config):
    """Used by the pytest fixtures via create_app("testing")."""

    TESTING = True
    # In-memory SQLite by default so tests never touch the dev/prod files.
    SQLALCHEMY_DATABASE_URI = os.environ.get("TEST_DATABASE_URL") or "sqlite://"


class ProductionConfig(Config):
    """Production: DATABASE_URL should always be set in the environment."""

    SQLALCHEMY_DATABASE_URI = (
        os.environ.get("DATABASE_URL")
        or "sqlite:///" + os.path.join(basedir, "prod.db")
    )


config = {
    "development": DevelopmentConfig,
    "testing": TestingConfig,
    "production": ProductionConfig,
    "default": DevelopmentConfig
}

测试（以下夹具调用 create_app("testing")，因此 config 字典中必须提供 "testing" 配置）

import pytest
from app import create_app, db

@pytest.fixture
def app():
    """Yield an app built from the "testing" config with a fresh schema.

    Tables are created before the test and dropped after it, inside one
    application context that stays pushed for the whole test.
    """
    application = create_app("testing")
    with application.app_context():
        db.create_all()
        yield application
        db.drop_all()

@pytest.fixture
def client(app):
    """HTTP test client bound to the ``app`` fixture."""
    test_client = app.test_client()
    return test_client

@pytest.fixture
def runner(app):
    """CLI runner for invoking the app's click commands in tests."""
    cli_runner = app.test_cli_runner()
    return cli_runner

def test_index(client):
    """The root URL answers 200."""
    assert client.get("/").status_code == 200

def test_create_user(client):
    """POST /api/users creates a user and echoes it back with 201."""
    payload = {"username": "testuser", "email": "test@example.com"}
    resp = client.post("/api/users", json=payload)
    assert resp.status_code == 201
    assert resp.json["username"] == "testuser"

def test_login(client):
    """A freshly registered user can log in and receives an access token."""
    # Register first, and assert it worked, so a registration failure is not
    # misreported as a login failure.
    register_response = client.post("/auth/register", json={
        "username": "testuser",
        "email": "test@example.com",
        "password": "123456"
    })
    assert register_response.status_code == 201

    # Log in with the same credentials
    response = client.post("/auth/login", json={
        "username": "testuser",
        "password": "123456"
    })
    assert response.status_code == 200
    assert "access_token" in response.json

总结

Flask是一个功能强大且灵活的Web框架。掌握扩展系统、RESTful API设计、用户认证和数据库集成后，你可以构建各种复杂的Web应用。记住遵循最佳实践：使用应用工厂模式、蓝图组织代码、统一响应格式、完善的错误处理和测试覆盖。