Flask - JWT 認證
目標
- 安裝並配置 PyJWT
- 實現基於 JWT 的認證
- 更新Blog API 使用 token 進行身份驗證
- 實現 token 刷新機制
- 實現 token 黑名單機制
步驟
準備環境
- 繼續使用
flask_api/
項目結構,激活虛擬環境:1 2
# Windows: flask_api_env\Scripts\activate # macOS/Linux: source flask_api_env/bin/activate
- 安裝 PyJWT:
1
pip install pyjwt
- 繼續使用
移除 Flask-HTTPAuth
- 我們將替換 Flask-HTTPAuth,因此先移除相關依賴,但保留密碼字段。
更新應用配置
修改 app/config.py,添加 JWT 相關配置:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
import os from dotenv import load_dotenv load_dotenv() class Config: SECRET_KEY = os.getenv('SECRET_KEY', 'default-secret-key') JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'jwt-secret-key') JWT_ACCESS_TOKEN_EXPIRES = int(os.getenv('JWT_ACCESS_TOKEN_EXPIRES', 3600)) # 1小時 JWT_REFRESH_TOKEN_EXPIRES = int(os.getenv('JWT_REFRESH_TOKEN_EXPIRES', 604800)) # 7天 SQLALCHEMY_TRACK_MODIFICATIONS = False JWT_BLACKLIST_ENABLED = True JWT_BLACKLIST_TOKEN_CHECKS = ['access', 'refresh'] class DevelopmentConfig(Config): DEBUG = True SQLALCHEMY_DATABASE_URI = os.getenv('DATABASE_URL', 'sqlite:///blog_dev.db') class TestingConfig(Config): TESTING = True SQLALCHEMY_DATABASE_URI = os.getenv('DATABASE_URL', 'sqlite:///blog_test.db') WTF_CSRF_ENABLED = False class ProductionConfig(Config): DEBUG = False SQLALCHEMY_DATABASE_URI = os.getenv('DATABASE_URL', 'sqlite:///blog_prod.db') config_map = { 'development': DevelopmentConfig, 'testing': TestingConfig, 'production': ProductionConfig }
實現 JWT 認證
修改 app/init.py,添加 JWT 工具函數:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
from flask import Flask, jsonify, g, request, current_app from flask_sqlalchemy import SQLAlchemy from flask_marshmallow import Marshmallow import jwt from functools import wraps from datetime import datetime, timedelta from .models import TokenBlacklist from .routes.v1.todos import todos_bp as todos_v1_bp from .routes.v1.users import users_bp as users_v1_bp from .routes.v1.posts import posts_bp as posts_v1_bp from .routes.v2.todos import todos_bp as todos_v2_bp from .routes.v2.posts import posts_bp as posts_v2_bp from .config import config_map import os db = SQLAlchemy() ma = Marshmallow() def create_app(): app = Flask(__name__) env = os.getenv('FLASK_ENV', 'development') app.config.from_object(config_map[env]) db.init_app(app) ma.init_app(app) app.register_blueprint(todos_v1_bp, url_prefix='/api/v1') app.register_blueprint(users_v1_bp, url_prefix='/api/v1') app.register_blueprint(posts_v1_bp, url_prefix='/api/v1') app.register_blueprint(todos_v2_bp, url_prefix='/api/v2') app.register_blueprint(posts_v2_bp, url_prefix='/api/v2') @app.errorhandler(404) def not_found(error): return jsonify({'error': 'Not Found', 'message': str(error)}), 404 @app.errorhandler(400) def bad_request(error): return jsonify({'error': 'Bad Request', 'message': str(error)}), 400 @app.errorhandler(500) def internal_error(error): return jsonify({'error': 'Internal Server Error', 'message': 'Something went wrong on our end'}), 500 with app.app_context(): db.create_all() return app def create_access_token(user_id): """創建訪問令牌""" return jwt.encode({ 'user_id': user_id, 'exp': datetime.utcnow() + timedelta(seconds=current_app.config['JWT_ACCESS_TOKEN_EXPIRES']), 'type': 'access' }, current_app.config['JWT_SECRET_KEY'], algorithm='HS256') def create_refresh_token(user_id): """創建刷新令牌""" return jwt.encode({ 'user_id': user_id, 'exp': datetime.utcnow() + timedelta(seconds=current_app.config['JWT_REFRESH_TOKEN_EXPIRES']), 'type': 'refresh' }, current_app.config['JWT_SECRET_KEY'], algorithm='HS256') def is_token_blacklisted(token): """檢查令牌是否在黑名單中""" return TokenBlacklist.query.filter_by(token=token).first() is not None def add_token_to_blacklist(token): """將令牌添加到黑名單""" blacklisted_token = TokenBlacklist(token=token) db.session.add(blacklisted_token) db.session.commit() # JWT 認證裝飾器 def login_required(f): @wraps(f) def decorated_function(*args, **kwargs): from .models import User token = request.headers.get('Authorization') if not token: abort(401, description='Missing token') try: if token.startswith('Bearer '): token = token[7:] # 移除 "Bearer " 前綴 # 檢查令牌是否在黑名單中 if is_token_blacklisted(token): abort(401, description='Token has been revoked') data = jwt.decode(token, current_app.config['JWT_SECRET_KEY'], algorithms=['HS256']) # 驗證令牌類型 if data.get('type') != 'access': abort(401, description='Invalid token type') user = User.query.get(data['user_id']) if not user: abort(401, description='Invalid token') g.current_user = user except jwt.ExpiredSignatureError: abort(401, description='Token has expired') except jwt.InvalidTokenError as e: abort(401, description=f'Invalid token: {str(e)}') return f(*args, **kwargs) return decorated_function
添加登錄端點
修改 app/routes/v1/users.py,添加生成 JWT 的登錄路由:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
from flask import Blueprint, jsonify, request, abort, current_app from ...models import User from ... import db from ...schemas import user_schema, users_schema from ... import create_access_token, create_refresh_token users_bp = Blueprint('users_v1', __name__) @users_bp.route('/users', methods=['GET']) def get_users(): users = User.query.all() return jsonify({'users': users_schema.dump(users)}) @users_bp.route('/users', methods=['POST']) def create_user(): if not request.is_json: abort(400, description='Request must be JSON') data = request.get_json() if 'username' not in data or 'password' not in data: abort(400, description='Missing username or password') if User.query.filter_by(username=data['username']).first(): abort(400, description='Username already exists') user = User(username=data['username'], password=data['password']) db.session.add(user) db.session.commit() return jsonify(user_schema.dump(user)), 201 @users_bp.route('/login', methods=['POST']) def login(): if not request.is_json: abort(400, description='Request must be JSON') data = request.get_json() if 'username' not in data or 'password' not in data: abort(400, description='Missing username or password') user = User.query.filter_by(username=data['username']).first() if not user or not user.check_password(data['password']): abort(401, description='Invalid credentials') access_token = create_access_token(user.id) refresh_token = create_refresh_token(user.id) return jsonify({ 'access_token': access_token, 'refresh_token': refresh_token, 'token_type': 'Bearer', 'expires_in': current_app.config['JWT_ACCESS_TOKEN_EXPIRES'] }) @users_bp.route('/refresh', methods=['POST']) def refresh(): if not request.is_json: abort(400, description='Request must be JSON') data = request.get_json() if 'refresh_token' not in data: abort(400, description='Missing refresh token') try: token = data['refresh_token'] if is_token_blacklisted(token): abort(401, description='Token has been revoked') data = jwt.decode(token, current_app.config['JWT_SECRET_KEY'], algorithms=['HS256']) if data.get('type') != 'refresh': abort(401, description='Invalid token type') user = User.query.get(data['user_id']) if not user: abort(401, description='Invalid token') access_token = create_access_token(user.id) return jsonify({ 'access_token': access_token, 'token_type': 'Bearer', 'expires_in': current_app.config['JWT_ACCESS_TOKEN_EXPIRES'] }) except jwt.ExpiredSignatureError: abort(401, description='Refresh token has expired') except jwt.InvalidTokenError as e: abort(401, description=f'Invalid token: {str(e)}') @users_bp.route('/logout', methods=['POST']) @login_required def logout(): token = request.headers.get('Authorization')[7:] # 移除 "Bearer " 前綴 add_token_to_blacklist(token) return jsonify({'message': 'Successfully logged out'})
保護路由
修改 app/routes/v1/posts.py,使用 JWT 認證:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
from flask import Blueprint, jsonify, request, abort, g from ...models import Post, User from ... import db, login_required from ...schemas import post_schema, posts_schema posts_bp = Blueprint('posts_v1', __name__) @posts_bp.route('/posts', methods=['GET']) def get_posts(): user_id = request.args.get('user_id', type=int) query = Post.query if user_id: query = query.filter_by(user_id=user_id) posts = query.all() return jsonify({'posts': posts_schema.dump(posts)}) @posts_bp.route('/posts/<int:post_id>', methods=['GET']) def get_post(post_id): post = Post.query.get_or_404(post_id, description='Post not found') return jsonify(post_schema.dump(post)) @posts_bp.route('/posts', methods=['POST']) @login_required def create_post(): if not request.is_json: abort(400, description='Request must be JSON') data = request.get_json() if 'title' not in data or 'content' not in data: abort(400, description='Missing title or content') post = Post( title=data['title'], content=data['content'], user_id=g.current_user.id ) db.session.add(post) db.session.commit() return jsonify(post_schema.dump(post)), 201 @posts_bp.route('/posts/<int:post_id>', methods=['PUT']) @login_required def update_post(post_id): post = Post.query.get_or_404(post_id, description='Post not found') if post.user_id != g.current_user.id: abort(403, description='You can only edit your own posts') if not request.is_json: abort(400, description='Request must be JSON') data = request.get_json() if 'title' in data: post.title = data['title'] if 'content' in data: post.content = data['content'] db.session.commit() return jsonify(post_schema.dump(post)) @posts_bp.route('/posts/<int:post_id>', methods=['DELETE']) @login_required def delete_post(post_id): post = Post.query.get_or_404(post_id, description='Post not found') if post.user_id != g.current_user.id: abort(403, description='You can only delete your own posts') db.session.delete(post) db.session.commit() return '', 204
添加 Token 黑名單模型
創建 app/models.py,添加 TokenBlacklist 模型:
1 2 3 4 5 6 7 8 9 10 11
from . import db from datetime import datetime class TokenBlacklist(db.Model): __tablename__ = 'token_blacklist' id = db.Column(db.Integer, primary_key=True) token = db.Column(db.String(500), unique=True, nullable=False) blacklisted_on = db.Column(db.DateTime, default=datetime.utcnow) def __repr__(self): return f'<Token {self.token}>'
運行應用
- 刪除舊的
blog.db
(因表結構改變),運行:1
python run.py
- 刪除舊的
測試 API
- 使用 Postman 測試:
- POST /api/v1/login:
- Body:
{"username": "alice", "password": "1234"}
- 預期響應:包含 access_token 和 refresh_token
- Body:
- POST /api/v1/refresh:
- Body:
{"refresh_token": "..."}
- 預期響應:新的 access_token
- Body:
- POST /api/v1/logout:
- Headers:
Authorization: Bearer <access_token>
- 預期響應:登出成功消息
- Headers:
- POST /api/v1/posts:
- Headers:
Authorization: Bearer <access_token>
- Body:
{"title": "My Post", "content": "Hello"}
- 預期響應:201
- Headers:
- POST /api/v1/login:
- 使用 Postman 測試:
安全最佳實踐
- 令牌管理
- 使用不同的密鑰進行簽名和驗證
- 實現令牌黑名單機制
- 設置合理的令牌過期時間
- 使用 HTTPS 傳輸令牌
- 錯誤處理
- 提供清晰的錯誤消息
- 記錄安全相關的錯誤
- 實現速率限制防止暴力破解
- 配置管理
- 使用環境變量存儲敏感信息
- 為不同環境使用不同的配置
- 定期輪換密鑰
- 監控與日誌
- 記錄認證相關的事件
- 監控異常的登錄嘗試
- 實現審計日誌
注意事項
- 生產環境應使用更強的密鑰
- 定期清理過期的黑名單令牌
- 考慮實現雙因素認證
- 實現密碼策略和強度檢查
- 考慮使用 OAuth2 或 OpenID Connect 進行第三方認證
本文章以 CC BY 4.0 授權