Flask进阶:扩展、RESTful API、认证与数据库集成
掌握Flask基础后,本文将带你深入学习Flask的高级特性,包括扩展系统、RESTful API设计、用户认证和数据库集成,让你能够构建生产级的Web应用。
Flask扩展系统
Flask通过扩展机制提供丰富的功能:
# Commonly used Flask extensions, one per line: package name - purpose.
"""
Flask-SQLAlchemy - 数据库集成
Flask-Migrate - 数据库迁移
Flask-Login - 用户认证
Flask-WTF - 表单处理
Flask-RESTful - RESTful API
Flask-JWT-Extended - JWT认证
Flask-CORS - 跨域支持
Flask-Mail - 邮件发送
Flask-Caching - 缓存
"""
# Install the extensions used by the examples in this article. flask-restful
# is needed by the Flask-RESTful API example, so it is listed here as well.
# pip install flask-sqlalchemy flask-migrate flask-login flask-cors flask-jwt-extended flask-restful
扩展初始化模式
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_login import LoginManager
from flask_cors import CORS
# Extension objects are created unbound at import time and attached to a
# concrete application inside the factory via init_app(), so the same
# objects can serve several application instances (e.g. one per test).
db = SQLAlchemy()
migrate = Migrate()
login_manager = LoginManager()
cors = CORS()


def create_app(config_name="development"):
    """Build and return a configured Flask application.

    Args:
        config_name: key into the ``config`` mapping of ``config.py`` that
            selects the configuration class to load.

    Returns:
        The new application with all extensions and blueprints registered.
    """
    application = Flask(__name__)

    # Imported lazily so config.py is only needed when an app is built.
    from config import config as config_by_name
    application.config.from_object(config_by_name[config_name])

    # Bind every extension to this application instance.
    db.init_app(application)
    migrate.init_app(application, db)
    for extension in (login_manager, cors):
        extension.init_app(application)

    # Blueprints are imported inside the factory; presumably to avoid
    # circular imports with modules that import ``db`` from here.
    from .api import api_bp
    from .auth import auth_bp
    for blueprint, prefix in ((api_bp, "/api"), (auth_bp, "/auth")):
        application.register_blueprint(blueprint, url_prefix=prefix)

    return application
RESTful API设计
使用Flask-RESTful
from flask import Flask
from flask_restful import Api, Resource, reqparse, fields, marshal_with
from flask_sqlalchemy import SQLAlchemy
app = Flask(__name__)
app.config.update(SQLALCHEMY_DATABASE_URI="sqlite:///api.db")
db = SQLAlchemy(app)
api = Api(app)


# Data model
class User(db.Model):
    """A user account with a unique username and a unique e-mail address."""

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)

    def to_dict(self):
        """Return the public fields of this user as a plain dict."""
        return {name: getattr(self, name) for name in ("id", "username", "email")}
# Output schema used by ``marshal_with`` to serialise User objects.
user_fields = dict(
    id=fields.Integer,
    username=fields.String,
    email=fields.String,
)

# Request parser shared by POST (create) and PUT (update). Both arguments are
# mandatory; ``help`` is the message reported when an argument fails to parse.
user_parser = reqparse.RequestParser()
for _arg_name, _help_text in (
    ("username", "用户名不能为空"),
    ("email", "邮箱不能为空"),
):
    user_parser.add_argument(_arg_name, type=str, required=True, help=_help_text)
del _arg_name, _help_text
# API resources
class UserResource(Resource):
    """Single-user endpoint: ``/api/users/<user_id>``."""

    @marshal_with(user_fields)
    def get(self, user_id):
        """Return one user serialised with ``user_fields`` (404 if absent)."""
        user = User.query.get_or_404(user_id)
        return user

    def put(self, user_id):
        """Replace the username and e-mail of an existing user (404 if absent)."""
        user = User.query.get_or_404(user_id)
        args = user_parser.parse_args()
        user.username = args["username"]
        user.email = args["email"]
        db.session.commit()
        return user.to_dict()

    def delete(self, user_id):
        """Delete a user and answer ``204 No Content`` (404 if absent).

        A 204 response carries no body; Werkzeug discards it for this status,
        so a JSON message here would never reach the client. Return an empty
        body instead.
        """
        user = User.query.get_or_404(user_id)
        db.session.delete(user)
        db.session.commit()
        return "", 204
class UserListResource(Resource):
    """Collection endpoint: ``/api/users``."""

    @marshal_with(user_fields)
    def get(self):
        """List every user, serialised with ``user_fields``."""
        return User.query.all()

    def post(self):
        """Create a user from the parsed request and return it with 201."""
        parsed = user_parser.parse_args()
        new_user = User(username=parsed["username"], email=parsed["email"])
        db.session.add(new_user)
        db.session.commit()
        return new_user.to_dict(), 201


# Route registration: the item resource takes an integer id converter.
api.add_resource(UserResource, "/api/users/<int:user_id>")
api.add_resource(UserListResource, "/api/users")
RESTful最佳实践
from flask import Blueprint, jsonify, request
from functools import wraps
api = Blueprint("api", __name__)


# Unified response envelope: every JSON reply has "code", "message" and a
# payload key ("data" on success, "errors" on failure).
def _envelope(status_code, message, payload_key, payload):
    """Build the ``(json_response, status_code)`` pair for one envelope."""
    body = {"code": status_code, "message": message, payload_key: payload}
    return jsonify(body), status_code


def success_response(data=None, message="success", status_code=200):
    """Return a success envelope carrying *data* under the ``data`` key."""
    return _envelope(status_code, message, "data", data)


def error_response(message="error", status_code=400, errors=None):
    """Return an error envelope carrying *errors* under the ``errors`` key."""
    return _envelope(status_code, message, "errors", errors)
# Pagination helper (a plain function, not a decorator): runs a paginated
# query and returns the current page plus metadata as a JSON-ready dict.
def paginate(query, page=1, per_page=10, max_per_page=None):
    """Paginate *query* and serialise the requested page.

    Args:
        query: a Flask-SQLAlchemy query (or any object with a compatible
            ``paginate`` method) whose items implement ``to_dict()``.
        page: 1-based page number.
        per_page: requested number of items per page.
        max_per_page: optional upper bound for ``per_page``. Pass it when
            ``per_page`` comes from the client so one request cannot fetch an
            unbounded number of rows. ``None`` (default) keeps the previous
            unbounded behaviour and issues exactly the same call.

    Returns:
        dict with ``items`` (serialised rows), ``total``, ``page``,
        ``pages``, ``has_next`` and ``has_prev``.
    """
    # error_out=False: invalid page numbers are corrected and out-of-range
    # pages come back empty instead of aborting the request with 404.
    options = {"page": page, "per_page": per_page, "error_out": False}
    # Only forward max_per_page when it is set, so query objects whose
    # paginate() does not accept the keyword keep working.
    if max_per_page is not None:
        options["max_per_page"] = max_per_page
    pagination = query.paginate(**options)
    return {
        "items": [item.to_dict() for item in pagination.items],
        "total": pagination.total,
        "page": pagination.page,
        "pages": pagination.pages,
        "has_next": pagination.has_next,
        "has_prev": pagination.has_prev,
    }
# API example
@api.route("/articles", methods=["GET"])
def get_articles():
    """List articles with filtering, sorting and pagination.

    Query parameters:
        page, per_page: pagination; per_page is clamped to 1..100.
        category: exact-match category filter.
        q: substring search in the article title.
        sort: column to order by; must be in the whitelist below.
        order: "desc" (default) for descending, anything else ascending.

    Returns the unified success envelope, or a 400 error envelope when
    ``sort`` names a column that is not allowed.
    """
    # ``sort`` comes straight from the URL. Feeding it to getattr() unchecked
    # exposes every model attribute (and turns typos into 500 errors), so
    # only whitelisted columns may be used.
    # NOTE(review): adjust to the columns the Article model actually defines.
    sortable_fields = {"created_at", "title"}
    max_per_page = 100

    page = request.args.get("page", 1, type=int)
    per_page = request.args.get("per_page", 10, type=int)
    # Bound the page size so a single request cannot load the whole table.
    per_page = min(max(per_page, 1), max_per_page)

    # Query-parameter filters
    query = Article.query
    category = request.args.get("category")
    if category:
        query = query.filter_by(category=category)
    keyword = request.args.get("q")
    if keyword:
        query = query.filter(Article.title.contains(keyword))

    # Sorting
    sort = request.args.get("sort", "created_at")
    if sort not in sortable_fields:
        return error_response(f"不支持的排序字段: {sort}")
    column = getattr(Article, sort)
    order = request.args.get("order", "desc")
    query = query.order_by(column.desc() if order == "desc" else column.asc())

    result = paginate(query, page, per_page)
    return success_response(result)
@api.route("/articles", methods=["POST"])
def create_article():
    """Create an article from a JSON body and return it with HTTP 201.

    Expects a JSON object with a non-empty ``title``; ``content`` (default
    "") and ``author_id`` are optional. Invalid input yields a 400 error
    envelope.
    """
    # silent=True: a missing or malformed body (or wrong Content-Type) gives
    # None here instead of Flask aborting with its own non-envelope error.
    data = request.get_json(silent=True)

    # Data validation. A JSON array or scalar would crash ``data.get`` below,
    # so require an object first.
    if not isinstance(data, dict):
        return error_response("请求体必须是JSON对象")
    if not data.get("title"):
        return error_response("标题不能为空")

    article = Article(
        title=data["title"],
        content=data.get("content", ""),
        author_id=data.get("author_id"),
    )
    db.session.add(article)
    db.session.commit()
    return success_response(article.to_dict(), "创建成功", 201)
用户认证
JWT认证实现
from flask import Flask, request, jsonify
from flask_jwt_extended import (
JWTManager, create_access_token, create_refresh_token,
jwt_required, get_jwt_identity, get_jwt
)
from werkzeug.security import generate_password_hash, check_password_hash
from datetime import timedelta
import os

app = Flask(__name__)
# Signing key for every JWT. Anyone who knows it can forge valid tokens, so
# it must not be hard-coded: read it from the environment. The literal is
# only a fallback for local experiments and must never reach production.
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY") or "your-secret-key"
# Short-lived access tokens; long-lived refresh tokens are used to get new ones.
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
jwt = JWTManager(app)
import itertools

# Simplified in-memory user store: username -> User. Demo only: it lives in
# this process's memory and is lost on restart.
users_db = {}
# Source of unique, increasing user ids (starting at 1).
_user_id_seq = itertools.count(1)


class User:
    """In-memory user that keeps only a password hash, never the password."""

    def __init__(self, username, email, password):
        # Not derived from len(users_db) + 1: that repeats an id for any two
        # Users created before either one is stored (e.g. overlapping
        # registrations), and tokens are issued for this id.
        self.id = next(_user_id_seq)
        self.username = username
        self.email = email
        self.password_hash = generate_password_hash(password)

    def check_password(self, password):
        """Return True if *password* matches the stored hash."""
        return check_password_hash(self.password_hash, password)

    def to_dict(self):
        """Public representation (no password hash)."""
        return {"id": self.id, "username": self.username, "email": self.email}
@app.route("/auth/register", methods=["POST"])
def register():
    """Register a user and return it with a fresh access/refresh token pair.

    Expects a JSON object with string ``username``, ``email`` and
    ``password``. Responds 400 if a field is missing/invalid or the username
    is taken, 201 on success.
    """
    # silent=True returns None for a missing/invalid JSON body instead of
    # letting Flask abort with its own non-JSON error.
    data = request.get_json(silent=True)
    # Validate before indexing: a missing key used to raise KeyError (500).
    valid = isinstance(data, dict) and all(
        isinstance(data.get(name), str) and data.get(name)
        for name in ("username", "email", "password")
    )
    if not valid:
        return jsonify({"error": "用户名、邮箱和密码不能为空"}), 400
    if data["username"] in users_db:
        return jsonify({"error": "用户名已存在"}), 400

    user = User(data["username"], data["email"], data["password"])
    users_db[user.username] = user

    # The identity becomes the JWT "sub" claim, which RFC 7519 defines as a
    # string; PyJWT >= 2.10 rejects non-string subjects when decoding, which
    # makes @jwt_required refuse the token. Always pass a string.
    identity = str(user.id)
    access_token = create_access_token(identity=identity)
    refresh_token = create_refresh_token(identity=identity)
    return jsonify({
        "user": user.to_dict(),
        "access_token": access_token,
        "refresh_token": refresh_token,
    }), 201
@app.route("/auth/login", methods=["POST"])
def login():
    """Authenticate with username/password and return a token pair.

    Responds 401 with the same message for an unknown user and a wrong
    password, so the response does not reveal which usernames exist.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get("username")
    password = data.get("password")

    # Only strings are valid: a missing (None) password used to crash the
    # hash check, and an unhashable username crashed the dict lookup.
    user = users_db.get(username) if isinstance(username, str) else None
    if user is None or not isinstance(password, str) or not user.check_password(password):
        return jsonify({"error": "用户名或密码错误"}), 401

    # JWT "sub" must be a string (RFC 7519; enforced by PyJWT >= 2.10).
    identity = str(user.id)
    access_token = create_access_token(identity=identity)
    refresh_token = create_refresh_token(identity=identity)
    return jsonify({
        "user": user.to_dict(),
        "access_token": access_token,
        "refresh_token": refresh_token,
    })
@app.route("/auth/refresh", methods=["POST"])
@jwt_required(refresh=True)
def refresh():
    """Exchange a valid refresh token for a new access token.

    Only refresh tokens are accepted (``refresh=True``); the new access token
    carries the same identity as the refresh token.
    """
    return jsonify({"access_token": create_access_token(identity=get_jwt_identity())})
@app.route("/auth/me", methods=["GET"])
@jwt_required()
def get_current_user():
    """Return the identity stored in the caller's access token.

    Only the identity is returned; loading the full user record is left as
    a placeholder.
    """
    current_identity = get_jwt_identity()
    # Placeholder: look up the user record for current_identity here.
    return jsonify({"user_id": current_identity})
# JWT error handlers: answer with a uniform {"error": ...} JSON body and 401.
@jwt.expired_token_loader
def expired_token_callback(jwt_header, jwt_payload):
    """Handle a token whose expiry time has passed."""
    body = {"error": "令牌已过期"}
    return jsonify(body), 401


@jwt.invalid_token_loader
def invalid_token_callback(error):
    """Handle a token that fails validation (*error* describes why)."""
    body = {"error": "无效的令牌"}
    return jsonify(body), 401
角色权限控制
from functools import wraps
from flask_jwt_extended import get_jwt
@jwt.additional_claims_loader
def add_role_claim(identity):
    """Add the user's ``role`` to every token created for *identity*.

    ``role_required`` authorises on the ``role`` claim, but the token-issuing
    endpoints never put one into the token, so every role-protected route
    answered 403. Users without a ``role`` attribute are treated as "user";
    unknown identities get no role claim at all.
    NOTE: give User a ``role`` attribute to grant "admin"/"superadmin".
    """
    # Compare as strings: the identity may be the int id or its string form.
    user = next(
        (u for u in users_db.values() if str(u.id) == str(identity)),
        None,
    )
    if user is None:
        return {}
    return {"role": getattr(user, "role", "user")}


def role_required(*roles):
    """Decorator: require a valid access token whose ``role`` claim is in *roles*.

    ``jwt_required`` rejects missing/invalid tokens; a token whose role is
    absent or not listed gets a 403 JSON error.
    """
    def decorator(fn):
        @wraps(fn)
        @jwt_required()
        def wrapper(*args, **kwargs):
            claims = get_jwt()
            user_role = claims.get("role")
            if user_role not in roles:
                return jsonify({"error": "权限不足"}), 403
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@app.route("/admin/users", methods=["GET"])
@role_required("admin", "superadmin")
def admin_users():
    """Admin-only user listing (placeholder: returns an empty list)."""
    return jsonify({"users": []})


@app.route("/api/data", methods=["GET"])
@role_required("user", "admin", "superadmin")
def get_data():
    """Data endpoint available to any of the listed roles."""
    return jsonify({"data": "sensitive data"})
数据库集成
Flask-SQLAlchemy完整示例
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from datetime import datetime
app = Flask(__name__)
app.config.update(
    SQLALCHEMY_DATABASE_URI="sqlite:///app.db",
    # Turn off Flask-SQLAlchemy's modification-tracking events; they add
    # overhead and nothing in this example uses them.
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
)
# Extensions bound directly to this single application instance.
db = SQLAlchemy(app)
migrate = Migrate(app, db)
# Model definitions
class User(db.Model):
    """Registered user; owns zero or more ``Post`` rows."""

    __tablename__ = "users"

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    # lazy="dynamic" makes ``user.posts`` a query, so it can be counted or
    # filtered without loading every post; backref adds ``post.author``.
    posts = db.relationship("Post", backref="author", lazy="dynamic")

    def to_dict(self):
        """Serialise the user; ``post_count`` issues a COUNT query."""
        payload = {
            "id": self.id,
            "username": self.username,
            "email": self.email,
        }
        payload["post_count"] = self.posts.count()
        return payload
class Post(db.Model):
    """A post written by a ``User`` (reachable as ``post.author``)."""

    __tablename__ = "posts"

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    content = db.Column(db.Text)
    user_id = db.Column(db.Integer, db.ForeignKey("users.id"))
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    def to_dict(self):
        """Serialise the post; ``author`` is the author's username or None."""
        # The created_at column default is applied only when the row is
        # INSERTed, so an unsaved Post has created_at None; calling
        # .isoformat() on it unconditionally raised AttributeError.
        created = self.created_at
        return {
            "id": self.id,
            "title": self.title,
            "content": self.content,
            "author": self.author.username if self.author else None,
            "created_at": created.isoformat() if created is not None else None,
        }
# Database operations
def init_db():
    """Create every table defined on ``db`` that does not exist yet."""
    with app.app_context():
        db.create_all()


def add_user(username, email):
    """Insert a new user and return the saved instance.

    The instance outlives the app context: commit() expires its attributes,
    and Flask-SQLAlchemy removes the session when the context ends. Reading
    e.g. ``user.id`` afterwards would therefore raise DetachedInstanceError.
    Refreshing before leaving the context reloads the column values, so they
    stay readable on the detached instance.
    """
    with app.app_context():
        user = User(username=username, email=email)
        db.session.add(user)
        db.session.commit()
        db.session.refresh(user)
        return user
应用工厂模式
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
db = SQLAlchemy()
login_manager = LoginManager()


def create_app(config_name="development"):
    """Application factory: build an app configured by *config_name*.

    Args:
        config_name: key into the ``config`` mapping of ``config.py``.

    Returns:
        The configured Flask application with blueprints registered and
        missing tables created.
    """
    from config import config

    app = Flask(__name__)
    app.config.from_object(config[config_name])

    # Initialise extensions.
    db.init_app(app)
    login_manager.init_app(app)
    # Endpoint Flask-Login redirects anonymous users to for protected views.
    login_manager.login_view = "auth.login"

    # Register blueprints: main at the root, auth and api under prefixes.
    from .main import main as main_blueprint
    from .auth import auth as auth_blueprint
    from .api import api as api_blueprint

    app.register_blueprint(main_blueprint)
    for blueprint, prefix in ((auth_blueprint, "/auth"), (api_blueprint, "/api")):
        app.register_blueprint(blueprint, url_prefix=prefix)

    # NOTE(review): create_all() creates missing tables on every app start.
    # When Flask-Migrate manages the schema this bypasses migration history;
    # consider relying on `flask db upgrade` instead.
    with app.app_context():
        db.create_all()

    return app
# config.py
import os

# Absolute directory of this file; the default SQLite files live next to it.
basedir = os.path.abspath(os.path.dirname(__file__))


class Config:
    """Settings shared by every environment."""

    SECRET_KEY = os.environ.get("SECRET_KEY") or "hard-to-guess-string"
    SQLALCHEMY_TRACK_MODIFICATIONS = False


class DevelopmentConfig(Config):
    """Local development: debug mode and a SQLite file database."""

    DEBUG = True
    # "sqlite:///" + an absolute path yields the four slashes SQLAlchemy
    # needs for an absolute file path; "sqlite://" + path yields only three,
    # and the path would be treated as relative to the working directory.
    SQLALCHEMY_DATABASE_URI = os.environ.get("DEV_DATABASE_URL") or \
        "sqlite:///" + os.path.join(basedir, "dev.db")


class TestingConfig(Config):
    """Test runs, selected by create_app("testing")."""

    TESTING = True
    # "sqlite://" without a path is an in-memory database: fast and isolated.
    SQLALCHEMY_DATABASE_URI = os.environ.get("TEST_DATABASE_URL") or "sqlite://"


class ProductionConfig(Config):
    """Production: DATABASE_URL is expected to be set in the environment."""

    SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URL") or \
        "sqlite:///" + os.path.join(basedir, "prod.db")


config = {
    "development": DevelopmentConfig,
    "testing": TestingConfig,
    "production": ProductionConfig,
    "default": DevelopmentConfig,
}
测试
import pytest
from app import create_app, db
@pytest.fixture
def app():
    """Yield an app built with the "testing" config and fresh tables.

    Tables are created before each test and dropped afterwards, so tests do
    not see each other's data.
    """
    app = create_app("testing")
    with app.app_context():
        db.create_all()
        yield app
        # Release the scoped session before dropping tables: an open
        # session/transaction can hold locks that block DROP TABLE on server
        # databases, and would otherwise keep stale objects around.
        db.session.remove()
        db.drop_all()


@pytest.fixture
def client(app):
    """HTTP test client bound to the ``app`` fixture."""
    return app.test_client()


@pytest.fixture
def runner(app):
    """CLI runner for invoking the app's registered commands."""
    return app.test_cli_runner()
def test_index(client):
    """The home page responds with 200 OK."""
    assert client.get("/").status_code == 200


def test_create_user(client):
    """POST /api/users creates the user and echoes it back with 201."""
    payload = {"username": "testuser", "email": "test@example.com"}
    response = client.post("/api/users", json=payload)
    assert response.status_code == 201
    assert response.json["username"] == payload["username"]
def test_login(client):
    """A freshly registered user can log in and receives an access token."""
    credentials = {"username": "testuser", "password": "123456"}

    # Register first, and fail here if that step breaks; otherwise a
    # registration bug would surface as a confusing login failure.
    registered = client.post(
        "/auth/register",
        json={**credentials, "email": "test@example.com"},
    )
    assert registered.status_code == 201

    # Log in with the same credentials.
    response = client.post("/auth/login", json=credentials)
    assert response.status_code == 200
    assert "access_token" in response.json
总结
Flask是一个功能强大且灵活的Web框架。掌握扩展系统、RESTful API设计、用户认证和数据库集成后,你可以构建各种复杂的Web应用。记住遵循最佳实践:使用应用工厂模式、蓝图组织代码、统一响应格式、完善的错误处理和测试覆盖。