深入理解Django REST Framework的认证、权限与频率控制机制
认证、权限与频率控制的层级关系
在Django REST Framework中,认证、权限和频率控制形成了层层递进的安全机制:
- 身份认证:仅执行认证
- 权限验证:认证 + 权限检查
- 频率控制:认证 + 权限检查 + 频率限制
局部认证实现
URL配置 (urls.py)
from django.conf.urls import url, include
from api import views
# Route table for the api app. Every pattern is anchored with ^...$ so that a
# longer path such as "login/extra/" is not silently routed to the same view.
# NOTE(review): django.conf.urls.url() was removed in Django 4.0; switch to
# django.urls.re_path when upgrading.
urlpatterns = [
    url(r'^login/$', views.LoginView.as_view()),
    url(r'^products/$', views.ProductView.as_view()),
    # The captured (\d+) group is passed to the view as a positional argument.
    url(r'^products_detail/(\d+)/$', views.ProductDetail.as_view()),
    url(r'^order_info/$', views.OrderInfoView.as_view()),
]
视图实现 (views.py)
from rest_framework.views import APIView
from api.models import *
import uuid
from django.http import JsonResponse
from rest_framework.response import Response
from api.serializers import ProductSerializer
from rest_framework import exceptions
from rest_framework.generics import GenericAPIView
# In-memory sample order records returned by OrderInfoView. Keys are integer
# ids (presumably user ids — confirm); each value describes the customer
# (name, age, gender) and what was bought.
USER_ORDERS = {
    1: dict(customer="张三", age=28, gender='男', details='购买商品A'),
    2: dict(customer="李四", age=32, gender='女', details='购买商品B'),
}
class LoginView(APIView):
    """Exchange a username/password pair for an API token.

    JSON response fields:
        status  -- 0 on success, 1 on wrong credentials, 2 on an unexpected error
        message -- error text, or None on success
        token   -- the newly issued token (present only on success)
    """

    # Logging in must not itself require a token: opt out of any globally
    # configured DEFAULT_AUTHENTICATION_CLASSES (e.g. a token authenticator
    # that rejects requests without a valid token) and permission classes.
    authentication_classes = []
    permission_classes = []

    def post(self, request):
        import logging

        result = {'status': 0, 'message': None}
        try:
            # request.data is DRF's parsed request body: it covers the
            # form-encoded data that request._request.POST exposed, and also
            # JSON bodies.
            username = request.data.get('username')
            password = request.data.get('password')
            # NOTE(review): passwords are stored and matched in plain text;
            # consider django.contrib.auth.hashers (make_password/check_password).
            user = User.objects.filter(username=username, password=password).first()
            if not user:
                result['status'] = 1
                result['message'] = '用户名或密码错误'
            else:
                # Issue a fresh token; update_or_create replaces any token this
                # user already had (UserToken is one-to-one with User).
                token = str(uuid.uuid4())
                UserToken.objects.update_or_create(user=user, defaults={'token': token})
                result['token'] = token
        except Exception:
            # Request boundary: keep the status-2 contract for the client, but
            # record the traceback instead of silently discarding it.
            logging.getLogger(__name__).exception('login request failed')
            result['status'] = 2
            result['message'] = '请求异常'
        return JsonResponse(result)
class CustomAuthentication(object):
    """Authenticate a request by its ``token`` query-string parameter.

    The token is looked up in ``UserToken``; on a match the request is
    authenticated as that token's user.
    """

    def authenticate(self, request):
        """Return ``(user, token_obj)`` for a valid token.

        DRF assigns the two elements to ``request.user`` and ``request.auth``.

        Raises:
            exceptions.AuthenticationFailed: if the token is missing, empty,
                or not present in ``UserToken``.
        """
        # query_params is DRF's public accessor for the underlying request.GET.
        token = request.query_params.get('token')
        if not token:
            # Reject early: never match a missing/empty token against the table.
            raise exceptions.AuthenticationFailed('认证失败')
        # select_related loads the owning user in the same query, so reading
        # token_obj.user below does not trigger a second query.
        token_obj = UserToken.objects.select_related('user').filter(token=token).first()
        if not token_obj:
            raise exceptions.AuthenticationFailed('认证失败')
        return (token_obj.user, token_obj)

    def authenticate_header(self, request):
        """Provide no WWW-Authenticate challenge.

        Because this returns None, when this is the first authentication class
        DRF answers a failed authentication with HTTP 403 instead of 401.
        """
        pass
class OrderInfoView(APIView):
    """Return the sample order table to any token-authenticated client."""

    # Only this view uses the token authenticator (local configuration).
    authentication_classes = [CustomAuthentication]

    def get(self, request):
        payload = {'status': 0, 'message': None, 'data': USER_ORDERS}
        return JsonResponse(payload)
模型定义 (models.py)
from django.db import models
class User(models.Model):
    """Application user account used by the token login/authentication flow."""

    username = models.CharField(max_length=32)
    # NOTE(review): stored in plain text and matched directly in a query by the
    # login view; consider storing a hash (django.contrib.auth.hashers).
    password = models.CharField(max_length=32)
    # Account tiers: 1 = regular user, 2 = VIP user, 3 = SVIP user.
    user_types = ((1, "普通用户"), (2, "VIP用户"), (3, "SVIP用户"))
    # New accounts default to the regular tier (1).
    user_type = models.IntegerField(choices=user_types, default=1)
class UserToken(models.Model):
    """API token issued to a User at login (at most one token per user)."""

    # on_delete is mandatory from Django 2.0 on; CASCADE is what older Django
    # versions applied implicitly, so behavior is unchanged: deleting a User
    # also deletes its token.
    user = models.OneToOneField("User", on_delete=models.CASCADE)
    # Opaque token string (the login view stores str(uuid.uuid4()) here).
    token = models.CharField(max_length=128)

    def __str__(self):
        return self.token
全局认证配置
settings.py 配置
# Project-wide DRF settings: the token authenticator below runs for every
# APIView that does not declare its own authentication_classes.
REST_FRAMEWORK = dict(
    DEFAULT_AUTHENTICATION_CLASSES=[
        'app.service.auth.CustomAuthentication',
    ],
)
认证服务模块 (app/service/auth.py,对应模块路径 app.service.auth)
from app.models import *
from rest_framework import exceptions
class CustomAuthentication(object):
    """Authenticate a request by its ``token`` query-string parameter.

    The token is looked up in ``UserToken``; on a match the request is
    authenticated as that token's user.
    """

    def authenticate(self, request):
        """Return ``(user, token_obj)`` for a valid token.

        DRF assigns the two elements to ``request.user`` and ``request.auth``.

        Raises:
            exceptions.AuthenticationFailed: if the token is missing, empty,
                or not present in ``UserToken``.
        """
        # query_params is DRF's public accessor for the underlying request.GET.
        token = request.query_params.get('token')
        if not token:
            # Reject early: never match a missing/empty token against the table.
            raise exceptions.AuthenticationFailed('认证失败')
        # select_related loads the owning user in the same query, so reading
        # token_obj.user below does not trigger a second query.
        token_obj = UserToken.objects.select_related('user').filter(token=token).first()
        if not token_obj:
            raise exceptions.AuthenticationFailed('认证失败')
        return (token_obj.user, token_obj)

    def authenticate_header(self, request):
        """Provide no WWW-Authenticate challenge.

        Because this returns None, when this is the first authentication class
        DRF answers a failed authentication with HTTP 403 instead of 401.
        """
        pass
认证流程源码分析
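REST Framework的认证流程如下:
- 视图在执行处理方法(如get/post)之前,会在initial()中调用perform_authentication(),即访问request.user以触发认证
- 按authentication_classes中的顺序,依次调用各认证类的authenticate方法
- 如果某个认证类抛出AuthenticationFailed异常,认证立即失败,请求被拒绝;若第一个认证类的authenticate_header返回None,响应状态码为403,否则为401
- 如果某个认证类返回(user, auth)元组,则分别赋值给request.user和request.auth,后续认证类不再执行
- 如果认证类返回None,则继续尝试下一个认证类;若全部返回None,request.user为匿名用户(AnonymousUser),request.auth为None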
权限控制
class VIPPermission(object):
    """Grant access only to users whose ``user_type`` is the VIP tier (2)."""

    def has_permission(self, request, view):
        """Return True only for VIP users.

        Requests whose ``request.user`` has no ``user_type`` attribute (for
        example an anonymous user, or None) are denied instead of raising
        AttributeError and producing a server error.
        """
        # 2 is the VIP tier in User.user_types.
        return getattr(request.user, 'user_type', None) == 2
class OrderInfoView(APIView):
    """Return the sample order table, gated by VIPPermission."""

    permission_classes = [VIPPermission]

    def get(self, request):
        # Only VIP users (user_type 2) receive the data; others get None.
        is_vip = request.user.user_type == 2
        payload = {
            'status': 0,
            'message': None,
            'data': USER_ORDERS if is_vip else None,
        }
        return JsonResponse(payload)
频率控制
import time
# Per-client request history, {ip: [newest_timestamp, ..., oldest_timestamp]},
# shared by every RateThrottle instance (and subclass) in this process.
# NOTE(review): entries are never evicted and the dict lives in process
# memory, so it grows with the number of distinct clients and is not shared
# between worker processes.
VISIT_LOG = {}


class RateThrottle(object):
    """Sliding-window throttle: at most ``num_requests`` per ``duration`` seconds per IP.

    Defaults to 5 requests per 60 seconds; subclasses may override the two
    class attributes to change the limit.
    """

    # Length of the sliding window, in seconds.
    duration = 60
    # Maximum number of requests accepted inside one window.
    num_requests = 5

    def __init__(self):
        # History list of the client seen by the last allow_request() call;
        # wait() reads it.
        self.history = None

    def allow_request(self, request, view):
        """Record and allow the request unless the client is over the limit.

        Returns:
            True if the request is allowed (its timestamp is recorded), False
            if the client already made ``num_requests`` requests within the
            last ``duration`` seconds.
        """
        # NOTE(review): REMOTE_ADDR is the direct peer address; behind a
        # reverse proxy all clients share the proxy's address.
        client_ip = request.META.get('REMOTE_ADDR')
        current_time = time.time()
        history = VISIT_LOG.setdefault(client_ip, [])
        self.history = history
        # Newest timestamps are kept at the front, so expired ones are at the end.
        while history and history[-1] < current_time - self.duration:
            history.pop()
        if len(history) < self.num_requests:
            history.insert(0, current_time)
            return True
        return False

    def wait(self):
        """Return the seconds until the oldest recorded request leaves the window.

        Returns None when no history is available (e.g. allow_request() was
        not called); DRF permits wait() to return None.
        """
        if not self.history:
            return None
        remaining = self.duration - (time.time() - self.history[-1])
        return max(remaining, 0)