要保存已经训练好的模型在构造期末尾创建一个 saver 节点,在执行期中调用save()方法,传入一个会话和检查点文件的路径即可。以保存加州房价线性模型为例(构造阶段前的代码并无改动,原代码复制即可):
# 导入包
# ......
# 下载及整理数据
# ......
# 数据预处理
# ......
# 构造阶段
X = tf.placeholder(tf.float32, shape=(None, n+1), name="X")
y = tf.placeholder(tf.float32, shape=(None, 1), name="y")
n_epochs = 1000
batch_size = 100
n_batches= int(np.ceil(m/batch_size))
global_learning_rate = 0.01
XT = tf.transpose(X)
theta = tf.Variable(tf.random_uniform([n+1,1],-1.0,1.0),name="theta") # * 参数 seed=42
y_pred = tf.matmul(X, theta, name="prediction") # 预测值
error = y_pred-y # 误差
mse = tf.reduce_mean(tf.square(error), name="mse") # 均方误差(成本函数)
# 调用特定的优化器求解梯度并优化
optimizer = tf.train.GradientDescentOptimizer(learning_rate = global_learning_rate)
training_op = optimizer.minimize(mse)
# 添加一个saver节点用来保存模型参数
saver = tf.train.Saver()
# 执行阶段
init = tf.global_variables_initializer()
def fetch_batch(epoch, batch_index, batch_size):
np.random.seed(epoch * n_batches + batch_index)
indices = np.random.randint(m, size=batch_size)
X_batch = scaled_housing_data_plus_bias[indices]
y_batch = housing.target.reshape(-1, 1)[indices]
return X_batch, y_batch
with tf.Session() as sess:
sess.run(init)
for epoch in range(n_epochs):
if epoch%100 == 0:
save_path = saver.save(sess, "./tmp/my_model_1.ckpt")
for batch_index in range(n_batches):
X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
best_theta = theta.eval()
print("The best theta is", best_theta)
save_path = saver.save(sess, "./tmp/my_model_final_1.ckpt")
与保存模型一样,恢复使用模型时需要在构造期末尾创建一个saver节点,但在执行期开始时候不是用init节点来初始化所有变量,而是调用Saver对象上的restore()方法。继续以加州房价线性模型为例:
# 导入包
# .....
# 下载并整理数据
# .....
# 数据预处理
# .....
# 构造期
# .....
# 执行期
def fetch_batch(epoch, batch_index, batch_size):
np.random.seed(epoch * n_batches + batch_index)
indices = np.random.randint(m, size=batch_size)
X_batch = scaled_housing_data_plus_bias[indices]
y_batch = housing.target.reshape(-1, 1)[indices]
return X_batch, y_batch
with tf.Session() as sess:
saver.restore(sess, "./tmp/my_model_final.ckpt")
for epoch in range(n_epochs):
for batch_index in range(n_batches):
X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
best_theta = theta.eval()
print("The best theta is", best_theta)
训练完毕的Tensorflow模型可以保存为checkpoint文件(.ckpt)或protocolbuff文件(.pb)。其中ckpt文件是权重与结构相分离的四个文件,而pb文件是储存固定模型结构及权重的一个序列化文件。ckpt文件适合进行训练,而pb文件适合发布和离线预测。可以利用官方提供的freeze_grapah.py脚本将ckpt文件转换为pb文件(具体使用可见这个博客)。
模型文件checkpoint中的四个文件的内容及作用如下所示:
模型文件pb的内容及作用可参见另一博客TensorFlow的三种文件 。
[1] jimlee.为什么tesnorflow保存model.ckpt文件会生成4个文件?
[2] pan_jinquan.tensorflow实现将ckpt转pb文件