文章

Flask - 測試

目標

  • 配置測試環境
  • 編寫針對用戶和文章端點的單元測試
  • 運行測試並檢查覆蓋率

步驟

  1. 準備環境

    • 繼續使用 flask_api/ 項目結構,激活虛擬環境:
      1
      2
      
      # Windows: flask_api_env\Scripts\activate
      # macOS/Linux: source flask_api_env/bin/activate
      
    • 安裝測試相關依賴:
      1
      
      pip install coverage
      
  2. 設置測試文件結構

    • 在項目根目錄下創建 tests/ 目錄:
      1
      2
      3
      4
      5
      6
      7
      
      flask_api/
      ├── app/
      ├── tests/
      │   ├── __init__.py
      │   ├── test_users.py
      │   └── test_posts.py
      └── run.py
      
  3. 配置測試基類

    • tests/init.py 中留空(作為包標識)。
    • 新建 tests/test_users.py,設置測試基類並測試用戶相關功能:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      
      import unittest
      from app import create_app, db
      from app.models import User
      
      class BaseTestCase(unittest.TestCase):
          def setUp(self):
              self.app = create_app()
              self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///test.db'
              self.app.config['TESTING'] = True
              self.client = self.app.test_client()
              with self.app.app_context():
                  db.create_all()
      
          def tearDown(self):
              with self.app.app_context():
                  db.session.remove()
                  db.drop_all()
      
          def get_headers(self, token=None):
              headers = {'Content-Type': 'application/json'}
              if token:
                  headers['Authorization'] = f'Bearer {token}'
              return headers
      
      class TestUsers(BaseTestCase):
          def test_create_user(self):
              response = self.client.post('/api/v1/users',
                                        json={'username': 'alice', 'password': '1234'},
                                        headers=self.get_headers())
              self.assertEqual(response.status_code, 201)
              data = response.get_json()
              self.assertEqual(data['username'], 'alice')
              self.assertEqual(data['role'], 'user')
      
          def test_create_duplicate_user(self):
              self.client.post('/api/v1/users',
                             json={'username': 'alice', 'password': '1234'},
                             headers=self.get_headers())
              response = self.client.post('/api/v1/users',
                                        json={'username': 'alice', 'password': '5678'},
                                        headers=self.get_headers())
              self.assertEqual(response.status_code, 400)
      
          def test_login(self):
              self.client.post('/api/v1/users',
                             json={'username': 'alice', 'password': '1234'},
                             headers=self.get_headers())
              response = self.client.post('/api/v1/login',
                                        json={'username': 'alice', 'password': '1234'},
                                        headers=self.get_headers())
              self.assertEqual(response.status_code, 200)
              data = response.get_json()
              self.assertIn('token', data)
      
          def test_login_invalid(self):
              response = self.client.post('/api/v1/login',
                                        json={'username': 'alice', 'password': 'wrong'},
                                        headers=self.get_headers())
              self.assertEqual(response.status_code, 401)
      
      if __name__ == '__main__':
          unittest.main()
      
  4. 測試文章端點

    • 新建 tests/test_posts.py

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      
      import unittest
      from app import create_app, db
      from app.models import User, Post
      import jwt
      import datetime
      
      class TestPosts(BaseTestCase):
          def setUp(self):
              super().setUp()
              with self.app.app_context():
                  # 創建測試用戶
                  user = User(username='alice')
                  user.set_password('1234')
                  db.session.add(user)
                  db.session.commit()
                  self.user_id = user.id
                  self.token = jwt.encode({
                      'user_id': user.id,
                      'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=24)
                  }, self.app.config['SECRET_KEY'], algorithm='HS256')
      
          def test_create_post(self):
              response = self.client.post('/api/v1/posts',
                                        json={'title': 'Test Post', 'content': 'Hello'},
                                        headers=self.get_headers(self.token))
              self.assertEqual(response.status_code, 201)
              data = response.get_json()
              self.assertEqual(data['title'], 'Test Post')
              self.assertEqual(data['user_id'], self.user_id)
      
          def test_create_post_unauthorized(self):
              response = self.client.post('/api/v1/posts',
                                        json={'title': 'Test Post', 'content': 'Hello'},
                                        headers=self.get_headers())
              self.assertEqual(response.status_code, 401)
      
          def test_update_post(self):
              # 先創建一篇帖子
              response = self.client.post('/api/v1/posts',
                                        json={'title': 'Test Post', 'content': 'Hello'},
                                        headers=self.get_headers(self.token))
              post_id = response.get_json()['id']
      
              # 更新帖子
              response = self.client.put(f'/api/v1/posts/{post_id}',
                                       json={'title': 'Updated Post'},
                                       headers=self.get_headers(self.token))
              self.assertEqual(response.status_code, 200)
              data = response.get_json()
              self.assertEqual(data['title'], 'Updated Post')
      
          def test_delete_post(self):
              response = self.client.post('/api/v1/posts',
                                        json={'title': 'Test Post', 'content': 'Hello'},
                                        headers=self.get_headers(self.token))
              post_id = response.get_json()['id']
      
              response = self.client.delete(f'/api/v1/posts/{post_id}',
                                          headers=self.get_headers(self.token))
              self.assertEqual(response.status_code, 200)
              self.assertEqual(response.get_json()['message'], 'Post deleted')
      
      if __name__ == '__main__':
          unittest.main()
      
  5. 運行測試

    • 在項目根目錄下運行單個測試文件:
      1
      2
      
      python -m unittest tests/test_users.py
      python -m unittest tests/test_posts.py
      
    • 運行所有測試:
      1
      
      python -m unittest discover -s tests
      
    • 使用 coverage 檢查測試覆蓋率:
      1
      2
      
      coverage run -m unittest discover -s tests
      coverage report -m
      
  6. 測試結果

    • 確保所有測試通過(輸出顯示 OK)。
    • coverage report 會顯示代碼覆蓋率,例如:
      1
      2
      3
      4
      5
      6
      
      Name                    Stmts   Miss  Cover   Missing
      -----------------------------------------------------
      app/__init__.py           60     20    67%   20-30, 40-50
      app/models.py             25      5    80%   15-20
      app/routes/v1/posts.py    50     10    80%   35-45
      ...
      
  7. 作業

    • 為 v2 的 posts 路由編寫測試,涵蓋 category 過濾功能。
    • 添加一個測試用例,驗證管理员可以刪除他人文章(提示:創建兩個用戶,一個設為 admin)。

注意事項

  • setUptearDown 確保每次測試使用乾淨的數據庫。
  • 測試數據庫使用獨立的 test.db,避免影響開發數據。
  • coverage 需關注未覆蓋的代碼,可能是邏輯分支未測試。

本文章以 CC BY 4.0 授權