AI 第20天:RNN 與 LSTM
遞歸神經網路(Recurrent Neural Network, RNN)與長短期記憶網路(Long Short-Term Memory, LSTM)是處理序列數據的重要工具,廣泛應用於自然語言處理、時間序列分析、語音識別等領域。今天,我們將學習 RNN 和 LSTM 的核心概念,並實作簡單的 LSTM 模型。
課程目標
- 理解 RNN 與 LSTM 的結構與工作原理。
- 瞭解它們如何處理序列數據(例如時間序列或文本)。
- 使用 TensorFlow/Keras 實作一個基於 LSTM 的簡單模型。
課程內容
1. RNN 的核心概念
1.1 為什麼使用 RNN?
RNN 的特點是擁有循環結構,能記住序列數據的上下文資訊,適合處理有時間相關性的數據(如文本、語音)。
1.2 RNN 的結構
RNN 會將前一個時間步的輸出作為下一個時間步的輸入,公式如下:
\(h_t = f(W_h h_{t-1} + W_x x_t + b)\)
其中:
- (h_t):當前時間步的隱藏狀態
- (x_t):當前時間步的輸入
- (W_h)、(W_x):權重矩陣
- (b):偏置
1.3 RNN 的缺點
- 梯度消失或爆炸問題:在處理長序列時,梯度可能會消失或爆炸,導致模型無法學習遠距資訊。
- 短期記憶:普通 RNN 無法有效記住長期上下文資訊。
2. LSTM 的核心概念
為了解決 RNN 的缺點,LSTM 被提出,它是一種特殊的 RNN,透過「記憶單元」與「門機制」有效處理長期依賴問題。
2.1 LSTM 的結構
LSTM 包含三個門:
- 遺忘門(Forget Gate):決定丟棄多少先前的信息。
- 輸入門(Input Gate):決定更新多少新信息到記憶單元中。
- 輸出門(Output Gate):決定從記憶單元輸出多少信息。
數學公式:
\(f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)\)
\(i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)\)
\(\tilde{C}_t = \tanh(W_C [h_{t-1}, x_t] + b_C)\)
\(C_t = f_t * C_{t-1} + i_t * \tilde{C}_t\)
\(o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)\)
\(h_t = o_t * \tanh(C_t)\)
3. 實作:基於 LSTM 的時間序列預測
3.1 載入與準備數據
使用 Sine 函數生成時間序列數據,作為模型的輸入。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import matplotlib.pyplot as plt
# 生成時間序列數據
time_steps = np.linspace(0, 100, 1000)
data = np.sin(time_steps)
# 構建訓練數據
def create_sequences(data, seq_length):
X, y = [], []
for i in range(len(data) - seq_length):
X.append(data[i:i+seq_length])
y.append(data[i+seq_length])
return np.array(X), np.array(y)
seq_length = 50
X, y = create_sequences(data, seq_length)
# 分割訓練集與測試集
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]
# 增加維度以符合 LSTM 輸入
X_train = X_train[:, :, np.newaxis]
X_test = X_test[:, :, np.newaxis]
print(f"訓練集形狀: {X_train.shape}, 測試集形狀: {X_test.shape}")
3.2 建構 LSTM 模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
# 建立 LSTM 模型
model = Sequential([
LSTM(50, activation='tanh', input_shape=(seq_length, 1)),
Dense(1) # 單一輸出
])
# 編譯模型
model.compile(optimizer='adam', loss='mse')
# 查看模型架構
model.summary()
3.3 訓練與評估模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 訓練模型
history = model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2, verbose=1)
# 評估模型
loss = model.evaluate(X_test, y_test)
print(f"測試集損失: {loss:.4f}")
# 預測與可視化結果
y_pred = model.predict(X_test)
plt.figure(figsize=(10, 6))
plt.plot(y_test, label="真實值")
plt.plot(y_pred, label="預測值")
plt.legend()
plt.title("LSTM 時間序列預測結果")
plt.show()
4. LSTM 的進階應用
- 文本生成:使用 LSTM 訓練語料庫,生成新文本。
- 語音識別:處理聲音的時間序列特徵。
- 股價預測:分析股票的歷史數據進行價格預測。
課後作業
- 調整 LSTM 層數或神經元數量,觀察模型性能的變化。
- 嘗試使用 GRU(門控循環單元)替代 LSTM,並比較結果。
- 使用更複雜的時間序列數據(如氣象數據)進行模型訓練與預測。
本文章以 CC BY 4.0 授權