diff --git a/simulation/train_models/models_dope.py b/simulation/train_models/models_dope.py new file mode 100755 index 0000000..0c89004 --- /dev/null +++ b/simulation/train_models/models_dope.py @@ -0,0 +1,196 @@ +""" +NVIDIA from jtremblay@gmail.com +""" + +# Networks +import torch +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data +import torchvision.models as models + + +class DopeNetwork(nn.Module): + def __init__( + self, + pretrained=False, + numBeliefMap=9, + numAffinity=16, + stop_at_stage=6, # number of stages to process (if less than total number of stages) + ): + super(DopeNetwork, self).__init__() + + self.stop_at_stage = stop_at_stage + + vgg_full = models.vgg19(pretrained=False).features + self.vgg = nn.Sequential() + for i_layer in range(24): + self.vgg.add_module(str(i_layer), vgg_full[i_layer]) + + # Add some layers + i_layer = 23 + self.vgg.add_module( + str(i_layer), nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1) + ) + self.vgg.add_module(str(i_layer + 1), nn.ReLU(inplace=True)) + self.vgg.add_module( + str(i_layer + 2), nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) + ) + self.vgg.add_module(str(i_layer + 3), nn.ReLU(inplace=True)) + + # print('---Belief------------------------------------------------') + # _2 are the belief map stages + self.m1_2 = DopeNetwork.create_stage(128, numBeliefMap, True) + self.m2_2 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numBeliefMap, False + ) + self.m3_2 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numBeliefMap, False + ) + self.m4_2 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numBeliefMap, False + ) + self.m5_2 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numBeliefMap, False + ) + self.m6_2 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numBeliefMap, False + ) + + # print('---Affinity----------------------------------------------') + # _1 are the affinity map stages + self.m1_1 = DopeNetwork.create_stage(128, numAffinity, True) + self.m2_1 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numAffinity, False + ) + self.m3_1 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numAffinity, False + ) + self.m4_1 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numAffinity, False + ) + self.m5_1 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numAffinity, False + ) + self.m6_1 = DopeNetwork.create_stage( + 128 + numBeliefMap + numAffinity, numAffinity, False + ) + + def forward(self, x): + """Runs inference on the neural network""" + + out1 = self.vgg(x) + + out1_2 = self.m1_2(out1) + out1_1 = self.m1_1(out1) + + if self.stop_at_stage == 1: + return [out1_2], [out1_1] + + out2 = torch.cat([out1_2, out1_1, out1], 1) + out2_2 = self.m2_2(out2) + out2_1 = self.m2_1(out2) + + if self.stop_at_stage == 2: + return [out1_2, out2_2], [out1_1, out2_1] + + out3 = torch.cat([out2_2, out2_1, out1], 1) + out3_2 = self.m3_2(out3) + out3_1 = self.m3_1(out3) + + if self.stop_at_stage == 3: + return [out1_2, out2_2, out3_2], [out1_1, out2_1, out3_1] + + out4 = torch.cat([out3_2, out3_1, out1], 1) + out4_2 = self.m4_2(out4) + out4_1 = self.m4_1(out4) + + if self.stop_at_stage == 4: + return [out1_2, out2_2, out3_2, out4_2], [out1_1, out2_1, out3_1, out4_1] + + out5 = torch.cat([out4_2, out4_1, out1], 1) + out5_2 = self.m5_2(out5) + out5_1 = self.m5_1(out5) + + if self.stop_at_stage == 5: + return [out1_2, out2_2, out3_2, out4_2, out5_2], [ + out1_1, + 
out2_1, + out3_1, + out4_1, + out5_1, + ] + + out6 = torch.cat([out5_2, out5_1, out1], 1) + out6_2 = self.m6_2(out6) + out6_1 = self.m6_1(out6) + + return [out1_2, out2_2, out3_2, out4_2, out5_2, out6_2], [ + out1_1, + out2_1, + out3_1, + out4_1, + out5_1, + out6_1, + ] + + @staticmethod + def create_stage(in_channels, out_channels, first=False): + """Create the neural network layers for a single stage.""" + + model = nn.Sequential() + mid_channels = 128 + if first: + padding = 1 + kernel = 3 + count = 6 + final_channels = 512 + else: + padding = 3 + kernel = 7 + count = 10 + final_channels = mid_channels + + # First convolution + model.add_module( + "0", + nn.Conv2d( + in_channels, mid_channels, kernel_size=kernel, stride=1, padding=padding + ), + ) + + # Middle convolutions + i = 1 + while i < count - 1: + model.add_module(str(i), nn.ReLU(inplace=True)) + i += 1 + model.add_module( + str(i), + nn.Conv2d( + mid_channels, + mid_channels, + kernel_size=kernel, + stride=1, + padding=padding, + ), + ) + i += 1 + + # Penultimate convolution + model.add_module(str(i), nn.ReLU(inplace=True)) + i += 1 + model.add_module( + str(i), nn.Conv2d(mid_channels, final_channels, kernel_size=1, stride=1) + ) + i += 1 + + # Last convolution + model.add_module(str(i), nn.ReLU(inplace=True)) + i += 1 + model.add_module( + str(i), nn.Conv2d(final_channels, out_channels, kernel_size=1, stride=1) + ) + i += 1 + + return model diff --git a/simulation/train_models/rbs_train.py b/simulation/train_models/rbs_train.py new file mode 100644 index 0000000..2a78759 --- /dev/null +++ b/simulation/train_models/rbs_train.py @@ -0,0 +1,29 @@ +""" + rbs_train + Общая задача: web-service pipeline + Реализуемая функция: обучение нейросетевой модели по заданному BOP-датасету + + python3 $PYTHON_EDUCATION --path /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296 \ + --name test1234 --datasetName 32123213 + + 27.04.2024 @shalenikol release 0.1 +""" +import argparse +from train_Yolo import train_YoloV8 +from train_Dope import train_Dope_i + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--path", required=True, help="Path for dataset") + parser.add_argument("--name", required=True, help="String with result weights name") + parser.add_argument("--datasetName", required=True, help="String with dataset name") + parser.add_argument("--outpath", default="weights", help="Output path for weights") + parser.add_argument("--type", default="ObjectDetection", help="Type of implementation") + parser.add_argument("--epoch", default=3, type=int, help="How many training epochs") + parser.add_argument('--pretrain', action="store_true", help="Use pretraining") + args = parser.parse_args() + + if args.type == "ObjectDetection": + train_YoloV8(args.path, args.name, args.datasetName, args.outpath, args.epoch, args.pretrain) + else: + train_Dope_i(args.path, args.name, args.datasetName, args.outpath, args.epoch, args.pretrain) diff --git a/simulation/train_models/train_Dope.py b/simulation/train_models/train_Dope.py new file mode 100644 index 0000000..f9908bc --- /dev/null +++ b/simulation/train_models/train_Dope.py @@ -0,0 +1,542 @@ +""" + train_Dope + Общая задача: оценка позиции объекта (Pose estimation) + Реализуемая функция: обучение нейросетевой модели DOPE по заданному BOP-датасету + + python3 $PYTHON_EDUCATION --path /Users/user/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296 \ + --name test1234 --datasetName 32123213 + + 08.05.2024 @shalenikol release 
0.1 +""" +import os +import json +import shutil +import numpy as np +import transforms3d as t3d + +FILE_RBS_INFO = "rbs_info.json" +FILE_CAMERA = "camera.json" +FILE_GT = "scene_gt.json" +FILE_GT_COCO = "scene_gt_coco.json" +FILE_GT_INFO = "scene_gt_info.json" + +FILE_MODEL = "epoch" +EXT_MODEL = ".pth" +EXT_RGB = "jpg" +DIR_ROOT_DS = "dataset_dope" +DIR_TRAIN_OUT = "out_weights" + +MODEL_SCALE = 1000 # исходная модель в метрах, преобразуем в мм (для DOPE) + +# Own_Numbering_Files = True # наименование image-файлов: собственная нумерация +nn_image = 0 +K_intrinsic = [] +model_info = [] +camera_data = {} +im_width = 0 + +nb_update_network = 0 +# [ +# [min(x), min(y), min(z)], +# [min(x), max(y), min(z)], +# [min(x), max(y), max(z)], +# [min(x), min(y), max(z)], +# [max(x), min(y), max(z)], +# [max(x), max(y), min(z)], +# [max(x), max(y), max(z)], +# [max(x), min(y), max(z)], +# [xc, yc, zc] # min + (max - min) / 2 +# ] + +def trans_3Dto2D_point_in_camera(xyz, K_m, R_m2c, t_m2c): + """ + xyz : 3D-координаты точки + K_m : внутренняя матрица камеры 3х3 + R_m2c : матрица поворота 3х3 + t_m2c : вектор перемещения 3х1 + return [u,v] + """ + K = np.array(K_m) + r = np.array(R_m2c) + r.shape = (3, 3) + t = np.array(t_m2c) + t.shape = (3, 1) + T = np.concatenate((r, t), axis=1) + + P_m = np.array(xyz) + P_m.resize(4) + P_m[-1] = 1.0 + P_m.shape = (4, 1) + + # Project (X, Y, Z, 1) into cameras coordinate system + P_c = T @ P_m # 4x1 + # Apply camera intrinsics to map (Xc, Yc, Zc) to p=(x, y, z) + p = K @ P_c + # Normalize by z to get (u,v,1) + uv = (p / p[2][0])[:-1] + return uv.flatten().tolist() + +def gt_parse(path: str, out_dir: str): + global nn_image + with open(os.path.join(path, FILE_GT_COCO), "r") as fh: + coco_data = json.load(fh) + with open(os.path.join(path, FILE_GT), "r") as fh: + gt_data = json.load(fh) + with open(os.path.join(path, FILE_GT_INFO), "r") as fh: + gt_info = json.load(fh) + + for img in coco_data["images"]: + rgb_file = os.path.join(path, img["file_name"]) + if os.path.isfile(rgb_file): + # if Own_Numbering_Files: + ext = os.path.splitext(rgb_file)[1] # only ext + f = f"{nn_image:06}" + out_img = os.path.join(out_dir, f + ext) + # else: + # f = os.path.split(rgb_file)[1] # filename with extension + # f = os.path.splitext(f)[0] # only filename + # out_img = out_dir + shutil.copy2(rgb_file, out_img) + out_file = os.path.join(out_dir,f+".json") + nn_image += 1 + + # full annotation of the one image + all_data = camera_data.copy() + cat_names = {obj["id"]: obj["name"] for obj in coco_data["categories"]} + id_img = img["id"] # 0, 1, 2 ... + sid_img = str(id_img) # "0", "1", "2" ... 
+ img_info = gt_info[sid_img] + img_gt = gt_data[sid_img] + img_idx = 0 # object index on the image + objs = [] + for ann in coco_data["annotations"]: + if ann["image_id"] == id_img: + item = ann["category_id"] + obj_data = {} + obj_data["class"] = cat_names[item] + x, y, width, height = ann["bbox"] + obj_data["bounding_box"] = {"top_left":[x,y], "bottom_right":[x+width,y+height]} + + # visibility from FILE_GT_INFO + item_info = img_info[img_idx] + obj_data["visibility"] = item_info["visib_fract"] + + # location from FILE_GT + item_gt = img_gt[img_idx] + obj_id = item_gt["obj_id"] - 1 # index with 0 + cam_R_m2c = item_gt["cam_R_m2c"] + cam_t_m2c = item_gt["cam_t_m2c"] + obj_data["location"] = cam_t_m2c + q = t3d.quaternions.mat2quat(np.array(cam_R_m2c)) + obj_data["quaternion_xyzw"] = [q[1], q[2], q[3], q[0]] + + cuboid_xyz = model_info[obj_id] + obj_data["projected_cuboid"] = [ + trans_3Dto2D_point_in_camera(cub, K_intrinsic, cam_R_m2c, cam_t_m2c) + for cub in cuboid_xyz + ] + + objs.append(obj_data) + img_idx += 1 + + all_data["objects"] = objs + with open(out_file, "w") as fh: + json.dump(all_data, fh, indent=2) + +def explore(path: str, res_dir: str): + if not os.path.isdir(path): + return + folders = [ + os.path.join(path, o) + for o in os.listdir(path) + if os.path.isdir(os.path.join(path, o)) + ] + for path_entry in folders: + if os.path.isfile(os.path.join(path_entry,FILE_GT_COCO)) and \ + os.path.isfile(os.path.join(path_entry,FILE_GT_INFO)) and \ + os.path.isfile(os.path.join(path_entry,FILE_GT)): + gt_parse(path_entry, res_dir) + else: + explore(path_entry, res_dir) + +def BOP2DOPE_dataset(dpath: str, out_dir: str) -> str: + """ Convert BOP-dataset to YOLO format for train """ + res_dir = os.path.join(out_dir, DIR_ROOT_DS) + if os.path.isdir(res_dir): + shutil.rmtree(res_dir) + os.mkdir(res_dir) + + explore(dpath, res_dir) + + return out_dir + +def train(dopepath:str, wname:str, epochs:int, pretrain: bool, lname: list): + import random + # try: + import configparser as configparser + # except ImportError: + # import ConfigParser as configparser + import torch + # import torch.nn.parallel + import torch.optim as optim + import torch.utils.data + import torchvision.transforms as transforms + from torch.autograd import Variable + import datetime + from tensorboardX import SummaryWriter + + from models_dope import DopeNetwork + from utils_dope import CleanVisiiDopeLoader #, VisualizeBeliefMap, save_image + + import warnings + warnings.filterwarnings("ignore") + + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" + + torch.autograd.set_detect_anomaly(False) + torch.autograd.profiler.profile(False) + torch.autograd.gradcheck = False + torch.backends.cudnn.benchmark = True + + start_time = datetime.datetime.now() + print("start:", start_time.strftime("%m/%d/%Y, %H:%M:%S")) + + res_model = os.path.join(dopepath, wname + EXT_MODEL) + + local_rank = 0 + opt = lambda: None + opt.use_s3 = False + opt.train_buckets = [] + opt.endpoint = None + opt.lr=0.0001 + opt.loginterval=100 + opt.sigma=0.5 # 4 + opt.nbupdates=None + # opt.save=False + # opt.option="default" + # opt.gpuids=[0] + + opt.namefile=FILE_MODEL + opt.workers=8 + opt.batchsize=16 + + opt.data = [os.path.join(dopepath, DIR_ROOT_DS)] + opt.outf = os.path.join(dopepath, DIR_TRAIN_OUT) + opt.object = lname #["fork"] + opt.exts = [EXT_RGB] + # opt.imagesize = im_width + opt.epochs = epochs + opt.pretrained = pretrain + opt.net_path = res_model if pretrain else None + opt.manualseed = random.randint(1, 10000) + + # # Validate 
Arguments + # if opt.use_s3 and (opt.train_buckets is None or opt.endpoint is None): + # raise ValueError( + # "--train_buckets and --endpoint must be specified if training with data from s3 bucket." + # ) + # if not opt.use_s3 and opt.data is None: + # raise ValueError("--data field must be specified.") + + os.makedirs(opt.outf, exist_ok=True) + + # if local_rank == 0: + # writer = SummaryWriter(opt.outf + "/runs/") + random.seed(opt.manualseed) + torch.cuda.set_device(local_rank) + # torch.distributed.init_process_group(backend="nccl", init_method="env://") + torch.manual_seed(opt.manualseed) + torch.cuda.manual_seed_all(opt.manualseed) + + # # Data Augmentation + # if not opt.save: + # contrast = 0.2 + # brightness = 0.2 + # noise = 0.1 + # normal_imgs = [0.59, 0.25] + # transform = transforms.Compose( + # [ + # AddRandomContrast(0.2), + # AddRandomBrightness(0.2), + # transforms.Resize(opt.imagesize), + # ] + # ) + # else: + # contrast = 0.00001 + # brightness = 0.00001 + # noise = 0.00001 + # normal_imgs = None + # transform = transforms.Compose( + # [transforms.Resize(opt.imagesize), transforms.ToTensor()] + # ) + + # Load Model + net = DopeNetwork() + output_size = 50 + # opt.sigma = 0.5 + + train_dataset = CleanVisiiDopeLoader( + opt.data, + sigma=opt.sigma, + output_size=output_size, + extensions=opt.exts, + objects=opt.object, + use_s3=opt.use_s3, + buckets=opt.train_buckets, + endpoint_url=opt.endpoint, + ) + trainingdata = torch.utils.data.DataLoader( + train_dataset, + batch_size=opt.batchsize, + shuffle=True, + num_workers=opt.workers, + pin_memory=True, + ) + if not trainingdata is None: + print(f"training data: {len(trainingdata)} batches") + + print("Loading Model...") + net = net.cuda() + # net = torch.nn.parallel.DistributedDataParallel( + # net.cuda(), device_ids=[local_rank], output_device=local_rank + # ) + if opt.pretrained: + if opt.net_path is not None: + net.load_state_dict(torch.load(opt.net_path)) + else: + print("Error: Did not specify path to pretrained weights.") + quit() + + parameters = filter(lambda p: p.requires_grad, net.parameters()) + optimizer = optim.Adam(parameters, lr=opt.lr) + + print("ready to train!") + + global nb_update_network + nb_update_network = 0 + # best_results = {"epoch": None, "passed": None, "add_mean": None, "add_std": None} + + scaler = torch.cuda.amp.GradScaler() + + def _runnetwork(epoch, train_loader): #, syn=False + global nb_update_network + # net + net.train() + + loss_avg_to_log = {} + loss_avg_to_log["loss"] = [] + loss_avg_to_log["loss_affinities"] = [] + loss_avg_to_log["loss_belief"] = [] + loss_avg_to_log["loss_class"] = [] + for batch_idx, targets in enumerate(train_loader): + optimizer.zero_grad() + + data = Variable(targets["img"].cuda()) + target_belief = Variable(targets["beliefs"].cuda()) + target_affinities = Variable(targets["affinities"].cuda()) + + output_belief, output_aff = net(data) + + loss = None + + loss_belief = torch.tensor(0).float().cuda() + loss_affinities = torch.tensor(0).float().cuda() + loss_class = torch.tensor(0).float().cuda() + + for stage in range(len(output_aff)): # output, each belief map layers. 
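+                # DOPE-style intermediate supervision: every stage's output is penalised
+                # with an L2 loss against the same ground-truth belief and affinity maps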
+ loss_affinities += ( + (output_aff[stage] - target_affinities) + * (output_aff[stage] - target_affinities) + ).mean() + + loss_belief += ( + (output_belief[stage] - target_belief) + * (output_belief[stage] - target_belief) + ).mean() + + loss = loss_affinities + loss_belief + + # if batch_idx == 0: + # post = "train" + # if local_rank == 0: + # for i_output in range(1): + # # input images + # writer.add_image( + # f"{post}_input_{i_output}", + # targets["img_original"][i_output], + # epoch, + # dataformats="CWH", + # ) + # # belief maps gt + # imgs = VisualizeBeliefMap(target_belief[i_output]) + # img, grid = save_image( + # imgs, "some_img.png", mean=0, std=1, nrow=3, save=False + # ) + # writer.add_image( + # f"{post}_belief_ground_truth_{i_output}", + # grid, + # epoch, + # dataformats="CWH", + # ) + # # belief maps guess + # imgs = VisualizeBeliefMap(output_belief[-1][i_output]) + # img, grid = save_image( + # imgs, "some_img.png", mean=0, std=1, nrow=3, save=False + # ) + # writer.add_image( + # f"{post}_belief_guess_{i_output}", + # grid, + # epoch, + # dataformats="CWH", + # ) + + loss.backward() + + optimizer.step() + + nb_update_network += 1 + + # log the loss + loss_avg_to_log["loss"].append(loss.item()) + loss_avg_to_log["loss_class"].append(loss_class.item()) + loss_avg_to_log["loss_affinities"].append(loss_affinities.item()) + loss_avg_to_log["loss_belief"].append(loss_belief.item()) + + if batch_idx % opt.loginterval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)] \tLoss: {:.15f} \tLocal Rank: {}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + local_rank, + ) + ) + # # log the loss values + # if local_rank == 0: + # writer.add_scalar("loss/train_loss", np.mean(loss_avg_to_log["loss"]), epoch) + # writer.add_scalar("loss/train_cls", np.mean(loss_avg_to_log["loss_class"]), epoch) + # writer.add_scalar("loss/train_aff", np.mean(loss_avg_to_log["loss_affinities"]), epoch) + # writer.add_scalar("loss/train_bel", np.mean(loss_avg_to_log["loss_belief"]), epoch) + + for epoch in range(1, opt.epochs + 1): + + _runnetwork(epoch, trainingdata) + + try: + if local_rank == 0: + torch.save( + net.state_dict(), + f"{opt.outf}/{opt.namefile}_{str(epoch).zfill(3)}.pth", + ) + except Exception as e: + print(f"Encountered Exception: {e}") + + if not opt.nbupdates is None and nb_update_network > int(opt.nbupdates): + break + + # if local_rank == 0: + # save result model + torch.save(net.state_dict(), res_model) #os.path.join(dopepath, wname + EXT_MODEL)) + # else: + # torch.save( + # net.state_dict(), + # f"{opt.outf}/{opt.namefile}_{str(epoch).zfill(3)}_rank_{local_rank}.pth", + # ) + + print("end:", datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")) + print("Total time taken: ", str(datetime.datetime.now() - start_time).split(".")[0]) + +def train_Dope_i(path:str, wname:str, dname:str, outpath:str, epochs:int, pretrain: bool): + """ Main procedure for train DOPE model """ + global K_intrinsic, model_info, camera_data, im_width + + if not os.path.isdir(outpath): + print(f"Invalid output path '{outpath}'") + exit(-1) + out_dir = os.path.join(outpath, wname) + ds_path = os.path.join(path, dname) + + if not os.path.isdir(ds_path): + print(f"{ds_path} : no BOP directory") + return "" + + camera_json = os.path.join(ds_path, FILE_CAMERA) + if not os.path.isfile(camera_json): + print(f"{camera_json} : no intrinsic camera file") + return "" + + rbs_info = os.path.join(ds_path, FILE_RBS_INFO) + if not 
os.path.isfile(rbs_info):
+        print(f"{rbs_info} : no dataset info file")
+        return ""
+
+    camera_data = {}
+    with open(camera_json, "r") as fh:
+        data = json.load(fh)
+    keys = ["cx", "cy", "fx", "fy"]
+    intrinsic = {k: data[k] for k in keys}
+    im_height = data["height"]
+    im_width = data["width"]
+    camera_data["camera_data"] = dict(intrinsic=intrinsic, height=im_height, width=im_width)
+    K_intrinsic = [
+        [data["fx"], 0.0, data["cx"]],
+        [0.0, data["fy"], data["cy"]],
+        [0.0, 0.0, 1.0]
+    ]
+    # calc cuboid + center
+    with open(rbs_info, "r") as fh:
+        info = json.load(fh)
+    # list of object names
+    list_name = list(map(lambda x: x["name"], info))
+    # in FILE_RBS_INFO the models are numbered from smallest to largest
+    model_info = []
+    for m_info in info:
+        cub = np.array(m_info["cuboid"]) * MODEL_SCALE
+        xyz_min = cub.min(axis=0)
+        xyz_max = cub.max(axis=0)
+        # [xc, yc, zc]  # min + (max - min) / 2
+        center = []
+        for i in range(3):
+            center.append(xyz_min[i] + (xyz_max[i] - xyz_min[i]) / 2)
+        c = np.array(center, ndmin=2)
+        model_info.append(np.append(cub, c, axis=0))
+
+    if pretrain:
+        # continue training
+        if not os.path.isdir(out_dir):
+            print(f"No dir '{out_dir}'")
+            exit(-2)
+        dpath = out_dir
+        # model_path = os.path.join(dpath, wname + ".pt")
+    else:
+        # train from scratch
+        if not os.path.isdir(out_dir):
+            os.mkdir(out_dir)
+
+        dpath = BOP2DOPE_dataset(ds_path, out_dir)
+        if len(dpath) == 0:
+            print(f"Error converting dataset '{ds_path}' to '{outpath}'")
+            exit(-4)
+        # model_path = os.path.join(dpath, FILE_BASEMODEL)
+
+    # results = f"python train.py --local_rank 0 --data {dpath} --object fork" \
+    #     + f" -e {epochs} --batchsize 16 --exts jpg --imagesize 640 --pretrained" \
+    #     + " --net_path /home/shalenikol/fork_work/dope_training/output/weights_2996/net_epoch_47.pth"
+    # print(results)
+    train(dpath, wname, epochs, pretrain, list_name)
+
+import argparse
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--path", required=True, help="Path for dataset")
+    parser.add_argument("--name", required=True, help="String with result weights name")
+    parser.add_argument("--datasetName", required=True, help="String with dataset name")
+    parser.add_argument("--outpath", default="weights", help="Output path for weights")
+    parser.add_argument("--epoch", default=3, type=int, help="How many training epochs")
+    parser.add_argument('--pretrain', action="store_true", help="Use pretraining")
+    args = parser.parse_args()
+
+    train_Dope_i(args.path, args.name, args.datasetName, args.outpath, args.epoch, args.pretrain)
diff --git a/simulation/train_models/train_Yolo.py b/simulation/train_models/train_Yolo.py
new file mode 100644
index 0000000..1eaf7a0
--- /dev/null
+++ b/simulation/train_models/train_Yolo.py
@@ -0,0 +1,181 @@
+"""
+    train_Yolo
+    Overall task: object detection
+    Implemented function: training a YOLOv8 neural network model on a given BOP dataset
+
+    python3 $PYTHON_TRAIN --path /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296/datasets \
+        --name test123 --datasetName ds213 --outpath /Users/idontsudo/webservice/server/build/public/7065d6b6-c8a3-48c5-9679-bb8f3a690296/weights
+
+    27.04.2024 @shalenikol release 0.1
+"""
+import os
+import shutil
+import json
+import yaml
+
+from ultralytics import YOLO
+# from ultralytics.utils.metrics import DetMetrics
+
+FILE_BASEMODEL = "yolov8n.pt"
+FILE_RBS_INFO = "rbs_info.json"
+FILE_RBS_TRAIN = "rbs_train.yaml"
+FILE_GT_COCO = "scene_gt_coco.json"
+FILE_L_TRAIN = "i_train.txt"
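+# train/val split lists: every SZ_SERIES-th converted image is written to the val list (see gt_parse)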
+FILE_L_VAL = "i_val.txt" +FILE_TRAIN_RES = "weights/last.pt" +DIR_ROOT_DS = "datasets" +DIR_COCO_DS = "rbs_coco" +DIR_RGB_DS = "images" +DIR_LABELS_DS = "labels" + +SZ_SERIES = 15 # number of train images per validation images + +nn_image = 0 +f1 = f2 = None + +def convert2relative(height, width, bbox): + """ YOLO format use relative coordinates for annotation """ + x, y, w, h = bbox + x += w/2 + y += h/2 + return x/width, y/height, w/width, h/height + +def gt_parse(path: str, out_dir: str): + global nn_image, f1, f2 + with open(os.path.join(path, FILE_GT_COCO), "r") as fh: + coco_data = json.load(fh) + + for img in coco_data["images"]: + rgb_file = os.path.join(path, img["file_name"]) + if os.path.isfile(rgb_file): + ext = os.path.splitext(rgb_file)[1] # only ext + f = f"{nn_image:06}" + out_img = os.path.join(out_dir, DIR_RGB_DS, f + ext) + shutil.copy2(rgb_file, out_img) + + # заполним файлы с метками bbox + img_id = img["id"] + with open(os.path.join(out_dir, DIR_LABELS_DS, f + ".txt"), "w") as fh: + for i in coco_data["annotations"]: + if i["image_id"] == img_id: + cat_id = i["category_id"] + if cat_id < 999: + bbox = i["bbox"] + im_h = i["height"] + im_w = i["width"] + rel = convert2relative(im_h,im_w,bbox) + # формат: + fh.write(f"{cat_id-1} {rel[0]} {rel[1]} {rel[2]} {rel[3]}\n") # category from 0 + + nn_image += 1 + line = os.path.join("./", DIR_RGB_DS, f + ext) + "\n" + if nn_image % SZ_SERIES == 0: + f2.write(line) + else: + f1.write(line) + +def explore(path: str, res_dir: str): + if not os.path.isdir(path): + return + folders = [ + os.path.join(path, o) + for o in os.listdir(path) + if os.path.isdir(os.path.join(path, o)) + ] + for path_entry in folders: + if os.path.isfile(os.path.join(path_entry,FILE_GT_COCO)): + gt_parse(path_entry, res_dir) + else: + explore(path_entry, res_dir) + +def BOP2Yolo_dataset(dpath: str, out_dir: str, lname: list) -> str: + """ Convert BOP-dataset to YOLO format for train """ + cfg_yaml = os.path.join(out_dir, FILE_RBS_TRAIN) + p = os.path.join(out_dir, DIR_ROOT_DS, DIR_COCO_DS) + cfg_data = {"path": p, "train": FILE_L_TRAIN, "val": FILE_L_VAL} + cfg_data["names"] = {i:x for i,x in enumerate(lname)} + with open(cfg_yaml, "w") as fh: + yaml.dump(cfg_data, fh) + + res_dir = os.path.join(out_dir, DIR_ROOT_DS) + if not os.path.isdir(res_dir): + os.mkdir(res_dir) + + res_dir = os.path.join(res_dir, DIR_COCO_DS) + if not os.path.isdir(res_dir): + os.mkdir(res_dir) + + p = os.path.join(res_dir, DIR_RGB_DS) + if not os.path.isdir(p): + os.mkdir(p) + p = os.path.join(res_dir, DIR_LABELS_DS) + if not os.path.isdir(p): + os.mkdir(p) + + global f1, f2 + f1 = open(os.path.join(res_dir, FILE_L_TRAIN), "w") + f2 = open(os.path.join(res_dir, FILE_L_VAL), "w") + explore(dpath, res_dir) + f1.close() + f2.close() + + return out_dir + +def train_YoloV8(path:str, wname:str, dname:str, outpath:str, epochs:int, pretrain: bool): + """ Main procedure for train YOLOv8 model """ + if not os.path.isdir(outpath): + print(f"Invalid output path '{outpath}'") + exit(-1) + out_dir = os.path.join(outpath, wname) + + if pretrain: + # продолжить обучение + if not os.path.isdir(out_dir): + print(f"No dir '{out_dir}'") + exit(-2) + dpath = out_dir + model_path = os.path.join(dpath, wname + ".pt") + else: + # обучение сначала + if not os.path.isdir(out_dir): + os.mkdir(out_dir) + + ds_path = os.path.join(path, dname) + rbs_info = os.path.join(ds_path, FILE_RBS_INFO) + if not os.path.isfile(rbs_info): + print(f"{rbs_info} : no dataset description file") + exit(-3) + + with 
open(rbs_info, "r") as fh: + y = json.load(fh) + # список имён объектов + list_name = list(map(lambda x: x["name"], y)) + + dpath = BOP2Yolo_dataset(ds_path, out_dir, list_name) + if len(dpath) == 0: + print(f"Error in convert dataset '{ds_path}' to '{outpath}'") + exit(-4) + model_path = os.path.join(dpath, FILE_BASEMODEL) + + model = YOLO(model_path) + results = model.train(data=os.path.join(dpath, FILE_RBS_TRAIN), epochs=epochs, project=out_dir) + wf = os.path.join(results.save_dir, FILE_TRAIN_RES) + if not os.path.isfile(wf): + print(f"Error in train: no result file '{wf}'") + exit(-5) + + shutil.copy2(wf, os.path.join(dpath, wname + ".pt")) + shutil.rmtree(results.save_dir) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--path", required=True, help="Path for dataset") + parser.add_argument("--name", required=True, help="String with result weights name") + parser.add_argument("--datasetName", required=True, help="String with dataset name") + parser.add_argument("--outpath", default="weights", help="Output path for weights") + parser.add_argument("--epoch", default=3, type=int, help="How many training epochs") + parser.add_argument('--pretrain', action="store_true", help="Use pretraining") + args = parser.parse_args() + + train_YoloV8(args.path, args.name, args.datasetName, args.outpath, args.epoch, args.pretrain) diff --git a/simulation/train_models/utils_dope.py b/simulation/train_models/utils_dope.py new file mode 100755 index 0000000..55ab058 --- /dev/null +++ b/simulation/train_models/utils_dope.py @@ -0,0 +1,967 @@ +""" +NVIDIA from jtremblay@gmail.com +""" +import numpy as np +import torch + +import os + +import torch +import torch.nn as nn +import torch.nn.parallel + +import torch.utils.data + +import torchvision.transforms as transforms + +import torch.utils.data as data +import glob +import os +import boto3 +import io + +from PIL import Image +from PIL import ImageDraw +from PIL import ImageEnhance + +from math import acos +from math import sqrt +from math import pi + +from os.path import exists, basename +import json +from os.path import join + +import albumentations as A + + +def default_loader(path): + return Image.open(path).convert("RGB") + + +def length(v): + return sqrt(v[0] ** 2 + v[1] ** 2) + + +def dot_product(v, w): + return v[0] * w[0] + v[1] * w[1] + + +def normalize(v): + norm = np.linalg.norm(v, ord=1) + if norm == 0: + norm = np.finfo(v.dtype).eps + return v / norm + + +def determinant(v, w): + return v[0] * w[1] - v[1] * w[0] + + +def inner_angle(v, w): + cosx = dot_product(v, w) / (length(v) * length(w)) + rad = acos(cosx) # in radians + return rad * 180 / pi # returns degrees + + +def py_ang(A, B=(1, 0)): + inner = inner_angle(A, B) + det = determinant(A, B) + if ( + det < 0 + ): # this is a property of the det. 
If the det < 0 then B is clockwise of A + return inner + else: # if the det > 0 then A is immediately clockwise of B + return 360 - inner + + +import colorsys, math + + +def append_dot(extensions): + res = [] + + for ext in extensions: + if not ext.startswith("."): + res.append(f".{ext}") + else: + res.append(ext) + + return res + + +def loadimages(root, extensions=["png"]): + imgs = [] + extensions = append_dot(extensions) + + def add_json_files( + path, + ): + for ext in extensions: + for file in os.listdir(path): + imgpath = os.path.join(path, file) + if ( + imgpath.endswith(ext) + and exists(imgpath) + and exists(imgpath.replace(ext, ".json")) + ): + imgs.append( + ( + imgpath, + imgpath.replace(path, "").replace("/", ""), + imgpath.replace(ext, ".json"), + ) + ) + + def explore(path): + if not os.path.isdir(path): + return + folders = [ + os.path.join(path, o) + for o in os.listdir(path) + if os.path.isdir(os.path.join(path, o)) + ] + + for path_entry in folders: + explore(path_entry) + + add_json_files(path) + + explore(root) + + return imgs + + +def loadweights(root): + if root.endswith(".pth") and os.path.isfile(root): + return [root] + else: + weights = [ + os.path.join(root, f) + for f in os.listdir(root) + if os.path.isfile(os.path.join(root, f)) and f.endswith(".pth") + ] + + weights.sort() + return weights + + +def loadimages_inference(root, extensions): + imgs, imgsname = [], [] + extensions = append_dot(extensions) + + def add_imgs( + path, + ): + for ext in extensions: + for file in os.listdir(path): + imgpath = os.path.join(path, file) + if imgpath.endswith(ext) and exists(imgpath): + imgs.append(imgpath) + imgsname.append(imgpath.replace(root, "")) + + def explore(path): + if not os.path.isdir(path): + return + folders = [ + os.path.join(path, o) + for o in os.listdir(path) + if os.path.isdir(os.path.join(path, o)) + ] + + for path_entry in folders: + explore(path_entry) + + add_imgs(path) + + explore(root) + + return imgs, imgsname + + +class CleanVisiiDopeLoader(data.Dataset): + def __init__( + self, + path_dataset, + objects=None, + sigma=1, + output_size=400, + extensions=["png"], + debug=False, + use_s3=False, + buckets=[], + endpoint_url=None, + ): + ################### + self.path_dataset = path_dataset + self.objects_interest = objects + self.sigma = sigma + self.output_size = output_size + self.extensions = append_dot(extensions) + self.debug = debug + ################### + + self.imgs = [] + self.s3_buckets = {} + self.use_s3 = use_s3 + + if self.use_s3: + self.session = boto3.Session() + self.s3 = self.session.resource( + service_name="s3", endpoint_url=endpoint_url + ) + + for bucket_name in buckets: + try: + self.s3_buckets[bucket_name] = self.s3.Bucket(bucket_name) + except Exception as e: + print( + f"Error trying to load bucket {bucket_name} for training data:", + e, + ) + + for bucket in self.s3_buckets: + bucket_objects = [ + str(obj.key) for obj in self.s3_buckets[bucket].objects.all() + ] + + jsons = set([json for json in bucket_objects if json.endswith(".json")]) + imgs = [ + img + for img in bucket_objects + if img.endswith(tuple(self.extensions)) + ] + + for ext in self.extensions: + for img in imgs: + # Only add images that have a ground truth file + if img.endswith(ext) and img.replace(ext, ".json") in jsons: + # (img key, bucket name, json key) + self.imgs.append((img, bucket, img.replace(ext, ".json"))) + + else: + for path_look in path_dataset: + self.imgs += loadimages(path_look, extensions=self.extensions) + + # np.random.shuffle(self.imgs) + 
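+        # no shuffling here: the training script builds its DataLoader with shuffle=True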
print("Number of Training Images:", len(self.imgs)) + print(self.imgs) + + if debug: + print("Debuging will be save in debug/") + if os.path.isdir("debug"): + print(f'folder {"debug"}/ exists') + else: + os.mkdir("debug") + print(f'created folder {"debug"}/') + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, index): + + # load the data + if self.use_s3: + img_key, bucket, json_key = self.imgs[index] + mem_img = io.BytesIO() + + object_img = self.s3_buckets[bucket].Object(img_key) + object_img.download_fileobj(mem_img) + + img = np.array(Image.open(mem_img).convert("RGB")) + + object_json = self.s3_buckets[bucket].Object(json_key) + data_json = json.load(object_json.get()["Body"]) + + img_name = img_key[:-3] + + else: + path_img, img_name, path_json = self.imgs[index] + + # load the image + img = np.array(Image.open(path_img).convert("RGB")) + + # load the json file + with open(path_json) as f: + data_json = json.load(f) + + all_projected_cuboid_keypoints = [] + + # load the projected cuboid keypoints + for obj in data_json["objects"]: + if ( + self.objects_interest is not None + and not obj["class"] in self.objects_interest + ): + continue + # load the projected_cuboid_keypoints + # 06.02.2024 @shalenikol + # if obj["visibility_image"] > 0: + if obj["visibility"] > 0: + projected_cuboid_keypoints = obj["projected_cuboid"] + # FAT dataset only has 8 corners for 'projected_cuboid' + if len(projected_cuboid_keypoints) == 8: + projected_cuboid_keypoints.append(obj["projected_cuboid_centroid"]) + else: + projected_cuboid_keypoints = [ + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + ] + all_projected_cuboid_keypoints.append(projected_cuboid_keypoints) + + if len(all_projected_cuboid_keypoints) == 0: + all_projected_cuboid_keypoints = [ + [ + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + [-100, -100], + ] + ] + + # flatten the keypoints + flatten_projected_cuboid = [] + for obj in all_projected_cuboid_keypoints: + for p in obj: + flatten_projected_cuboid.append(p) + + ####### + if self.debug: + img_to_save = Image.fromarray(img) + draw = ImageDraw.Draw(img_to_save) + + for ip, p in enumerate(flatten_projected_cuboid): + draw.ellipse( + (int(p[0]) - 2, int(p[1]) - 2, int(p[0]) + 2, int(p[1]) + 2), + fill="green", + ) + + img_to_save.save(f"debug/{img_name.replace('.png','_original.png')}") + ####### + + # data augmentation + transform = A.Compose( + [ + A.RandomCrop(width=400, height=400), + A.Rotate(limit=180), + A.RandomBrightnessContrast( + brightness_limit=0.2, contrast_limit=0.15, p=1 + ), + A.GaussNoise(p=1), + ], + keypoint_params=A.KeypointParams(format="xy", remove_invisible=False), + ) + transformed = transform(image=img, keypoints=flatten_projected_cuboid) + img_transformed = transformed["image"] + flatten_projected_cuboid_transformed = transformed["keypoints"] + + ####### + + # transform to the final output + if not self.output_size == 400: + transform = A.Compose( + [ + A.Resize(width=self.output_size, height=self.output_size), + ], + keypoint_params=A.KeypointParams(format="xy", remove_invisible=False), + ) + transformed = transform( + image=img_transformed, keypoints=flatten_projected_cuboid_transformed + ) + img_transformed_output_size = transformed["image"] + flatten_projected_cuboid_transformed_output_size = transformed["keypoints"] + + else: + img_transformed_output_size = 
img_transformed + flatten_projected_cuboid_transformed_output_size = ( + flatten_projected_cuboid_transformed + ) + + ####### + if self.debug: + img_transformed_saving = Image.fromarray(img_transformed) + + draw = ImageDraw.Draw(img_transformed_saving) + + for ip, p in enumerate(flatten_projected_cuboid_transformed): + draw.ellipse( + (int(p[0]) - 2, int(p[1]) - 2, int(p[0]) + 2, int(p[1]) + 2), + fill="green", + ) + + img_transformed_saving.save( + f"debug/{img_name.replace('.png','_transformed.png')}" + ) + ####### + + # update the keypoints list + # obj x keypoint_id x (x,y) + i_all = 0 + for i_obj, obj in enumerate(all_projected_cuboid_keypoints): + for i_p, point in enumerate(obj): + all_projected_cuboid_keypoints[i_obj][ + i_p + ] = flatten_projected_cuboid_transformed_output_size[i_all] + i_all += 1 + + # generate the belief maps + beliefs = CreateBeliefMap( + size=int(self.output_size), + pointsBelief=all_projected_cuboid_keypoints, + sigma=self.sigma, + nbpoints=9, + save=False, + ) + beliefs = torch.from_numpy(np.array(beliefs)) + # generate affinity fields with centroid. + affinities = GenerateMapAffinity( + size=int(self.output_size), + nb_vertex=8, + pointsInterest=all_projected_cuboid_keypoints, + objects_centroid=np.array(all_projected_cuboid_keypoints)[:, -1].tolist(), + scale=1, + ) + + # prepare for the image tensors + normalize_tensor = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + ) + to_tensor = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + img_tensor = normalize_tensor(Image.fromarray(img_transformed)) + img_original = to_tensor(img_transformed) + + ######## + if self.debug: + imgs = VisualizeBeliefMap(beliefs) + img, grid = save_image( + imgs, + f"debug/{img_name.replace('.png','_beliefs.png')}", + mean=0, + std=1, + nrow=3, + save=True, + ) + imgs = VisualizeAffinityMap(affinities) + save_image( + imgs, + f"debug/{img_name.replace('.png','_affinities.png')}", + mean=0, + std=1, + nrow=3, + save=True, + ) + ######## + img_tensor[torch.isnan(img_tensor)] = 0 + affinities[torch.isnan(affinities)] = 0 + beliefs[torch.isnan(beliefs)] = 0 + + img_tensor[torch.isinf(img_tensor)] = 0 + affinities[torch.isinf(affinities)] = 0 + beliefs[torch.isinf(beliefs)] = 0 + + return { + "img": img_tensor, + "affinities": torch.clamp(affinities, -1, 1), + "beliefs": torch.clamp(beliefs, 0, 1), + "file_name": img_name, + "img_original": img_original, + } + + +def VisualizeAffinityMap( + tensor, + # tensor of (len(keypoints)*2)xwxh + threshold_norm_vector=0.4, + # how long does the vector has to be to be drawn + points=None, + # list of points to draw in white on top of the image + factor=1.0, + # by how much the image was reduced, scale factor + translation=(0, 0) + # by how much the points were moved + # return len(keypoints)x3xwxh # stack of images +): + images = torch.zeros(tensor.shape[0] // 2, 3, tensor.shape[1], tensor.shape[2]) + for i_image in range(0, tensor.shape[0], 2): # could be read as i_keypoint + + indices = ( + torch.abs(tensor[i_image, :, :]) + torch.abs(tensor[i_image + 1, :, :]) + > threshold_norm_vector + ).nonzero() + + for indice in indices: + + i, j = indice + + angle_vector = np.array([tensor[i_image, i, j], tensor[i_image + 1, i, j]]) + if length(angle_vector) > threshold_norm_vector: + angle = py_ang(angle_vector) + c = colorsys.hsv_to_rgb(angle / 360, 1, 1) + else: + c = [0, 0, 0] + for i_c in range(3): + images[i_image // 2, i_c, i, j] = c[i_c] + if not points is None: 
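+            # overlay the corresponding keypoint as a small white square on top of the affinity colours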
+ point = points[i_image // 2] + + print( + int(point[1] * factor + translation[1]), + int(point[0] * factor + translation[0]), + ) + images[ + i_image // 2, + :, + int(point[1] * factor + translation[1]) + - 1 : int(point[1] * factor + translation[1]) + + 1, + int(point[0] * factor + translation[0]) + - 1 : int(point[0] * factor + translation[0]) + + 1, + ] = 1 + + return images + + +def VisualizeBeliefMap( + tensor, + # tensor of len(keypoints)xwxh + points=None, + # list of points to draw on top of the image + factor=1.0, + # by how much the image was reduced, scale factor + translation=(0, 0) + # by how much the points were moved + # return len(keypoints)x3xwxh # stack of images in torch tensor +): + images = torch.zeros(tensor.shape[0], 3, tensor.shape[1], tensor.shape[2]) + for i_image in range(0, tensor.shape[0]): # could be read as i_keypoint + + belief = tensor[i_image].clone() + belief -= float(torch.min(belief).item()) + belief /= float(torch.max(belief).item()) + + belief = torch.clamp(belief, 0, 1) + belief = torch.cat( + [belief.unsqueeze(0), belief.unsqueeze(0), belief.unsqueeze(0)] + ).unsqueeze(0) + + images[i_image] = belief + + return images + + +def GenerateMapAffinity( + size, nb_vertex, pointsInterest, objects_centroid, scale, save=False +): + # Apply the downscale right now, so the vectors are correct. + + img_affinity = Image.new("RGB", (int(size / scale), int(size / scale)), "black") + # create the empty tensors + totensor = transforms.Compose([transforms.ToTensor()]) + + affinities = [] + for i_points in range(nb_vertex): + affinities.append(torch.zeros(2, int(size / scale), int(size / scale))) + + for i_pointsImage in range(len(pointsInterest)): + pointsImage = pointsInterest[i_pointsImage] + center = objects_centroid[i_pointsImage] + for i_points in range(nb_vertex): + point = pointsImage[i_points] + + affinity_pair, img_affinity = getAfinityCenter( + int(size / scale), + int(size / scale), + tuple((np.array(pointsImage[i_points]) / scale).tolist()), + tuple((np.array(center) / scale).tolist()), + img_affinity=img_affinity, + radius=1, + ) + + affinities[i_points] = (affinities[i_points] + affinity_pair) / 2 + + # Normalizing + v = affinities[i_points].numpy() + + xvec = v[0] + yvec = v[1] + + norms = np.sqrt(xvec * xvec + yvec * yvec) + nonzero = norms > 0 + + xvec[nonzero] /= norms[nonzero] + yvec[nonzero] /= norms[nonzero] + + affinities[i_points] = torch.from_numpy(np.concatenate([[xvec], [yvec]])) + affinities = torch.cat(affinities, 0) + + return affinities + + +def getAfinityCenter( + width, height, point, center, radius=7, tensor=None, img_affinity=None +): + """ + Create the affinity map + """ + if tensor is None: + tensor = torch.zeros(2, height, width).float() + + # create the canvas for the afinity output + imgAffinity = Image.new("RGB", (width, height), "black") + totensor = transforms.Compose([transforms.ToTensor()]) + draw = ImageDraw.Draw(imgAffinity) + r1 = radius + p = point + draw.ellipse((p[0] - r1, p[1] - r1, p[0] + r1, p[1] + r1), (255, 255, 255)) + + del draw + + # compute the array to add the afinity + array = (np.array(imgAffinity) / 255)[:, :, 0] + + angle_vector = np.array(center) - np.array(point) + angle_vector = normalize(angle_vector) + affinity = np.concatenate([[array * angle_vector[0]], [array * angle_vector[1]]]) + + if not img_affinity is None: + # find the angle vector + if length(angle_vector) > 0: + angle = py_ang(angle_vector) + else: + angle = 0 + c = np.array(colorsys.hsv_to_rgb(angle / 360, 1, 1)) * 255 + draw = 
ImageDraw.Draw(img_affinity) + draw.ellipse( + (p[0] - r1, p[1] - r1, p[0] + r1, p[1] + r1), + fill=(int(c[0]), int(c[1]), int(c[2])), + ) + del draw + re = torch.from_numpy(affinity).float() + tensor + return re, img_affinity + + +def CreateBeliefMap(size, pointsBelief, nbpoints, sigma=16, save=False): + # Create the belief maps in the points + beliefsImg = [] + for numb_point in range(nbpoints): + array = np.zeros([size, size]) + out = np.zeros([size, size]) + + for point in pointsBelief: + p = [point[numb_point][1], point[numb_point][0]] + w = int(sigma * 2) + if p[0] - w >= 0 and p[0] + w < size and p[1] - w >= 0 and p[1] + w < size: + for i in range(int(p[0]) - w, int(p[0]) + w + 1): + for j in range(int(p[1]) - w, int(p[1]) + w + 1): + + # if there is already a point there. + array[i, j] = max( + np.exp( + -( + ((i - p[0]) ** 2 + (j - p[1]) ** 2) + / (2 * (sigma**2)) + ) + ), + array[i, j], + ) + + beliefsImg.append(array.copy()) + + if save: + stack = np.stack([array, array, array], axis=0).transpose(2, 1, 0) + imgBelief = Image.fromarray((stack * 255).astype("uint8")) + imgBelief.save("debug/{}.png".format(numb_point)) + return beliefsImg + + +def crop(img, i, j, h, w): + """Crop the given PIL.Image. + Args: + img (PIL.Image): Image to be cropped. + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + Returns: + PIL.Image: Cropped image. + """ + return img.crop((j, i, j + w, i + h)) + + +class AddRandomContrast(object): + """ + Apply some random image filters from PIL + """ + + def __init__(self, sigma=0.1): + self.sigma = sigma + + def __call__(self, im): + + contrast = ImageEnhance.Contrast(im) + + im = contrast.enhance(np.random.normal(1, self.sigma)) + + return im + + +class AddRandomBrightness(object): + """ + Apply some random image filters from PIL + """ + + def __init__(self, sigma=0.1): + self.sigma = sigma + + def __call__(self, im): + + contrast = ImageEnhance.Brightness(im) + im = contrast.enhance(np.random.normal(1, self.sigma)) + return im + + +class AddNoise(object): + """Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + """ + + def __init__(self, std=0.1): + self.std = std + + def __call__(self, tensor): + # TODO: make efficient + t = torch.FloatTensor(tensor.size()).normal_(0, self.std) + + t = tensor.add(t) + t = torch.clamp(t, -1, 1) # this is expansive + return t + + +irange = range + + +def make_grid( + tensor, + nrow=8, + padding=2, + normalize=False, + range=None, + scale_each=False, + pad_value=0, +): + """Make a grid of images. + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The Final grid size is (B / nrow, nrow). Default is 8. + padding (int, optional): amount of padding. Default is 2. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by subtracting the minimum and dividing by the maximum pixel value. + range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each (bool, optional): If True, scale each image in the batch of + images separately rather than the (min, max) over all images. + pad_value (float, optional): Value for the padded pixels. 
+ Example: + See this notebook `here `_ + """ + if not ( + torch.is_tensor(tensor) + or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor)) + ): + raise TypeError( + "tensor or list of tensors expected, got {}".format(type(tensor)) + ) + + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = torch.stack(tensor, dim=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.view(1, tensor.size(0), tensor.size(1)) + if tensor.dim() == 3: # single image + if tensor.size(0) == 1: # if single-channel, convert to 3-channel + tensor = torch.cat((tensor, tensor, tensor), 0) + tensor = tensor.view(1, tensor.size(0), tensor.size(1), tensor.size(2)) + + if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images + tensor = torch.cat((tensor, tensor, tensor), 1) + + if normalize is True: + tensor = tensor.clone() # avoid modifying tensor in-place + if range is not None: + assert isinstance( + range, tuple + ), "range has to be a tuple (min, max) if specified. min and max are numbers" + + def norm_ip(img, min, max): + img.clamp_(min=min, max=max) + img.add_(-min).div_(max - min + 1e-5) + + def norm_range(t, range): + if range is not None: + norm_ip(t, range[0], range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, range) + else: + norm_range(tensor, range) + + if tensor.size(0) == 1: + return tensor.squeeze() + + # make the mini-batch of images into a grid + nmaps = tensor.size(0) + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) + grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_( + pad_value + ) + k = 0 + for y in irange(ymaps): + for x in irange(xmaps): + if k >= nmaps: + break + grid.narrow(1, y * height + padding, height - padding).narrow( + 2, x * width + padding, width - padding + ).copy_(tensor[k]) + k = k + 1 + return grid + + +def save_image(tensor, filename, nrow=4, padding=2, mean=None, std=None, save=True): + """ + Saves a given Tensor into an image file. + If given a mini-batch tensor, will save the tensor as a grid of images. + """ + from PIL import Image + + tensor = tensor.cpu() + grid = make_grid(tensor, nrow=nrow, padding=10, pad_value=1) + if not mean is None: + # ndarr = grid.mul(std).add(mean).mul(255).byte().transpose(0,2).transpose(0,1).numpy() + ndarr = ( + grid.mul(std) + .add(mean) + .mul(255) + .byte() + .transpose(0, 2) + .transpose(0, 1) + .numpy() + ) + else: + ndarr = ( + grid.mul(0.5) + .add(0.5) + .mul(255) + .byte() + .transpose(0, 2) + .transpose(0, 1) + .numpy() + ) + im = Image.fromarray(ndarr) + if save is True: + im.save(filename) + return im, grid + + +from PIL import ImageDraw, Image, ImageFont +import json + + +class Draw(object): + """Drawing helper class to visualize the neural network output""" + + def __init__(self, im): + """ + :param im: The image to draw in. 
+ """ + self.draw = ImageDraw.Draw(im) + self.width = im.size[0] + + def draw_line(self, point1, point2, line_color, line_width=2): + """Draws line on image""" + if point1 is not None and point2 is not None: + self.draw.line([point1, point2], fill=line_color, width=line_width) + + def draw_dot(self, point, point_color, point_radius): + """Draws dot (filled circle) on image""" + if point is not None: + xy = [ + point[0] - point_radius, + point[1] - point_radius, + point[0] + point_radius, + point[1] + point_radius, + ] + self.draw.ellipse(xy, fill=point_color, outline=point_color) + + def draw_text(self, point, text, text_color): + """Draws text on image""" + if point is not None: + self.draw.text(point, text, fill=text_color, font=ImageFont.truetype("misc/arial.ttf", self.width // 50)) + + def draw_cube(self, points, color=(0, 255, 0)): + """ + Draws cube with a thick solid line across + the front top edge and an X on the top face. + """ + # draw front + self.draw_line(points[0], points[1], color) + self.draw_line(points[1], points[2], color) + self.draw_line(points[3], points[2], color) + self.draw_line(points[3], points[0], color) + + # draw back + self.draw_line(points[4], points[5], color) + self.draw_line(points[6], points[5], color) + self.draw_line(points[6], points[7], color) + self.draw_line(points[4], points[7], color) + + # draw sides + self.draw_line(points[0], points[4], color) + self.draw_line(points[7], points[3], color) + self.draw_line(points[5], points[1], color) + self.draw_line(points[2], points[6], color) + + # draw dots + self.draw_dot(points[0], point_color=color, point_radius=4) + self.draw_dot(points[1], point_color=color, point_radius=4) + + # draw x on the top + self.draw_line(points[0], points[5], color) + self.draw_line(points[1], points[4], color) + + # Draw center + self.draw_dot(points[8], point_color=color, point_radius=6) + + for i in range(9): + self.draw_text(points[i], str(i), (255, 0, 0)) + +