使用Tensorflow keras 创建LSTM

奚瑾瑜
2023-12-01

本文介绍使用LSTM和 RNN+LSTMCell 等2种方法实现LSTM网络。SimpleRNN的全连接循环神经网络收敛速度是比较慢,而LSTM就快多了。

  1. LSTM
    代码如下:
import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
#读取本地mnist数据
def my_load_data(path='mnist.npz'):
    origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
    path = tf.keras.utils.get_file(
        path,
        origin=origin_folder + 'mnist.npz',
        cache_dir='DataSet/',
        cache_subdir=""
    )
    with np.load(path, allow_pickle=True) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']

        return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = my_load_data(path='mnist.npz')
#归一化数据
x_train=x_train/255.
x_test=x_test/255.
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)

# 数据长度 一行有28个像素
input_size = 28
# 序列的长度
time_steps = 28
# 隐藏层block的个数
cell_size = 64
model = keras.Sequential()
#使用LSTM,如果用RNN+LSTMCell的方法,只需要替换下面的语句即可
model.add(keras.layers.LSTM(
        units = cell_size, # 输出
        input_shape= (time_steps, input_size), # 输入
))
# 输出层
model.add(keras.layers.Dense(10, activation='softmax'))

# 定义优化器
adam = keras.optimizers.Adam(lr=1e-4)
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, batch_size=64, epochs=4)
# 评估模型
loss, accuracy = model.evaluate(x_test, y_test)
print('test loss', loss)
print('test accuracy', accuracy)

i = np.random.randint(0,10000)
plt.imshow(x_test[i],cmap=plt.cm.binary)
plt.show()

predictions = model.predict(tf.expand_dims(x_test[i],0))
print("RNN 结果:%d"%np.argmax(predictions))

  1. RNN+LSTMCell
    此方法就是用RNN+LSTMCell替换LSTM。
    修改的代码是:
model.add(keras.layers.RNN(keras.layers.LSTMCell(
        units = cell_size, # 输出
        input_shape= (time_steps, input_size), # 输入
   )))

 类似资料: