文章

AI 第22天:生成對抗網路(GANs)

生成對抗網路(Generative Adversarial Networks, GANs)是深度學習中一個令人興奮的領域。它由兩個神經網路——生成器(Generator)和判別器(Discriminator)——互相競爭並共同訓練,最終實現生成高品質的數據(如圖像或文本)。今天,我們將探討 GANs 的核心概念,並實作一個基礎的手寫數字生成模型。


課程目標

  1. 理解生成對抗網路的工作原理與架構。
  2. 學習如何構建一個簡單的 GAN,生成手寫數字圖片。
  3. 初步了解 GANs 的訓練挑戰與優化策略。

課程內容

1. GAN 的基本架構

1.1 核心概念

  • 生成器(Generator):接收隨機噪聲作為輸入,生成類似真實數據的偽造數據。
  • 判別器(Discriminator):對輸入數據進行判斷,分辨是真實數據還是生成數據。
  • 生成器和判別器互相競爭,生成器希望騙過判別器,而判別器則不斷提升自己的判斷能力。

1.2 搭配損失函數

  • 判別器的損失函數:將真實數據分類為真,將生成數據分類為假。
  • 生成器的損失函數:讓生成的數據更像真實數據,以騙過判別器。

2. 數據集準備

我們將使用 MNIST 數據集(手寫數字)作為基礎數據集。

1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf
from tensorflow.keras.datasets import mnist

# 載入數據集
(x_train, _), (_, _) = mnist.load_data()

# 將數據標準化到 [-1, 1] 區間,並展平
x_train = (x_train.astype("float32") - 127.5) / 127.5
x_train = x_train.reshape(x_train.shape[0], -1)

print(f"訓練數據形狀: {x_train.shape}")

3. 架構生成器與判別器

3.1 生成器模型

生成器將隨機噪聲轉換為逼真的手寫數字圖片。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape

def build_generator(latent_dim):
    model = Sequential([
        Dense(256, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        Dense(28 * 28, activation="tanh"),
        Reshape((28, 28))
    ])
    return model

latent_dim = 100
generator = build_generator(latent_dim)
generator.summary()

3.2 判別器模型

判別器將輸入圖片分類為真或假。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from tensorflow.keras.layers import Flatten, Dropout
from tensorflow.keras.optimizers import Adam

def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Dense(1, activation="sigmoid")
    ])
    return model

discriminator = build_discriminator()
discriminator.compile(optimizer=Adam(0.0002, 0.5), loss="binary_crossentropy", metrics=["accuracy"])
discriminator.summary()

4. 搭建 GAN 模型

將生成器與判別器結合,形成完整的 GAN。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from tensorflow.keras.models import Model

# 冷凍判別器權重(僅訓練生成器)
discriminator.trainable = False

# GAN 模型
gan_input = tf.keras.Input(shape=(latent_dim,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = Model(gan_input, gan_output)
gan.compile(optimizer=Adam(0.0002, 0.5), loss="binary_crossentropy")

gan.summary()

5. 訓練 GAN

5.1 訓練流程

  1. 使用生成器生成一批假的手寫數字圖片。
  2. 訓練判別器,分辨真實圖片與生成圖片。
  3. 訓練 GAN,優化生成器的權重以騙過判別器。

5.2 訓練代碼

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np

def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        # 隨機抽取真實數據
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]

        # 生成假圖片
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_images = generator.predict(noise)

        # 訓練判別器
        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # 訓練生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        valid_labels = np.ones((batch_size, 1))
        g_loss = gan.train_on_batch(noise, valid_labels)

        # 每 100 回合打印損失
        if epoch % 100 == 0:
            print(f"Epoch {epoch}: D_loss={d_loss[0]}, G_loss={g_loss}")

# 開始訓練
train_gan(epochs=10000, batch_size=64)

6. 生成圖像與可視化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt

def generate_images(num_images=10):
    noise = np.random.normal(0, 1, (num_images, latent_dim))
    generated_images = generator.predict(noise)
    generated_images = (generated_images + 1) / 2.0  # 恢復至 [0, 1] 範圍

    plt.figure(figsize=(10, 2))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(generated_images[i], cmap="gray")
        plt.axis("off")
    plt.show()

generate_images()

課後作業

  1. 調整生成器和判別器的架構,觀察對生成結果的影響。
  2. 使用更高解析度的數據集,如 CIFAR-10,構建彩色圖像生成模型。
  3. 研究改進 GAN 的技術,如條件生成對抗網路(Conditional GAN)或 WGAN。

本文章以 CC BY 4.0 授權