当前位置: 首页 > 知识库问答 >
问题:

CNN关于时尚MNIST数据集

董哲
2023-03-14

我正在使用fashion MNIST数据集来解决这个问题。我正在使用链接中的数据:

培训:http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz

训练集标签:http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz

测试集图像http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz

测试集标签http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz

我使用代码打开数据集:

def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels

label = ['T-shirt/top',  'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt',
         'Sneaker', 'Bag', 'Ankle boot']

data_dir = './'
X_train, y_train = load_mnist('D:\book', kind='train')
X_test, y_test = load_mnist('D:\book', kind='t10k')

X_train = X_train.astype(np.float32) / 256.0
X_test = X_test.astype(np.float32) / 256.0

我正在尝试构建一个具有以下架构的卷积神经网络:

  • 卷积层,32个滤波器,尺寸为3x3
  • ReLU激活功能
  • 2x2 MaxPooling
  • 卷积层,64个滤波器,尺寸为3x3
  • ReLU激活功能
  • 2x2 MaxPooling
  • 具有512个单元和ReLU激活功能的全连接层
  • 使用SGD优化器为100个时代的输出层设置Softmax激活层

我的代码是:

X_train = X_train.reshape([60000, 28, 28, 1])
X_train = X_train.astype('float32') / 255.0
X_test = X_test.reshape([10000, 28, 28, 1])
X_test = X_test.astype('float32') / 255.0
model = Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=[28,28,1]))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Flatten())
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.summary()
y_train = keras.utils.np_utils.to_categorical(y_train)
y_test = keras.utils.np_utils.to_categorical(y_test)
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=100)

但是它需要很多时间来执行。就像每个时代30分钟。我认为我的代码做错了什么。有人能帮我弄清楚吗?

共有1个答案

楚天宇
2023-03-14

我想强调几点:

>

X_train = X_train.astype(np.float32) / 256.0
X_test = X_test.astype(np.float32) / 256.0 

为什么要除以256.0?图像中的像素数据范围为0-255。因此,您应该将其除以255.0,将其规格化为0-1。

在加载后对数据进行一次规范化之后,您将再次对其进行规范化。检查以下代码:

X_train = X_train.reshape([60000, 28, 28, 1])
X_train = X_train.astype('float32') / 255.0
X_test = X_test.reshape([10000, 28, 28, 1])
X_test = X_test.astype('float32') / 255.0

在这里,重塑后,您将再次对其进行规格化。这是没有必要的。在训练网络时,多次规范化数据可能会减慢收敛速度。

您没有在model.fit函数中传递batch_size值。根据这里的留档,

如果未指定,batch_size将默认为32。

这可能是它需要更多时间执行的原因。尝试将批处理大小增加到100、200等,然后检查执行时间。

 类似资料:
  • 我正在做一个ML/Tensorflow hello world,通过使用MNIST数据集来预测某物是什么类型的衣服,但是当我尝试使用数据将数据加载到我的doe中时。load_data()它给了我以下错误: 使用TensorFlow后端。从下载数据https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1

  • 本文向大家介绍关于Pytorch的MNIST数据集的预处理详解,包括了关于Pytorch的MNIST数据集的预处理详解的使用技巧和注意事项,需要的朋友参考一下 关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等。 操作系统:ubuntu18.04 显卡:GTX1080ti p

  • 源码: tensorflow/g3doc/tutorials-mnist/ 本教程的目标是展示如何下载用于手写数字分类问题所要用到的(经典)MNIST数据集。 教程 文件 本教程需要使用以下文件: 文件 目的 input_data.py 下载用于训练和测试的MNIST数据集的源码 准备数据 MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字

  • 源码: tensorflow/g3doc/tutorials/mnist/ 本教程的目标是展示如何下载用于手写数字分类问题所要用到的(经典)MNIST数据集。 教程 文件 本教程需要使用以下文件: 文件 目的 input_data.py 下载用于训练和测试的MNIST数据集的源码 准备数据 MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字

  • 数据源引用连接中的表,并且可以从不同服务器类型的表中选择数据。数据集中的字段可用于构造图表。在构建图表时,你需要指定图表使用的数据源。 连接窗格 连接窗格是浏览连接、数据库、表、查询的基本途径。如果连接窗格已隐藏,从菜单栏选择“查看”->“显示连接”。 数据源工具栏 数据源工具栏提供了可用于处理数据的控件。 设计窗格 设计窗格让你直观地构建数据源。 Navicat 提供了两种用于连接数据的模式:实

  • 数据源引用连接中的表,并且可以从不同服务器类型的表中选择数据。数据集中的字段可用于构造图表。在构建图表时,你需要指定图表使用的数据源。 连接窗格 连接窗格是浏览连接、数据库、表、查询的基本途径。如果连接窗格已隐藏,从菜单栏选择“查看”->“显示连接”。 数据源工具栏 数据源工具栏提供了可用于处理数据的控件。 设计窗格 设计窗格让你直观地构建数据源。 Navicat 提供了两种用于连接数据的模式:实