当前位置: 首页 > 知识库问答 >
问题:

无法打印正确的混淆矩阵,也无法打印热图值,例如2e 2,e 4等

贺彬
2023-03-14

无法正确打印混淆矩阵和打印热图。在示例2e 2、e 4等中,某些块或列中的值正在打印。请在这方面帮助我

导入numpy作为np导入matplotlib.pyplot作为plt导入seaborn作为sns从keras导入pandas作为pd。模型从keras导入顺序。图层从keras导入卷积2D。图层从keras导入MaxPoolig2D。图层从keras导入展平。图层从sklearn导入稠密。度量导入分类报告,混淆矩阵

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Initialising the CNN
classifier = Sequential()

# Step 1 - Convolution
classifier.add(Convolution2D(64, 3, 3, input_shape = (64, 64, 3), activation = 'relu'))

# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))

# Adding a second convolutional layer
classifier.add(Convolution2D(64, 3, 3, activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))


classifier.add(Convolution2D(64, 3, 3, activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))


# Step 3 - Flattening
classifier.add(Flatten())

# Step 4 - Full connection
classifier.add(Dense(output_dim = 128, activation = 'relu'))
classifier.add(Dense(output_dim = 10, activation = 'sigmoid'))

# Compiling the CNN
classifier.compile(optimizer = 'Adam', loss = 'binary_crossentropy', metrics = ['accuracy'])

# Part 2 - Fitting the CNN to the images

from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.4,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

#importing training data

training_set = train_datagen.flow_from_directory('Dataset/train',
                                                 target_size = (64,64),
                                                 batch_size = 64,
                                                 class_mode = 'categorical')

#importing test data
test_set = test_datagen.flow_from_directory('Dataset/test',target_size = (64,64),
                                            batch_size = 64,
                                            class_mode = 'categorical',shuffle=False)

#storing all the history

history = classifier.fit_generator(
        training_set,
        steps_per_epoch=20,
        epochs=5,
        validation_data=test_set,
        validation_steps=2000)
print(history.history.keys())

#汇总精度

plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

#Confution Matrix 
Y_pred = classifier.predict_generator(test_set, steps=len(test_set), max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
y_pred = np.argmax(Y_pred, axis=1)

#assigning values 
confusion=(confusion_matrix(test_set.classes, y_pred))
confusion_df = pd.DataFrame(confusion,
                     index = ['Airplan','Car','Birds','Cats','Deer', 'Dogs','Frog', 'Horse','Ship','Truck'], 
                     columns = ['Airplan','Car','Birds','Cats','Deer', 'Dogs','Frog', 'Horse','Ship','Truck'])

#heatmap    
sns.heatmap(confusion_df, annot=True)
print(confusion_df)

#classification report
print('Classification Report')
target_names = ['Airplan','Car','Birds','Cats','Deer', 'Dogs','Frog', 'Horse','Ship','Truck']
print(classification_report(test_set.classes, y_pred, target_names=target_names))

共有2个答案

夏振国
2023-03-14

假设你的混淆矩阵是厘米

import numpy as np
import matplotlib.pyplot as plt
import seaborn  as  sns

cm = np.array([[345,12],[0,104763]])

plt.figure(figsize=(10,7))
sns.heatmap(cm,annot=True,linewidths=1, fmt = 'd')
plt.xlabel('predicted')
plt.ylabel('Truth')
plt.show()

您将得到以下矩阵:

宋昕
2023-03-14

你能试试这个吗。

sns.heatmap(confusion_df, annot=True, fmt='.2f')
 类似资料:
  • 我得到了混淆矩阵,但是因为我的实际数据集有很多分类类别,所以很难理解。 范例- 但是如何打印标签/列名以便更好地理解呢? 我甚至试过这个- 需要帮忙吗?

  • 本文向大家介绍螺旋打印矩阵,包括了螺旋打印矩阵的使用技巧和注意事项,需要的朋友参考一下 该算法用于以螺旋方式打印数组元素。首先,从第一行开始,先打印全部内容,然后按照最后一列打印,然后再最后一行,依此类推,从而以螺旋方式打印元素。  该算法的时间复杂度为O(MN),M为行数,N为列数。 输入输出 算法 输入: 矩阵矩阵,行和列m和n。 输出:以螺旋方式打印矩阵的元素。 示例 输出结果

  • 我想正确打印unicode(比如希腊字符),但我有问题。例如: 问题是是否有任何解决方案可以正确打印所有卡哈拉特。我认为对于希腊字符,UTF-16是可以的。

  • 编辑:所以在一天的混乱之后。我的问题是spintf。我最初认为我的循环是错误的。

  • 我想获取我在 Web 表单中使用的多个标签的值。我使用隐藏的输入字段来完成这项工作,但我无法通过servlet打印标签的值,因为我得到空作为标签的值 下面是我用来测试的示例代码 索引.jsp 这是提交。服务程序 } 我想显示孩子的年龄,但输出为空 请帮助提供可能的解决方案 提前致谢

  • 我尝试在A4纸上以横向打印JavaFX WebView (JavaFX 8_25)中的HTML页面,但它以纵向打印出来,字体很小 系统输出显示纵向方向 纸张=纸:A4 (210 x 297mm) 尺寸=594.90087890625x841.3598022460938 MM 东方=纵向左边缘=54.0 右边缘=54.0 顶部边缘=54.0 底部边缘=54.0 我发现以横向模式打印HTML页面的唯一