当前位置: 首页 > 面试题库 >

从Keras多类模型获取混淆矩阵

陈正业
2023-03-14
问题内容

我正在用Keras构建多类模型。

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[checkpoint], validation_data=(X_test, y_test))  # starts training

这是我的测试数据的外观(文本数据)。

X_test
Out[25]: 
array([[621, 139, 549, ...,   0,   0,   0],
       [621, 139, 543, ...,   0,   0,   0]])

y_test
Out[26]: 
array([[0, 0, 1],
       [0, 1, 0]])

生成预测后…

predictions = model.predict(X_test)
predictions
Out[27]: 
array([[ 0.29071924,  0.2483743 ,  0.46090645],
       [ 0.29566404,  0.45295066,  0.25138539]], dtype=float32)

我做了以下工作来获得混淆矩阵。

y_pred = (predictions > 0.5)

confusion_matrix(y_test, y_pred)
Traceback (most recent call last):

  File "<ipython-input-38-430e012b2078>", line 1, in <module>
    confusion_matrix(y_test, y_pred)

  File "/Users/abrahammathew/anaconda3/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 252, in confusion_matrix
    raise ValueError("%s is not supported" % y_type)

ValueError: multilabel-indicator is not supported

但是,我收到上述错误。

在Keras中建立多类神经网络时,如何获得混淆矩阵?


问题答案:

您输入的内容confusion_matrix必须是整数数组,而不是一种热编码。

matrix = metrics.confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))


 类似资料:
  • X代表特征,Y代表图像分类的标签。我使用CNN进行二进制图像分类,就像猫和狗一样。 预测和y_test形状的输出是(90,2)和(90,),当我使用混淆矩阵时,它刷新:-ValueError:分类指标不能处理二进制和连续多输出目标的混合。

  • 我正在对实际数据和来自分类器的预测数据进行多标签分类。实际数据包括三类(c1、c2和c3),同样,预测数据也包括三类(c1、c2和c3)。数据如下 在多标签分类中,文档可能属于多个类别。在上述数据中,1表示文档属于特定类,0表示文档不属于特定类。 第一行Actual\u数据表示文档属于c1类和c2类,不属于c3类。类似地,第一行predicted\u数据表示文档属于类别c1、c2和c3。 最初我使

  • 我试图弄清楚如何使用神经网络为多标签分类任务生成混淆矩阵。我之前设法使用函数“交集”计算准确性,因为对此我不关心任何排序。 然而,为了计算混淆矩阵,我确实关心预测/标签的索引顺序。由于标签的值始终相同(

  • 我正在使用分类器的多类多标签输出。类的总数为14,实例可以关联多个类。例如: 我现在制作混淆矩阵的方式: 输出如下: 现在,我不确定sklearn的混淆矩阵是否能够处理多标签多类数据。谁能帮我一下吗?

  • 假设我有一个具有n个级别的因子变量y,我有预测和实际结果。如何构造混淆矩阵? 对于n=2的情况,这个问题已经得到了回答。看见 R:如何为预测模型制作混淆矩阵? 我试过的 这就是我能走多远 现在这必须以矩阵的形式呈现。 出身背景 混淆矩阵具有水平标签“实际类别”和垂直标签“预测类别”。矩阵元素的计数如下所示: 元素(1,1)=实际类的计数数为A,预测类的计数数为A 元素(1,2)=实际类别为A,预测

  • 我有一个具有登录功能的控制器类。当我输入用户名和密码并按submit时,它将调用此控制器并在会话中存储客户信息。但有一件事让我感到困惑,那就是@model属性 我将使用@ModelAttribute Customer存储我输入的用户名和密码,并使用Customer c存储我从customService获得的所有信息,并将其存储到会话中。但是会话存储的是客户。 如果我这样改变论点。它工作正常

  • 对不起,我是新来WEKA,刚刚学习。 在我的决策树(J48)分类器输出中,有一个混淆矩阵: 我如何读取这个矩阵?

  • 问题内容: 为什么在定义枚举时将其传递给字段名称列表,然后为什么这些字段名称(例如Days.MONDAY)最终引用字段 值 呢?我可以传递一个字段(例如Days.MONDAY),然后使用开关获取该字段的值。更奇怪的是,当我声明枚举字段时,即使它们实际上是值,我什至不必用引号将它们引起来。 问题答案: 将Java枚举视为定义类的一种不错的语法。这是一个可能有帮助的shell脚本: 是的,您可以说是c