文章

Flask - 文件上傳

目標

  • 配置文件上傳路徑和限制
  • 添加上傳圖片的端點
  • 將文件與文章關聯

步驟

  1. 準備環境

    • 繼續使用 flask_api/ 項目結構,激活虛擬環境:
      1
      2
      
      # Windows: flask_api_env\Scripts\activate
      # macOS/Linux: source flask_api_env/bin/activate
      
  2. 配置文件上傳

    • 修改 app/__init__.py,添加文件上傳配置:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
      100
      101
      102
      103
      104
      105
      106
      107
      108
      109
      110
      111
      112
      113
      114
      115
      116
      117
      118
      119
      120
      121
      122
      123
      124
      125
      126
      127
      128
      129
      130
      131
      132
      133
      134
      135
      136
      
      from flask import Flask, jsonify, g
      from flask_sqlalchemy import SQLAlchemy
      from flask_marshmallow import Marshmallow
      from flask_bcrypt import Bcrypt
      from flask_restx import Api
      from flask_caching import Cache
      from flask_limiter import Limiter
      from flask_limiter.util import get_remote_address
      import jwt
      from functools import wraps
      from .routes.v1.todos import todos_bp as todos_v1_bp
      from .routes.v1.users import users_bp as users_v1_bp
      from .routes.v1.posts import posts_bp as posts_v1_bp
      from .routes.v2.todos import todos_bp as todos_v2_bp
      from .routes.v2.posts import posts_bp as posts_v2_bp
      from .config import config_map
      from .celery_config import make_celery
      import os
      import logging
      from logging.handlers import RotatingFileHandler
      
      # Extension objects are created unbound at import time so other modules can
      # import them; create_app() attaches each one to the app via init_app().
      db = SQLAlchemy()
      ma = Marshmallow()
      bcrypt = Bcrypt()
      cache = Cache()
      # Rate limits are keyed by the value get_remote_address returns per request.
      limiter = Limiter(key_func=get_remote_address)
      
      def setup_logging(app):
          """Attach log handlers to ``app.logger``.

          Outside debug mode, a rotating file handler writing ``app.log`` at INFO
          level is added (10000-byte files, 3 backups). A DEBUG-level console
          handler is always added, and the logger itself is set to DEBUG.
          """
          if not app.debug:
              # File records carry the source location of the logging call.
              file_handler = RotatingFileHandler('app.log', maxBytes=10000, backupCount=3)
              file_handler.setLevel(logging.INFO)
              file_handler.setFormatter(logging.Formatter(
                  '%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]'
              ))
              app.logger.addHandler(file_handler)

          stream_handler = logging.StreamHandler()
          stream_handler.setLevel(logging.DEBUG)
          stream_format = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
          stream_handler.setFormatter(stream_format)
          app.logger.addHandler(stream_handler)
          app.logger.setLevel(logging.DEBUG)
      
      def create_app():
          """Build and configure the Flask application.

          Loads the config object selected by the ``FLASK_ENV`` environment
          variable (default ``'development'``), applies cache and upload settings,
          binds the extensions, sets up logging, creates the flask-restx ``Api``
          and the Celery instance, registers the v1/v2 blueprints and JSON error
          handlers, and creates the database tables.

          Returns:
              The configured ``Flask`` application.
          """
          app = Flask(__name__)
          env = os.getenv('FLASK_ENV', 'development')
          # An unknown FLASK_ENV value raises KeyError here.
          app.config.from_object(config_map[env])
          app.config['CACHE_TYPE'] = 'simple'
          app.config['CACHE_DEFAULT_TIMEOUT'] = 300

          # File upload configuration
          app.config['UPLOAD_FOLDER'] = os.path.join(app.root_path, 'uploads')
          app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB maximum
          os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)  # create the upload directory

          db.init_app(app)
          ma.init_app(app)
          bcrypt.init_app(app)
          cache.init_app(app)
          limiter.init_app(app)
          setup_logging(app)

          # Swagger UI is served at /api/docs/; the 'jwt' scheme documents the
          # Authorization header used by the protected endpoints.
          api = Api(app,
                    title='Blog API',
                    version='1.0',
                    description='A simple blog API with user and post management',
                    doc='/api/docs/',
                    authorizations={
                        'jwt': {
                            'type': 'apiKey',
                            'in': 'header',
                            'name': 'Authorization',
                            'description': 'Enter "Bearer <token>"'
                        }
                    })

          # Rebinds the module-level ``celery`` name (initialised to None at the
          # bottom of this module) to an instance built from this app.
          global celery
          celery = make_celery(app)

          # NOTE(review): the v1 posts module exports a flask-restx Namespace under
          # the name posts_bp -- confirm register_blueprint is the right call for
          # it (Namespaces are normally attached with Api.add_namespace).
          app.register_blueprint(todos_v1_bp, url_prefix='/api/v1')
          app.register_blueprint(users_v1_bp, url_prefix='/api/v1')
          app.register_blueprint(posts_v1_bp, url_prefix='/api/v1')
          app.register_blueprint(todos_v2_bp, url_prefix='/api/v2')
          app.register_blueprint(posts_v2_bp, url_prefix='/api/v2')

          @app.errorhandler(404)
          def not_found(error):
              # Log the error and answer with a JSON body instead of an HTML page.
              app.logger.error(f'404 error: {str(error)}')
              return jsonify({'error': 'Not Found', 'message': str(error)}), 404

          @app.errorhandler(400)
          def bad_request(error):
              app.logger.warning(f'400 error: {str(error)}')
              return jsonify({'error': 'Bad Request', 'message': str(error)}), 400

          @app.errorhandler(500)
          def internal_error(error):
              # The error detail is logged only; the client gets a generic message.
              app.logger.critical(f'500 error: {str(error)}')
              return jsonify({'error': 'Internal Server Error', 'message': 'Something went wrong on our end'}), 500

          # Create any missing tables for the registered models.
          with app.app_context():
              db.create_all()

          return app
      
      def login_required(f):
          """Require a valid JWT in the ``Authorization`` header.

          The header may carry the token bare or prefixed with ``"Bearer "``. The
          token is decoded with the active app's ``SECRET_KEY`` (HS256), and the
          user named by its ``user_id`` claim is stored on ``g.current_user``
          before the wrapped view runs.

          Aborts with 401 when the header is missing, when the token is expired
          or otherwise invalid, when it has no ``user_id`` claim, or when no
          matching user exists.

          Args:
              f: The view function to protect.

          Returns:
              The wrapped view, with ``f``'s metadata preserved.
          """
          @wraps(f)
          def decorated_function(*args, **kwargs):
              # The module-level flask import only brings in Flask, jsonify and g,
              # and there is no module-level ``app`` object: the application is
              # built inside create_app(), so it is reached through current_app.
              from flask import abort, current_app, request
              # Deferred: models imports db/bcrypt from this package.
              from .models import User

              token = request.headers.get('Authorization')
              if not token:
                  abort(401, description='Missing token')
              if token.startswith('Bearer '):
                  token = token[7:]

              try:
                  data = jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
              except jwt.ExpiredSignatureError:
                  abort(401, description='Token has expired')
              except jwt.InvalidTokenError:
                  abort(401, description='Invalid token')

              # A well-signed token without a user_id claim is treated like any
              # other invalid token rather than surfacing a KeyError as a 500.
              user = User.query.get(data['user_id']) if 'user_id' in data else None
              if not user:
                  abort(401, description='Invalid token')
              g.current_user = user
              return f(*args, **kwargs)
          return decorated_function
      
      def admin_required(f):
          """Require an authenticated user whose ``is_admin`` flag is true.

          Authentication is delegated to ``login_required``, which populates
          ``g.current_user``; a non-admin user is rejected with 403.

          Args:
              f: The view function to protect.

          Returns:
              The wrapped view, with ``f``'s metadata preserved.
          """
          @wraps(f)
          @login_required
          def decorated_function(*args, **kwargs):
              # abort is not among this module's top-level flask imports.
              from flask import abort

              if not g.current_user.is_admin:
                  abort(403, description='Admin access required')
              return f(*args, **kwargs)
          return decorated_function
      
      # Module-level Celery handle; create_app() rebinds it through ``global celery``.
      celery = None
      
  3. 更新模型和序列化器

    • 修改 app/models.py,添加圖片字段:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      
      from . import db, bcrypt
      from datetime import datetime
      
      class User(db.Model):
          """Account record stored in the ``users`` table.

          Only a bcrypt hash of the password is kept, in ``password_hash``.
          """
          __tablename__ = 'users'
          id = db.Column(db.Integer, primary_key=True)
          username = db.Column(db.String(50), unique=True, nullable=False)
          password_hash = db.Column(db.String(128), nullable=False)
          # Role name; 'admin' is the only value this class checks (see is_admin).
          role = db.Column(db.String(20), default='user')
          # Posts written by this user; each Post gains a ``user`` back-reference.
          posts = db.relationship('Post', backref='user', lazy=True)

          def set_password(self, password):
              """Hash *password* with bcrypt and store the hash as a UTF-8 string."""
              self.password_hash = bcrypt.generate_password_hash(password).decode('utf-8')

          def check_password(self, password):
              """Return the result of checking *password* against the stored hash."""
              return bcrypt.check_password_hash(self.password_hash, password)

          @property
          def is_admin(self):
              """True when the user's role is exactly ``'admin'``."""
              return self.role == 'admin'
      
      class Post(db.Model):
          """Blog post stored in the ``posts`` table, owned by one user."""
          __tablename__ = 'posts'
          id = db.Column(db.Integer, primary_key=True)
          title = db.Column(db.String(100), nullable=False)
          content = db.Column(db.Text, nullable=False)
          # The callable itself is passed as the default, not a fixed timestamp.
          created_at = db.Column(db.DateTime, default=datetime.utcnow)
          user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
          category = db.Column(db.String(50), default='general')
          image_path = db.Column(db.String(255))  # newly added: path of the uploaded image (nullable)
      
    • 修改 app/schemas.py

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      
      from . import ma
      from .models import User, Post
      
      class PostSchema(ma.SQLAlchemyAutoSchema):
          """Serialization schema for ``Post``, including the ``image_path`` column."""
          class Meta:
              model = Post
              # Needed so the user_id foreign-key column is part of the schema.
              include_fk = True
              fields = ('id', 'title', 'content', 'created_at', 'user_id', 'category', 'image_path')
      
      class UserSchema(ma.SQLAlchemyAutoSchema):
          """Serialization schema for ``User``; ``password_hash`` is not exposed."""
          class Meta:
              model = User
              fields = ('id', 'username', 'role', 'posts')
          # Each of the user's posts is embedded using PostSchema.
          posts = ma.Nested('PostSchema', many=True)
      
      # Shared schema instances: singular for one object, plural (many=True) for lists.
      post_schema = PostSchema()
      posts_schema = PostSchema(many=True)
      user_schema = UserSchema()
      users_schema = UserSchema(many=True)
      
  4. 實現文件上傳端點

    • 修改 app/routes/v1/posts.py,支持圖片上傳:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
      100
      101
      102
      103
      104
      105
      106
      107
      108
      109
      110
      111
      112
      113
      114
      115
      116
      117
      118
      119
      120
      121
      122
      123
      124
      125
      126
      127
      128
      129
      130
      131
      132
      133
      134
      135
      136
      137
      138
      139
      140
      141
      142
      143
      144
      145
      146
      
      from flask import request, g, abort
      from flask_restx import Namespace, Resource, fields
      from ...models import Post, User
      from ... import db, login_required, admin_required, cache, limiter
      from ...schemas import post_schema, posts_schema
      from ...tasks import send_email_notification
      import os
      import uuid
      
      # flask-restx namespace that the post resources below are attached to.
      api = Namespace('posts', description='Post management operations')

      # Documentation/marshalling model for a post. Fields marked readonly are
      # assigned by the server, not supplied by the client.
      post_model = api.model('Post', {
          'id': fields.Integer(readonly=True),
          'title': fields.String(required=True, description='The post title'),
          'content': fields.String(required=True, description='The post content'),
          'created_at': fields.DateTime(readonly=True),
          'user_id': fields.Integer(readonly=True),
          'category': fields.String(default='general'),
          'image_path': fields.String(readonly=True, description='Path to uploaded image')
      })

      # Parser for the create-post request: two required text fields plus an
      # optional uploaded file looked up in the request's files.
      # NOTE(review): flask-restx's upload examples pass
      # type=werkzeug.datastructures.FileStorage; 'file' here is a plain string,
      # not a callable -- confirm the parser accepts an uploaded file with it.
      upload_parser = api.parser()
      upload_parser.add_argument('title', type=str, required=True, help='Post title')
      upload_parser.add_argument('content', type=str, required=True, help='Post content')
      upload_parser.add_argument('image', type='file', location='files', help='Image file')
      
      def allowed_file(filename, allowed_extensions=None):
          """Return True when *filename* has a permitted extension.

          The extension is the text after the last ``'.'``, compared
          case-insensitively. A missing or empty filename (an upload may carry
          none) and a name with no dot are rejected instead of raising.

          Args:
              filename: Client-supplied file name; may be ``None`` or empty.
              allowed_extensions: Optional collection of lower-case extensions
                  without the dot. Defaults to png, jpg, jpeg and gif.

          Returns:
              bool: whether the extension is in the allowed set.
          """
          if not filename or '.' not in filename:
              return False
          if allowed_extensions is None:
              allowed_extensions = {'png', 'jpg', 'jpeg', 'gif'}
          extension = filename.rsplit('.', 1)[1].lower()
          return extension in allowed_extensions
      
      @api.route('')
      class PostList(Resource):
          # Collection endpoint: list posts (optionally filtered by author) and
          # create a post with an optional image attachment.
          #
          # ``api`` in this module is a flask-restx Namespace, which has no ``app``
          # attribute, so the running application (logger, config) is reached
          # through flask.current_app inside each handler.

          @api.doc('list_posts')
          @api.marshal_list_with(post_model)
          @cache.cached(timeout=60, query_string=True)
          def get(self):
              """List all posts"""
              from flask import current_app

              # ?user_id=<int> narrows the list to one author; a missing, zero or
              # non-integer value means no filter.
              user_id = request.args.get('user_id', type=int)
              query = Post.query
              if user_id:
                  query = query.filter_by(user_id=user_id)
              posts = query.all()
              current_app.logger.debug('Fetched posts')
              return posts_schema.dump(posts)

          @api.doc('create_post', security='jwt')
          @api.expect(upload_parser)
          @api.marshal_with(post_model, code=201)
          @login_required
          @limiter.limit("3 per minute")
          def post(self):
              """Create a new post with optional image"""
              from flask import current_app

              args = upload_parser.parse_args()
              title = args['title']
              content = args['content']
              image = args['image']

              post = Post(
                  title=title,
                  content=content,
                  user_id=g.current_user.id
              )

              # An upload with a disallowed extension is skipped, not rejected:
              # the post is still created, just without an image.
              if image and allowed_file(image.filename):
                  # Stored under a random name so uploads cannot collide; only
                  # the lower-cased original extension is kept.
                  filename = f"{uuid.uuid4().hex}.{image.filename.rsplit('.', 1)[1].lower()}"
                  image.save(os.path.join(current_app.config['UPLOAD_FOLDER'], filename))
                  post.image_path = f"/uploads/{filename}"
                  current_app.logger.info(f'Image uploaded: {filename}')

              db.session.add(post)
              db.session.commit()
              current_app.logger.info(f'Post created: {post.title} by {g.current_user.username}')
              # The notification is queued as a background task.
              send_email_notification.delay(g.current_user.id, post.title)
              current_app.logger.info(f'Queued email notification for {post.title}')
              # NOTE(review): assumes the cached list response is stored under
              # this exact key -- confirm against flask-caching's key scheme,
              # given that get() caches with query_string=True.
              cache.delete('view/api/v1/posts')
              return post_schema.dump(post), 201
      
      @api.route('/<int:post_id>')
      class PostResource(Resource):
          # Single-post endpoint: fetch, update and delete by id. Update and
          # delete are restricted to the post's author or an admin.
          #
          # ``api`` in this module is a flask-restx Namespace, which has no ``app``
          # attribute, so the running application (logger, root_path) is reached
          # through flask.current_app inside each handler.

          @api.doc('get_post')
          @api.marshal_with(post_model)
          @cache.cached(timeout=60)
          def get(self, post_id):
              """Get a single post"""
              post = Post.query.get_or_404(post_id, description='Post not found')
              return post_schema.dump(post)

          @api.doc('update_post', security='jwt')
          @api.expect(post_model)
          @api.marshal_with(post_model)
          @login_required
          def put(self, post_id):
              """Update a post"""
              from flask import current_app

              post = Post.query.get_or_404(post_id, description='Post not found')
              if post.user_id != g.current_user.id and not g.current_user.is_admin:
                  api.abort(403, 'You can only edit your own posts unless you are an admin')
              if not request.is_json:
                  api.abort(400, 'Request must be JSON')
              data = request.get_json()
              # Partial update: only title and content are writable, and only
              # the keys present in the body are applied.
              if 'title' in data:
                  post.title = data['title']
              if 'content' in data:
                  post.content = data['content']
              db.session.commit()
              current_app.logger.info(f'Post updated: {post.title} by {g.current_user.username}')
              cache.delete(f'view/api/v1/posts/{post_id}')
              cache.delete('view/api/v1/posts')
              return post_schema.dump(post)

          @api.doc('delete_post', security='jwt')
          @login_required
          def delete(self, post_id):
              """Delete a post"""
              from flask import current_app

              post = Post.query.get_or_404(post_id, description='Post not found')
              if post.user_id != g.current_user.id and not g.current_user.is_admin:
                  api.abort(403, 'You can only delete your own posts unless you are an admin')
              if post.image_path:
                  # image_path is stored as "/uploads/<name>"; dropping the
                  # leading slash makes it relative to the application root.
                  image_path = os.path.join(current_app.root_path, post.image_path.lstrip('/'))
                  if os.path.exists(image_path):
                      os.remove(image_path)
                      current_app.logger.info(f'Image deleted: {post.image_path}')
              db.session.delete(post)
              db.session.commit()
              current_app.logger.info(f'Post deleted: {post.title} by {g.current_user.username}')
              cache.delete(f'view/api/v1/posts/{post_id}')
              cache.delete('view/api/v1/posts')
              return {'message': 'Post deleted'}, 200
      
      @api.route('/all')
      class PostAll(Resource):
          # Admin-only bulk endpoint that removes every post and its image file.

          @api.doc('delete_all_posts', security='jwt')
          @admin_required
          def delete(self):
              """Delete all posts (admin only)"""
              # ``api`` is a flask-restx Namespace, which has no ``app``
              # attribute; the running application is reached via current_app.
              from flask import current_app

              posts = Post.query.all()
              # Remove the image files first; the rows go in one bulk delete below.
              for post in posts:
                  if post.image_path:
                      # image_path is stored as "/uploads/<name>", resolved here
                      # relative to the application root.
                      image_path = os.path.join(current_app.root_path, post.image_path.lstrip('/'))
                      if os.path.exists(image_path):
                          os.remove(image_path)
              Post.query.delete()
              db.session.commit()
              current_app.logger.warning(f'All posts deleted by admin {g.current_user.username}')
              cache.delete('view/api/v1/posts')
              return {'message': 'All posts deleted'}, 200
      
      # Exported under the name the application factory imports.
      # NOTE(review): this is a flask-restx Namespace, not a Flask Blueprint --
      # confirm the importer registers it accordingly.
      posts_bp = api
      
  5. 提供靜態文件訪問

    • 修改 run.py,添加靜態文件路由:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      
      from flask import send_from_directory

      from app import create_app

      app = create_app()


      @app.route('/uploads/<filename>')
      def uploaded_file(filename):
          """Serve a previously uploaded file from the configured upload folder."""
          # The file is looked up in UPLOAD_FOLDER (set by create_app), so this
          # route needs neither ``os`` -- which this script never imported -- nor
          # the app's static folder to be repointed at the package root.
          return send_from_directory(app.config['UPLOAD_FOLDER'], filename)


      if __name__ == '__main__':
          # Development server only; debug mode must not be used in production.
          app.run(debug=True)
      
    • 將 uploads/ 目錄設置為靜態文件路徑,在 app/__init__.py 的 create_app() 中添加:

      1
      
      app.static_folder = app.root_path
      
  6. 運行應用

    • 運行:
      1
      
      python run.py
      
  7. 測試文件上傳

    • 使用 Postman 測試:
      • POST /api/v1/posts
        • Headers:Authorization: Bearer <token>
        • Body:選擇 form-data
          • title:"My Post"
          • content:"With image"
          • image:選擇一個圖片文件(例如 test.jpg)
        • 預期響應:201,包含 image_path(如 /uploads/abc123.jpg)。
      • GET /uploads/<filename>:
        • 訪問返回的圖片路徑,應顯示圖片。
      • DELETE /api/v1/posts/<post_id>:
        • 刪除帖子後檢查圖片文件是否被移除。
  8. 作業

    • 在 v2 的 posts 路由中添加文件上傳功能,支持多個圖片(提示:使用 request.files.getlist)。
    • 添加文件類型和大小的更嚴格驗證(例如限制大小到 5MB)。

注意事項

  • MAX_CONTENT_LENGTH 限制總請求大小,超過會返回 413。
  • 文件名使用 UUID 避免重名。
  • 生產環境應使用雲存儲(如 AWS S3)替代本地文件系統。
  • 表結構改變後需重建數據庫(刪除舊的 blog.db)。

本文章以 CC BY 4.0 授權