AI 第22天:生成對抗網路(GANs)
生成對抗網路(Generative Adversarial Networks, GANs)是深度學習中一個令人興奮的領域。它由兩個神經網路——生成器(Generator)和判別器(Discriminator)——互相競爭並共同訓練,最終實現生成高品質的數據(如圖像或文本)。今天,我們將探討 GANs 的核心概念,並實作一個基礎的手寫數字生成模型。
課程目標
- 理解生成對抗網路的工作原理與架構。
- 學習如何構建一個簡單的 GAN,生成手寫數字圖片。
- 初步了解 GANs 的訓練挑戰與優化策略。
課程內容
1. GAN 的基本架構
1.1 核心概念
- 生成器(Generator):接收隨機噪聲作為輸入,生成類似真實數據的偽造數據。
- 判別器(Discriminator):對輸入數據進行判斷,分辨是真實數據還是生成數據。
- 生成器和判別器互相競爭,生成器希望騙過判別器,而判別器則不斷提升自己的判斷能力。
1.2 搭配損失函數
- 判別器的損失函數:將真實數據分類為真,將生成數據分類為假。
- 生成器的損失函數:讓生成的數據更像真實數據,以騙過判別器。
2. 數據集準備
我們將使用 MNIST 數據集(手寫數字)作為基礎數據集。
1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 載入數據集
(x_train, _), (_, _) = mnist.load_data()
# 將數據標準化到 [-1, 1] 區間,並展平
x_train = (x_train.astype("float32") - 127.5) / 127.5
x_train = x_train.reshape(x_train.shape[0], -1)
print(f"訓練數據形狀: {x_train.shape}")
3. 架構生成器與判別器
3.1 生成器模型
生成器將隨機噪聲轉換為逼真的手寫數字圖片。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape
def build_generator(latent_dim):
model = Sequential([
Dense(256, input_dim=latent_dim),
LeakyReLU(alpha=0.2),
Dense(512),
LeakyReLU(alpha=0.2),
Dense(1024),
LeakyReLU(alpha=0.2),
Dense(28 * 28, activation="tanh"),
Reshape((28, 28))
])
return model
latent_dim = 100
generator = build_generator(latent_dim)
generator.summary()
3.2 判別器模型
判別器將輸入圖片分類為真或假。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from tensorflow.keras.layers import Flatten, Dropout
from tensorflow.keras.optimizers import Adam
def build_discriminator():
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(512),
LeakyReLU(alpha=0.2),
Dropout(0.4),
Dense(256),
LeakyReLU(alpha=0.2),
Dropout(0.4),
Dense(1, activation="sigmoid")
])
return model
discriminator = build_discriminator()
discriminator.compile(optimizer=Adam(0.0002, 0.5), loss="binary_crossentropy", metrics=["accuracy"])
discriminator.summary()
4. 搭建 GAN 模型
將生成器與判別器結合,形成完整的 GAN。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from tensorflow.keras.models import Model
# 冷凍判別器權重(僅訓練生成器)
discriminator.trainable = False
# GAN 模型
gan_input = tf.keras.Input(shape=(latent_dim,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = Model(gan_input, gan_output)
gan.compile(optimizer=Adam(0.0002, 0.5), loss="binary_crossentropy")
gan.summary()
5. 訓練 GAN
5.1 訓練流程
- 使用生成器生成一批假的手寫數字圖片。
- 訓練判別器,分辨真實圖片與生成圖片。
- 訓練 GAN,優化生成器的權重以騙過判別器。
5.2 訓練代碼
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
def train_gan(epochs, batch_size):
for epoch in range(epochs):
# 隨機抽取真實數據
idx = np.random.randint(0, x_train.shape[0], batch_size)
real_images = x_train[idx]
# 生成假圖片
noise = np.random.normal(0, 1, (batch_size, latent_dim))
fake_images = generator.predict(noise)
# 訓練判別器
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
d_loss_real = discriminator.train_on_batch(real_images, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# 訓練生成器
noise = np.random.normal(0, 1, (batch_size, latent_dim))
valid_labels = np.ones((batch_size, 1))
g_loss = gan.train_on_batch(noise, valid_labels)
# 每 100 回合打印損失
if epoch % 100 == 0:
print(f"Epoch {epoch}: D_loss={d_loss[0]}, G_loss={g_loss}")
# 開始訓練
train_gan(epochs=10000, batch_size=64)
6. 生成圖像與可視化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import matplotlib.pyplot as plt
def generate_images(num_images=10):
noise = np.random.normal(0, 1, (num_images, latent_dim))
generated_images = generator.predict(noise)
generated_images = (generated_images + 1) / 2.0 # 恢復至 [0, 1] 範圍
plt.figure(figsize=(10, 2))
for i in range(num_images):
plt.subplot(1, num_images, i + 1)
plt.imshow(generated_images[i], cmap="gray")
plt.axis("off")
plt.show()
generate_images()
課後作業
- 調整生成器和判別器的架構,觀察對生成結果的影響。
- 使用更高解析度的數據集,如 CIFAR-10,構建彩色圖像生成模型。
- 研究改進 GAN 的技術,如條件生成對抗網路(Conditional GAN)或 WGAN。
本文章以 CC BY 4.0 授權