当前位置: 首页 > 工具软件 > Py3Cache > 使用案例 >

yolov5-5.0版本代码详解----datasets.py的LoadImagesAndLabels函数

戈宏义
2023-12-01

yolov5-5.0版本代码详解----datasets.py的LoadImagesAndLabels函数(1)

1、作用

数据载入(数据增强)部分,即自定义数据集部分

2、调用位置

在dataset.py的create_dataloader函数中被调用

3、代码详解

1、__init__函数

1.1 传入参数
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None,
                 rect=False, image_weights=False, cache_images=False, single_cls=False,
                 stride=32, pad=0.0, prefix=''):
        self.img_size = img_size  # 图像尺寸,默认为640 X 640
        self.augment = augment  # 是否采用数据增强
        self.hyp = hyp  # 超参列表
        self.image_weights = image_weights  # 是否开启按权重采样  即根据类别频率(频率高的权重小)来进行采样  默认False
        self.rect = False if image_weights else rect  # 是否启动矩形训练 一般训练时关闭 验证时打开 可以加速
        self.mosaic = self.augment and not self.rect  # 是否启用马赛克增强
        self.mosaic_border = [-img_size // 2, -img_size // 2]  # mosaic增强的边界值
        self.stride = stride  # 最大下采样率 32
        self.path = path  # 图片路径  我们训练一般传入:test.txt/train.txt

        # 如果数据增强,用pytorch自带的Albumentations()进行数据增强
        self.albumentations = Albumentations() if augment else None
1.2 得到path路径下的所有图片的全路径
1、判断path类型(为txt还是文件夹)
2、筛选f中所有的图片文件
self.img_files == 需要载入的图片全路径列表
        # 2、得到path路径下的所有图片的路径 self.img_files-------------------------------------
        try:
            f = []  # image files
            # 2.1 判断path的类型--------------------------------------------------------------
            for p in path if isinstance(path, list) else [path]:
                # 获取数据集路径path,包含图片路径的txt文件或者包含图片的文件夹路径
                p = Path(p)

                # 如果路径path为包含图片的文件夹路径
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)

                # 如果路径path为包含图片路径的txt文件
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        # # 获取图片路径
                        t = t.read().strip().splitlines()

                        # 获取数据集路径的上级父目录  os.sep为路径里的分隔符(不同路径的分隔符不同,os.sep可以根据系统自适应
                        parent = str(p.parent) + os.sep

                        # 更换成全路径,防止传入的为相对路径
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]
                else:
                    raise Exception(f'{prefix}{p}不存在')

            # 破折号替换为os.sep
            # os.path.splitext(x)将文件名与扩展名分开并返回一个列表
            # 2.2 筛选f中所有的图片文件------------------------------------------------------------------------------
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS])
            # IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
1.3 根据imgs路径找到labels的路径self.label_files
	self.label_files = img2label_paths(self.img_files)
1.4 cache label 下次运行这个脚本的时候直接从cache中取label而不是去文件中取label 速度更快
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            # 如果有cache文件,直接加载  exists=True: 是否已从cache文件中读出了nf, nm, ne, nc, n等信息
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict

            # 如果图片版本信息或者文件列表的hash值对不上号 说明本地数据集图片和label可能发生了变化 就重新cache label文件
            assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
        except:
            # 否则调用cache_labels缓存标签及标签相关信息
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # 打印cache的结果 nf nm ne nc n = 找到的标签数量,漏掉的标签数量,空的标签数量,损坏的标签数量,总的标签数量
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # 展示cache results
            if cache['msgs']:
                logging.info('\n'.join(cache['msgs']))  # display warnings

        # 数据集没有标签信息 就发出警告并显示标签label下载地址help_url
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
1.5 从cache中读出最新变量并赋值
       # cache中的键值对最初有:
        # cache[img_file]=[l, shape, segments] cache[hash] cache[results] cache[msg] cache[version]

        # 5.1 先从cache中去除cache文件中其他无关键值如:'hash', 'version', 'msgs'等
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        # 只剩下cache[img_file]=[l, shape, segments]

        # cache.values(): 取cache中所有值 对应所有l, shape, segments
        # labels: 如果数据集所有图片中没有一个多边形label  labels存储的label就都是原始label(都是正常的矩形label)
        #         否则将所有图片正常gt的label存入labels 不正常gt(存在一个多边形)经过segments2boxes转换为正常的矩形label
        # shapes: 所有图片的shape
        # self.segments: 如果数据集所有图片中没有一个多边形label  self.segments=None
        #                否则存储数据集中所有存在多边形gt的图片的所有原始label(肯定有多边形label 也可能有矩形正常label 未知数)
        # zip 是因为cache中所有labels、shapes、segments信息都是按每张img分开存储的, zip是将所有图片对应的信息叠在一起

        labels, shapes, self.segments = zip(*cache.values())  # segments: 都是[]
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # 更新所有图片的img_files信息 update img_files from cache result
        self.label_files = img2label_paths(cache.keys())  # 更新所有图片的label_files信息(因为img_files信息可能发生了变化)
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)  # 所有图片的index
几个重要参数
1、self.n  # 图片的数量
2、self.indices  # 所有图片的索引 0,1,2,3---
3、self.labels # [array([[7, 0.16265,0.50026, 0.072321, 0.99947]], dtype=float32),.....] 所有label的框
4、self.segments # [[],[]] 存储数据集中所有存在多边形gt的图片的所有原始label
5、self.shapes # [[3360 1900] [3360 1900][3360 1900]---] 所有图片的尺寸
6、self.img_files  # 所有图片的全路径
7、self.label_files  # 所有xml的全路径
8、self.batch  # 图片索引/你设的bs
9、nb: 每一轮的batch个数
1 .6设置batch_shapes,即放入网络的图像尺寸
        # 这里主要是注意shapes的生成 这一步很重要 因为如果采样矩形训练那么整个batch的形状要一样 就要计算这个符合整个batch的shape
        # 而且还要对数据集按照高宽比进行排序 这样才能保证同一个batch的图片的形状差不多相同 再选则一个共同的shape代价也比较小
        if self.rect:
            # self.rect = False if image_weights else rect  # 是否启动矩形训练 一般训练时关闭 验证时打开 可以加速

            # 6.1 排序
            s = self.shapes  # 图像的宽高
            ar = s[:, 1] / s[:, 0]  # 高宽比
            irect = ar.argsort()  # 根据高宽比排序
            # np.argsort() 将x中的元素从小到大排列,返回其索引列表
            self.img_files = [self.img_files[i] for i in irect]  # 获取排序后的img_files
            self.label_files = [self.label_files[i] for i in irect]  # 获取排序后的label_files
            self.labels = [self.labels[i] for i in irect]  # 获取排序后的labels
            self.shapes = s[irect]  # 获取排序后的wh
            ar = ar[irect]  # 获取排序后的高宽比

            # 6.2 计算每个batch采用的统一尺度
            shapes = [[1, 1]] * nb  # nb: 每一轮的batch个数
            for i in range(nb):
                ari = ar[bi == i]  # bi: batch index
                mini, maxi = ari.min(), ari.max()  # 获取第i个batch中,最小和最大高宽比
                # 如果高/宽小于1(w > h),将w设为img_size(保证原图像尺度不变进行缩放)
                if maxi < 1:
                    shapes[i] = [maxi, 1]  # maxi: h相对指定尺度的比例  1: w相对指定尺度的比例
                # 如果高/宽大于1(w < h),将h设置为img_size(保证原图像尺度不变进行缩放)
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            # 计算每个batch输入网络的shape值(向上设置为32的整数倍)
            # 要求每个batch_shapes的高宽都是32的整数倍,所以要先除以32,取整再乘以32(不过img_size如果是32倍数这里就没必要了)
            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
1.7 是否需要cache image
       # 一般是False 因为RAM会不足  cache label还可以 但是cache image就太大了 所以一般不用
        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

2、cache_labels函数

在上面的__init__函数调用,用于缓存标签及标签相关信息

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))
            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [l, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

        pbar.close()
        if msgs:
            logging.info('\n'.join(msgs))
        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, len(self.img_files)
        x['msgs'] = msgs  # warnings
        x['version'] = 0.4  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x
 类似资料: