当前位置: 首页 > 工具软件 > Fit > 使用案例 >

【Python-Keras】keras.fit()和keras.fit_generator()的解析与使用

袁旻
2023-12-01

1 作用与区别

作用: 用于训练神经网络模型,两者可以完成相同的任务

区别:
.fit()时使用的整个训练数据集可以放入内存,并没有应用数据增强,就是.fit()无需使用Keras生成器(即无需数据参数)

当我们有一个巨大的数据集可容纳到我们的内存中或需要应用数据扩充时,将使用.fit_generator()。就是需要使用Keras生成器去扩充数据等等操作。

2 解析与使用

2.1 keras.fit()

(1)参数介绍

keras.fit(object, #要训练的模型
	x = NULL, #训练数据。可以是向量,数组或矩阵
	y = NULL, #训练标签。可以是向量,数组或矩阵
	batch_size = NULL, #它可以接受任何整数值或NULL,默认情况下,它将设置为32。它指定否。每个梯度样本数
	epochs = 10,#一个整数,我们要训练模型epochs的数量
  	verbose = getOption("keras.fit_verbose", default = 1),#指定详细模式(0 =静音,1 =进度栏,2 = 1每行记录)      
  	callbacks = NULL, 
  	view_metrics = getOption("keras.view_metrics",
  	default = "auto"), 
  	validation_split = 0, 
  	validation_data = NULL,
  	shuffle = TRUE, 
  	class_weight = NULL, 
  	sample_weight = NULL,
  	initial_epoch = 0, 
  	steps_per_epoch = NULL, #它指定之前执行的步骤总数
  	validation_steps = NULL,
  ...)

(2)举例使用

我们首先输入训练数据(Xtrain)和训练标签(Ytrain)。然后,我们使用Keras允许我们的模型以batch_size为32训练100个epoch。

model.fit(Xtrain, Ytrain, batch_size = 32, epochs = 100)

(3)原理讲解

当我们调用.fit()函数时,它会做一些假设:

  • 整个训练集可以放入计算机的随机存取存储器(RAM)中。
  • 调用模型。fit方法第二次不会重新初始化我们已经训练好的权重,这意味着如果需要,我们实际上可连续调用fit以进行调整。
  • 无需使用Keras生成器(即无需数据参数)
  • 原始数据本身就是用于训练我们的网络的,而我们的原始数据只能放入内存中

2.2 keras.fit_generator()

(1)参数介绍

fit_generator(object, #Keras对象模型
	generator, #生成器,其输出必须是以下形式的列表:
			#  - (inputs, targets)    
            #  - (input, targets, sample_weights)
            # 生成器的单个输出进行单个批处理,因此列表中的所有数组 长度必须等于批次的大小。生成器是期望的
			# 遍历其数据无限。有时,它永远不会返回或退出。
	steps_per_epoch, #它指定从生成器采取的步骤总数
	epochs = 1,
  	verbose = getOption("keras.fit_verbose", default = 1),
  	callbacks = NULL, 
  	view_metrics = getOption("keras.view_metrics",
  	default = "auto"), 
  	validation_data = NULL, #可以是以下的其中一种
  			# - 一个 inputs 和 targets 列表
            # - 一个发生器
  			# - inputs, targets, 和sample_weights 列表,可用于在任何时期结束后评估任何模型的损失和度量。
  	validation_steps = NULL,#仅当validation_data是生成器时,才此参数可以使用。它指定生成器之前从生成器采取的步骤总数,在每个epoch停止,其值=在数据集中验证数据点的总数/验证batch大小。
  	class_weight = NULL, 
  	max_queue_size = 10, 
  	workers = 1,
  	initial_epoch = 0)

(2)举例使用

数据增强 是一种从现有训练数据集中人为创建新数据集进行训练的方法,以利用可用数据量来提高深度学习神经网络的性能。这是一种正则化形式,使我们的模型比以前更好地推广。
在这里,我们使用Keras ImageDataGenerator对象将数据增强应用于图像的随机平移,调整大小,旋转等。每一批新数据都会根据提供给ImageDataGenerator的参数进行随机调整。

#通过训练图像生成器执行数据论证
dataAugmentaion = ImageDataGenerator(rotation_range = 30,
									zoom_range = 0.20, 
									fill_mode =“ nearest”,
									shear_range= 0.20,
									horizo​​ntal_flip = True, 
									width_shift_range = 0.1,
									height_shift_range = 0.1)

#训练模型
model.fit_generator(dataAugmentaion.flow(trainX,trainY,batch_size = 32),
 					validate_data =(testX,testY),
 					steps_per_epoch = len(trainX)// 32,
 					epoch= 10)

网络训练10个epoch,默认batch大小为32。
对于较小和较不复杂的数据集,建议使用keras.fit函数,而在处理实际数据集时,并不是那么简单,因为实际数据集的大小很大,很难放入计算机内存中。
处理这些数据集更具挑战性,处理这些数据集的重要步骤是执行数据扩充,以避免模型的过拟合,并提高模型的泛化能力。

(3)原理解析

当调用.fit_generator()函数时,它会做一些假设:

  • Keras首先调用了生成器函数(dataAugmentaion)
  • 生成器函数为.fit_generator()函数提供了32的batch_size。
  • .fit_generator()函数首先接受一批数据集,然后对其进行反向传播,然后更新模型中的权重。
  • 对于指定的epoch数(在本例中为10),将重复此过程。
 类似资料: