文章

Flask - 速率限制

目標

  • 安裝並配置 Flask-Limiter
  • 為關鍵端點添加速率限制
  • 測試限制效果

步驟

  1. 準備環境

    • 繼續使用 flask_api/ 項目結構,激活虛擬環境:
      1
      2
      
      # Windows: flask_api_env\Scripts\activate
      # macOS/Linux: source flask_api_env/bin/activate
      
    • 安裝 Flask-Limiter:
      1
      
      pip install flask-limiter
      
  2. 配置 Flask-Limiter

    • 修改 app/__init__.py,初始化速率限制:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
      100
      101
      102
      103
      104
      105
      106
      107
      108
      109
      110
      111
      112
      113
      114
      115
      116
      117
      118
      119
      120
      121
      122
      123
      124
      125
      126
      127
      128
      129
      130
      131
      
      from flask import Flask, jsonify, g
      from flask_sqlalchemy import SQLAlchemy
      from flask_marshmallow import Marshmallow
      from flask_bcrypt import Bcrypt
      from flask_restx import Api
      from flask_caching import Cache
      from flask_limiter import Limiter  # 新增
      from flask_limiter.util import get_remote_address  # 新增
      import jwt
      from functools import wraps
      from .routes.v1.todos import todos_bp as todos_v1_bp
      from .routes.v1.users import users_bp as users_v1_bp
      from .routes.v1.posts import posts_bp as posts_v1_bp
      from .routes.v2.todos import todos_bp as todos_v2_bp
      from .routes.v2.posts import posts_bp as posts_v2_bp
      from .config import config_map
      from .celery_config import make_celery
      import os
      import logging
      from logging.handlers import RotatingFileHandler
      
      # Extension objects are created unbound at import time so other modules
      # can import them; create_app() binds each one to the app via init_app().
      db = SQLAlchemy()
      ma = Marshmallow()
      bcrypt = Bcrypt()
      cache = Cache()
      limiter = Limiter(key_func=get_remote_address)  # rate-limit key defaults to the client IP
      
      def setup_logging(app):
          """Attach log handlers to ``app.logger``.

          When ``app.debug`` is false, a rotating file handler writing to
          ``app.log`` (INFO and above) is added first. A console handler
          (DEBUG and above) is always added, and the logger itself is set
          to DEBUG.
          """
          logger = app.logger

          # (handler, level, format) triples, in the order they get attached.
          handler_specs = []
          if not app.debug:
              handler_specs.append((
                  RotatingFileHandler('app.log', maxBytes=10000, backupCount=3),
                  logging.INFO,
                  '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]',
              ))
          handler_specs.append((
              logging.StreamHandler(),
              logging.DEBUG,
              '%(asctime)s %(levelname)s: %(message)s',
          ))

          for log_handler, level, fmt in handler_specs:
              log_handler.setLevel(level)
              log_handler.setFormatter(logging.Formatter(fmt))
              logger.addHandler(log_handler)

          logger.setLevel(logging.DEBUG)
      
      def _register_routes(app, api, routes, url_prefix):
          """Mount *routes* under *url_prefix*, whether Namespace or Blueprint."""
          # Imported locally so the module-level import block stays untouched.
          from flask_restx import Namespace

          if isinstance(routes, Namespace):
              # A flask-restx Namespace cannot be passed to
              # app.register_blueprint(); it has to be attached to the Api.
              # Each namespace gets its own sub-path, e.g. /api/v1/users.
              api.add_namespace(routes, path=f'{url_prefix}/{routes.name}')
          else:
              app.register_blueprint(routes, url_prefix=url_prefix)


      def create_app():
          """Application factory: build and configure the Flask app.

          Loads the config selected by the ``FLASK_ENV`` environment variable
          (default ``'development'``), binds the module-level extensions,
          sets up logging, creates the flask-restx ``Api``, builds the Celery
          instance, mounts the v1/v2 routes, installs JSON error handlers and
          creates the database tables.

          Returns:
              The configured ``Flask`` application.
          """
          global celery

          app = Flask(__name__)
          env = os.getenv('FLASK_ENV', 'development')
          app.config.from_object(config_map[env])
          app.config['CACHE_TYPE'] = 'simple'
          app.config['CACHE_DEFAULT_TIMEOUT'] = 300

          db.init_app(app)
          ma.init_app(app)
          bcrypt.init_app(app)
          cache.init_app(app)
          limiter.init_app(app)  # bind the rate limiter to this app
          setup_logging(app)

          api = Api(app,
                    title='Blog API',
                    version='1.0',
                    description='A simple blog API with user and post management',
                    doc='/api/docs/',
                    authorizations={
                        'jwt': {
                            'type': 'apiKey',
                            'in': 'header',
                            'name': 'Authorization',
                            'description': 'Enter "Bearer <token>"'
                        }
                    })

          # Rebinds the module-level ``celery`` placeholder.
          celery = make_celery(app)

          # The route modules may export either a Blueprint or a flask-restx
          # Namespace, so registration dispatches on the object's type.
          for routes, prefix in (
              (todos_v1_bp, '/api/v1'),
              (users_v1_bp, '/api/v1'),
              (posts_v1_bp, '/api/v1'),
              (todos_v2_bp, '/api/v2'),
              (posts_v2_bp, '/api/v2'),
          ):
              _register_routes(app, api, routes, prefix)

          @app.errorhandler(404)
          def not_found(error):
              app.logger.error(f'404 error: {str(error)}')
              return jsonify({'error': 'Not Found', 'message': str(error)}), 404

          @app.errorhandler(400)
          def bad_request(error):
              app.logger.warning(f'400 error: {str(error)}')
              return jsonify({'error': 'Bad Request', 'message': str(error)}), 400

          @app.errorhandler(500)
          def internal_error(error):
              app.logger.critical(f'500 error: {str(error)}')
              # The client gets a generic message; details stay in the log.
              return jsonify({'error': 'Internal Server Error', 'message': 'Something went wrong on our end'}), 500

          # create_all() needs an application context to find the bound engine.
          with app.app_context():
              db.create_all()

          return app
      
      def login_required(f):
          """Decorator that requires a valid JWT in the ``Authorization`` header.

          The header may carry the raw token or ``Bearer <token>``. On success
          the matching ``User`` is stored in ``g.current_user`` and the wrapped
          view is called. Otherwise the request is aborted with HTTP 401
          (missing, expired or invalid token, or unknown user).
          """
          @wraps(f)
          def decorated_function(*args, **kwargs):
              # This module only imports Flask, jsonify and g at the top, and no
              # global ``app`` exists (the app is built by create_app()), so the
              # request helpers are imported here and current_app is used.
              from flask import abort, current_app, request
              # Local import, as in the original.
              from .models import User

              token = request.headers.get('Authorization')
              if not token:
                  abort(401, description='Missing token')
              if token.startswith('Bearer '):
                  token = token[7:]

              try:
                  data = jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
              except jwt.ExpiredSignatureError:
                  abort(401, description='Token has expired')
              except jwt.InvalidTokenError:
                  abort(401, description='Invalid token')

              # A validly signed token without a user_id claim is rejected as
              # invalid rather than raising KeyError.
              user_id = data.get('user_id')
              user = User.query.get(user_id) if user_id is not None else None
              if not user:
                  abort(401, description='Invalid token')

              g.current_user = user
              return f(*args, **kwargs)
          return decorated_function
      
      def admin_required(f):
          """Decorator that requires an authenticated admin user.

          Runs ``login_required`` first, which populates ``g.current_user``,
          then aborts with HTTP 403 unless ``g.current_user.is_admin`` is truthy.
          """
          @wraps(f)
          @login_required
          def decorated_function(*args, **kwargs):
              # ``abort`` is not imported at module level in this file.
              from flask import abort

              if not g.current_user.is_admin:
                  abort(403, description='Admin access required')
              return f(*args, **kwargs)
          return decorated_function
      
      # Module-level placeholder; create_app() rebinds it to make_celery(app).
      celery = None
      
  3. 添加速率限制

    • 修改 app/routes/v1/users.py,限制登錄和用戶創建:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      
      from flask import request, g, abort
      from flask_restx import Namespace, Resource, fields
      from ...models import User
      from ... import db, admin_required, cache, limiter
      from ...schemas import user_schema, users_schema
      import jwt
      import datetime
      
      # flask-restx namespace that groups the user endpoints defined below.
      api = Namespace('users', description='User management operations')

      # Schema used to document and marshal user payloads.
      user_model = api.model('User', {
          'id': fields.Integer(readonly=True),
          'username': fields.String(required=True, description='The user username'),
          'password': fields.String(required=True, description='The user password'),
          'role': fields.String(description='User role (user or admin)', default='user')
      })

      # Schema of the expected login request body.
      login_model = api.model('Login', {
          'username': fields.String(required=True),
          'password': fields.String(required=True)
      })
      
      # Collection endpoint: list all users (cached) and create a new user
      # (rate limited).
      @api.route('')
      class UserList(Resource):
          @api.doc('list_users')
          @api.marshal_list_with(user_model)
          @cache.cached(timeout=60)
          def get(self):
              """List all users"""
              users = User.query.all()
              # A Namespace has no ``app`` attribute; it has its own logger.
              api.logger.info('Fetched all users')
              return users_schema.dump(users)

          @api.doc('create_user')
          @api.expect(user_model)
          @api.marshal_with(user_model, code=201)
          @limiter.limit("5 per minute")  # at most 5 requests per minute
          def post(self):
              """Create a new user"""
              if not request.is_json:
                  api.abort(400, 'Request must be JSON')
              data = request.get_json()
              # Valid JSON that is not an object (null, a number, a list...)
              # would otherwise fail below with a TypeError/500.
              if not isinstance(data, dict):
                  api.abort(400, 'Request body must be a JSON object')
              if 'username' not in data or 'password' not in data:
                  api.abort(400, 'Missing username or password')
              if User.query.filter_by(username=data['username']).first():
                  api.abort(400, 'Username already exists')
              user = User(username=data['username'])
              user.set_password(data['password'])
              # NOTE(security): this endpoint is unauthenticated, so any caller
              # can create an account with role 'admin'. Restrict before
              # production use.
              if data.get('role') in ['user', 'admin']:
                  user.role = data['role']
              db.session.add(user)
              db.session.commit()
              api.logger.info(f'User created: {user.username}')
              # cache.cached() stores the listing under 'view/' + request.path,
              # and this POST shares its path with the cached GET.
              cache.delete(f'view/{request.path}')
              return user_schema.dump(user), 201
      
      # Authentication endpoint: exchanges a username/password for a JWT that
      # expires after 24 hours.
      @api.route('/login')
      class Login(Resource):
          @api.doc('login_user')
          @api.expect(login_model)
          @limiter.limit("10 per minute")  # at most 10 requests per minute
          def post(self):
              """Login and get a JWT token"""
              # A Namespace has no ``app`` attribute, so the running
              # application is reached through current_app.
              from flask import current_app

              if not request.is_json:
                  api.abort(400, 'Request must be JSON')
              data = request.get_json()
              # Reject valid JSON that is not an object instead of failing
              # with a TypeError/500 on the key checks below.
              if not isinstance(data, dict):
                  api.abort(400, 'Request body must be a JSON object')
              if 'username' not in data or 'password' not in data:
                  api.abort(400, 'Missing username or password')
              user = User.query.filter_by(username=data['username']).first()
              if not user or not user.check_password(data['password']):
                  api.logger.warning(f'Failed login attempt for {data["username"]}')
                  api.abort(401, 'Invalid credentials')
              # Timezone-aware "now"; datetime.utcnow() is deprecated.
              expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=24)
              token = jwt.encode({
                  'user_id': user.id,
                  'exp': expires_at
              }, current_app.config['SECRET_KEY'], algorithm='HS256')
              api.logger.info(f'User logged in: {user.username}')
              return {'token': token}
      
      # Item endpoint: delete a single user (admin only).
      @api.route('/<int:user_id>')
      class UserResource(Resource):
          # No marshal_with(user_model) here: it would replace the
          # {'message': ...} payload with a user object of null fields.
          @api.doc('delete_user', security='jwt')
          @admin_required
          def delete(self, user_id):
              """Delete a user (admin only)"""
              user = User.query.get_or_404(user_id, description='User not found')
              # Read the name before the row is deleted.
              deleted_username = user.username
              db.session.delete(user)
              db.session.commit()
              # A Namespace has no ``app`` attribute; it has its own logger.
              api.logger.info(f'User deleted: {deleted_username} by admin {g.current_user.username}')
              # The cached listing is keyed 'view/' + its request path, which
              # is this URL without the trailing '/<user_id>' segment.
              list_path = request.path.rsplit('/', 1)[0]
              cache.delete(f'view/{list_path}')
              return {'message': 'User deleted'}, 200
      
      # Alias under the name imported by app/__init__.py. Despite the "_bp"
      # suffix this is a flask-restx Namespace, not a Flask Blueprint.
      users_bp = api
      
    • 修改 app/routes/v1/posts.py,限制文章創建:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
      100
      101
      102
      103
      104
      105
      106
      107
      108
      109
      110
      111
      112
      113
      114
      115
      116
      117
      
      from flask import request, g, abort
      from flask_restx import Namespace, Resource, fields
      from ...models import Post, User
      from ... import db, login_required, admin_required, cache, limiter
      from ...schemas import post_schema, posts_schema
      from ...tasks import send_email_notification
      
      # flask-restx namespace that groups the post endpoints defined below.
      api = Namespace('posts', description='Post management operations')

      # Schema used to document and marshal post payloads.
      post_model = api.model('Post', {
          'id': fields.Integer(readonly=True),
          'title': fields.String(required=True, description='The post title'),
          'content': fields.String(required=True, description='The post content'),
          'created_at': fields.DateTime(readonly=True),
          'user_id': fields.Integer(readonly=True),
          'category': fields.String(default='general')
      })
      
      # Collection endpoint: list posts (optionally filtered by ?user_id=,
      # cached per query string) and create a post (login required, rate
      # limited).
      @api.route('')
      class PostList(Resource):
          @api.doc('list_posts')
          @api.marshal_list_with(post_model)
          @cache.cached(timeout=60, query_string=True)
          def get(self):
              """List all posts"""
              user_id = request.args.get('user_id', type=int)
              query = Post.query
              if user_id:
                  query = query.filter_by(user_id=user_id)
              posts = query.all()
              # A Namespace has no ``app`` attribute; it has its own logger.
              api.logger.debug('Fetched posts')
              return posts_schema.dump(posts)

          @api.doc('create_post', security='jwt')
          @api.expect(post_model)
          @api.marshal_with(post_model, code=201)
          @login_required
          @limiter.limit("3 per minute")  # at most 3 requests per minute
          def post(self):
              """Create a new post"""
              if not request.is_json:
                  api.abort(400, 'Request must be JSON')
              data = request.get_json()
              # Reject valid JSON that is not an object instead of failing
              # with a TypeError/500 on the key checks below.
              if not isinstance(data, dict):
                  api.abort(400, 'Request body must be a JSON object')
              if 'title' not in data or 'content' not in data:
                  api.abort(400, 'Missing title or content')
              post = Post(
                  title=data['title'],
                  content=data['content'],
                  user_id=g.current_user.id
              )
              db.session.add(post)
              db.session.commit()
              api.logger.info(f'Post created: {post.title} by {g.current_user.username}')
              # Hand the notification to the background task queue.
              send_email_notification.delay(g.current_user.id, post.title)
              api.logger.info(f'Queued email notification for {post.title}')
              # NOTE(review): the listing is cached with query_string=True, so
              # Flask-Caching presumably keys it on the path plus a hash of the
              # query args. This delete likely does not evict it - confirm.
              cache.delete('view/api/v1/posts')
              return post_schema.dump(post), 201
      
      # Item endpoint: read (cached), update or delete a single post. Updates
      # and deletes are limited to the post's owner or an admin.
      @api.route('/<int:post_id>')
      class PostResource(Resource):
          @api.doc('get_post')
          @api.marshal_with(post_model)
          @cache.cached(timeout=60)
          def get(self, post_id):
              """Get a single post"""
              post = Post.query.get_or_404(post_id, description='Post not found')
              return post_schema.dump(post)

          @api.doc('update_post', security='jwt')
          @api.expect(post_model)
          @api.marshal_with(post_model)
          @login_required
          def put(self, post_id):
              """Update a post"""
              post = Post.query.get_or_404(post_id, description='Post not found')
              if post.user_id != g.current_user.id and not g.current_user.is_admin:
                  api.abort(403, 'You can only edit your own posts unless you are an admin')
              if not request.is_json:
                  api.abort(400, 'Request must be JSON')
              data = request.get_json()
              # Reject valid JSON that is not an object instead of failing
              # with a TypeError/500 on the key checks below.
              if not isinstance(data, dict):
                  api.abort(400, 'Request body must be a JSON object')
              if 'title' in data:
                  post.title = data['title']
              if 'content' in data:
                  post.content = data['content']
              db.session.commit()
              # A Namespace has no ``app`` attribute; it has its own logger.
              api.logger.info(f'Post updated: {post.title} by {g.current_user.username}')
              # The cached GET for this post is keyed 'view/' + request.path,
              # and this request shares that path.
              cache.delete(f'view/{request.path}')
              # NOTE(review): the listing is cached with query_string=True, so
              # this key likely does not match its entry - confirm.
              cache.delete('view/api/v1/posts')
              return post_schema.dump(post)

          @api.doc('delete_post', security='jwt')
          @login_required
          def delete(self, post_id):
              """Delete a post"""
              post = Post.query.get_or_404(post_id, description='Post not found')
              if post.user_id != g.current_user.id and not g.current_user.is_admin:
                  api.abort(403, 'You can only delete your own posts unless you are an admin')
              # Read the title before the row is deleted.
              deleted_title = post.title
              db.session.delete(post)
              db.session.commit()
              api.logger.info(f'Post deleted: {deleted_title} by {g.current_user.username}')
              # Evict the cached GET for this post (same path as this request).
              cache.delete(f'view/{request.path}')
              # NOTE(review): see put() - the list-cache key likely differs.
              cache.delete('view/api/v1/posts')
              return {'message': 'Post deleted'}, 200
      
      # Bulk endpoint: delete every post (admin only).
      @api.route('/all')
      class PostAll(Resource):
          @api.doc('delete_all_posts', security='jwt')
          @admin_required
          def delete(self):
              """Delete all posts (admin only)"""
              Post.query.delete()
              db.session.commit()
              # A Namespace has no ``app`` attribute; it has its own logger.
              api.logger.warning(f'All posts deleted by admin {g.current_user.username}')
              # NOTE(review): the listing is cached with query_string=True, so
              # this key likely does not match its entry. Cached single-post
              # responses are not evicted here either - confirm.
              cache.delete('view/api/v1/posts')
              return {'message': 'All posts deleted'}, 200
      
      # Alias under the name imported by app/__init__.py. Despite the "_bp"
      # suffix this is a flask-restx Namespace, not a Flask Blueprint.
      posts_bp = api
      
  4. 運行應用

    • 運行:
      1
      
      python run.py
      
  5. 測試速率限制

    • 使用 Postman 測試:
      • POST /api/v1/users
        • 連續發送 6 次請求(每次使用不同的 username,否則第 2 次起會因用戶名重複返回 400),前 5 次應返回 201,第 6 次應返回 429(Too Many Requests)。
      • POST /api/v1/users/login
        • 連續發送 11 次,前 10 次正常,第 11 次 429。
      • POST /api/v1/posts(用 token):
        • 連續發送 4 次,前 3 次 201,第 4 次 429。
    • 等待 1 分鐘後,限制應重置。
  6. 作業

    • 在 v2 的 posts 路由中添加速率限制,例如 GET /api/v2/posts 每分鐘 10 次。
    • 修改 Limiter 使用用戶 ID 作為 key(提示:自定義 key_func 檢查 JWT token)。

注意事項

  • 默認使用內存存儲限制計數,生產環境建議配置 Redis(limiter = Limiter(key_func=get_remote_address, storage_uri="redis://localhost:6379"),之後仍以 limiter.init_app(app) 綁定應用)。
  • 429 響應包含 Retry-After 頭,指示何時可重試。
  • 限制應根據實際需求調整,例如登錄可更寬鬆,創建內容更嚴格。

本文章以 CC BY 4.0 授權