在进行目标检测或语义分割模型的训练时，我们往往会融合多个开源数据集，但数据集一多，就容易出现同一类别的标签命名不一致的问题。下面的代码可用于批量修改 txt 格式标签文件中的 label 字段：
import os
def changeLabel(txt_path, new_txt_path, old_label='storage_tank',
                new_label='storage-tank', encoding='gbk'):
    """Rewrite a label name in every ``.txt`` label file of a directory.

    Each file directly inside ``txt_path`` whose name ends with ``.txt`` is
    read, every occurrence of ``old_label`` is replaced with ``new_label``,
    and the result is saved under the same file name in ``new_txt_path``.

    :param txt_path: directory holding the original label files
    :param new_txt_path: directory the modified label files are written to;
        it is created when it does not exist yet
    :param old_label: text to look for (the wrong label)
    :param new_label: text written in its place (the correct label)
    :param encoding: text encoding used for both reading and writing
    :return: number of label files written
    """
    if new_txt_path:
        os.makedirs(new_txt_path, exist_ok=True)

    written = 0
    for filename in sorted(os.listdir(txt_path)):
        # Test the extension itself: a substring check would also accept
        # names such as "a.txt.bak".
        if not filename.lower().endswith('.txt'):
            continue
        # os.path.join works whether or not the directory argument carries a
        # trailing separator.
        with open(os.path.join(txt_path, filename), 'r', encoding=encoding) as f:
            file_data = f.read().replace(old_label, new_label)
        with open(os.path.join(new_txt_path, filename), 'w', encoding=encoding) as fw:
            fw.write(file_data)
        print('filename: ', filename)
        written += 1
    return written
# Folder that holds the original label files.
label_path = './test/labels/'
# Folder that receives the modified label files.
dst_path = './test/dst_labels/'
changeLabel(txt_path=label_path, new_txt_path=dst_path)
对json格式的标签文件进行修改:
import json
import os
def changeJsonLabelName(json_dir, new_json_dir, match='pla', new_name='plane'):
    """Rename shape labels in every ``.json`` annotation file of a directory.

    Each JSON file is expected to be an object with a ``shapes`` list whose
    items carry a ``label`` string. Any label that contains ``match`` as a
    substring is replaced by ``new_name``; labels that do not match are
    printed so they can be reviewed. The updated document is saved under the
    same file name in ``new_json_dir``.

    Note that the match is a plain substring test, so the default ``'pla'``
    also hits labels such as ``'plant'``; pass a more specific ``match`` when
    that is not wanted.

    :param json_dir: directory holding the original JSON label files
    :param new_json_dir: directory the modified files are written to; it is
        created when it does not exist yet
    :param match: substring that marks a label as needing the rename
    :param new_name: label written in place of every matching label
    :return: None
    """
    if new_json_dir:
        os.makedirs(new_json_dir, exist_ok=True)

    for json_file in sorted(os.listdir(json_dir)):
        # Anything that is not a JSON file would make json.load fail.
        if not json_file.lower().endswith('.json'):
            continue

        with open(os.path.join(json_dir, json_file), 'r', encoding='utf-8') as jf:
            info = json.load(jf)

        # A file without a "shapes" list is copied through unchanged.
        for shape in info.get('shapes', []):
            if match in shape['label']:
                shape['label'] = new_name
            else:
                print(shape['label'])

        # Write with the same encoding the input was read with and keep
        # non-ASCII text readable instead of \u-escaped.
        with open(os.path.join(new_json_dir, json_file), 'w', encoding='utf-8') as new_jf:
            json.dump(info, new_jf, indent=2, ensure_ascii=False)
    print('change name over!')
# Folder that holds the original label files.
label_path = './test/labels/'
# Folder that receives the modified label files.
dst_path = './test/dst_labels/'
changeJsonLabelName(json_dir=label_path, new_json_dir=dst_path)
对xml格式的标签文件进行修改:
import os
import os.path
from xml.etree.ElementTree import parse, Element
def changeName(xml_fold, origin_name, new_name):
    """Rename one label class in every XML annotation file of a directory.

    For each ``.xml`` file directly inside ``xml_fold``, every ``<object>``
    element whose ``<name>`` text equals ``origin_name`` gets that text
    replaced by ``new_name``. Files are modified in place, and only files in
    which at least one object was renamed are written back.

    :param xml_fold: directory holding the XML label files
    :param origin_name: label name to replace
    :param new_name: label name written in its place
    :return: number of objects that were renamed
    """
    cnt = 0
    for xmlFile in os.listdir(xml_fold):
        # Non-XML entries would make the parser raise.
        if not xmlFile.lower().endswith('.xml'):
            continue
        file_path = os.path.join(xml_fold, xmlFile)
        dom = parse(file_path)

        changed = False
        for obj in dom.getroot().iter('object'):
            name_node = obj.find('name')
            # Objects without a <name> child are left untouched.
            if name_node is None or name_node.text != origin_name:
                continue
            name_node.text = new_name
            print("change %s to %s." % (origin_name, new_name))
            cnt += 1
            changed = True

        if changed:
            # Explicit UTF-8 keeps non-ASCII text as-is; the default
            # (us-ascii) would turn it into character references.
            dom.write(file_path, encoding='utf-8', xml_declaration=True)
    return cnt
def changeAll(xml_fold, new_name):
    """Set the label of every object in every XML annotation file to one name.

    For each ``.xml`` file directly inside ``xml_fold``, the ``<name>`` text
    of every ``<object>`` element is overwritten with ``new_name``. Files are
    modified in place, and only files in which at least one object was
    relabelled are written back. The total number of relabelled objects is
    printed at the end.

    :param xml_fold: directory holding the XML label files
    :param new_name: label name assigned to every object
    :return: number of objects that were relabelled
    """
    cnt = 0
    for xmlFile in os.listdir(xml_fold):
        # Non-XML entries would make the parser raise.
        if not xmlFile.lower().endswith('.xml'):
            continue
        file_path = os.path.join(xml_fold, xmlFile)
        dom = parse(file_path)

        changed = False
        for obj in dom.getroot().iter('object'):
            name_node = obj.find('name')
            # Objects without a <name> child have nothing to overwrite.
            if name_node is None:
                continue
            tmp_name = name_node.text
            name_node.text = new_name
            print("change %s to %s." % (tmp_name, new_name))
            cnt += 1
            changed = True

        if changed:
            # Explicit UTF-8 keeps non-ASCII text as-is; the default
            # (us-ascii) would turn it into character references.
            dom.write(file_path, encoding='utf-8', xml_declaration=True)
    print(cnt)
    return cnt