webstudio/web_p/train_Yolo.py

248 lines
9.3 KiB
Python
Raw Normal View History

2024-05-02 17:36:44 +03:00
"""
train_Yolo
Общая задача: обнаружение объекта (Object detection)
Реализуемая функция: обучение нейросетевой модели YoloV8 по заданному BOP-датасету
python3 $PYTHON_TRAIN --path /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296/datasets \
--name test123 --datasetName ds213 --outpath /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296/weights
27.04.2024 @shalenikol release 0.1
20.11.2024 @shalenikol release 0.2: new "--addon" option (folder with an add-on for the dataset)
20.02.2025 @shalenikol release 0.2.1: add_on_dataset fix
"""
import os
import shutil
import json
import yaml
from ultralytics import YOLO
# from ultralytics import settings
2024-05-02 17:36:44 +03:00
# from ultralytics.utils.metrics import DetMetrics
# import torch
# import torch.profiler
# import torch.utils.data
2024-05-02 17:36:44 +03:00
# Base checkpoint used when training from scratch (alternative: "yolov8n.pt").
FILE_BASEMODEL = "yolov8s.pt"

# Input/output file names.
FILE_RBS_INFO = "rbs_info.json"        # dataset description: list of objects with a "name" key
FILE_RBS_TRAIN = "rbs_train.yaml"      # data config passed to model.train(data=...)
FILE_GT_COCO = "scene_gt_coco.json"    # per-scene COCO ground truth
FILE_L_TRAIN = "i_train.txt"           # list of training image paths
FILE_L_VAL = "i_val.txt"               # list of validation image paths
FILE_TRAIN_RES = "weights/last.pt"     # checkpoint path relative to the training run folder

# Layout of the converted dataset: <out>/datasets/rbs_coco/{images,labels}.
DIR_ROOT_DS = "datasets"
DIR_COCO_DS = "rbs_coco"
DIR_RGB_DS, DIR_LABELS_DS = "images", "labels"
LABELS_EXT = ".txt"

# Every SZ_SERIES-th sample (index % SZ_SERIES == 0) goes to the validation
# list, the rest to the training list.
SZ_SERIES = 15

# Running sample counter shared by gt_parse and add_on_dataset
# (output files are named NNNNNN<ext>).
nn_image = 0
# Open handles of the train (f1) and validation (f2) list files,
# assigned by BOP2Yolo_dataset during conversion.
f1 = None
f2 = None
def convert2relative(height, width, bbox):
    """Convert a corner-based box to YOLO's normalized center-based form.

    Args:
        height: image height (same units as the box).
        width: image width (same units as the box).
        bbox: ``(x, y, w, h)`` where ``(x, y)`` is the top-left corner.

    Returns:
        ``(x_center, y_center, w, h)`` with horizontal values divided by
        *width* and vertical values divided by *height*.
    """
    left, top, box_w, box_h = bbox
    center_x = left + box_w / 2
    center_y = top + box_h / 2
    return center_x / width, center_y / height, box_w / width, box_h / height
def add_on_dataset(source_dir, target_dir) -> dict:
    """Append a flat folder of ready-made samples to the converted dataset.

    Files in *source_dir* are grouped by base name (an image and its label
    share the same stem). Each new base name takes the next number from the
    global ``nn_image`` counter, and its files are copied as ``NNNNNN<ext>``:

    * files with ``LABELS_EXT`` go to ``<target_dir>/labels``;
    * every other file goes to ``<target_dir>/images`` and is listed in the
      validation list ``f2`` when ``number % SZ_SERIES == 0``, otherwise in the
      training list ``f1``.

    Sub-directories and hidden entries (names starting with ".") are skipped.

    Args:
        source_dir: folder with the add-on files.
        target_dir: root of the converted dataset; its ``images`` and
            ``labels`` sub-folders must already exist.

    Returns:
        Mapping from original base name to the assigned sample number.
    """
    global nn_image
    file_nn = {}
    for file in sorted(os.listdir(source_dir)):
        old_file_path = os.path.join(source_dir, file)
        # Hidden files such as macOS ".DS_Store" have an empty extension per
        # os.path.splitext; they used to be copied as extension-less "images"
        # and listed for training, which YOLO cannot load.
        if file.startswith(".") or os.path.isdir(old_file_path):
            continue
        file_name, file_extension = os.path.splitext(file)
        nn = file_nn.get(file_name)
        if nn is None:  # first file with this base name: allocate a number
            nn = nn_image
            file_nn[file_name] = nn
            nn_image += 1
        new_file_name = f"{nn:06}{file_extension}"
        if file_extension == LABELS_EXT:
            sub_dir = DIR_LABELS_DS
        else:
            sub_dir = DIR_RGB_DS
            line = os.path.join("./", DIR_RGB_DS, new_file_name) + "\n"
            if nn % SZ_SERIES == 0:
                f2.write(line)
            else:
                f1.write(line)
        shutil.copy2(old_file_path, os.path.join(target_dir, sub_dir, new_file_name))
    return file_nn
def gt_parse(path: str, out_dir: str):
    """Convert one scene folder with COCO ground truth into YOLO samples.

    Reads ``<path>/FILE_GT_COCO``. For every listed image that exists on disk:

    * copies it to ``<out_dir>/images/NNNNNN<ext>`` (``NNNNNN`` = global
      ``nn_image`` counter);
    * writes ``<out_dir>/labels/NNNNNN.txt`` with one line per annotation:
      ``<category_id - 1> <x_center> <y_center> <width> <height>`` (values
      normalized by ``convert2relative``); annotations with
      ``category_id >= 999`` are skipped;
    * appends the image path to the validation list ``f2`` when
      ``nn_image % SZ_SERIES == 0``, otherwise to the training list ``f1``.

    Args:
        path: scene folder containing ``FILE_GT_COCO`` and the images it references.
        out_dir: root of the converted dataset; its ``images`` and ``labels``
            sub-folders must already exist.
    """
    global nn_image
    with open(os.path.join(path, FILE_GT_COCO), "r") as fh:
        coco_data = json.load(fh)

    # Group annotations by image once instead of rescanning the full
    # annotation list for every image (was O(images * annotations)).
    # List order within each image is preserved.
    anns_by_image = {}
    for ann in coco_data["annotations"]:
        anns_by_image.setdefault(ann["image_id"], []).append(ann)

    for img in coco_data["images"]:
        rgb_file = os.path.join(path, img["file_name"])
        if not os.path.isfile(rgb_file):
            continue
        ext = os.path.splitext(rgb_file)[1]
        stem = f"{nn_image:06}"
        shutil.copy2(rgb_file, os.path.join(out_dir, DIR_RGB_DS, stem + ext))

        label_path = os.path.join(out_dir, DIR_LABELS_DS, stem + LABELS_EXT)
        with open(label_path, "w") as label_fh:
            for ann in anns_by_image.get(img["id"], ()):
                cat_id = ann["category_id"]
                # NOTE(review): ids >= 999 are skipped; presumably a reserved /
                # non-object marker in the generator -- confirm.
                if cat_id < 999:
                    # Size is taken from the annotation as before; fall back to
                    # the COCO image record when the annotation lacks it.
                    im_h = ann["height"] if "height" in ann else img["height"]
                    im_w = ann["width"] if "width" in ann else img["width"]
                    rel = convert2relative(im_h, im_w, ann["bbox"])
                    # YOLO classes are 0-based, COCO category ids here start at 1
                    label_fh.write(f"{cat_id - 1} {rel[0]} {rel[1]} {rel[2]} {rel[3]}\n")

        line = os.path.join("./", DIR_RGB_DS, stem + ext) + "\n"
        if nn_image % SZ_SERIES == 0:
            f2.write(line)
        else:
            f1.write(line)
        nn_image += 1
def explore(path: str, res_dir: str):
    """Recursively find scene folders under *path* and convert each one.

    A sub-directory that contains ``FILE_GT_COCO`` is passed to ``gt_parse``;
    any other sub-directory is searched recursively. *path* itself is not
    checked for ``FILE_GT_COCO``. Does nothing if *path* is not a directory.

    Entries are visited in sorted order. ``os.listdir`` order is arbitrary,
    and the global sample numbering (and therefore the train/validation split)
    depends on visiting order, so sorting makes conversions reproducible.

    Args:
        path: folder to search.
        res_dir: root of the converted dataset, forwarded to ``gt_parse``.
    """
    if not os.path.isdir(path):
        return
    for entry in sorted(os.listdir(path)):
        path_entry = os.path.join(path, entry)
        if not os.path.isdir(path_entry):
            continue
        if os.path.isfile(os.path.join(path_entry, FILE_GT_COCO)):
            gt_parse(path_entry, res_dir)
        else:
            explore(path_entry, res_dir)
def BOP2Yolo_dataset(dpath: str, out_dir: str, lname: list, addon:str) -> str:
    """Convert a BOP dataset to YOLO format for training.

    Writes the data config ``<out_dir>/FILE_RBS_TRAIN`` (dataset root,
    train/val list files, class names), creates
    ``<out_dir>/datasets/rbs_coco/{images,labels}``, converts every scene found
    under *dpath* via ``explore``, and, if *addon* is non-empty, appends the
    samples from that folder via ``add_on_dataset``.

    Args:
        dpath: root folder of the BOP dataset.
        out_dir: existing output folder for the config and converted dataset.
        lname: object names; the name at index ``i`` becomes class ``i``.
        addon: folder with extra samples, or "" to skip.

    Returns:
        *out_dir*, the folder that contains ``FILE_RBS_TRAIN``.
    """
    global f1, f2
    res_dir = os.path.join(out_dir, DIR_ROOT_DS, DIR_COCO_DS)

    cfg_data = {"path": res_dir, "train": FILE_L_TRAIN, "val": FILE_L_VAL}
    cfg_data["names"] = dict(enumerate(lname))
    with open(os.path.join(out_dir, FILE_RBS_TRAIN), "w") as fh:
        yaml.dump(cfg_data, fh)

    # makedirs also creates datasets/ and rbs_coco/ on the way
    for sub_dir in (DIR_RGB_DS, DIR_LABELS_DS):
        os.makedirs(os.path.join(res_dir, sub_dir), exist_ok=True)

    # The list files are shared with gt_parse/add_on_dataset through the
    # globals f1/f2 (the `global` above makes the with-targets bind them).
    # The with-statement guarantees they are flushed and closed even if the
    # conversion raises.
    train_list = os.path.join(res_dir, FILE_L_TRAIN)
    val_list = os.path.join(res_dir, FILE_L_VAL)
    with open(train_list, "w") as f1, open(val_list, "w") as f2:
        explore(dpath, res_dir)
        if addon:
            add_on_dataset(addon, res_dir)
    return out_dir
def train_YoloV8(path:str, wname:str, dname:str, outpath:str, epochs:int, pretrain: bool, addon: str):
    """Main procedure: train (or continue training) a YOLOv8 model.

    Args:
        path: folder that contains the datasets.
        wname: result name; output goes to ``<outpath>/<wname>/`` and the
            final weights are copied to ``<outpath>/<wname>/<wname>.pt``.
        dname: dataset folder inside *path*; must contain ``FILE_RBS_INFO``.
        outpath: existing folder for the results.
        epochs: number of training epochs.
        pretrain: if True, continue training from ``<outpath>/<wname>/<wname>.pt``
            with the previously written data config; otherwise convert the
            dataset and start from ``FILE_BASEMODEL``.
        addon: folder with extra samples appended to the dataset ("" to skip).

    On any error a message is printed and ``sys.exit`` is called with a
    negative code (-1 .. -5).
    """
    # sys.exit instead of the exit() builtin: exit() is injected by the `site`
    # module for interactive use and is not guaranteed to exist.
    import sys

    if not os.path.isdir(outpath):
        print(f"Invalid output path '{outpath}'")
        sys.exit(-1)
    out_dir = os.path.join(outpath, wname)

    if pretrain:
        # continue training from the previously saved result
        if not os.path.isdir(out_dir):
            print(f"No dir '{out_dir}'")
            sys.exit(-2)
        dpath = out_dir
        model_path = os.path.join(dpath, wname + ".pt")
        if not os.path.isfile(model_path):
            print(f"No weights file '{model_path}'")
            sys.exit(-2)
    else:
        # train from scratch: convert the dataset first
        os.makedirs(out_dir, exist_ok=True)
        ds_path = os.path.join(path, dname)
        rbs_info = os.path.join(ds_path, FILE_RBS_INFO)
        if not os.path.isfile(rbs_info):
            print(f"{rbs_info} : no dataset description file")
            sys.exit(-3)
        with open(rbs_info, "r") as fh:
            objects = json.load(fh)
        # object names, in file order, become the class list
        list_name = [obj["name"] for obj in objects]
        dpath = BOP2Yolo_dataset(ds_path, out_dir, list_name, addon)
        if not dpath:
            print(f"Error in convert dataset '{ds_path}' to '{outpath}'")
            sys.exit(-4)
        model_path = os.path.join(dpath, FILE_BASEMODEL)

    model = YOLO(model_path)
    results = model.train(data=os.path.join(dpath, FILE_RBS_TRAIN), epochs=epochs, project=out_dir)

    # results.save_dir is the run folder reported by Ultralytics;
    # take its last checkpoint as the result.
    wf = os.path.join(results.save_dir, FILE_TRAIN_RES)
    if not os.path.isfile(wf):
        print(f"Error in train: no result file '{wf}'")
        sys.exit(-5)
    shutil.copy2(wf, os.path.join(dpath, wname + ".pt"))
if __name__ == "__main__":
    import argparse

    # Command-line entry point; see the module docstring for an example call.
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", required=True, help="Path for dataset")
    parser.add_argument("--name", required=True, help="String with result weights name")
    parser.add_argument("--datasetName", required=True, help="String with dataset name")
    parser.add_argument("--outpath", default="weights", help="Output path for weights")
    parser.add_argument("--epoch", default=3, type=int, help="How many training epochs")
    # The flag does not enable generic "pretraining": it resumes from the
    # weights saved by a previous run under the same --outpath/--name.
    parser.add_argument("--pretrain", action="store_true",
                        help="Continue training from previously saved weights <outpath>/<name>/<name>.pt")
    parser.add_argument("--addon", default="", help="Folder with add-on for dataset")
    args = parser.parse_args()
    train_YoloV8(args.path, args.name, args.datasetName, args.outpath,
                 args.epoch, args.pretrain, args.addon)