当前位置: 首页 > 知识库问答 >
问题:

Keras fit_generator和fit结果不同

淳于思淼
2023-03-14

我正在用人脸图像数据集训练一个卷积神经网络。该数据集有10,000个尺寸为700x700的图像。我的模型有12层。我正在使用生成器函数将图像读取到Keras fit_generator函数中,如下所示。

train_file_names==>包含训练实例文件名的Python列表
train_class_labels==>一个热编码类标记的数组([0,1,0],[0,0,1]等)
train_data==>训练实例的数组
train_steps_epoch==>16(批大小为400,我有6400个实例用于训练,因此一次通过整个数据集需要16次迭代)
batch_size==>400
calls_made==>当生成器到达训练实例的末尾时,它会重置索引以从下一个时期的第一个索引加载数据。

我将此生成器作为参数传递给keras的'fit_generater'函数,以生成每个历元的新数据批。

val_data,val_class_labels==>验证数据个数组
epochs==>epochs个数

使用Keras fit_generator:

model.fit_generator(generator=train_generator, steps_per_epoch=train_steps_per_epoch, epochs=epochs, use_multiprocessing=False, validation_data=[val_data, val_class_labels], verbose=True, callbacks=[history, model_checkpoint], shuffle=True, initial_epoch=0) 

代码

def train_data_generator(self):     
    index_start = index_end = 0 
    temp = 0
    calls_made = 0

    while temp < train_steps_per_epoch:
        index_end = index_start + batch_size
        for temp1 in range(index_start, index_end):
            index = 0
            # Read image
            img = cv2.imread(str(TRAIN_DIR / train_file_names[temp1]), cv2.IMREAD_GRAYSCALE).T
            train_data[index]  = cv2.resize(img, (self.ROWS, self.COLS), interpolation=cv2.INTER_CUBIC)
            index += 1       
        yield train_data, self.train_class_labels[index_start:index_end]
        calls_made += 1
        if calls_made == train_steps_per_epoch:
            index_start = 0
            temp = 0
            calls_made = 0
        else:
            index_start = index_end
            temp += 1  
        gc.collect()

fit_generator的输出

86/300
16/16纪元[

我的问题是,在使用具有上述生成器功能的fit_generator时,我的模型丢失没有得到任何改善,并且验证精度非常差。但是当我使用keras的fit函数时,模型损失减少了,验证精度也提高了很多。

在不使用生成器的情况下使用Keras fit函数

model.fit(self.train_data, self.train_class_labels, batch_size=self.batch_size, epochs=self.epochs, validation_data=[self.val_data, self.val_class_labels], verbose=True, callbacks=[history, model_checkpoint])    

使用fit函数训练时输出

Epoch 25/300
6400/6400CC:0.7795
Epoch 32/300
6400/6400[=================================================================================================================================================]-20秒3ms/步进-损耗:0.0106-Acc:0.9969-

共有1个答案

萧繁
2023-03-14

您必须确保您的数据生成器在各个时期之间对数据进行洗牌。我建议您在循环外部创建一个可能的索引列表,使用random.shuffle将其随机化,然后在循环内部迭代。

来源:https://github.com/keras-team/keras/issues/2389和自己的经验。

 类似资料:
  • 问题内容: 我很难理解scikit-learn和scikit-learn之间的区别(如果有)。 试图预测具有不平衡类的二进制输出(Y = 1时约为1.5%)。 分类器 大鹏曲线 AUC的 和 有人可以解释这种差异吗?我以为两者都只是在计算ROC曲线下的面积。可能是因为数据集不平衡,但我不知道为什么。 谢谢! 问题答案: AUC并不总是在ROC曲线的曲线下方。曲线下面积为下(抽象)地区 的一些 曲线

  • 我正在使用neo4j来计算一个数据集上的一些统计数据。为此,我经常在浮点值上使用sum。我得到不同的结果取决于环境。例如,执行以下操作的查询: 差别很小(与)。但使简单的等式检查失败就足够了。另一个例子是数据库的不同实例,使用相同的加载过程加载相同的数据可能会产生相同的问题(dbs可能不是1:1,某些关系的加载顺序可能不同)。我取了neo4j求和的原始值(通过简单地移除),并验证它们在所有情况下都

  • 这是我的代码: null

  • 主要内容:分发结果类型:,FreeMaker结果类型:,重定向结果类型:正如前面提到的,<results>标签在Struts2的MVC框架的视图中所扮演的角色。动作是负责执行业务逻辑。执行业务逻辑后,接下来的步骤是使用<results>标签显示的视图。  经常有一些附带导航规则的结果。例如,如果在操作方法是对用户进行验证,有三种可能的结果。 (一)成功登录;(二)不成功的登录,用户名或密码错误;(三)帐户锁定。 在这种情况下的动作方法将被配置呈现的结果有三种可能的结果

  • Fit

    在此布局中,容器填充单个面板,当没有与布局相关的特定要求时,将使用此布局。 语法 (Syntax) 以下是使用Fit布局的简单语法。 layout: 'fit' 例子 (Example) 以下是显示Fit布局用法的简单示例。 <!DOCTYPE html> <html> <head> <link href = "https://cdnjs.cloudflare.com/aja

  • Fit

    Fit 是一个软件开发团队中的增强协作的工具,用来帮助客户、测试人员和开发人员来了解他们的软件以及接下来所要做的工作,它可以自动的根客户的预期要求进行比较生成相应报告。