当前位置: 首页 > 工具软件 > Keras > 使用案例 >

用keras计算F1

公西培
2023-12-01

加粗样式@TOC

from keras.callbacks import Callback
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

class Metrics(Callback):
    def on_train_begin(self, logs={}):
        self.val_f1s = []
    def on_epoch_end(self, epoch, logs={}):
        val_predict = (np.asarray(self.model.predict(self.validation_data[0]))).round()  ##.model
        val_targ = self.validation_data[1]  ###.model
        _val_f1 = f1_score(val_targ, val_predict, average='micro')

        self.val_f1s.append(_val_f1)
        print("— val_f1: %f " % _val_f1)
        return
f1=Metrics()
history = model.fit(train_data_array,  # 训练集输入特征
                        train_label_array,  # 训练集标签
                        batch_size=batchsize,  # 每次喂入网络256组数据
                        epochs=nb_epoch,  # 数据集迭代10次 (可自己设置)
                        verbose=1,  # 日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
                        validation_data=(x_validation, y_validation),
                        callbacks=[f1,check,time_callback]) 

出现报错:ValueError: unknown is not supported
解决方法: 数据维度转变,sklearn只能接受二维的,现在传入的是三维数据。把[batchsize,len,classes]转变为[batchsize*len,classes]的数据。
在上面添加下面两行代码。
val_predict = np.reshape(val_predict, (-1, val_predict.shape[-1]))
val_targ = np.reshape(val_targ, (-1, val_targ.shape[-1]))

最后总的代码如下

from keras.callbacks import Callback
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

class Metrics(Callback):
    def on_train_begin(self, logs={}):
        self.val_f1s = []
    def on_epoch_end(self, epoch, logs={}):
        val_predict = (np.asarray(self.model.predict(self.validation_data[0]))).round()  ##.model
        val_predict = np.reshape(val_predict, (-1, val_predict.shape[-1]))

        val_targ = self.validation_data[1]  ###.model
        val_targ = np.reshape(val_targ, (-1, val_targ.shape[-1]))
        _val_f1 = f1_score(val_targ, val_predict, average='micro')

        self.val_f1s.append(_val_f1)
        print("— val_f1: %f " % _val_f1)
        return


看了网上好多博客,转换成list array 都没成功,最后debug发现还是维度不对。费了一天时间

 类似资料: