Flask - 緩存
目標
- 安裝並配置 Flask-Caching
- 為 GET 端點添加緩存
- 處理緩存失效
步驟
準備環境
- 繼續使用 `flask_api/` 項目結構,激活虛擬環境:
  # Windows:
  flask_api_env\Scripts\activate
  # macOS/Linux:
  source flask_api_env/bin/activate
- 安裝 Flask-Caching:
  pip install flask-caching
配置 Flask-Caching
修改 `app/__init__.py`,初始化緩存:
"""Application factory and shared extensions for the Flask blog API.

This module owns the extension singletons (``db``, ``ma``, ``bcrypt``,
``cache``), the ``create_app`` factory, and the ``login_required`` /
``admin_required`` view decorators that the route modules import.
"""
import logging
import os
from functools import wraps
from logging.handlers import RotatingFileHandler

import jwt
from flask import Flask, jsonify, g, request, abort, current_app
from flask_bcrypt import Bcrypt
from flask_caching import Cache  # added in this step
from flask_marshmallow import Marshmallow
from flask_restx import Api, Namespace
from flask_sqlalchemy import SQLAlchemy

from .celery_config import make_celery
from .config import config_map

# Extension singletons. They are created unbound and attached to the app in
# create_app(), so route modules can import them without needing an app.
db = SQLAlchemy()
ma = Marshmallow()
bcrypt = Bcrypt()
cache = Cache()  # cache extension, bound in create_app()

# Celery instance; populated by create_app().
celery = None


def setup_logging(app):
    """Attach log handlers to ``app.logger``.

    A rotating file handler (``app.log``, INFO and above) is added only when
    the app is not in debug mode; a console handler (DEBUG and above) is
    always added. The logger itself is set to DEBUG.
    """
    if not app.debug:
        handler = RotatingFileHandler('app.log', maxBytes=10000, backupCount=3)
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
        )
        handler.setFormatter(formatter)
        app.logger.addHandler(handler)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
    )
    app.logger.addHandler(console_handler)
    app.logger.setLevel(logging.DEBUG)


def _mount_routes(app, api, routes, url_prefix):
    """Mount a route collection under ``url_prefix``.

    flask-restx ``Namespace`` objects have to be registered on the ``Api``;
    they are not Flask blueprints, so ``app.register_blueprint`` cannot take
    them. Anything else is assumed to be a regular Blueprint.
    """
    if isinstance(routes, Namespace):
        # e.g. Namespace('users') under '/api/v1' is served at '/api/v1/users'.
        api.add_namespace(routes, path=f'{url_prefix}/{routes.name}')
    else:
        app.register_blueprint(routes, url_prefix=url_prefix)


def create_app():
    """Build and configure the Flask application.

    Loads the config selected by the ``FLASK_ENV`` environment variable
    (default ``development``), binds the extensions, sets up logging, the
    REST API, Celery, routes and JSON error handlers, and creates the
    database tables.

    Returns:
        The configured ``Flask`` application.
    """
    global celery

    app = Flask(__name__)
    env = os.getenv('FLASK_ENV', 'development')
    app.config.from_object(config_map[env])

    # Cache defaults: in-process memory cache, entries live for 5 minutes.
    # setdefault() lets a config class switch the backend (for example to
    # 'RedisCache') without editing this function. 'SimpleCache' is the
    # current name of the backend formerly called 'simple'.
    app.config.setdefault('CACHE_TYPE', 'SimpleCache')
    app.config.setdefault('CACHE_DEFAULT_TIMEOUT', 300)

    db.init_app(app)
    ma.init_app(app)
    bcrypt.init_app(app)
    cache.init_app(app)
    setup_logging(app)

    api = Api(
        app,
        title='Blog API',
        version='1.0',
        description='A simple blog API with user and post management',
        doc='/api/docs/',
        authorizations={
            'jwt': {
                'type': 'apiKey',
                'in': 'header',
                'name': 'Authorization',
                'description': 'Enter "Bearer <token>"'
            }
        },
    )

    celery = make_celery(app)

    # Imported here rather than at module top: the route modules do
    # `from ... import db, cache, login_required, ...`, which only resolves
    # once this module has finished defining those names.
    from .routes.v1.todos import todos_bp as todos_v1_bp
    from .routes.v1.users import users_bp as users_v1_bp
    from .routes.v1.posts import posts_bp as posts_v1_bp
    from .routes.v2.todos import todos_bp as todos_v2_bp
    from .routes.v2.posts import posts_bp as posts_v2_bp

    for routes, url_prefix in (
        (todos_v1_bp, '/api/v1'),
        (users_v1_bp, '/api/v1'),
        (posts_v1_bp, '/api/v1'),
        (todos_v2_bp, '/api/v2'),
        (posts_v2_bp, '/api/v2'),
    ):
        _mount_routes(app, api, routes, url_prefix)

    @app.errorhandler(404)
    def not_found(error):
        app.logger.error(f'404 error: {str(error)}')
        return jsonify({'error': 'Not Found', 'message': str(error)}), 404

    @app.errorhandler(400)
    def bad_request(error):
        app.logger.warning(f'400 error: {str(error)}')
        return jsonify({'error': 'Bad Request', 'message': str(error)}), 400

    @app.errorhandler(500)
    def internal_error(error):
        app.logger.critical(f'500 error: {str(error)}')
        return jsonify({
            'error': 'Internal Server Error',
            'message': 'Something went wrong on our end'
        }), 500

    with app.app_context():
        db.create_all()

    return app


def login_required(f):
    """Require a valid JWT in the ``Authorization`` header.

    Accepts the token with or without a ``Bearer `` prefix. On success the
    matching user is stored in ``g.current_user``; otherwise the request is
    aborted with 401.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        # Local import: models import `db` from this module.
        from .models import User

        token = request.headers.get('Authorization')
        if not token:
            abort(401, description='Missing token')
        if token.startswith('Bearer '):
            token = token[7:]

        try:
            # current_app, not `app`: there is no module-level app object
            # when the application is built by a factory.
            data = jwt.decode(
                token, current_app.config['SECRET_KEY'], algorithms=['HS256']
            )
        except jwt.ExpiredSignatureError:
            abort(401, description='Token has expired')
        except jwt.InvalidTokenError:
            abort(401, description='Invalid token')

        # A correctly signed token without a user_id claim is still invalid.
        user_id = data.get('user_id')
        user = User.query.get(user_id) if user_id is not None else None
        if not user:
            abort(401, description='Invalid token')

        g.current_user = user
        return f(*args, **kwargs)
    return decorated_function


def admin_required(f):
    """Require an authenticated user whose ``is_admin`` is truthy (else 403)."""
    @wraps(f)
    @login_required
    def decorated_function(*args, **kwargs):
        if not g.current_user.is_admin:
            abort(403, description='Admin access required')
        return f(*args, **kwargs)
    return decorated_function
為端點添加緩存
修改 `app/routes/v1/users.py`,緩存用戶列表:
"""v1 user endpoints: list/create users, log in, delete a user (admin only)."""
import datetime

import jwt
from flask import request, g, abort, current_app
from flask_restx import Namespace, Resource, fields

from ... import db, admin_required, cache
from ...models import User
from ...schemas import user_schema, users_schema

api = Namespace('users', description='User management operations')

# Fixed cache key for the user list. Flask-Caching's default key is
# 'view/%s' % request.path, which for '/api/v1/users' is 'view//api/v1/users'
# (two slashes), so deleting 'view/api/v1/users' never hits it. Sharing one
# explicit key between the decorator and the invalidation avoids that and
# does not depend on where the namespace is mounted.
USERS_LIST_CACHE_KEY = 'users_v1_list'

user_model = api.model('User', {
    'id': fields.Integer(readonly=True),
    'username': fields.String(required=True, description='The user username'),
    'password': fields.String(required=True, description='The user password'),
    'role': fields.String(description='User role (user or admin)', default='user')
})

login_model = api.model('Login', {
    'username': fields.String(required=True),
    'password': fields.String(required=True)
})


@api.route('')
class UserList(Resource):
    @api.doc('list_users')
    @api.marshal_list_with(user_model)
    @cache.cached(timeout=60, key_prefix=USERS_LIST_CACHE_KEY)  # cache for 60 seconds
    def get(self):
        """List all users"""
        users = User.query.all()
        # current_app, not api.app: a Namespace has no `app` attribute.
        current_app.logger.info('Fetched all users')
        return users_schema.dump(users)

    @api.doc('create_user')
    @api.expect(user_model)
    @api.marshal_with(user_model, code=201)
    def post(self):
        """Create a new user"""
        if not request.is_json:
            api.abort(400, 'Request must be JSON')
        data = request.get_json()
        # A non-object JSON body (list, string, ...) cannot carry the fields.
        if not isinstance(data, dict) or 'username' not in data or 'password' not in data:
            api.abort(400, 'Missing username or password')
        if User.query.filter_by(username=data['username']).first():
            api.abort(400, 'Username already exists')

        user = User(username=data['username'])
        user.set_password(data['password'])
        # NOTE(review): this endpoint is unauthenticated, so any caller can
        # create an admin by sending role='admin' - confirm this is intended.
        if data.get('role') in ['user', 'admin']:
            user.role = data['role']
        db.session.add(user)
        db.session.commit()

        cache.delete(USERS_LIST_CACHE_KEY)  # the cached list is now stale
        current_app.logger.info(f'User created: {user.username}')
        return user_schema.dump(user), 201


@api.route('/login')
class Login(Resource):
    @api.doc('login_user')
    @api.expect(login_model)
    def post(self):
        """Login and get a JWT token"""
        if not request.is_json:
            api.abort(400, 'Request must be JSON')
        data = request.get_json()
        if not isinstance(data, dict) or 'username' not in data or 'password' not in data:
            api.abort(400, 'Missing username or password')

        user = User.query.filter_by(username=data['username']).first()
        if not user or not user.check_password(data['password']):
            current_app.logger.warning(f'Failed login attempt for {data["username"]}')
            api.abort(401, 'Invalid credentials')

        # Timezone-aware expiry; datetime.utcnow() is deprecated.
        expires_at = (
            datetime.datetime.now(datetime.timezone.utc)
            + datetime.timedelta(hours=24)
        )
        token = jwt.encode(
            {'user_id': user.id, 'exp': expires_at},
            current_app.config['SECRET_KEY'],
            algorithm='HS256',
        )
        current_app.logger.info(f'User logged in: {user.username}')
        return {'token': token}


@api.route('/<int:user_id>')
class UserResource(Resource):
    # No marshal_with(user_model) here: it would replace the
    # {'message': ...} body below with a user object of null fields.
    @api.doc('delete_user', security='jwt')
    @admin_required
    def delete(self, user_id):
        """Delete a user (admin only)"""
        user = User.query.get_or_404(user_id, description='User not found')
        # Read the name first so the log line does not rely on attribute
        # access to an instance that has already been deleted.
        username = user.username
        db.session.delete(user)
        db.session.commit()

        cache.delete(USERS_LIST_CACHE_KEY)  # the cached list is now stale
        current_app.logger.info(
            f'User deleted: {username} by admin {g.current_user.username}'
        )
        return {'message': 'User deleted'}, 200


# Exported under the name the application factory imports.
users_bp = api
修改 `app/routes/v1/posts.py`,緩存文章列表和單篇文章:
"""v1 post endpoints with cached reads and explicit cache invalidation."""
from flask import request, g, abort, current_app
from flask_restx import Namespace, Resource, fields

from ... import db, login_required, admin_required, cache
from ...models import Post, User
from ...schemas import post_schema, posts_schema
from ...tasks import send_email_notification

api = Namespace('posts', description='Post management operations')

post_model = api.model('Post', {
    'id': fields.Integer(readonly=True),
    'title': fields.String(required=True, description='The post title'),
    'content': fields.String(required=True, description='The post content'),
    'created_at': fields.DateTime(readonly=True),
    'user_id': fields.Integer(readonly=True),
    'category': fields.String(default='general')
})


# Reads are cached with cache.memoize rather than cache.cached. The keys that
# cache.cached generates cannot be rebuilt by hand for deletion: the default
# is 'view/%s' % request.path (so 'view//api/v1/posts', with two slashes), and
# with query_string=True it is the path plus a hash of the query args.
# cache.delete_memoized drops one argument combination or all of them.

@cache.memoize(timeout=60)  # cache for 60 seconds, one entry per user_id
def _cached_post_list(user_id=None):
    """Return the serialized posts, optionally filtered by author."""
    query = Post.query
    if user_id:
        query = query.filter_by(user_id=user_id)
    posts = query.all()
    current_app.logger.debug('Fetched posts')
    return posts_schema.dump(posts)


@cache.memoize(timeout=60)  # cache each post for 60 seconds
def _cached_post(post_id):
    """Return one serialized post; a missing post aborts with 404."""
    post = Post.query.get_or_404(post_id, description='Post not found')
    return post_schema.dump(post)


def _invalidate_post_lists():
    """Forget every cached list variant (all user_id filters)."""
    cache.delete_memoized(_cached_post_list)


def _invalidate_post(post_id):
    """Forget the cached copy of one post and every cached list."""
    cache.delete_memoized(_cached_post, post_id)
    _invalidate_post_lists()


@api.route('')
class PostList(Resource):
    @api.doc('list_posts')
    @api.marshal_list_with(post_model)
    def get(self):
        """List all posts"""
        user_id = request.args.get('user_id', type=int)
        return _cached_post_list(user_id)

    @api.doc('create_post', security='jwt')
    @api.expect(post_model)
    @api.marshal_with(post_model, code=201)
    @login_required
    def post(self):
        """Create a new post"""
        if not request.is_json:
            api.abort(400, 'Request must be JSON')
        data = request.get_json()
        # A non-object JSON body (list, string, ...) cannot carry the fields.
        if not isinstance(data, dict) or 'title' not in data or 'content' not in data:
            api.abort(400, 'Missing title or content')

        post = Post(
            title=data['title'],
            content=data['content'],
            user_id=g.current_user.id
        )
        db.session.add(post)
        db.session.commit()
        # Invalidate right after the commit, before the notification is
        # queued, so an enqueue error cannot leave a stale list cached.
        _invalidate_post_lists()
        current_app.logger.info(
            f'Post created: {post.title} by {g.current_user.username}'
        )

        send_email_notification.delay(g.current_user.id, post.title)
        current_app.logger.info(f'Queued email notification for {post.title}')
        return post_schema.dump(post), 201


@api.route('/<int:post_id>')
class PostResource(Resource):
    @api.doc('get_post')
    @api.marshal_with(post_model)
    def get(self, post_id):
        """Get a single post"""
        return _cached_post(post_id)

    @api.doc('update_post', security='jwt')
    @api.expect(post_model)
    @api.marshal_with(post_model)
    @login_required
    def put(self, post_id):
        """Update a post"""
        post = Post.query.get_or_404(post_id, description='Post not found')
        if post.user_id != g.current_user.id and not g.current_user.is_admin:
            api.abort(403, 'You can only edit your own posts unless you are an admin')
        if not request.is_json:
            api.abort(400, 'Request must be JSON')
        data = request.get_json()
        if not isinstance(data, dict):
            api.abort(400, 'Request must be JSON')

        if 'title' in data:
            post.title = data['title']
        if 'content' in data:
            post.content = data['content']
        db.session.commit()

        _invalidate_post(post_id)
        current_app.logger.info(
            f'Post updated: {post.title} by {g.current_user.username}'
        )
        return post_schema.dump(post)

    @api.doc('delete_post', security='jwt')
    @login_required
    def delete(self, post_id):
        """Delete a post"""
        post = Post.query.get_or_404(post_id, description='Post not found')
        if post.user_id != g.current_user.id and not g.current_user.is_admin:
            api.abort(403, 'You can only delete your own posts unless you are an admin')

        # Read the title first so the log line does not rely on attribute
        # access to an instance that has already been deleted.
        title = post.title
        db.session.delete(post)
        db.session.commit()

        _invalidate_post(post_id)
        current_app.logger.info(
            f'Post deleted: {title} by {g.current_user.username}'
        )
        return {'message': 'Post deleted'}, 200


@api.route('/all')
class PostAll(Resource):
    @api.doc('delete_all_posts', security='jwt')
    @admin_required
    def delete(self):
        """Delete all posts (admin only)"""
        Post.query.delete()
        db.session.commit()

        # Every cached single post is stale too, not only the lists.
        cache.delete_memoized(_cached_post)
        _invalidate_post_lists()
        current_app.logger.warning(
            f'All posts deleted by admin {g.current_user.username}'
        )
        return {'message': 'All posts deleted'}, 200


# Exported under the name the application factory imports.
posts_bp = api
運行應用
- 運行:
  python run.py
測試緩存
- 使用 Postman 測試:
  - GET /api/v1/users:
    - 第一次請求:查詢數據庫,記錄時間。
    - 第二次請求(60 秒內):應更快(緩存命中)。
  - POST /api/v1/users:
    - 創建用戶後再次 GET,應返回最新數據(緩存已清除)。
  - GET /api/v1/posts:
    - 重複請求應使用緩存。
  - POST /api/v1/posts(用 token):
    - 創建後 GET,應反映新數據。
作業
- 在 v2 的 posts 路由中添加緩存,支持 `category` 查詢參數。
- 使用 Redis 作為緩存後端(修改 `CACHE_TYPE` 為 'RedisCache'(舊版名稱為 'redis'),並用 `CACHE_REDIS_URL` 配置 Redis URL)。
注意事項
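- `@cache.cached` 的 `timeout` 單位是秒,根據需求調整。
- `query_string=True` 確保不同查詢參數生成不同緩存。
- 數據更新時必須清除相關緩存,否則返回過期數據。
- `@cache.cached` 的默認緩存鍵是 `'view/%s' % request.path`(例如 `view//api/v1/users`,注意有兩個斜線);啟用 `query_string=True` 後,鍵會變成路徑加上查詢參數的哈希值。因此用 `cache.delete()` 清除時必須使用實際的鍵,或者改用自定義 `key_prefix`,或使用 `cache.memoize` 搭配 `cache.delete_memoized`。
- 生產環境建議使用 Redis 替代簡單內存緩存。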
本文章以 CC BY 4.0 授權