framework/train_models/train_Yolo.py

181 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
train_Yolo
Общая задача: обнаружение объекта (Object detection)
Реализуемая функция: обучение нейросетевой модели YoloV8 по заданному BOP-датасету
python3 $PYTHON_TRAIN --path /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296/datasets \
--name test123 --datasetName ds213 --outpath /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296/weights
27.04.2024 @shalenikol release 0.1
"""
import json
import os
import shutil
import sys

import yaml
from ultralytics import YOLO
# from ultralytics.utils.metrics import DetMetrics
FILE_BASEMODEL = "yolov8n.pt"
FILE_RBS_INFO = "rbs_info.json"
FILE_RBS_TRAIN = "rbs_train.yaml"
FILE_GT_COCO = "scene_gt_coco.json"
FILE_L_TRAIN = "i_train.txt"
FILE_L_VAL = "i_val.txt"
FILE_TRAIN_RES = "weights/last.pt"
DIR_ROOT_DS = "datasets"
DIR_COCO_DS = "rbs_coco"
DIR_RGB_DS = "images"
DIR_LABELS_DS = "labels"
SZ_SERIES = 15 # number of train images per validation images
nn_image = 0
f1 = f2 = None
def convert2relative(height, width, bbox):
    """Convert a corner-based box to YOLO's normalized centre format.

    :param height: image height (same units as *bbox*)
    :param width: image width (same units as *bbox*)
    :param bbox: ``(x, y, w, h)`` where ``(x, y)`` is the top-left corner
    :return: ``(x_center, y_center, w, h)`` each divided by the image size
    """
    left, top, box_w, box_h = bbox
    x_center = left + box_w / 2
    y_center = top + box_h / 2
    return x_center / width, y_center / height, box_w / width, box_h / height
def _annotations_by_image(annotations: list) -> dict:
    """Group COCO annotation dicts by their ``image_id``, preserving input order."""
    grouped = {}
    for ann in annotations:
        grouped.setdefault(ann["image_id"], []).append(ann)
    return grouped


def gt_parse(path: str, out_dir: str):
    """Convert one COCO ground-truth folder into YOLO images and label files.

    Reads ``FILE_GT_COCO`` from *path*. For every listed image that exists on
    disk it:
      * copies the image to ``<out_dir>/<DIR_RGB_DS>/NNNNNN<ext>``;
      * writes ``<out_dir>/<DIR_LABELS_DS>/NNNNNN.txt`` with one
        ``<class> <x-center> <y-center> <width> <height>`` line per annotation;
      * appends ``./<DIR_RGB_DS>/NNNNNN<ext>`` to the training list (global
        ``f1``) or, for every ``SZ_SERIES``-th image, to the validation list
        (global ``f2``).

    Uses the module globals ``nn_image`` (running image counter, advanced here)
    and ``f1``/``f2`` (list files that must already be open for writing).
    """
    global nn_image
    with open(os.path.join(path, FILE_GT_COCO), "r", encoding="utf-8") as fh:
        coco_data = json.load(fh)
    # Index annotations once; scanning the full list for every image is
    # O(images * annotations) on large scenes.
    anns_by_image = _annotations_by_image(coco_data["annotations"])
    for img in coco_data["images"]:
        rgb_file = os.path.join(path, img["file_name"])
        if not os.path.isfile(rgb_file):
            continue  # listed in the ground truth but missing on disk: skip
        ext = os.path.splitext(rgb_file)[1]  # extension only, e.g. ".png"
        stem = f"{nn_image:06}"
        shutil.copy2(rgb_file, os.path.join(out_dir, DIR_RGB_DS, stem + ext))
        # fill the bbox label file for this image
        with open(os.path.join(out_dir, DIR_LABELS_DS, stem + ".txt"), "w") as fh:
            for ann in anns_by_image.get(img["id"], []):
                cat_id = ann["category_id"]
                if cat_id >= 999:
                    # NOTE(review): ids >= 999 are skipped; meaning of 999 is not visible here
                    continue
                # BOP-generated annotations carry the image size themselves; plain
                # COCO files keep it on the image record, so fall back to that.
                im_h = ann["height"] if "height" in ann else img["height"]
                im_w = ann["width"] if "width" in ann else img["width"]
                rel = convert2relative(im_h, im_w, ann["bbox"])
                # format: <target> <x-center> <y-center> <width> <height>
                # YOLO classes start at 0 (category ids assumed to start at 1)
                fh.write(f"{cat_id-1} {rel[0]} {rel[1]} {rel[2]} {rel[3]}\n")
        nn_image += 1
        line = os.path.join("./", DIR_RGB_DS, stem + ext) + "\n"
        if nn_image % SZ_SERIES == 0:
            f2.write(line)
        else:
            f1.write(line)
def explore(path: str, res_dir: str):
    """Recursively find ground-truth folders under *path* and convert them.

    Each sub-folder of *path* that contains ``FILE_GT_COCO`` is handed to
    gt_parse(); any other sub-folder is searched recursively. *path* itself is
    not checked for the file, and a folder that contains it is not descended
    into further. A *path* that is not a directory is ignored.

    Sub-folders are visited in sorted order. ``os.listdir`` returns entries in
    arbitrary order, which would make image numbering and the train/val split
    differ between runs or filesystems.
    """
    if not os.path.isdir(path):
        return
    for name in sorted(os.listdir(path)):
        entry = os.path.join(path, name)
        if not os.path.isdir(entry):
            continue  # plain files at this level are not datasets
        if os.path.isfile(os.path.join(entry, FILE_GT_COCO)):
            gt_parse(entry, res_dir)
        else:
            explore(entry, res_dir)
def BOP2Yolo_dataset(dpath: str, out_dir: str, lname: list) -> str:
    """Convert a BOP dataset to the YOLO layout for training.

    Writes ``<out_dir>/FILE_RBS_TRAIN``, a YOLO data config holding:
      * the dataset root;
      * the train/val list file names;
      * the class names, numbered from 0 in *lname* order.
    Creates ``<out_dir>/DIR_ROOT_DS/DIR_COCO_DS/{images,labels}``, then walks
    *dpath* with explore() to fill the images, labels and list files.

    :param dpath: root folder of the BOP dataset
    :param out_dir: existing output folder for the config and converted data
    :param lname: object names; position ``i`` becomes YOLO class ``i``
    :return: *out_dir* (the folder holding ``FILE_RBS_TRAIN``)
    """
    global nn_image, f1, f2
    res_dir = os.path.join(out_dir, DIR_ROOT_DS, DIR_COCO_DS)
    cfg_data = {"path": res_dir, "train": FILE_L_TRAIN, "val": FILE_L_VAL}
    cfg_data["names"] = dict(enumerate(lname))
    with open(os.path.join(out_dir, FILE_RBS_TRAIN), "w") as fh:
        yaml.dump(cfg_data, fh)
    # exist_ok keeps re-runs over an existing tree working
    for sub_dir in (DIR_RGB_DS, DIR_LABELS_DS):
        os.makedirs(os.path.join(res_dir, sub_dir), exist_ok=True)
    # Restart image numbering so a second conversion in the same process does
    # not continue from the previous conversion's counter.
    nn_image = 0
    train_path = os.path.join(res_dir, FILE_L_TRAIN)
    val_path = os.path.join(res_dir, FILE_L_VAL)
    # gt_parse() writes to the list files through the f1/f2 globals; ``with``
    # guarantees they are flushed and closed even if the conversion raises.
    with open(train_path, "w") as train_list, open(val_path, "w") as val_list:
        f1, f2 = train_list, val_list
        explore(dpath, res_dir)
    return out_dir
def _resume_paths(out_dir: str, wname: str):
    """Return ``(dpath, model_path)`` for continuing training of *wname*; exit on error."""
    if not os.path.isdir(out_dir):
        print(f"No dir '{out_dir}'")
        sys.exit(-2)
    model_path = os.path.join(out_dir, wname + ".pt")
    if not os.path.isfile(model_path):
        # without this check a missing file surfaces later inside YOLO()
        print(f"No weights file '{model_path}'")
        sys.exit(-2)
    return out_dir, model_path


def _fresh_paths(path: str, dname: str, out_dir: str, outpath: str):
    """Convert dataset *dname* into *out_dir*; return ``(dpath, model_path)``; exit on error."""
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)
    ds_path = os.path.join(path, dname)
    rbs_info = os.path.join(ds_path, FILE_RBS_INFO)
    if not os.path.isfile(rbs_info):
        print(f"{rbs_info} : no dataset description file")
        sys.exit(-3)
    with open(rbs_info, "r", encoding="utf-8") as fh:
        objects = json.load(fh)
    # object names, in the order that defines the YOLO class ids
    list_name = [obj["name"] for obj in objects]
    dpath = BOP2Yolo_dataset(ds_path, out_dir, list_name)
    if not dpath:
        print(f"Error in convert dataset '{ds_path}' to '{outpath}'")
        sys.exit(-4)
    return dpath, os.path.join(dpath, FILE_BASEMODEL)


def train_YoloV8(path: str, wname: str, dname: str, outpath: str, epochs: int, pretrain: bool):
    """Main procedure for training a YOLOv8 model.

    With *pretrain* set, training continues from ``<outpath>/<wname>/<wname>.pt``
    using the data config converted earlier in that folder. Otherwise the BOP
    dataset ``<path>/<dname>`` is converted into ``<outpath>/<wname>``, and
    training starts from ``FILE_BASEMODEL``. Afterwards ``FILE_TRAIN_RES`` from
    the run folder is copied to ``<dpath>/<wname>.pt`` and the run folder is
    deleted.

    Terminates the process via ``sys.exit`` with these codes:
      * -1: invalid *outpath*;
      * -2: missing folder or weights when continuing training;
      * -3: no dataset description file;
      * -4: dataset conversion failed;
      * -5: no result weights after training.
    """
    if not os.path.isdir(outpath):
        print(f"Invalid output path '{outpath}'")
        sys.exit(-1)
    out_dir = os.path.join(outpath, wname)
    if pretrain:
        dpath, model_path = _resume_paths(out_dir, wname)
    else:
        dpath, model_path = _fresh_paths(path, dname, out_dir, outpath)
    model = YOLO(model_path)
    results = model.train(data=os.path.join(dpath, FILE_RBS_TRAIN), epochs=epochs, project=out_dir)
    # NOTE(review): assumes the object returned by YOLO.train() exposes the run folder as save_dir
    wf = os.path.join(results.save_dir, FILE_TRAIN_RES)
    if not os.path.isfile(wf):
        print(f"Error in train: no result file '{wf}'")
        sys.exit(-5)
    # keep only the final checkpoint next to the data config, then drop the run folder
    shutil.copy2(wf, os.path.join(dpath, wname + ".pt"))
    shutil.rmtree(results.save_dir)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--path", required=True, help="Path for dataset")
parser.add_argument("--name", required=True, help="String with result weights name")
parser.add_argument("--datasetName", required=True, help="String with dataset name")
parser.add_argument("--outpath", default="weights", help="Output path for weights")
parser.add_argument("--epoch", default=3, type=int, help="How many training epochs")
parser.add_argument('--pretrain', action="store_true", help="Use pretraining")
args = parser.parse_args()
train_YoloV8(args.path, args.name, args.datasetName, args.outpath, args.epoch, args.pretrain)