当前位置: 首页 > 工具软件 > MindSpore > 使用案例 >

【MindSpore易点通】MindSpore Data经验解析

谭煜
2023-12-01

一、简介

首先MindSpore Data提供了简洁、丰富的数据读取、处理、增强等功能;同时使用读取数据的流程,主要分为三步(使用和PyTorch中数据读取方式类似):

  1. 数据集加载 - 根据数据格式,选择最简单、高效的数据集加载方式;
  2. 数据增强 - 使用几何变换、颜色变换、旋转、平移、缩放等基本图像处理技术来扩充数据集;
  3. 数据处理 - 对数据集做repeat、batch、shuffle、map、zip等操作。

二、使用说明

1、数据集加载

首先加载要使用的数据集,根据实际使用的数据集格式,从以下三种数据集读取方式选取一种即可:

  • 常用标准数据集:例如 ImageNet、MNIST、CIFAR-10、VOC等;
  • 特定格式数据集 :特定存储格式的数据,例如:MindRecord;
  • 自定义数据集:数据组织形式自定义的数据集。

2.1 常用数据集加载

目前已经支持的常用数据集有:MNIST, CIFAR-10, CIFAR-100, VOC, ImageNet, CelebA。如果使用以上开源数据集或者已经将所使用的数据整理为以上标准数据集格式,可以直接使用如下方法加载数据集。以CIFAR-10为例:

import mindspore.dataset as ds

DATA_DIR = "./cifar-10-batches-bin/"
cifar_ds = ds.Cifar10Dataset(DATA_DIR)

数据集加载好之后,就可以调用接口create_dict_iterator()创建迭代器读取数据,后面两种方式同理。

for data in cifar_ds.create_dict_iterator():# In CIFAR-10 dataset, each dictionary of data has keys "image" and "label".

    print(data["image"])
    print(data["label"])

2.2 特定格式数据集加载

目前支持的特定格式数据集为:MindRecord。MindRecord格式的数据读取性能更优,推荐用户将数据转换为MindRecord格式。转换示例如下:

from mindspore.mindrecord import Cifar10ToMR

cifar10_path = "./cifar-10-batches-py"
mindrecord_path = "./cifar10.mindrecord"
cifar10_transformer = Cifar10ToMR(cifar10_path, mindrecord_path)
cifar10_transformer.transform(["label"])

MindRecord数据加载:

import mindspore.dataset as ds

CV_FILE_NAME = "./cifar10.mindrecord"
cifar_ds = ds.MindDataset(dataset_file=CV_FILE_NAME,columns_list=["data","label"], shuffle=True)

2.3 自定义数据集加载

提供的自定义数据集加载方式为:GeneratorDataset接口。GeneratorDataset接口需要自己实现一个生成器,生成训练数据和标签,适用于较复杂的任务。

GeneratorDataset()需要传入一个生成器,生成训练数据。

import mindspore.dataset as ds

class Dataset:
    def __init__(self, image_list, label_list):
        super(Dataset, self).__init__()
        self.imgs = image_list
        self.labels = label_list

    def __getitem__(self, index):
        img = Image.open(self.imgs[index]).convert('RGB')
        return img, self.labels[index]
 
    def __len__(self):
        return len(self.imgs)


class MySampler():
    def __init__(self, dataset):
        self.__num_data = len(dataset)

    def __iter__(self):
        indices = list(range(self.__num_data))
        return iter(indices)

dataset = Dataset(save_image_list, save_label_list)
sampler = MySampler(dataset)
cifar_ds = ds.GeneratorDataset(dataset, column_names=["image", "label"], sampler=sampler, shuffle=True)

以上例子中 dataset是一个生成器,产生image和label。

2、数据增强

提供 c_transforms 和 py_transforms 两个模块来供用户完成数据增强操作,两者的对比如下:

模块名称实现优缺点
c_transforms基于C++的OpenCV实现性能较高
py_transforms基于Python的PIL实现性能较差,但是可以自定义增强函数

使用建议:如果不需要自定义增强函数,并且c_transforms中有对应的实现,建议使用c_transforms模块。

2.1 c_transforms模块

目前c_transforms接口包括两部分:mindspore.dataset.transforms.c_transforms和mindspore.dataset.vision.c_transforms。

使用方法:

1.定义好数据增强函数:把多个增强函数加入到一个list中,并调用Compose封装;

2.调用dataset.map()函数,将定义好的函数或算子作用于指定的数据列。

示例代码如下:

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV_transforms
import mindspore.dataset.transforms.c_transforms as C_transforms

DATA_DIR = "./cifar-10-batches-bin/"
cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')#定义增强函数列表
transforms_list = C_transforms.Compose[CV_transforms.RandomCrop((32, 32), (4, 4, 4, 4)), CV_transforms.RandomHorizontalFlip(), CV_transforms.Rescale(rescale, shift), CV_transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), CV_transforms.HWC2CHW()]#调用map()函数
cifar_ds = cifar_ds.map(operations=transforms_list, input_columns="image")

其中,input_columns为指定要做增强的数据列,operations为定义的增强函数。

2.2 py_transforms模块

py_transforms接口也包括两部分mindspore.dataset.transforms.py_transforms和mindspore.dataset.vision.py_transforms。

使用方法:和c_transforms模块中的使用方法类似。示例代码如下:

import mindspore.dataset as ds
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms

DATA_DIR = "./cifar-10-batches-bin/"
cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')
transform_list = py_transforms.Compose([
            py_vision.ToPIL(),
            py_vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            py_vision.RandomHorizontalFlip(),
            py_vision.ToTensor(),
            py_vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_ds = cifar_ds.map(operations=transforms_list, input_columns="image")

使用py_transforms自定义增强函数:

自定义增强函数可参考MindSpore源码中的py_transforms_util.py脚本。下面以RandomBrightness为例,说明自定义增强算子的定义方式:

#自定义增强函数定义class RandomBrightness(object):
    """
    Randomly adjust the brightness of the input image.
    Args:
        brightness (float): Brightness adjustment factor (default=0.0).
    Returns:
        numpy.ndarray, image.
    """
    def __init__(self, brightness=0.0):
        self.brightness = brightness
    def __call__(self, img):
        alpha = random.uniform(-self.brightness, self.brightness)
        return (1-alpha) * img

自定义算子的调用和py_transforms_util.py中的算子调用没有区别。

3、数据处理

数据处理操作有:zip、shuffle、map、batch、repeat。

数据处理操作说明
zip合并多个数据集
shuffle混洗数据
map将函数和算子作用于指定列数据
batch将数据分批,每次迭代返回一个batch的数据
repeat对数据集进行复制

一般训练过程中都会用到shuffle、map、batch、repeat,如下示例:

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV_transforms
import mindspore.dataset.transforms.c_transforms as C_transforms

DATA_DIR = "./cifar-10-batches-bin/"
cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')
transform_list = C.Compose([
            CV.RandomCrop((32, 32), (4, 4, 4, 4)),
            CV.RandomHorizontalFlip(),
            CV.Rescale(rescale, shift),
            CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            CV.HWC2CHW()])# map()
cifar_ds.map(input_columns="image", operations=transforms_list)# batch()
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)# repeat()
cifar_ds = cifar_ds.repeat(repeat_num)

在实际使用过程中,需要组合使用这几个操作时,为达到最优性能,推荐按照如下顺序: 数据集加载并shuffle -> map -> batch -> repeat。

以下简单介绍一下数据处理函数的使用方法:

3.1数据集加载与shuffle

方式一:加载数据集时shuffle

import mindspore.dataset as ds

DATA_DIR = "./cifar-10-batches-bin/"
cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')

方式二:加载数据集后shuffle

import mindspore.dataset as ds

DATA_DIR = "./cifar-10-batches-bin/"
cifar_ds = ds.Cifar10Dataset(DATA_DIR, usage='train')
cifar_ds = cifar_ds.shuffle(buffer_size=10000)

参数说明:

buffer_size:buffer_size越大,混洗程度越大,时间消耗更大

3.2 map:

func = lambda x : x*2
cifar_ds = cifar_ds.map(input_columns="data", operations=func)

参数说明:

input_columns:函数作用的列数据

operations:对数据做操作的函数

3.3 batch

cifar_ds = cifar_ds.batch(batch_size=32, drop_remainder=True, num_parallel_workers=4)

参数说明:

drop_remainder:舍弃最后不完整的batch

num_parallel_workers: 用几个线程来读取数据

3.4 repeat

cifar_ds = cifar_ds.repeat(count=2)

参数说明:

count: 数据集复制数量

3.5 zip

import mindspore.dataset as ds

DATA_DIR_1 = "custom_dataset_dir_1/"
DATA_DIR_2 = "custom_dataset_dir_2/"
imagefolder_dataset_1 = ds.ImageFolderDatasetV2(DATA_DIR_1)
imagefolder_dataset_2 = ds.ImageFolderDatasetV2(DATA_DIR_2)
imagefolder_dataset = ds.zip((imagefolder_dataset_1, imagefolder_dataset_2))

详细代码请前往MindSpore论坛进行下载:华为云论坛_云计算论坛_开发者论坛_技术论坛-华为云

说明:严禁转载本文内容,否则视为侵权。 

 类似资料: