Django进阶：中间件、信号、缓存与REST Framework
如果你已经掌握了Django基础，本文将带你深入学习中间件机制、信号系统、缓存策略和Django REST Framework，帮助你构建高性能、可扩展的Web应用。
中间件
中间件是Django处理请求和响应的钩子框架：它在视图执行前后依次处理每个请求和响应，适合实现日志记录、速率限制、跨域等全局逻辑：
自定义中间件
# myapp/middleware.py
import time
import logging
from django.http import JsonResponse
from django.utils.deprecation import MiddlewareMixin
logger = logging.getLogger(__name__)


class RequestLoggingMiddleware(MiddlewareMixin):
    """Log method, path, status code and duration of every request.

    ``process_request`` stamps the request with a start time and
    ``process_response`` logs the elapsed time at INFO level.
    """

    def process_request(self, request):
        # Monotonic clock: time.time() follows the wall clock, which can jump
        # (NTP sync, manual changes) and yield negative or inflated durations.
        request.start_time = time.monotonic()

    def process_response(self, request, response):
        # Defensive: only log when process_request stamped this request.
        if hasattr(request, "start_time"):
            duration = time.monotonic() - request.start_time
            # Lazy %-style arguments: the message is only formatted when the
            # INFO level is enabled for this logger.
            logger.info(
                "%s %s - %s (%.3fs)",
                request.method,
                request.path,
                response.status_code,
                duration,
            )
        return response
class APIRateLimitMiddleware(MiddlewareMixin):
    """Per-client fixed-window rate limiter for paths under ``/api/``.

    Each client IP may make at most ``rate_limit`` requests per window of
    ``rate_window`` seconds, the window starting at the client's first
    request. Further requests get HTTP 429 until that window expires.

    NOTE(review): counters live in this instance's memory, so each worker
    process keeps its own counts, and the dict is not guarded against
    concurrent access from multiple threads. Multi-process deployments need a
    shared store (e.g. the cache backend) or DRF's throttling instead.
    """

    # Maximum number of requests allowed per client within one window.
    rate_limit = 100
    # Window length in seconds.
    rate_window = 60

    def __init__(self, get_response):
        super().__init__(get_response)
        # ip -> {"count": requests in current window, "time": window start}
        self.request_counts = {}
        # When expired entries were last swept out of request_counts.
        self._last_sweep = time.monotonic()

    def process_request(self, request):
        if not request.path.startswith("/api/"):
            return None

        ip = self.get_client_ip(request)
        now = time.monotonic()
        self._sweep_expired(now)

        entry = self.request_counts.get(ip)
        if entry is None or now - entry["time"] >= self.rate_window:
            # First request from this client, or its previous window expired:
            # start a new window.
            self.request_counts[ip] = {"count": 1, "time": now}
            return None

        if entry["count"] >= self.rate_limit:
            return JsonResponse(
                {"error": "请求过于频繁,请稍后再试"},
                status=429,
            )
        entry["count"] += 1
        return None

    def _sweep_expired(self, now):
        """Drop expired entries at most once per window to bound memory.

        Expiry is also checked per entry in process_request, so limiting does
        not depend on this sweep; it only avoids rebuilding the dict on every
        request.
        """
        if now - self._last_sweep < self.rate_window:
            return
        self.request_counts = {
            ip: entry
            for ip, entry in self.request_counts.items()
            if now - entry["time"] < self.rate_window
        }
        self._last_sweep = now

    def get_client_ip(self, request):
        """Return the client IP, preferring the first X-Forwarded-For hop.

        SECURITY NOTE(review): X-Forwarded-For is client-controlled unless a
        trusted reverse proxy overwrites it, so a client could rotate the value
        to evade the limit. Only trust it when running behind such a proxy.
        """
        x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
        if x_forwarded_for:
            # Strip surrounding whitespace so one client maps to one key; an
            # empty first hop falls back to the socket address.
            first_hop = x_forwarded_for.split(",")[0].strip()
            if first_hop:
                return first_hop
        return request.META.get("REMOTE_ADDR")
class CORSMiddleware(MiddlewareMixin):
    """Attach fixed CORS headers to every outgoing response.

    NOTE(review): the "*" origin lets any site read these responses; confirm
    this is intended for the deployment.
    """

    # (header name, value) pairs written onto each response, in order.
    _CORS_HEADERS = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Authorization"),
    )

    def process_response(self, request, response):
        for header_name, header_value in self._CORS_HEADERS:
            response[header_name] = header_value
        return response
# settings.py: register the middleware.
# Request hooks run top-to-bottom and response hooks bottom-to-top.
# RequestLoggingMiddleware is listed first so its timing covers the whole
# stack, and so it also logs responses that a later-listed middleware returns
# without reaching the view (e.g. SSL or slash redirects, 429s).
MIDDLEWARE = [
    "myapp.middleware.RequestLoggingMiddleware",
    "django.middleware.security.SecurityMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    # CORS headers must be added before CommonMiddleware can short-circuit.
    "myapp.middleware.CORSMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
    "myapp.middleware.APIRateLimitMiddleware",
]
信号系统
信号允许在特定操作（如模型保存、删除或用户登录）发生时执行代码，从而让发送方与接收方解耦：
# myapp/signals.py
from django.db.models.signals import pre_save, post_save, pre_delete, post_delete
from django.dispatch import receiver
from django.contrib.auth.signals import user_logged_in, user_logged_out
from .models import Article, User
# Model signals
@receiver(pre_save, sender=Article)
def article_pre_save(sender, instance, **kwargs):
    """Fill in ``instance.slug`` from its title before an Article is saved."""
    # Local import: slugify is not among this module's top-level imports, and
    # without it the handler raised NameError on any save lacking a slug.
    from django.utils.text import slugify

    if not instance.slug:
        # NOTE(review): by default slugify() keeps only ASCII letters, digits,
        # "_" and "-", so a title with none of those (e.g. an all-Chinese
        # title) yields an empty slug. Consider allow_unicode=True (and a
        # matching URL converter) or a fallback if such titles are expected.
        instance.slug = slugify(instance.title)
    print(f"文章即将保存: {instance.title}")
@receiver(post_save, sender=Article)
def article_post_save(sender, instance, created, **kwargs):
    """Report whether an Article was just created or updated."""
    if not created:
        print(f"文章更新: {instance.title}")
        return
    print(f"新文章创建: {instance.title}")
    # TODO: send a notification e-mail here, e.g. send_notification(instance)
@receiver(post_delete, sender=Article)
def article_post_delete(sender, instance, **kwargs):
    """Report that an Article row has been deleted.

    TODO: also remove files attached to the article,
    e.g. ``instance.image.delete()``.
    """
    print(f"文章已删除: {instance.title}")
# Authentication signals
@receiver(user_logged_in)
def user_logged_in_handler(sender, request, user, **kwargs):
    """Print a line announcing that ``user`` has just logged in."""
    username = user.username
    print(f"用户 {username} 登录")
@receiver(user_logged_out)
def user_logged_out_handler(sender, request, user, **kwargs):
    """Print a line announcing that ``user`` has just logged out.

    Django sends ``user=None`` when the session being logged out was not
    authenticated. Dereferencing ``user.username`` then raised AttributeError,
    which propagates out of ``logout()`` because the signal is sent with
    ``send()``. Such logouts are now skipped.
    """
    if user is None:
        return
    print(f"用户 {user.username} 登出")
# Custom signals
from django.dispatch import Signal

# Sent when an article is viewed; receivers get ``article`` and ``user``
# keyword arguments.
article_viewed = Signal()


def handle_article_viewed(sender, article, user, **kwargs):
    """Print which user viewed which article."""
    print(f"用户 {user} 浏览了文章 {article.title}")


# Same effect as decorating the function with @receiver(article_viewed).
article_viewed.connect(handle_article_viewed)
# Sending the custom signal from a view
def article_detail(request, slug):
    """Render one article and broadcast an ``article_viewed`` event for it."""
    article = get_object_or_404(Article, slug=slug)
    context = {"article": article}
    article_viewed.send(sender=Article, article=article, user=request.user)
    return render(request, "article_detail.html", context)
# Register the signal handlers from apps.py
# myapp/apps.py
from django.apps import AppConfig


class MyappConfig(AppConfig):
    """Configuration for ``myapp``; connects its signal receivers at startup."""

    name = "myapp"
    default_auto_field = "django.db.models.BigAutoField"

    def ready(self):
        # Importing the module executes its @receiver decorators, which is
        # what connects the handlers; the bound name itself is unused.
        import myapp.signals  # noqa: F401
缓存系统
# settings.py: cache configuration (Django's built-in Redis backend)
CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.redis.RedisCache",
        # The Redis database is selected by the URL path ("/1"). The former
        # OPTIONS {"db": "1"} duplicated it (as a string) and is overridden
        # by the URL anyway, so it was removed to avoid two diverging sources.
        "LOCATION": "redis://127.0.0.1:6379/1",
        # Default expiry in seconds for entries stored without a timeout.
        "TIMEOUT": 300,
    }
}
# Alternative: Memcached via pymemcache. Assigning CACHES here replaces any
# CACHES dict bound earlier in the same settings module.
CACHES = dict(
    default=dict(
        BACKEND="django.core.cache.backends.memcached.PyMemcacheCache",
        LOCATION="127.0.0.1:11211",
    ),
)
# Per-view caching
from django.views.decorators.cache import cache_page


@cache_page(15 * 60)  # keep the rendered response for 15 minutes
def article_list(request):
    """Render the list of published articles."""
    return render(
        request,
        "article_list.html",
        {"articles": Article.objects.filter(status="published")},
    )
# Template fragment caching. This is Django template syntax, kept here as a
# string literal for illustration: the markup between {% cache %} and
# {% endcache %} is cached for 300 seconds under the fragment name
# "article_list".
"""
{% load cache %}
{% cache 300 article_list %}
{% for article in articles %}
<article>{{ article.title }}</article>
{% endfor %}
{% endcache %}
"""
# Low-level cache API
from django.core.cache import cache


def get_article(article_id):
    """Return the Article with ``article_id``, from the cache when possible.

    On a miss the row is loaded from the database and cached for 300
    seconds; ``Article.objects.get`` exceptions propagate unchanged.
    """
    cache_key = f"article_{article_id}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached
    fresh = Article.objects.get(id=article_id)
    cache.set(cache_key, fresh, timeout=300)
    return fresh
def update_article(article_id, data):
    """Update an article's title and content and keep related caches coherent.

    ``data`` must provide ``"title"`` and ``"content"`` keys.

    The per-article entry is refreshed with the saved instance. The cached
    ``"article_list"`` (built by get_article_list) also holds Article objects,
    so it is invalidated too. Otherwise the list would keep showing the old
    title until its 600-second timeout.
    """
    article = Article.objects.get(id=article_id)
    article.title = data["title"]
    article.content = data["content"]
    article.save()

    # Write-through for the detail entry. Deleting it instead would make the
    # next get_article() reload from the database.
    cache.set(f"article_{article_id}", article, timeout=300)
    cache.delete("article_list")
def get_article_list():
    """Return all published articles as a list, cached for 600 seconds."""
    cache_key = "article_list"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached
    # Evaluate the query now so the cached value is a concrete list of rows.
    published = list(Article.objects.filter(status="published"))
    cache.set(cache_key, published, timeout=600)
    return published
# Redis in practice: one module-level client configured from Django settings.
import redis
from django.conf import settings

_REDIS_CONNECTION = {
    "host": settings.REDIS_HOST,
    "port": settings.REDIS_PORT,
    "db": settings.REDIS_DB,
}
redis_client = redis.Redis(**_REDIS_CONNECTION)
class ArticleCache:
    """Redis-backed view counters and a "hot articles" ranking.

    View counts live in plain keys ``article:<id>:views``. The ranking is the
    sorted set ``hot_articles`` (member = article id, score = hotness).
    """

    # Sorted-set key holding the hotness ranking.
    _HOT_ARTICLES_KEY = "hot_articles"

    @staticmethod
    def _views_key(article_id):
        """Build the Redis key holding ``article_id``'s view count."""
        return f"article:{article_id}:views"

    @staticmethod
    def get_views(article_id):
        """Return the stored view count for ``article_id`` (0 if never set)."""
        # GET returns None for a missing key; int() accepts the bytes reply.
        return int(redis_client.get(ArticleCache._views_key(article_id)) or 0)

    @staticmethod
    def increment_views(article_id):
        """Add one view to ``article_id`` and return the new count."""
        return redis_client.incr(ArticleCache._views_key(article_id))

    @staticmethod
    def get_hot_articles(limit=10):
        """Return up to ``limit`` ``(member, score)`` pairs, hottest first.

        Members are returned as the client delivers them (bytes, since the
        client is not created with decode_responses). A non-positive ``limit``
        returns ``[]``. Without that guard ``limit - 1`` becomes a negative end
        index, which Redis counts from the tail, returning (almost) the whole
        set.
        """
        if limit <= 0:
            return []
        return redis_client.zrevrange(
            ArticleCache._HOT_ARTICLES_KEY, 0, limit - 1, withscores=True
        )

    @staticmethod
    def update_hot_score(article_id, score=1):
        """Add ``score`` to ``article_id``'s hotness and return the new score."""
        return redis_client.zincrby(ArticleCache._HOT_ARTICLES_KEY, score, article_id)
Django REST Framework
# Installation
# pip install djangorestframework django-filter
# settings.py
INSTALLED_APPS = [
    # Django's default apps (as generated by startproject). The original
    # placeholder `...` directly before a string literal was a SyntaxError.
    # The admin, sessions, auth and messages apps back the admin URLs and the
    # middleware configured for this project.
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    "rest_framework",
    # Provides the Token model that TokenAuthentication (configured below)
    # looks up; run `migrate` after adding it.
    "rest_framework.authtoken",
    # django-filter, registered as its docs ask; its DjangoFilterBackend is
    # used by the API views.
    "django_filters",
    "myapp",
]

REST_FRAMEWORK = {
    "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
    "PAGE_SIZE": 10,
    "DEFAULT_AUTHENTICATION_CLASSES": [
        "rest_framework.authentication.SessionAuthentication",
        "rest_framework.authentication.TokenAuthentication",
    ],
    "DEFAULT_PERMISSION_CLASSES": [
        "rest_framework.permissions.IsAuthenticatedOrReadOnly",
    ],
    "DEFAULT_THROTTLE_CLASSES": [
        "rest_framework.throttling.AnonRateThrottle",
        "rest_framework.throttling.UserRateThrottle",
    ],
    "DEFAULT_THROTTLE_RATES": {
        "anon": "100/day",
        "user": "1000/day",
    },
}
# 序列化器
# myapp/serializers.py
from rest_framework import serializers
from .models import Article, Category, Tag
class TagSerializer(serializers.ModelSerializer):
    """Serialize a Tag as its primary key and name."""

    class Meta:
        model = Tag
        fields = ("id", "name")
class CategorySerializer(serializers.ModelSerializer):
    """Serialize a Category together with how many articles it contains."""

    # Read-only; filled in by get_article_count().
    article_count = serializers.SerializerMethodField()

    class Meta:
        model = Category
        fields = ("id", "name", "slug", "article_count")

    def get_article_count(self, obj):
        # NOTE(review): runs one COUNT query per category; annotate the
        # queryset if category lists grow large.
        related_articles = obj.article_set
        return related_articles.count()
class ArticleListSerializer(serializers.ModelSerializer):
    """Compact article representation used by list-style endpoints."""

    # Flattened, read-only names of the related author and category.
    author_name = serializers.CharField(read_only=True, source="author.username")
    category_name = serializers.CharField(read_only=True, source="category.name")

    class Meta:
        model = Article
        fields = (
            "id",
            "title",
            "slug",
            "author_name",
            "category_name",
            "status",
            "views_count",
            "created_at",
        )
class ArticleDetailSerializer(serializers.ModelSerializer):
    """Full article representation with nested author, category and tags."""

    author = serializers.StringRelatedField()
    category = CategorySerializer()
    tags = TagSerializer(many=True)

    class Meta:
        model = Article
        fields = "__all__"

    def validate_title(self, value):
        """Reject titles shorter than 5 characters."""
        if len(value) < 5:
            raise serializers.ValidationError("标题至少需要5个字符")
        return value

    def validate(self, data):
        """Require content whenever the article is (or remains) published.

        ``status`` and ``content`` may be missing from ``data`` (partial
        updates, or fields left to model defaults). Fall back to the current
        instance's values (None when creating) instead of raising KeyError.
        """
        status = data.get("status", getattr(self.instance, "status", None))
        content = data.get("content", getattr(self.instance, "content", None))
        if status == "published" and not content:
            raise serializers.ValidationError("发布状态必须有内容")
        return data
class ArticleCreateSerializer(serializers.ModelSerializer):
    """Writable article serializer used for create/update requests.

    Carries the same title and publish-content rules as
    ArticleDetailSerializer, so those rules are enforced on writes as well.
    """

    class Meta:
        model = Article
        fields = ["title", "slug", "content", "category", "tags", "status"]

    def validate_title(self, value):
        """Reject titles shorter than 5 characters."""
        if len(value) < 5:
            raise serializers.ValidationError("标题至少需要5个字符")
        return value

    def validate(self, data):
        """Require content whenever the article is (or remains) published.

        Missing keys (e.g. on PATCH) fall back to the current instance's
        values; ``self.instance`` is None when creating.
        """
        status = data.get("status", getattr(self.instance, "status", None))
        content = data.get("content", getattr(self.instance, "content", None))
        if status == "published" and not content:
            raise serializers.ValidationError("发布状态必须有内容")
        return data

    def create(self, validated_data):
        """Create the article and attach its tags."""
        from django.db import transaction

        tags = validated_data.pop("tags", [])
        # One transaction: a failure while setting tags must not leave a
        # half-created, tagless article behind.
        with transaction.atomic():
            article = Article.objects.create(**validated_data)
            article.tags.set(tags)
        return article
# 视图
# myapp/views.py
from rest_framework import viewsets, permissions, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from django_filters.rest_framework import DjangoFilterBackend
from .models import Article, Category, Tag
from .serializers import (
ArticleListSerializer, ArticleDetailSerializer,
ArticleCreateSerializer, CategorySerializer, TagSerializer
)
from .permissions import IsAuthorOrReadOnly
class ArticleViewSet(viewsets.ModelViewSet):
    """CRUD API for articles plus the ``publish`` and ``hot`` custom actions.

    Read-only actions (list, retrieve, hot) are open to anonymous users. Every
    other action requires authentication, and object-level writes are further
    limited to the article's author by IsAuthorOrReadOnly.
    """

    # author/category are read by both the list and detail serializers and
    # tags by the detail serializer. Join/prefetch them so serializing N
    # articles does not cost extra queries per article.
    queryset = Article.objects.select_related("author", "category").prefetch_related("tags")
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    filterset_fields = ["status", "category", "author"]
    search_fields = ["title", "content"]
    ordering_fields = ["created_at", "views_count"]
    ordering = ["-created_at"]

    # Actions that only read data. "hot" belongs here: the public ranking
    # previously fell through to the authenticated-only branch.
    public_actions = ("list", "retrieve", "hot")

    def get_queryset(self):
        """Return published articles plus, for a logged-in user, their own.

        The read actions are public, so without this filter anyone could list
        or retrieve other users' unpublished (e.g. draft) articles.
        NOTE(review): staff users are not special-cased here.
        """
        from django.db.models import Q

        queryset = super().get_queryset()
        user = self.request.user
        if user.is_authenticated:
            return queryset.filter(Q(status="published") | Q(author=user))
        return queryset.filter(status="published")

    def get_serializer_class(self):
        """Choose the serializer for the current action."""
        if self.action == "list":
            return ArticleListSerializer
        if self.action in ("create", "update", "partial_update"):
            return ArticleCreateSerializer
        return ArticleDetailSerializer

    def get_permissions(self):
        """Public read actions; authenticated, author-only access otherwise."""
        if self.action in self.public_actions:
            return [permissions.AllowAny()]
        return [permissions.IsAuthenticated(), IsAuthorOrReadOnly()]

    def perform_create(self, serializer):
        # The author is always the requesting user, never client-supplied.
        serializer.save(author=self.request.user)

    @action(detail=True, methods=["post"])
    def publish(self, request, pk=None):
        """Move a draft to "published"; respond 400 if it is not a draft."""
        article = self.get_object()
        if article.status == "draft":
            article.status = "published"
            article.save()
            return Response({"status": "published"})
        return Response(
            {"error": "文章已经是发布状态"},
            status=status.HTTP_400_BAD_REQUEST,
        )

    @action(detail=False, methods=["get"])
    def hot(self, request):
        """Return the 10 most-viewed published articles."""
        articles = (
            self.get_queryset()
            .filter(status="published")
            .order_by("-views_count")[:10]
        )
        serializer = ArticleListSerializer(
            articles, many=True, context=self.get_serializer_context()
        )
        return Response(serializer.data)
class CategoryViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for categories; anonymous users get read-only access."""

    permission_classes = [permissions.IsAuthenticatedOrReadOnly]
    serializer_class = CategorySerializer
    queryset = Category.objects.all()
class TagViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for tags; anonymous users get read-only access."""

    permission_classes = [permissions.IsAuthenticatedOrReadOnly]
    serializer_class = TagSerializer
    queryset = Tag.objects.all()
# Custom permission
# myapp/permissions.py
from rest_framework import permissions


class IsAuthorOrReadOnly(permissions.BasePermission):
    """Object-level permission: anyone may read, only the author may write."""

    def has_object_permission(self, request, view, obj):
        is_read_only = request.method in permissions.SAFE_METHODS
        return is_read_only or obj.author == request.user
# URL configuration
# myapp/api_urls.py
from rest_framework.routers import DefaultRouter

from . import views

router = DefaultRouter()
for _prefix, _viewset in (
    (r"articles", views.ArticleViewSet),
    (r"categories", views.CategoryViewSet),
    (r"tags", views.TagViewSet),
):
    router.register(_prefix, _viewset)

urlpatterns = router.urls
# myproject/urls.py
from django.contrib import admin
from django.urls import include, path

urlpatterns = [
    # Django admin site.
    path("admin/", admin.site.urls),
    # REST API routes registered in myapp/api_urls.py.
    path("api/", include("myapp.api_urls")),
    # DRF's login/logout views (used by the browsable API).
    path("auth/", include("rest_framework.urls")),
]
JWT认证
# Installation
# pip install djangorestframework-simplejwt
# settings.py
from datetime import timedelta

# Switch API authentication to JWT. Update the REST_FRAMEWORK dict already
# defined in this settings module instead of rebinding the name: a fresh dict
# would silently drop the pagination, permission and throttle settings.
REST_FRAMEWORK["DEFAULT_AUTHENTICATION_CLASSES"] = (
    "rest_framework_simplejwt.authentication.JWTAuthentication",
)

# BLACKLIST_AFTER_ROTATION only takes effect when simplejwt's token_blacklist
# app is installed (run `migrate` after adding it). Otherwise the blacklist
# step is skipped and rotated-out refresh tokens stay usable.
INSTALLED_APPS += ["rest_framework_simplejwt.token_blacklist"]

SIMPLE_JWT = {
    "ACCESS_TOKEN_LIFETIME": timedelta(minutes=30),
    "REFRESH_TOKEN_LIFETIME": timedelta(days=7),
    # Each refresh returns a new refresh token ...
    "ROTATE_REFRESH_TOKENS": True,
    # ... and the old one is blacklisted (requires token_blacklist above).
    "BLACKLIST_AFTER_ROTATION": True,
    # Clients send "Authorization: Bearer <token>".
    "AUTH_HEADER_TYPES": ("Bearer",),
}
# urls.py: JWT token endpoints
from rest_framework_simplejwt.views import (
    TokenObtainPairView,
    TokenRefreshView,
)

# Extend the project's existing urlpatterns. Rebinding the name would drop the
# admin/ and api/ routes already defined in this module.
urlpatterns += [
    # POST username/password -> access + refresh token pair.
    path("api/token/", TokenObtainPairView.as_view(), name="token_obtain_pair"),
    # POST refresh token -> new access token (and rotated refresh token).
    path("api/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
]
部署配置
# settings.py: production configuration
DEBUG = False
ALLOWED_HOSTS = ["yourdomain.com"]

# Security settings.
# SECURE_BROWSER_XSS_FILTER is intentionally absent: Django 4.0 removed it
# (this project already relies on the 4.0+ RedisCache backend), and modern
# browsers ignore the X-XSS-Protection header it used to set.
SECURE_CONTENT_TYPE_NOSNIFF = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
# NOTE(review): behind a TLS-terminating proxy, also set
# SECURE_PROXY_SSL_HEADER, otherwise SECURE_SSL_REDIRECT can redirect-loop.
SECURE_SSL_REDIRECT = True
# HSTS for one year, covering subdomains, eligible for browser preload lists.
SECURE_HSTS_SECONDS = 31536000
SECURE_HSTS_INCLUDE_SUBDOMAINS = True
SECURE_HSTS_PRELOAD = True
# Static assets (collectstatic target) and user-uploaded media
STATIC_ROOT, MEDIA_ROOT = "/var/www/static/", "/var/www/media/"
# Database. Credentials come from the environment so the password is not
# committed with the source. The other defaults match the previously
# hard-coded values; the password defaults to empty and must be supplied via
# DB_PASSWORD.
import os

DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": os.environ.get("DB_NAME", "mydb"),
        "USER": os.environ.get("DB_USER", "myuser"),
        "PASSWORD": os.environ.get("DB_PASSWORD", ""),
        "HOST": os.environ.get("DB_HOST", "localhost"),
        "PORT": os.environ.get("DB_PORT", "5432"),
    }
}
# Cache: Redis reachable under the hostname "redis", database 1.
CACHES = {
    "default": dict(
        BACKEND="django.core.cache.backends.redis.RedisCache",
        LOCATION="redis://redis:6379/1",
    ),
}
# Logging: write WARNING-and-above records from the "django" logger to a file.
LOGGING = {
    "version": 1,
    # dictConfig disables every already-existing logger not named here when
    # this key is omitted (its default is True); keep them enabled.
    "disable_existing_loggers": False,
    "formatters": {
        "verbose": {
            "format": "{asctime} {levelname} {name} {message}",
            "style": "{",
        },
    },
    "handlers": {
        "file": {
            "level": "WARNING",
            "class": "logging.FileHandler",
            "filename": "/var/log/django.log",
            # Without a formatter the file only gets bare messages, with no
            # timestamp, level or logger name.
            "formatter": "verbose",
        },
    },
    "loggers": {
        "django": {
            "handlers": ["file"],
            "level": "WARNING",
            "propagate": True,
        },
    },
}
总结
Django是一个功能强大的Web框架。掌握中间件、信号、缓存和REST Framework等进阶知识后，你可以构建高性能、可扩展的Web应用。在实际项目中，要根据需求选择合适的架构，遵循最佳实践，并做好性能优化和安全防护。