当前位置: 首页 > 工具软件 > Regain > 使用案例 >

Save&Regain Model

宋飞掣
2023-12-01

Table of Contents

第六章 模型保存与恢复

6.1 保存模型

       要保存已经训练好的模型在构造期末尾创建一个 saver 节点,在执行期中调用save()方法,传入一个会话和检查点文件的路径即可。以保存加州房价线性模型为例(构造阶段前的代码并无改动,原代码复制即可):

# 导入包
# ......
# 下载及整理数据
# ......
# 数据预处理
# ......
# 构造阶段  
X = tf.placeholder(tf.float32, shape=(None, n+1), name="X")
y = tf.placeholder(tf.float32, shape=(None, 1), name="y")
n_epochs = 1000
batch_size = 100
n_batches= int(np.ceil(m/batch_size))
global_learning_rate = 0.01
XT = tf.transpose(X)
theta = tf.Variable(tf.random_uniform([n+1,1],-1.0,1.0),name="theta")     # *    参数 seed=42
y_pred =  tf.matmul(X, theta, name="prediction")                          # 预测值
error = y_pred-y                                                          # 误差
mse = tf.reduce_mean(tf.square(error), name="mse")                        # 均方误差(成本函数)                                 
# 调用特定的优化器求解梯度并优化
optimizer = tf.train.GradientDescentOptimizer(learning_rate =   global_learning_rate)
training_op = optimizer.minimize(mse)
# 添加一个saver节点用来保存模型参数
saver = tf.train.Saver()
  
# 执行阶段  
init = tf.global_variables_initializer()

def fetch_batch(epoch, batch_index, batch_size):
    np.random.seed(epoch * n_batches + batch_index) 
    indices = np.random.randint(m, size=batch_size)
    X_batch = scaled_housing_data_plus_bias[indices] 
    y_batch = housing.target.reshape(-1, 1)[indices] 
    return X_batch, y_batch

with tf.Session() as sess:
    sess.run(init)
 
    for epoch in range(n_epochs):
        if epoch%100 == 0:
            save_path =  saver.save(sess, "./tmp/my_model_1.ckpt")
        for batch_index in range(n_batches):
            X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
            sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
    best_theta = theta.eval()
    print("The best theta is", best_theta)
    save_path = saver.save(sess, "./tmp/my_model_final_1.ckpt")

6.2 恢复模型

       与保存模型一样,恢复使用模型时需要在构造期末尾创建一个saver节点,但在执行期开始时候不是用init节点来初始化所有变量,而是调用Saver对象上的restore()方法。继续以加州房价线性模型为例:

# 导入包
# .....
# 下载并整理数据
# .....
# 数据预处理
# .....
# 构造期
# .....
# 执行期
def fetch_batch(epoch, batch_index, batch_size):
    np.random.seed(epoch * n_batches + batch_index) 
    indices = np.random.randint(m, size=batch_size)
    X_batch = scaled_housing_data_plus_bias[indices] 
    y_batch = housing.target.reshape(-1, 1)[indices] 
    return X_batch, y_batch

with tf.Session() as sess:
    saver.restore(sess, "./tmp/my_model_final.ckpt")
 
    for epoch in range(n_epochs):
        for batch_index in range(n_batches):
            X_batch, y_batch = fetch_batch(epoch, batch_index, batch_size)
            sess.run(training_op, feed_dict={X:X_batch, y:y_batch})
    best_theta = theta.eval()
    print("The best theta is", best_theta)

6.3 模型文件

       训练完毕的Tensorflow模型可以保存为checkpoint文件(.ckpt)或protocolbuff文件(.pb)。其中ckpt文件是权重与结构相分离的四个文件,而pb文件是储存固定模型结构及权重的一个序列化文件。ckpt文件适合进行训练,而pb文件适合发布和离线预测。可以利用官方提供的freeze_grapah.py脚本将ckpt文件转换为pb文件(具体使用可见这个博客)。
       模型文件checkpoint中的四个文件的内容及作用如下所示:

  1. Checkpoint:是一个文本文件,用于保存断点文件列表和迅速查找最近一次的断点文件;
  2. meta:序列化二进制文件,保存图结构信息;
  3. data:保存模型变量即参数的值;
  4. index:保存模型参数名。

        模型文件pb的内容及作用可参见另一博客TensorFlow的三种文件

6.4 参考

[1] jimlee.为什么tesnorflow保存model.ckpt文件会生成4个文件?
[2] pan_jinquan.tensorflow实现将ckpt转pb文件

 类似资料: