← 返回首页

Django进阶:中间件、信号、缓存与REST Framework

📂 python ⏱ 6 min 1072 words

Django进阶:中间件、信号、缓存与REST Framework

在掌握 Django 基础之后，本文将带你深入学习中间件机制、信号系统、缓存策略和 Django REST Framework，帮助你构建高性能、可扩展的 Web 应用。

中间件

中间件是 Django 处理请求和响应的钩子框架：请求按 MIDDLEWARE 列表从上到下依次经过各个中间件，响应则按相反顺序返回，因此注册顺序会直接影响行为（见下方 settings.py 中的注册示例）。

自定义中间件

# myapp/middleware.py
import time
import logging
from django.http import JsonResponse
from django.utils.deprecation import MiddlewareMixin

# Logger named after this module (``__name__``); RequestLoggingMiddleware
# writes its per-request lines through it.
logger = logging.getLogger(__name__)

class RequestLoggingMiddleware(MiddlewareMixin):
    """Log method, path, status code and processing time of every request."""

    def process_request(self, request):
        # Wall-clock timestamp, still stored for any code that reads
        # ``request.start_time``.
        request.start_time = time.time()
        # Monotonic high-resolution clock used for the duration: unlike
        # time.time() it cannot jump when the system clock is adjusted, so the
        # measured duration is never negative or skewed.
        request._log_perf_start = time.perf_counter()

    def process_response(self, request, response):
        start = getattr(request, "_log_perf_start", None)
        # Defensive: skip requests whose process_request never ran.
        if start is not None:
            duration = time.perf_counter() - start
            # %-style arguments: the message is only formatted if INFO is enabled.
            logger.info(
                "%s %s - %s (%.3fs)",
                request.method,
                request.path,
                response.status_code,
                duration,
            )
        return response

class APIRateLimitMiddleware(MiddlewareMixin):
    """Fixed-window rate limit per client IP for paths under ``API_PREFIX``.

    A client's window opens on its first request. Within ``WINDOW_SECONDS``
    it may make ``RATE_LIMIT`` requests; further requests get HTTP 429
    (with a ``Retry-After`` header) until the window expires.

    NOTE(review): the counters live in this object's memory, so every worker
    process keeps its own counts (the effective limit grows with the number
    of workers), and the dict is mutated without a lock. A real limit needs
    a shared store such as the Django cache, or DRF throttling.
    """

    RATE_LIMIT = 100  # requests allowed per window per client IP
    WINDOW_SECONDS = 60
    API_PREFIX = "/api/"
    # X-Forwarded-For is chosen by the client unless a trusted reverse proxy
    # overwrites it. Honoring it without such a proxy lets any caller pick its
    # own "IP" and bypass the limit. Set to False when not behind a proxy.
    TRUST_X_FORWARDED_FOR = True

    def __init__(self, get_response):
        super().__init__(get_response)
        # ip -> {"count": requests in window, "time": window start (monotonic)}
        self.request_counts = {}
        self._last_sweep = time.monotonic()

    def process_request(self, request):
        if not request.path.startswith(self.API_PREFIX):
            return None

        ip = self.get_client_ip(request)
        now = time.monotonic()

        # Purge expired windows at most once per window, instead of rebuilding
        # the whole dict on every request (O(active clients) each time).
        if now - self._last_sweep >= self.WINDOW_SECONDS:
            self.request_counts = {
                k: v for k, v in self.request_counts.items()
                if now - v["time"] < self.WINDOW_SECONDS
            }
            self._last_sweep = now

        entry = self.request_counts.get(ip)
        if entry is None or now - entry["time"] >= self.WINDOW_SECONDS:
            # First request from this client, or its previous window expired.
            self.request_counts[ip] = {"count": 1, "time": now}
            return None

        if entry["count"] >= self.RATE_LIMIT:
            response = JsonResponse(
                {"error": "请求过于频繁,请稍后再试"},
                status=429,
            )
            # Tell clients when the window reopens (whole seconds, never 0).
            remaining = self.WINDOW_SECONDS - (now - entry["time"])
            response["Retry-After"] = str(int(remaining) + 1)
            return response

        entry["count"] += 1
        return None

    def get_client_ip(self, request):
        """Return the client address used as the rate-limit key."""
        if self.TRUST_X_FORWARDED_FOR:
            x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
            if x_forwarded_for:
                # Left-most entry is the originating client; strip the
                # optional whitespace that follows commas in the header.
                return x_forwarded_for.split(",")[0].strip()
        return request.META.get("REMOTE_ADDR")

class CORSMiddleware(MiddlewareMixin):
    """Add permissive CORS headers (any origin) to every response.

    NOTE(review): "*" lets any website call this API from a browser. Preflight
    OPTIONS requests are not answered here; they reach the view like any other
    request. A maintained package such as django-cors-headers handles both.
    """

    def process_response(self, request, response):
        cors_headers = {
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, Authorization",
        }
        for header_name, header_value in cors_headers.items():
            response[header_name] = header_value
        return response

# settings.py注册中间件
MIDDLEWARE = [
    # Outermost: its timer covers the full request/response cycle, including
    # every middleware below and any early response (e.g. a 429).
    "myapp.middleware.RequestLoggingMiddleware",
    "django.middleware.security.SecurityMiddleware",
    # Above CommonMiddleware and the rate limiter, so redirects and 429
    # responses also carry the CORS headers a browser needs to read them.
    "myapp.middleware.CORSMiddleware",
    # Reject over-limit API calls before session/auth work touches the database.
    "myapp.middleware.APIRateLimitMiddleware",
    "django.contrib.sessions.middleware.SessionMiddleware",
    "django.middleware.common.CommonMiddleware",
    "django.middleware.csrf.CsrfViewMiddleware",
    "django.contrib.auth.middleware.AuthenticationMiddleware",
    "django.contrib.messages.middleware.MessageMiddleware",
    "django.middleware.clickjacking.XFrameOptionsMiddleware",
]

信号系统

信号允许在特定操作（如模型保存、删除或用户登录/登出）发生时自动执行代码，从而把这些附加逻辑与触发它的业务代码解耦：

# myapp/signals.py
from django.db.models.signals import pre_save, post_save, pre_delete, post_delete
from django.dispatch import receiver
from django.contrib.auth.signals import user_logged_in, user_logged_out
from .models import Article, User

# 模型信号
@receiver(pre_save, sender=Article)
def article_pre_save(sender, instance, **kwargs):
    """Fill in a missing slug from the title before an Article is saved."""
    # slugify is not among this module's imports, so import it here.
    from django.utils.text import slugify

    if not instance.slug:
        # allow_unicode=True keeps non-ASCII letters (e.g. Chinese). The default
        # ASCII-only mode drops them, so an all-Chinese title would become "".
        # NOTE(review): URL patterns must then use <str:slug> instead of
        # <slug:slug> (that converter matches ASCII only). Uniqueness of the
        # generated slug is not checked here.
        instance.slug = slugify(instance.title, allow_unicode=True)
    print(f"文章即将保存: {instance.title}")

@receiver(post_save, sender=Article)
def article_post_save(sender, instance, created, **kwargs):
    """React after an Article row has been inserted (created=True) or updated."""
    if not created:
        print(f"文章更新: {instance.title}")
        return
    print(f"新文章创建: {instance.title}")
    # TODO: send a notification e-mail here, e.g. send_notification(instance)

@receiver(post_delete, sender=Article)
def article_post_delete(sender, instance, **kwargs):
    """React after an Article row has been deleted."""
    deleted_title = instance.title
    print(f"文章已删除: {deleted_title}")
    # TODO: clean up files owned by the article, e.g. instance.image.delete()

# 用户信号
@receiver(user_logged_in)
def user_logged_in_handler(sender, request, user, **kwargs):
    """Report a successful login."""
    username = user.username
    print(f"用户 {username} 登录")

@receiver(user_logged_out)
def user_logged_out_handler(sender, request, user, **kwargs):
    """Report a logout.

    Django sends ``user=None`` when the session being logged out was not
    authenticated, so guard before reading ``user.username``.
    """
    if user is None:
        return
    print(f"用户 {user.username} 登出")

# 自定义信号
from django.dispatch import Signal

# Custom signal; senders pass the keyword arguments ``article`` and ``user``.
article_viewed = Signal()

@receiver(article_viewed)
def handle_article_viewed(sender, article, user, **kwargs):
    """Report that *user* opened *article*."""
    viewer = user
    print(f"用户 {viewer} 浏览了文章 {article.title}")

# 在视图中发送信号
def article_detail(request, slug):
    """Render one article (404 if the slug is unknown) and emit ``article_viewed``."""
    article = get_object_or_404(Article, slug=slug)
    # Listeners receive the article and the current (possibly anonymous) user.
    article_viewed.send(sender=Article, article=article, user=request.user)
    context = {"article": article}
    return render(request, "article_detail.html", context)

# apps.py中注册信号
# myapp/apps.py
from django.apps import AppConfig

class MyappConfig(AppConfig):
    """App configuration for ``myapp``; connects its signal handlers on startup."""

    name = "myapp"
    default_auto_field = "django.db.models.BigAutoField"

    def ready(self):
        # Importing the module runs its @receiver decorators, which connect the
        # handlers. Doing it in ready() defers the import until the app
        # registry (and thus the models) is loaded.
        from . import signals  # noqa: F401

缓存系统

# settings.py缓存配置
CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.redis.RedisCache",
        # The trailing "/1" selects Redis database 1. A separate "db" entry in
        # OPTIONS is redundant: redis-py lets values parsed from the URL
        # override keyword options, so such an entry never takes effect.
        "LOCATION": "redis://127.0.0.1:6379/1",
        "TIMEOUT": 300,  # default expiry in seconds
    }
}

# Alternative: Memcached instead of Redis.
# NOTE: this reassigns CACHES, so if both snippets are kept in the same
# settings.py it replaces the Redis configuration above. Keep only the
# backend you actually use.
CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.memcached.PyMemcacheCache",
        "LOCATION": "127.0.0.1:11211",
    }
}

# 视图缓存
from django.views.decorators.cache import cache_page

@cache_page(60 * 15)  # cache the rendered response for 15 minutes
def article_list(request):
    """Render the published articles; the whole response is cached."""
    context = {"articles": Article.objects.filter(status="published")}
    return render(request, "article_list.html", context)

# Template fragment caching (a Django template snippet, shown here as a string).
# {% cache 300 article_list %} caches the enclosed HTML for 300 seconds under
# the fragment name "article_list". With no extra vary-on arguments, every
# visitor is served the same cached fragment.
"""
{% load cache %}

{% cache 300 article_list %}
    {% for article in articles %}
        <article>{{ article.title }}</article>
    {% endfor %}
{% endcache %}
"""

# 低级缓存API
from django.core.cache import cache

def get_article(article_id):
    """Return the Article with *article_id*, reading through the cache.

    On a miss the row is loaded from the database and cached for 300 seconds.
    A missing row raises Article.DoesNotExist (from ``objects.get``).
    """
    cache_key = f"article_{article_id}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached
    fresh = Article.objects.get(id=article_id)
    cache.set(cache_key, fresh, timeout=300)
    return fresh

def update_article(article_id, data):
    """Update an article's title and content, then refresh the affected caches.

    ``data`` must contain ``"title"`` and ``"content"`` (KeyError otherwise).
    A missing row raises Article.DoesNotExist (from ``objects.get``).
    """
    article = Article.objects.get(id=article_id)
    article.title = data["title"]
    article.content = data["content"]
    article.save()

    # Write-through: the per-article entry reflects the new data immediately.
    # (Alternative: cache.delete(...) and let the next read reload it.)
    cache.set(f"article_{article_id}", article, timeout=300)
    # The cached published list (key "article_list", 600s) holds its own copies
    # of the Article objects. Drop it as well, or readers keep seeing the old
    # title/content for up to 10 minutes.
    cache.delete("article_list")

def get_article_list():
    """Return the published articles as a list, cached for 600 seconds.

    The queryset is materialized with list(), so concrete objects rather than
    a lazy query are stored in the cache.
    """
    cached = cache.get("article_list")
    if cached is not None:
        return cached
    published = list(Article.objects.filter(status="published"))
    cache.set("article_list", published, timeout=600)
    return published

# Redis缓存实战
import redis
from django.conf import settings

# Shared Redis client, built once at import time from the project settings
# REDIS_HOST / REDIS_PORT / REDIS_DB, and used by ArticleCache.
# No decode_responses is passed, so replies come back as bytes.
redis_client = redis.Redis(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
)

class ArticleCache:
    """Redis-backed article view counters and a "hot articles" ranking.

    View counts are stored under ``article:<id>:views``. The ranking is the
    sorted set ``hot_articles``, whose members are article ids scored by heat.
    """

    HOT_KEY = "hot_articles"

    @staticmethod
    def _views_key(article_id):
        """Build the Redis key holding an article's view count."""
        return f"article:{article_id}:views"

    @staticmethod
    def get_views(article_id):
        """Return the recorded view count, or 0 if none has been recorded."""
        return int(redis_client.get(ArticleCache._views_key(article_id)) or 0)

    @staticmethod
    def increment_views(article_id):
        """Add one view (Redis INCR) and return the new count."""
        return redis_client.incr(ArticleCache._views_key(article_id))

    @staticmethod
    def get_hot_articles(limit=10):
        """Return up to *limit* ``(member, score)`` pairs, hottest first.

        Members are returned as stored by Redis (bytes unless the client was
        created with decode_responses=True).
        """
        if limit <= 0:
            # ZREVRANGE's stop index is inclusive and -1 means "last element",
            # so limit=0 would otherwise return the entire ranking.
            return []
        return redis_client.zrevrange(
            ArticleCache.HOT_KEY, 0, limit - 1, withscores=True
        )

    @staticmethod
    def update_hot_score(article_id, score=1):
        """Add *score* to the article's heat in the ranking."""
        redis_client.zincrby(ArticleCache.HOT_KEY, score, article_id)

Django REST Framework

# 安装
# pip install djangorestframework

# settings.py
INSTALLED_APPS = [
    # Django's default apps (they replace the "..." placeholder, which Django
    # would reject because every entry must be a dotted-path string).
    "django.contrib.admin",
    "django.contrib.auth",
    "django.contrib.contenttypes",
    "django.contrib.sessions",
    "django.contrib.messages",
    "django.contrib.staticfiles",
    # Third-party
    "rest_framework",
    # Required by rest_framework.authentication.TokenAuthentication (run migrate).
    "rest_framework.authtoken",
    # django-filter, used via DjangoFilterBackend in the API views.
    "django_filters",
    # Local
    "myapp",
]

REST_FRAMEWORK = {
    # Page-number pagination, 10 items per page.
    "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
    "PAGE_SIZE": 10,
    # Cookie-session authentication plus DRF token authentication.
    "DEFAULT_AUTHENTICATION_CLASSES": [
        "rest_framework.authentication.SessionAuthentication",
        "rest_framework.authentication.TokenAuthentication",
    ],
    # Unauthenticated clients get read-only access unless a view overrides it.
    "DEFAULT_PERMISSION_CLASSES": [
        "rest_framework.permissions.IsAuthenticatedOrReadOnly",
    ],
    # Daily quotas; the rates are keyed by scope in DEFAULT_THROTTLE_RATES.
    "DEFAULT_THROTTLE_CLASSES": [
        "rest_framework.throttling.AnonRateThrottle",
        "rest_framework.throttling.UserRateThrottle",
    ],
    "DEFAULT_THROTTLE_RATES": {
        "anon": "100/day",
        "user": "1000/day",
    },
}

# 序列化器
# myapp/serializers.py
from rest_framework import serializers
from .models import Article, Category, Tag

class TagSerializer(serializers.ModelSerializer):
    """Represent a Tag as ``{"id": ..., "name": ...}``."""

    class Meta:
        model = Tag
        fields = ("id", "name")

class CategorySerializer(serializers.ModelSerializer):
    """Category fields plus the number of articles in the category."""

    article_count = serializers.SerializerMethodField()

    class Meta:
        model = Category
        fields = ("id", "name", "slug", "article_count")

    def get_article_count(self, obj):
        # NOTE(review): this issues one COUNT query per serialized category;
        # for list endpoints consider annotating the queryset instead.
        related_articles = obj.article_set
        return related_articles.count()

class ArticleListSerializer(serializers.ModelSerializer):
    """Compact article representation for list endpoints.

    Author and category are flattened to display strings instead of being nested.
    """

    author_name = serializers.CharField(read_only=True, source="author.username")
    category_name = serializers.CharField(read_only=True, source="category.name")

    class Meta:
        model = Article
        fields = (
            "id",
            "title",
            "slug",
            "author_name",
            "category_name",
            "status",
            "views_count",
            "created_at",
        )

class ArticleDetailSerializer(serializers.ModelSerializer):
    """Full article representation with nested category and tags."""

    author = serializers.StringRelatedField()
    # Nested serializers are read-only: ModelSerializer refuses to write nested
    # relations without a custom create()/update() and raises if asked to.
    category = CategorySerializer(read_only=True)
    tags = TagSerializer(many=True, read_only=True)

    class Meta:
        model = Article
        fields = "__all__"

    def validate_title(self, value):
        if len(value) < 5:
            raise serializers.ValidationError("标题至少需要5个字符")
        return value

    def validate(self, data):
        # On partial input a field may be absent from ``data``. Fall back to the
        # instance being updated (if any) instead of raising KeyError.
        new_status = data.get("status", getattr(self.instance, "status", None))
        content = data.get("content", getattr(self.instance, "content", None))
        if new_status == "published" and not content:
            raise serializers.ValidationError("发布状态必须有内容")
        return data

class ArticleCreateSerializer(serializers.ModelSerializer):
    """Write serializer for creating and updating articles.

    ArticleViewSet routes create/update/partial_update here, so input
    validation must live on this class. Validators defined only on
    ArticleDetailSerializer never run for those actions. ``tags`` uses
    ModelSerializer's default many-to-many field (a list of primary keys).
    """

    class Meta:
        model = Article
        fields = ["title", "slug", "content", "category", "tags", "status"]

    def validate_title(self, value):
        if len(value) < 5:
            raise serializers.ValidationError("标题至少需要5个字符")
        return value

    def validate(self, data):
        # On partial updates a field may be missing from ``data``. Fall back to
        # the stored value so the rule still applies and no KeyError occurs.
        new_status = data.get("status", getattr(self.instance, "status", None))
        content = data.get("content", getattr(self.instance, "content", None))
        if new_status == "published" and not content:
            raise serializers.ValidationError("发布状态必须有内容")
        return data

    def create(self, validated_data):
        # Many-to-many values cannot be passed to objects.create(); attach them
        # once the row exists.
        tags = validated_data.pop("tags", [])
        article = Article.objects.create(**validated_data)
        article.tags.set(tags)
        return article

# 视图
# myapp/views.py
from rest_framework import viewsets, permissions, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from django_filters.rest_framework import DjangoFilterBackend
from .models import Article, Category, Tag
from .serializers import (
    ArticleListSerializer, ArticleDetailSerializer,
    ArticleCreateSerializer, CategorySerializer, TagSerializer
)
from .permissions import IsAuthorOrReadOnly

class ArticleViewSet(viewsets.ModelViewSet):
    """Article CRUD API plus ``publish`` and ``hot`` extra actions.

    list/retrieve are public. Every other action requires an authenticated
    user and, for object-level operations, IsAuthorOrReadOnly.
    """

    # Used by DefaultRouter to derive the URL basename; the rows actually served
    # come from get_queryset().
    queryset = Article.objects.all()
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    filterset_fields = ["status", "category", "author"]
    search_fields = ["title", "content"]
    ordering_fields = ["created_at", "views_count"]
    ordering = ["-created_at"]

    def get_queryset(self):
        """Return published articles plus the requesting user's own articles.

        list/retrieve use AllowAny, so serving ``Article.objects.all()`` would
        expose every unpublished article (e.g. via ``?status=draft``) to
        anonymous visitors. select_related fetches author and category in the
        same query instead of one extra query per row.
        """
        from django.db.models import Q  # not imported at the top of this module

        articles = Article.objects.select_related("author", "category")
        user = getattr(self.request, "user", None)
        if user is not None and user.is_authenticated:
            return articles.filter(Q(status="published") | Q(author=user))
        return articles.filter(status="published")

    def get_serializer_class(self):
        if self.action == "list":
            return ArticleListSerializer
        elif self.action in ["create", "update", "partial_update"]:
            return ArticleCreateSerializer
        return ArticleDetailSerializer

    def get_permissions(self):
        if self.action in ["list", "retrieve"]:
            return [permissions.AllowAny()]
        return [permissions.IsAuthenticated(), IsAuthorOrReadOnly()]

    def perform_create(self, serializer):
        # The author is always the requesting user, never client input.
        serializer.save(author=self.request.user)

    @action(detail=True, methods=["post"])
    def publish(self, request, pk=None):
        """Move a draft article to the published state."""
        article = self.get_object()
        if article.status == "draft":
            article.status = "published"
            article.save()
            return Response({"status": "published"})
        return Response(
            {"error": "文章已经是发布状态"},
            status=status.HTTP_400_BAD_REQUEST
        )

    @action(detail=False, methods=["get"])
    def hot(self, request):
        """Return the 10 most-viewed published articles."""
        articles = (
            Article.objects.select_related("author", "category")
            .filter(status="published")
            .order_by("-views_count")[:10]
        )
        serializer = ArticleListSerializer(
            articles, many=True, context=self.get_serializer_context()
        )
        return Response(serializer.data)

class CategoryViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for categories; unauthenticated clients are read-only."""

    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
    serializer_class = CategorySerializer
    queryset = Category.objects.all()

class TagViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for tags; unauthenticated clients are read-only."""

    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
    serializer_class = TagSerializer
    queryset = Tag.objects.all()

# 自定义权限
# myapp/permissions.py
from rest_framework import permissions

class IsAuthorOrReadOnly(permissions.BasePermission):
    """Object-level permission: anyone may read; only the author may modify."""

    def has_object_permission(self, request, view, obj):
        is_read_only = request.method in permissions.SAFE_METHODS
        return is_read_only or obj.author == request.user

# URL配置
# myapp/api_urls.py
from rest_framework.routers import DefaultRouter
from . import views

# One set of list/detail routes per ViewSet, plus DefaultRouter's API root view.
router = DefaultRouter()
for prefix, viewset in (
    (r"articles", views.ArticleViewSet),
    (r"categories", views.CategoryViewSet),
    (r"tags", views.TagViewSet),
):
    router.register(prefix, viewset)

urlpatterns = router.urls

# myproject/urls.py
from django.contrib import admin
from django.urls import path, include

urlpatterns = [
    path("admin/", admin.site.urls),
    # All REST routes registered in myapp.api_urls live under /api/.
    path("api/", include("myapp.api_urls")),
    # DRF's login/logout views (used by the browsable API).
    path("auth/", include("rest_framework.urls")),
]

JWT认证

# 安装
# pip install djangorestframework-simplejwt

# settings.py
from datetime import timedelta

# Switch API authentication to JWT by updating the REST_FRAMEWORK dict defined
# earlier in settings.py, rather than reassigning it. A fresh dict literal
# would silently drop the pagination, permission and throttling defaults.
REST_FRAMEWORK["DEFAULT_AUTHENTICATION_CLASSES"] = (
    "rest_framework_simplejwt.authentication.JWTAuthentication",
)

SIMPLE_JWT = {
    # Short-lived access tokens; clients renew them with the refresh token.
    "ACCESS_TOKEN_LIFETIME": timedelta(minutes=30),
    "REFRESH_TOKEN_LIFETIME": timedelta(days=7),
    # Each refresh call also issues a new refresh token.
    "ROTATE_REFRESH_TOKENS": True,
    # NOTE(review): blacklisting the replaced refresh token requires the
    # "rest_framework_simplejwt.token_blacklist" app in INSTALLED_APPS (plus
    # migrate); TODO confirm it is installed.
    "BLACKLIST_AFTER_ROTATION": True,
    # Clients send "Authorization: Bearer <token>".
    "AUTH_HEADER_TYPES": ("Bearer",),
}

# urls.py
from rest_framework_simplejwt.views import (
    TokenObtainPairView,
    TokenRefreshView,
)

# Append the token endpoints to the project's existing urlpatterns.
# Reassigning the list would discard the admin/ and api/ routes defined in
# the project urls.py.
urlpatterns += [
    path("api/token/", TokenObtainPairView.as_view(), name="token_obtain_pair"),
    path("api/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
]

部署配置

# settings.py — production configuration
DEBUG = False
ALLOWED_HOSTS = ["yourdomain.com"]

# Security settings
# SECURE_BROWSER_XSS_FILTER is intentionally not set. Django 4.0 removed it
# (along with the X-XSS-Protection header it controlled), and this project
# targets Django >= 4.0, as its use of the built-in RedisCache backend shows.
SECURE_CONTENT_TYPE_NOSNIFF = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
# NOTE(review): if TLS terminates at a reverse proxy, Django also needs
# SECURE_PROXY_SSL_HEADER (only when the proxy overwrites that header), or
# this redirect can loop forever.
SECURE_SSL_REDIRECT = True
SECURE_HSTS_SECONDS = 31536000  # one year
SECURE_HSTS_INCLUDE_SUBDOMAINS = True
# Preload asks browsers to hard-code HTTPS for the domain and is hard to undo;
# enable it only once every subdomain serves HTTPS.
SECURE_HSTS_PRELOAD = True

# File storage locations on the server.
MEDIA_ROOT = "/var/www/media/"  # user-uploaded files
STATIC_ROOT = "/var/www/static/"  # destination of `manage.py collectstatic`

# Database: connection values are read from environment variables so that
# credentials stay out of source control. The fallbacks are the previous
# hard-coded values.
# NOTE(review): always set DB_PASSWORD in production; the "password" fallback
# is only suitable for local development.
import os

DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": os.environ.get("DB_NAME", "mydb"),
        "USER": os.environ.get("DB_USER", "myuser"),
        "PASSWORD": os.environ.get("DB_PASSWORD", "password"),
        "HOST": os.environ.get("DB_HOST", "localhost"),
        "PORT": os.environ.get("DB_PORT", "5432"),
    }
}

# Cache: Redis database 1 on the host named "redis"
# (presumably a container/service hostname — confirm for your deployment).
CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.redis.RedisCache",
        "LOCATION": "redis://redis:6379/1",
    },
}

# Logging
LOGGING = {
    "version": 1,
    # dictConfig's default (True) would disable loggers that already exist,
    # including Django's own default loggers.
    "disable_existing_loggers": False,
    "formatters": {
        "verbose": {
            "format": "{asctime} {levelname} {name} {message}",
            "style": "{",
        },
    },
    "handlers": {
        "file": {
            # INFO so the "myapp" logger below can write through it; the
            # "django" logger still filters at WARNING.
            "level": "INFO",
            "class": "logging.FileHandler",
            "filename": "/var/log/django.log",
            "formatter": "verbose",
        },
    },
    "loggers": {
        "django": {
            "handlers": ["file"],
            "level": "WARNING",
            "propagate": True,
        },
        # Application loggers (logging.getLogger(__name__) inside the myapp
        # package, e.g. the request-logging middleware). Without this entry,
        # their INFO records have no handler and are dropped.
        "myapp": {
            "handlers": ["file"],
            "level": "INFO",
            "propagate": False,
        },
    },
}

总结

Django 是一个功能强大的 Web 框架。掌握中间件、信号、缓存和 REST Framework 等进阶知识后，你可以构建高性能、可扩展的 Web 应用。在实际项目中，要根据需求选择合适的架构，遵循最佳实践，并做好性能优化和安全防护。