当前位置: 首页 > 知识库问答 >
问题:

当达到特定验证精度时,如何停止培训?

姬经义
2023-03-14

我正在训练卷积网络,一旦验证错误达到90%,我想停止训练。我想过使用EarlyStop并将基线设置为0.90,但是每当给定时代数的验证精度低于该基线时,它就会停止训练(此处仅为0)。所以我的代码是:

es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])

当我使用此代码时,我的训练将在第一个历元后停止,并显示给定的结果:

培训60000个样本,验证10000个样本

纪元1/30 60000/60000 - 7s-损失: 0.4600-acc: 0.8330-val_loss: 0.3426-val_acc: 0.8787

一旦验证准确率达到90%或以上,我还可以尝试停止培训吗?

下面是代码的其余部分:

  tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
  tf.keras.layers.MaxPooling2D(2, 2),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(152, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer=Adam(learning_rate=0.001),loss='sparse_categorical_crossentropy', metrics=['accuracy'])
es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])

谢谢你!

共有2个答案

毋宏茂
2023-03-14

现有答案看起来不错,但我在过去使用了一个较短的版本:

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if logs.get('accuracy') >= 9e-1:
            self.model.stop_training = True

您可以这样实现它:

callback = CustomCallback()

history = model.fit(..., callbacks=[callback])
班建义
2023-03-14

早期停止回调将搜索停止增加(或减少)的值,因此它不适合您的问题。但是tf.keras允许您使用自定义回调。

例如:

class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None): 
        val_acc = logs["val_acc"]
        if val_acc >= self.threshold:
            self.model.stop_training = True

对于TF版本2.3或更高版本,您可能必须使用“val\u acc”而不是“val\u acc”。感谢Christian Westbrook在评论中的注释。

上面的回调,在每个纪元结束时,将从所有可用的日志中提取验证精度。然后它将与用户定义的阈值(在您的情况下为90%)进行比较。如果符合标准,训练将停止。

有了它,你可以简单地调用:

my_callback = MyThresholdCallback(threshold=0.9)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2, callbacks=[my_callback])

或者,您可以在批处理端(…)上使用def 如果要立即停止。但是,这需要参数batch,logs,而不是epoch,logs

 类似资料:
  • 我试图创建一个回归模型,但我的验证精度保持在。我试图训练图像,让网络找到一个物体的位置和它覆盖的粗糙区域。我增加了解冻层和精度的平台下降到。如果能帮我找出我做错了什么,我将不胜感激。 数据是由一个包含图像数据和一些标签的tf记录文件生成的。这是发电机的最后一点。 获取批处理:数据集从两个目录加载,其中包含tfrecord文件,一个用于培训,另一个用于验证 1月1/1000/1000 1000/10

  • 问题内容: 我必须编写上述程序才能将go float64变量的精度降低到2。在这种情况下,我同时使用了strconv和fmt。还有其他逻辑方法可以完成吗? 问题答案: 您不需要任何额外的代码…它就像 测试代码

  • 我通过创建固定数量的线程来使用执行器服务来进行HTTP GET数据检索。 当Tomcat停止时,我们会出现以下错误: 严重:web应用程序[/viewer]似乎已启动名为[ThreadExecutor_51616156]的线程,但未能停止该线程。这很可能会造成内存泄漏。 这是真的吗?在没有这些服务错误的情况下,如何正确停止tomcat。

  • 问题内容: 我正在使用SQL Server 2012,并且正在尝试执行以下操作: 因此,换句话说,我想从表中选择具有指定的截止日期和日期的随机行,并在里程总和等于或大于该数字时停止:3250 问题答案: 由于您使用的是SQL Server 2012,因此这是一种无需循环的简单方法。 SQLfiddle演示-多次单击“运行SQL”以查看随机结果。

  • 我有一个iOS应用程序,它运行几个不同的块,在屏幕上为几个不同的对象设置动画。这一切都是可行的,但我如何停止其中一个动画块而不停止其余的? 我尝试过使用以下方法: 但它什么都没做,动画还在继续。 然后,我尝试使用一个简单的值,并在设置为"NO"后,让动画从方法返回,但这也不起作用。 以下是我试图停止的动画: 谢谢你,丹。

  • 我是android编程的新手,所以这些问题可能是愚蠢的。我读了一些书,但不能完全得到答案。 我有一个广播接收器,从一个服务注册了一些意图- 由于我移除了“setforeground”调用以保持我的服务运行(因为我不想要状态栏图标,我想知道我是否懒惰使用这种方法),我的服务现在将定期关闭,通常在短时间后再次启动(但有时我看到它是5分钟)。