YOLOv5: Data Preparation for Training
作者:XD / 发表: 2022年12月22日 03:12 / 更新: 2022年12月22日 03:52 / 编程笔记 / 阅读量:1386
YOLOv5: Data Preparation for Training, from VOC Format to YOLO Format
First, split the data into training and validation sets and write the sample lists to txt files.
import os
import random
# Fraction of annotated samples assigned to the training split; the rest go to val.
train_percent = 0.95
xmlfilepath = 'bucket_v1/Annotations'
txtsavepath = 'bucket_v1/ImageSets'
# Only VOC annotation files are samples (skip stray files such as .DS_Store);
# sorting makes the candidate order independent of the filesystem.
total_xml = sorted(f for f in os.listdir(xmlfilepath) if f.lower().endswith('.xml'))
num = len(total_xml)
trainval = range(num)
tr = int(num * train_percent)
# A set gives O(1) membership tests in the loop below (a list made it O(n^2)).
train = set(random.sample(trainval, tr))
main_dir = os.path.join(txtsavepath, 'Main')
# The output directory may not exist yet on a fresh dataset.
os.makedirs(main_dir, exist_ok=True)
# Context managers guarantee both lists are flushed and closed.
with open(os.path.join(main_dir, 'train.txt'), 'w') as ftrain, \
        open(os.path.join(main_dir, 'val.txt'), 'w') as fval:
    for i in trainval:
        # Drop the extension to get the bare sample ID, one per line.
        name = os.path.splitext(total_xml[i])[0] + '\n'
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
Second, convert the VOC-format annotations (XML) into YOLO-format labels (TXT).
import os
from tqdm import tqdm
from lxml import etree
import json
import shutil
# Source dataset root, laid out as Annotations/, JPEGImages/ and ImageSets/Main/.
voc_root = "/dfs/data/others/byolov5/dataset/bucket_v1"
# Not referenced by the code in this script; kept for reference.
voc_version = "bucket_v1"
# File names of the split lists written by the splitting step.
train_txt = "train.txt"
val_txt = "val.txt"
# Destination root; translate_info writes <root>/<split>/{images,labels} below it.
save_file_root = "/dfs/data/others/byolov5/dataset/yolo_data"

# Image and annotation folders of the VOC dataset.
voc_images_path, voc_xml_path = (
    os.path.join(voc_root, sub_dir) for sub_dir in ("JPEGImages", "Annotations")
)
# Full paths of the train/val sample-ID lists.
train_txt_path, val_txt_path = (
    os.path.join(voc_root, "ImageSets", "Main", list_name)
    for list_name in (train_txt, val_txt)
)
def parse_xml_to_dict(xml):
if len(xml) == 0:
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = parse_xml_to_dict(child)
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result:
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
def translate_info(file_names: list, save_root: str, class_dict: dict, train_val='train'):
    """Convert one split's VOC XML annotations to YOLO txt labels and copy its images.

    For every sample ID this:

    - reads ``<voc_xml_path>/<id>.xml``;
    - writes ``<save_root>/<train_val>/labels/<id>.txt`` with one
      ``class x_center y_center width height`` line per valid box, with
      coordinates normalised by the image size and rounded to 6 decimals;
    - copies ``<voc_images_path>/<id>.jpg`` into
      ``<save_root>/<train_val>/images/``.

    Samples whose XML has no ``object`` entry, or has a non-positive image
    size, are skipped with a printed warning.

    Args:
        file_names: sample IDs (file names without extension).
        save_root: output root directory.
        class_dict: maps a VOC class name to a 1-based id. The YOLO class
            index written is ``class_dict[name] - 1``.
        train_val: split name, used as the output sub-directory.

    Raises:
        AssertionError: if an image or XML file is missing.
        KeyError: if an object's class name is not in ``class_dict``.
    """
    save_txt_path = os.path.join(save_root, train_val, "labels")
    save_images_path = os.path.join(save_root, train_val, "images")
    os.makedirs(save_txt_path, exist_ok=True)
    os.makedirs(save_images_path, exist_ok=True)

    for file_name in tqdm(file_names, desc="translate {} file...".format(train_val)):
        img_path = os.path.join(voc_images_path, file_name + ".jpg")
        assert os.path.exists(img_path), "file:{} not exist...".format(img_path)
        xml_path = os.path.join(voc_xml_path, file_name + ".xml")
        assert os.path.exists(xml_path), "file:{} not exist...".format(xml_path)

        # Read raw bytes. lxml raises ValueError when given a str that carries
        # an XML encoding declaration; with bytes, the declaration (or the
        # UTF-8 default) selects the decoding.
        with open(xml_path, "rb") as fid:
            xml_bytes = fid.read()
        xml = etree.fromstring(xml_bytes)
        data = parse_xml_to_dict(xml)["annotation"]
        img_height = int(data["size"]["height"])
        img_width = int(data["size"]["width"])

        if "object" not in data:
            print("Warning: in '{}' xml, there are no objects.".format(xml_path))
            continue
        # A zero width/height would make the normalisation below divide by zero.
        if img_width <= 0 or img_height <= 0:
            print("Warning: in '{}' xml, the image size is invalid.".format(xml_path))
            continue

        lines = []
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            class_name = obj["name"]
            # class_dict is 1-based; YOLO class indices are 0-based.
            class_index = class_dict[class_name] - 1
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            w = xmax - xmin
            h = ymax - ymin
            xcenter = round((xmin + w / 2) / img_width, 6)
            ycenter = round((ymin + h / 2) / img_height, 6)
            w = round(w / img_width, 6)
            h = round(h / img_height, 6)
            lines.append(" ".join(str(v) for v in (class_index, xcenter, ycenter, w, h)))

        # Join only the boxes actually kept. This avoids the leading blank line
        # produced when the first box was skipped.
        with open(os.path.join(save_txt_path, file_name + ".txt"), "w") as f:
            f.write("\n".join(lines))

        path_copy_to = os.path.join(save_images_path, os.path.basename(img_path))
        if not os.path.exists(path_copy_to):
            shutil.copyfile(img_path, path_copy_to)
def main():
    """Convert the train split, then the val split, into YOLO format under save_file_root."""
    # VOC class name -> 1-based class id (translate_info subtracts 1).
    class_dict = {"b": 1, "t": 2}
    for split_name, list_path in (("train", train_txt_path), ("val", val_txt_path)):
        with open(list_path, "r") as handle:
            # Ignore blank lines in the sample-ID list.
            sample_ids = [line for line in handle.read().splitlines() if line.strip()]
        translate_info(sample_ids, save_file_root, class_dict, split_name)
# Run the conversion only when executed as a script. The original
# `if name == "main":` (underscores lost in the paste) raised NameError.
if __name__ == "__main__":
    main()