数据载入(数据增强)部分,即自定义数据集部分
在dataset.py的create_dataloader函数中被调用
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None,
rect=False, image_weights=False, cache_images=False, single_cls=False,
stride=32, pad=0.0, prefix=''):
self.img_size = img_size # 图像尺寸,默认为640 X 640
self.augment = augment # 是否采用数据增强
self.hyp = hyp # 超参列表
self.image_weights = image_weights # 是否开启按权重采样 即根据类别频率(频率高的权重小)来进行采样 默认False
self.rect = False if image_weights else rect # 是否启动矩形训练 一般训练时关闭 验证时打开 可以加速
self.mosaic = self.augment and not self.rect # 是否启用马赛克增强
self.mosaic_border = [-img_size // 2, -img_size // 2] # mosaic增强的边界值
self.stride = stride # 最大下采样率 32
self.path = path # 图片路径 我们训练一般传入:test.txt/train.txt
# 如果数据增强,用pytorch自带的Albumentations()进行数据增强
self.albumentations = Albumentations() if augment else None
1、判断path类型(为txt还是文件夹)
2、筛选f中所有的图片文件
self.img_files == 需要载入的图片全路径列表
# 2、得到path路径下的所有图片的路径 self.img_files-------------------------------------
try:
f = [] # image files
# 2.1 判断path的类型--------------------------------------------------------------
for p in path if isinstance(path, list) else [path]:
# 获取数据集路径path,包含图片路径的txt文件或者包含图片的文件夹路径
p = Path(p)
# 如果路径path为包含图片的文件夹路径
if p.is_dir(): # dir
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
# 如果路径path为包含图片路径的txt文件
elif p.is_file(): # file
with open(p, 'r') as t:
# # 获取图片路径
t = t.read().strip().splitlines()
# 获取数据集路径的上级父目录 os.sep为路径里的分隔符(不同路径的分隔符不同,os.sep可以根据系统自适应
parent = str(p.parent) + os.sep
# 更换成全路径,防止传入的为相对路径
f += [x.replace('./', parent) if x.startswith('./') else x for x in t]
else:
raise Exception(f'{prefix}{p}不存在')
# 破折号替换为os.sep
# os.path.splitext(x)将文件名与扩展名分开并返回一个列表
# 2.2 筛选f中所有的图片文件------------------------------------------------------------------------------
self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS])
# IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats]) # pathlib
assert self.img_files, f'{prefix}No images found'
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
self.label_files = img2label_paths(self.img_files)
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
try:
# 如果有cache文件,直接加载 exists=True: 是否已从cache文件中读出了nf, nm, ne, nc, n等信息
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
# 如果图片版本信息或者文件列表的hash值对不上号 说明本地数据集图片和label可能发生了变化 就重新cache label文件
assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
except:
# 否则调用cache_labels缓存标签及标签相关信息
cache, exists = self.cache_labels(cache_path, prefix), False # cache
# 打印cache的结果 nf nm ne nc n = 找到的标签数量,漏掉的标签数量,空的标签数量,损坏的标签数量,总的标签数量
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
if exists:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=prefix + d, total=n, initial=n) # 展示cache results
if cache['msgs']:
logging.info('\n'.join(cache['msgs'])) # display warnings
# 数据集没有标签信息 就发出警告并显示标签label下载地址help_url
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
# cache中的键值对最初有:
# cache[img_file]=[l, shape, segments] cache[hash] cache[results] cache[msg] cache[version]
# 5.1 先从cache中去除cache文件中其他无关键值如:'hash', 'version', 'msgs'等
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
# 只剩下cache[img_file]=[l, shape, segments]
# cache.values(): 取cache中所有值 对应所有l, shape, segments
# labels: 如果数据集所有图片中没有一个多边形label labels存储的label就都是原始label(都是正常的矩形label)
# 否则将所有图片正常gt的label存入labels 不正常gt(存在一个多边形)经过segments2boxes转换为正常的矩形label
# shapes: 所有图片的shape
# self.segments: 如果数据集所有图片中没有一个多边形label self.segments=None
# 否则存储数据集中所有存在多边形gt的图片的所有原始label(肯定有多边形label 也可能有矩形正常label 未知数)
# zip 是因为cache中所有labels、shapes、segments信息都是按每张img分开存储的, zip是将所有图片对应的信息叠在一起
labels, shapes, self.segments = zip(*cache.values()) # segments: 都是[]
self.labels = list(labels)
self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys()) # 更新所有图片的img_files信息 update img_files from cache result
self.label_files = img2label_paths(cache.keys()) # 更新所有图片的label_files信息(因为img_files信息可能发生了变化)
if single_cls:
for x in self.labels:
x[:, 0] = 0
n = len(shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image
self.n = n
self.indices = range(n) # 所有图片的index
几个重要参数
1、self.n # 图片的数量
2、self.indices # 所有图片的索引 0,1,2,3---
3、self.labels # [array([[7, 0.16265,0.50026, 0.072321, 0.99947]], dtype=float32),.....] 所有label的框
4、self.segments # [[],[]] 存储数据集中所有存在多边形gt的图片的所有原始label
5、self.shapes # [[3360 1900] [3360 1900][3360 1900]---] 所有图片的尺寸
6、self.img_files # 所有图片的全路径
7、self.label_files # 所有xml的全路径
8、self.batch # 图片索引/你设的bs
9、nb: 每一轮的batch个数
# 这里主要是注意shapes的生成 这一步很重要 因为如果采样矩形训练那么整个batch的形状要一样 就要计算这个符合整个batch的shape
# 而且还要对数据集按照高宽比进行排序 这样才能保证同一个batch的图片的形状差不多相同 再选则一个共同的shape代价也比较小
if self.rect:
# self.rect = False if image_weights else rect # 是否启动矩形训练 一般训练时关闭 验证时打开 可以加速
# 6.1 排序
s = self.shapes # 图像的宽高
ar = s[:, 1] / s[:, 0] # 高宽比
irect = ar.argsort() # 根据高宽比排序
# np.argsort() 将x中的元素从小到大排列,返回其索引列表
self.img_files = [self.img_files[i] for i in irect] # 获取排序后的img_files
self.label_files = [self.label_files[i] for i in irect] # 获取排序后的label_files
self.labels = [self.labels[i] for i in irect] # 获取排序后的labels
self.shapes = s[irect] # 获取排序后的wh
ar = ar[irect] # 获取排序后的高宽比
# 6.2 计算每个batch采用的统一尺度
shapes = [[1, 1]] * nb # nb: 每一轮的batch个数
for i in range(nb):
ari = ar[bi == i] # bi: batch index
mini, maxi = ari.min(), ari.max() # 获取第i个batch中,最小和最大高宽比
# 如果高/宽小于1(w > h),将w设为img_size(保证原图像尺度不变进行缩放)
if maxi < 1:
shapes[i] = [maxi, 1] # maxi: h相对指定尺度的比例 1: w相对指定尺度的比例
# 如果高/宽大于1(w < h),将h设置为img_size(保证原图像尺度不变进行缩放)
elif mini > 1:
shapes[i] = [1, 1 / mini]
# 计算每个batch输入网络的shape值(向上设置为32的整数倍)
# 要求每个batch_shapes的高宽都是32的整数倍,所以要先除以32,取整再乘以32(不过img_size如果是32倍数这里就没必要了)
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
# 一般是False 因为RAM会不足 cache label还可以 但是cache image就太大了 所以一般不用
# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
self.imgs = [None] * n
if cache_images:
gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
pbar = tqdm(enumerate(results), total=n)
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
gb += self.imgs[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
pbar.close()
在上面的__init__函数调用,用于缓存标签及标签相关信息
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes
x = {} # dict
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
with Pool(NUM_THREADS) as pool:
pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
desc=desc, total=len(self.img_files))
for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
ne += ne_f
nc += nc_f
if im_file:
x[im_file] = [l, shape, segments]
if msg:
msgs.append(msg)
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
pbar.close()
if msgs:
logging.info('\n'.join(msgs))
if nf == 0:
logging.info(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = nf, nm, ne, nc, len(self.img_files)
x['msgs'] = msgs # warnings
x['version'] = 0.4 # cache version
try:
np.save(path, x) # save cache for next time
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
logging.info(f'{prefix}New cache created: {path}')
except Exception as e:
logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # path not writeable
return x