当前位置: 首页 > 面试题库 >

恢复模型需要很长时间

荣俊
2023-03-14
问题内容

我在重新整理模型时遇到问题。我训练了模型并使用此代码保存了模型。我不太确定这是否是正确的方法,我将不胜感激。当我尝试还原模型时会发生问题。我只需要预测,就不会再接受过培训了。从模型中恢复参数需要花费很多时间。在我仅需要预测的前提下,如何改进模型保护程序或模型恢复程序以使其快速完成。

X = tf.placeholder(tf.float32, [None, 56, 56, 1])
Y_ = tf.placeholder(tf.float32, [None, 36])

L1 = 432
L2 = 72
L3 = 36

W1 = tf.Variable(tf.truncated_normal([3136, L1], stddev=0.1))
b1 = tf.Variable(tf.zeros([L1]))
W2 = tf.Variable(tf.truncated_normal([L1, L2], stddev=0.1))
b2 = tf.Variable(tf.zeros([L2]))
W3 = tf.Variable(tf.truncated_normal([L2, L3], stddev=0.1))
b3 = tf.Variable(tf.zeros([L3]))

XX = tf.reshape(X, [-1, 3136])

Y1 = tf.nn.sigmoid(tf.matmul(XX, W1) + b1)
Y1 = tf.nn.dropout(Y1, keep_prob=0.8)
Y2 = tf.nn.sigmoid(tf.matmul(Y1, W2) + b2)
Y2 = tf.nn.dropout(Y2, keep_prob=0.8)
Ylogits = tf.matmul(Y2, W3) + b3
Y = tf.nn.softmax(Ylogits)

cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=Ylogits, labels=Y_)
cross_entropy = tf.reduce_mean(cross_entropy) * 100
correct_prediction = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
train_step = tf.train.GradientDescentOptimizer(0.0001).minimize(cross_entropy)

allweights = tf.concat([tf.reshape(W1, [-1]), tf.reshape(W2, [-1]), tf.reshape(W3, [-1])], 0)
allbiases = tf.concat([tf.reshape(b1, [-1]), tf.reshape(b2, [-1]), tf.reshape(b3, [-1])], 0)

init = tf.global_variables_initializer()

saver = tf.train.Saver()

def next_batch(x, y, batch, step):
    x_temp = x[cur_step:(step+batch)]
    y_temp = np.squeeze(y[step:(step + batch)])
    return x_temp, y_temp


with tf.Session() as sess:
    sess.run(init)
    cur_step = 0
    for i in range(NUM_ITERS + 1):
        batch_X, batch_Y = next_batch(train_xx, train_yy, BATCH, cur_step)
        if i % DISPLAY_STEP == 0:
            acc_trn, loss_trn, w, b = sess.run([accuracy, cross_entropy, allweights, allbiases], feed_dict={X: batch_X, Y_: batch_Y})
            acc_tst, loss_tst = sess.run([accuracy, cross_entropy], feed_dict={X: test_xx, Y_: test_yy})

        sess.run(train_step, feed_dict={X: batch_X, Y_: batch_Y})
    save_path = saver.save(sess, "abc/model")

恢复:

X = tf.placeholder(tf.float32, [None, 56, 56, 1])
Y_ = tf.placeholder(tf.float32, [None, 36])

L1 = 432
L2 = 72
L3 = 36

W1 = tf.Variable(tf.truncated_normal([3136, L1], stddev=0.1))
b1 = tf.Variable(tf.zeros([L1]))

W2 = tf.Variable(tf.truncated_normal([L1, L2], stddev=0.1))
b2 = tf.Variable(tf.zeros([L2]))

W3 = tf.Variable(tf.truncated_normal([L2, L3], stddev=0.1))
b3 = tf.Variable(tf.zeros([L3]))


XX = tf.reshape(X, [-1, 3136])

Y1 = tf.nn.sigmoid(tf.matmul(XX, W1) + b1)
Y1 = tf.nn.dropout(Y1, keep_prob=0.8)
Y2 = tf.nn.sigmoid(tf.matmul(Y1, W2) + b2)
Y2 = tf.nn.dropout(Y2, keep_prob=0.8)
Ylogits = tf.matmul(Y2, W3) + b3
Y = tf.nn.softmax(Ylogits)

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver = tf.train.import_meta_graph('model.meta')
    saver.restore(sess, 'model')

编辑:也许使用Google Colab的GPU训练模型,然后将其还原到我的PC上这一事实很重要。


问题答案:

它的重复项:Tensorflow:如何保存/恢复模型?。

您对模型的保存是正确的,但不正确。您正在做的是尝试创建一个与保存的模型具有相同节点的新图,而不是从保存的图中还原它。以下步骤应解决有关如何还原模型的问题:

#Start with resetting the default graph
tf.reset_default_graph()

with tf.Session() as sess:

   # Nodes:Before loading the graph
   print([n.name for n in tf.get_default_graph().as_graph_def().node])
   # Output is [] as no graph is loaded yet.

   # First let's load meta graph 
   saver = tf.train.import_meta_graph("abc/model.meta")
   # Nodes:after loading the graph'
   print([n.name for n in tf.get_default_graph().as_graph_def().node])
   # Output is [save/RestoreV2/shape_and_slices', 'save/RestoreV2/tensor_ ...]

   # The above step doesnt load the weights, can be checked by
   print(sess.run('Variable_1:0'))
   # Error: attempting to use uninitialized graph.

   #load the weights 
   saver.restore(sess,tf.train.latest_checkpoint('./abc/'))
   print(sess.run('Variable_1:0'))
   # Output: [-2.80421402e-04  3.53254407e-04 ...]

现在我们已经加载并准备好节点,您需要访问其中的一些以进行推断。但是,由于节点的命名不正确,因此很难确定哪个节点是输入和输出。为了避免这种情况,name在使用以下name参数正确保存模型时,需要张量/运算:

X = tf.placeholder(tf.float32, [None, 56, 56, 1], name='X')
Y = tf.identity(f.nn.softmax(Ylogits), name='logits').

加载图和权重后,可以在推理图中使用以下张量获得这些张量get_tensor_by_name

with tf.Session() as sess:

   #Load the graph and weights as above
   ....

   graph = tf.get_default_graph()
   X_infer = graph.get_tensor_by_name('X:0')
   Y_infer = graph.get_tensor_by_name('logits:0')
   sess.run(Y_infer,{X_infer:new_input}


 类似资料:
  • 我使用javamail通过IMAP协议从exchage帐户读取邮件。这些邮件是纯格式的,内容是XML。 几乎所有这些邮件的大小都很短(通常小于100Kb)。然而,有时我不得不处理大型邮件(大约10Mb-15Mb)。例如,昨天我收到一封13Mb大小的电子邮件。仅仅读它就花了50多分钟。这正常吗?有没有办法提高它的性能?代码是: 花费如此长时间的方法是。我做错了什么?有什么提示吗? 非常感谢,我的英语

  • 给出结果需要20多秒,而在mongo控制台中同样的查询需要不到一秒。 为什么会出现这种情况,如何减少速度差距?

  • 我有以下PHP代码在Laravel正在执行一个MySql查询: 执行此查询需要很长时间。 我对所排序的列以及其他查询的许多列都有索引。 我该怎么办? 更新: 执行的查询: 结果:

  • 在我们的kafka broker设置中,GC平均需要20毫秒,但随机增加到1-2秒。极端情况持续9秒。这种情况的发生频率相当随机。平均每天发生15次。我尝试过使用GCEasy,但没有给出任何见解。我的内存使用率为20%,但进程仍然使用交换,尽管内存可用。感谢您对如何将其最小化的任何意见 JVM选择: GC日志:

  • 问题内容: 我正在使用Hibernate 4.2,JPA 2.0和Postgres 9.2 代码卡在 在进一步调查中,我发现Hibernate调用了class 方法。此方法尝试加载有关每个数据库对象的元数据 的代码是Postgers的JDBC驱动程序的一部分,而确实是花费时间来执行该方法的驱动程序(我加载了驱动程序源并尝试了跟踪)。但是由于这个问题在Hibernate 3.3(我之前使用过)中没有

  • 我知道要冬眠。我有一个sql语句 我尝试用createCriteria和HQL实现它。 HQL: 问题是,此HQL的执行时间延长了10倍。并执行许多不必要的查询。我尝试使用注释字符串进行转换,它有了一些改进,但仍然比createCriteria查询长5倍,此外,我无法进行此转换 <代码>列表 版本数据防御