
Flask - Asynchronous Tasks

Goals

  • Install and configure Celery
  • Implement a simple asynchronous task (simulating sending an email)
  • Call the asynchronous task from the API

Steps

  1. Prepare the environment

    • Continue using the flask_api/ project structure and activate the virtual environment:
      
      # Windows: flask_api_env\Scripts\activate
      # macOS/Linux: source flask_api_env/bin/activate
      
    • Install Celery and the Redis client library (Redis serves as the message broker):
      
      pip install celery redis
      
    • Install the Redis server:
      • Windows: download a Redis build (for example, from GitHub), extract it, and run redis-server.exe
      • macOS/Linux: use brew install redis (macOS) or sudo apt install redis-server (Debian/Ubuntu), then start redis-server
  2. Configure Celery

    • Create app/celery_config.py:

      from celery import Celery
      
      def make_celery(app):
          # Name the Celery app after the Flask app's import name
          celery = Celery(
              app.import_name,
              backend='redis://localhost:6379/0',  # result backend
              broker='redis://localhost:6379/0'   # message broker
          )
          # Copy the Flask configuration into the Celery configuration
          celery.conf.update(app.config)
          return celery
      
    • Modify app/__init__.py to integrate Celery:

      
      from flask import Flask, jsonify, g, request, abort, current_app
      from flask_sqlalchemy import SQLAlchemy
      from flask_marshmallow import Marshmallow
      from flask_bcrypt import Bcrypt
      import jwt
      from functools import wraps
      from .config import config_map
      from .celery_config import make_celery  # new
      import os
      import logging
      from logging.handlers import RotatingFileHandler
      
      db = SQLAlchemy()
      ma = Marshmallow()
      bcrypt = Bcrypt()
      celery = None  # initialized in create_app
      
      def setup_logging(app):
          # Outside debug mode, also write INFO and above to a rotating log file
          if not app.debug:
              handler = RotatingFileHandler('app.log', maxBytes=10000, backupCount=3)
              handler.setLevel(logging.INFO)
              formatter = logging.Formatter(
                  '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
              )
              handler.setFormatter(formatter)
              app.logger.addHandler(handler)
          console_handler = logging.StreamHandler()
          console_handler.setLevel(logging.DEBUG)
          console_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
          app.logger.addHandler(console_handler)
          app.logger.setLevel(logging.DEBUG)
      
      def create_app():
          global celery
          app = Flask(__name__)
          env = os.getenv('FLASK_ENV', 'development')
          app.config.from_object(config_map[env])
          db.init_app(app)
          ma.init_app(app)
          bcrypt.init_app(app)
          setup_logging(app)
      
          # Initialize Celery
          celery = make_celery(app)
      
          # Import the blueprints here rather than at module level: route modules
          # such as posts.py import db, the auth decorators, and the tasks from this
          # package, so they must load after those names (and celery) exist.
          from .routes.v1.todos import todos_bp as todos_v1_bp
          from .routes.v1.users import users_bp as users_v1_bp
          from .routes.v1.posts import posts_bp as posts_v1_bp
          from .routes.v2.todos import todos_bp as todos_v2_bp
          from .routes.v2.posts import posts_bp as posts_v2_bp
      
          app.register_blueprint(todos_v1_bp, url_prefix='/api/v1')
          app.register_blueprint(users_v1_bp, url_prefix='/api/v1')
          app.register_blueprint(posts_v1_bp, url_prefix='/api/v1')
          app.register_blueprint(todos_v2_bp, url_prefix='/api/v2')
          app.register_blueprint(posts_v2_bp, url_prefix='/api/v2')
      
          @app.errorhandler(404)
          def not_found(error):
              app.logger.error(f'404 error: {str(error)}')
              return jsonify({'error': 'Not Found', 'message': str(error)}), 404
      
          @app.errorhandler(400)
          def bad_request(error):
              app.logger.warning(f'400 error: {str(error)}')
              return jsonify({'error': 'Bad Request', 'message': str(error)}), 400
      
          @app.errorhandler(500)
          def internal_error(error):
              app.logger.critical(f'500 error: {str(error)}')
              return jsonify({'error': 'Internal Server Error', 'message': 'Something went wrong on our end'}), 500
      
          with app.app_context():
              db.create_all()
      
          return app
      
      def login_required(f):
          @wraps(f)
          def decorated_function(*args, **kwargs):
              from .models import User
              token = request.headers.get('Authorization')
              if not token:
                  abort(401, description='Missing token')
              try:
                  # Strip the optional "Bearer " prefix
                  if token.startswith('Bearer '):
                      token = token[7:]
                  # current_app refers to the app handling this request
                  data = jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
                  user = User.query.get(data['user_id'])
                  if not user:
                      abort(401, description='Invalid token')
                  g.current_user = user
              except jwt.ExpiredSignatureError:
                  abort(401, description='Token has expired')
              except jwt.InvalidTokenError:
                  abort(401, description='Invalid token')
              return f(*args, **kwargs)
          return decorated_function
      
      def admin_required(f):
          @wraps(f)
          @login_required
          def decorated_function(*args, **kwargs):
              if not g.current_user.is_admin:
                  abort(403, description='Admin access required')
              return f(*args, **kwargs)
          return decorated_function
      
  3. Define the asynchronous task

    • Create app/tasks.py:

      
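      # celery is assigned inside create_app(), so this module must be
      # imported only after create_app() has run.
      from . import celery
      import time
      
      @celery.task
      def send_email_notification(user_id, post_title):
          # Simulate sending an email (this just sleeps for 5 seconds)
          time.sleep(5)
          print(f"Email sent to user {user_id} for post '{post_title}'")
          return f"Notification for {post_title} completed"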
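
To make the idea behind the task concrete, here is a conceptual sketch in plain Python of what queuing a job means: the caller enqueues the work and returns at once, while a separate worker picks it up later. It is only an analogy built on the standard library's queue and threading modules, not how Celery is implemented. The names task_queue, results, delay, and worker are made up for this illustration, and the sleep is shortened to 0.1 seconds.

```python
import queue
import threading
import time

task_queue = queue.Queue()  # stands in for the Redis message broker
results = []                # stands in for the result backend


def send_email_notification(user_id, post_title):
    time.sleep(0.1)  # simulate slow work (the tutorial task sleeps 5 seconds)
    return f"Notification for {post_title} completed"


def delay(func, *args):
    """Enqueue the call and return immediately, in the spirit of task.delay()."""
    task_queue.put((func, args))


def worker():
    """Run queued calls one by one until a None sentinel arrives."""
    while True:
        item = task_queue.get()
        if item is None:
            break
        func, args = item
        results.append(func(*args))


worker_thread = threading.Thread(target=worker)
worker_thread.start()

start = time.monotonic()
delay(send_email_notification, 1, "My Post")
enqueue_seconds = time.monotonic() - start  # the caller is not blocked by the sleep

task_queue.put(None)  # ask the worker to stop after the queued job
worker_thread.join()
print(results)
```

With Celery, the queue lives in Redis and the worker is a separate process, which is why the Flask request can return before the email is "sent".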
      
  4. Use the asynchronous task in a route

    • Modify app/routes/v1/posts.py to trigger the task when a post is created:

      
      from flask import Blueprint, jsonify, request, abort, g, current_app
      from ...models import Post, User
      from ... import db, login_required, admin_required
      from ...schemas import post_schema, posts_schema
      from ...tasks import send_email_notification
      
      posts_bp = Blueprint('posts_v1', __name__)
      
      @posts_bp.route('/posts', methods=['GET'])
      def get_posts():
          # Optional filter: /posts?user_id=<id>
          user_id = request.args.get('user_id', type=int)
          query = Post.query
          if user_id:
              query = query.filter_by(user_id=user_id)
          posts = query.all()
          # Blueprints have no .app attribute; log through current_app instead
          current_app.logger.debug('Fetched posts')
          return jsonify({'posts': posts_schema.dump(posts)})
      
      @posts_bp.route('/posts/<int:post_id>', methods=['GET'])
      def get_post(post_id):
          post = Post.query.get_or_404(post_id, description='Post not found')
          return jsonify(post_schema.dump(post))
      
      @posts_bp.route('/posts', methods=['POST'])
      @login_required
      def create_post():
          if not request.is_json:
              abort(400, description='Request must be JSON')
          data = request.get_json()
          if 'title' not in data or 'content' not in data:
              abort(400, description='Missing title or content')
          post = Post(
              title=data['title'],
              content=data['content'],
              user_id=g.current_user.id
          )
          db.session.add(post)
          db.session.commit()
          current_app.logger.info(f'Post created: {post.title} by {g.current_user.username}')
      
          # Trigger the asynchronous task; delay() queues it and returns immediately
          send_email_notification.delay(g.current_user.id, post.title)
          current_app.logger.info(f'Queued email notification for {post.title}')
      
          return jsonify(post_schema.dump(post)), 201
      
      @posts_bp.route('/posts/<int:post_id>', methods=['PUT'])
      @login_required
      def update_post(post_id):
          post = Post.query.get_or_404(post_id, description='Post not found')
          if post.user_id != g.current_user.id and not g.current_user.is_admin:
              abort(403, description='You can only edit your own posts unless you are an admin')
          if not request.is_json:
              abort(400, description='Request must be JSON')
          data = request.get_json()
          if 'title' in data:
              post.title = data['title']
          if 'content' in data:
              post.content = data['content']
          db.session.commit()
          current_app.logger.info(f'Post updated: {post.title} by {g.current_user.username}')
          return jsonify(post_schema.dump(post)), 200
      
      @posts_bp.route('/posts/<int:post_id>', methods=['DELETE'])
      @login_required
      def delete_post(post_id):
          post = Post.query.get_or_404(post_id, description='Post not found')
          if post.user_id != g.current_user.id and not g.current_user.is_admin:
              abort(403, description='You can only delete your own posts unless you are an admin')
          # Keep the title for logging; the instance's attributes may no longer
          # be loadable once the delete has been committed
          title = post.title
          db.session.delete(post)
          db.session.commit()
          current_app.logger.info(f'Post deleted: {title} by {g.current_user.username}')
          return jsonify({'message': 'Post deleted'}), 200
      
      @posts_bp.route('/posts/all', methods=['DELETE'])
      @admin_required
      def delete_all_posts():
          Post.query.delete()
          db.session.commit()
          current_app.logger.warning(f'All posts deleted by admin {g.current_user.username}')
          return jsonify({'message': 'All posts deleted'}), 200
      
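  5. Start the Celery worker

    • In another terminal window, go to the project root and start Celery. The -A app option makes the worker import the app package and look for a Celery instance there. Because celery is only assigned inside create_app(), whatever module the worker loads must have called create_app() by the time it is imported; otherwise the worker has no usable instance:
      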
      celery -A app worker --loglevel=info
      
    • Note: Windows users may need extra options, for example:
      
      celery -A app worker --loglevel=info --pool=solo
      
  6. Run the application

    • Make sure Redis is running, then start Flask:
      
      python run.py
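
Before starting Flask, it can help to confirm that something is actually listening on the Redis port. The sketch below uses only the standard library; is_port_open is a hypothetical helper name, and the check only proves that a TCP connection is accepted, not that the listener is really Redis.

```python
import socket


def is_port_open(host="localhost", port=6379, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and unresolved host names
        return False


if is_port_open():
    print("Something is listening on localhost:6379")
else:
    print("Nothing is listening on localhost:6379; start redis-server first")
```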
      
  7. Test the API

    • Test with Postman:
      • POST /api/v1/users
        • Body: {"username": "alice", "password": "1234"}
      • POST /api/v1/login
        • Body: {"username": "alice", "password": "1234"}
        • Copy the token from the response.
      • POST /api/v1/posts
        • Headers: Authorization: Bearer <token>
        • Body: {"title": "My Post", "content": "Hello"}
        • Expected response: 201 is returned immediately.
        • Check the Celery terminal: after about 5 seconds it shows Email sent to user 1 for post 'My Post'
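
If you prefer a script to Postman, the final request can be built with the standard library. This is a sketch under two assumptions: the server runs at http://localhost:5000 (Flask's usual development address, which your run.py may override), and build_create_post_request is a hypothetical helper name. The token is left as a placeholder because the shape of the login response is not shown here.

```python
import json
import urllib.request

BASE_URL = "http://localhost:5000"  # assumption: Flask's default dev address


def build_create_post_request(token, title, content, base_url=BASE_URL):
    """Build (but do not send) the authenticated POST /api/v1/posts request."""
    body = json.dumps({"title": title, "content": content}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/api/v1/posts",
        data=body,
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
    )


req = build_create_post_request("<token>", "My Post", "Hello")
print(req.get_method(), req.full_url)
# To send it against a running server: urllib.request.urlopen(req)
```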
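  8. Exercises

    • Add an asynchronous task to the v2 posts routes, for example one that records post creation in an external log file.
    • Modify send_email_notification to simulate failures (for example, by randomly raising an exception) and observe Celery's retry mechanism (hint: set max_retries).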
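
As a warm-up for the second exercise, the sketch below shows bounded-retry semantics in plain Python: one initial attempt plus at most max_retries further attempts. It is not Celery code. In Celery the same idea is usually written as a task declared with bind=True that calls self.retry(...), with max_retries capping the number of retries. The names TransientEmailError, flaky_send, and run_with_retries are hypothetical.

```python
import random


class TransientEmailError(Exception):
    """Simulated temporary failure while sending an email."""


def flaky_send(user_id, post_title, rng):
    # Fail roughly half of the time to imitate an unreliable mail server
    if rng.random() < 0.5:
        raise TransientEmailError("simulated failure")
    return f"Notification for {post_title} completed"


def run_with_retries(func, *args, max_retries=3):
    """Call func; after a failure, retry up to max_retries more times."""
    attempts = 0
    while True:
        attempts += 1
        try:
            return func(*args), attempts
        except TransientEmailError:
            if attempts > max_retries:
                raise  # retries exhausted: give up and surface the error


result, attempts = run_with_retries(flaky_send, 1, "My Post", random.Random(0))
print(result, "after", attempts, "attempt(s)")
```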

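Notes

  • Make sure Redis is running on the default port 6379; otherwise, adjust the addresses in celery_config.py.
  • The Celery worker must be run separately, in the same virtual environment as Flask.
  • In a real application, replace the simulation with actual email delivery (for example, smtplib with an SMTP server).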
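
If Redis does not run on localhost:6379, one option is to build the URL from environment variables instead of hard-coding it. This is a minimal sketch: the variable names REDIS_HOST, REDIS_PORT, and REDIS_DB and the redis_url helper are assumptions made for this example, not something Celery or Redis defines.

```python
import os


def redis_url(env=None):
    """Build a redis:// URL, falling back to the tutorial's defaults."""
    env = os.environ if env is None else env
    host = env.get("REDIS_HOST", "localhost")
    port = int(env.get("REDIS_PORT", "6379"))
    db = int(env.get("REDIS_DB", "0"))
    return f"redis://{host}:{port}/{db}"


print(redis_url({}))
print(redis_url({"REDIS_PORT": "6380"}))
```

The returned string could then be passed as both backend and broker in make_celery.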

This article is licensed under CC BY 4.0