366 lines
14 KiB
Python
366 lines
14 KiB
Python
|
#!/usr/bin/env python3
|
|||
|
"""
|
|||
|
pose_estimation_lifecycle_node_with_DOPE
|
|||
|
ROS 2 program for 6D Pose Estimation
|
|||
|
|
|||
|
Source run:
|
|||
|
python inference.py --weights ../output/weights_2996 --data ../sample_dataset100 --object fork --exts jpg \
|
|||
|
--config config/config_pose_fork.yaml --camera config/camera_info_640x480.yaml
|
|||
|
|
|||
|
@shalenikol release 0.4
|
|||
|
"""
|
|||
|
import os
|
|||
|
import json
|
|||
|
import yaml
|
|||
|
|
|||
|
import rclpy
|
|||
|
from rclpy.lifecycle import Node
|
|||
|
from rclpy.lifecycle import State
|
|||
|
from rclpy.lifecycle import TransitionCallbackReturn
|
|||
|
|
|||
|
from ament_index_python.packages import get_package_share_directory
|
|||
|
from sensor_msgs.msg import Image, CameraInfo
|
|||
|
from geometry_msgs.msg import Pose
|
|||
|
|
|||
|
from cv_bridge import CvBridge # Package to convert between ROS and OpenCV Images
|
|||
|
import cv2 # OpenCV library
|
|||
|
|
|||
|
# File name of the DOPE settings; it is looked up inside the package "config" directory.
FILE_DOPE_CONFIG = "pe_dope_config.yaml"
# FILE_TEMP_IMAGE = "image_rgb.png"

# Camera link used to build the default image / camera_info topic names.
CAMERA_LINK_DEFAULT = "rgbd_camera"

# This name must match the name in the description (["Module"]["node_name"]).
NODE_NAME_DEFAULT = "lc_dope"

# Name of the node parameter that carries the skill configuration as a JSON string.
PARAM_SKILL_CFG = "lc_dope_cfg"
|
|||
|
|
|||
|
|
|||
|
def get_transfer_path_() -> str:
    """Return the ``config`` directory inside the ``rbs_perception`` package share."""
    share_dir = get_package_share_directory("rbs_perception")
    return os.path.join(share_dir, "config")
|
|||
|
|
|||
|
class PE_DOPE(Node):
    """Pose estimation lifecycle node with DOPE.

    The node reads its skill configuration (a JSON string) from the
    ``lc_dope_cfg`` parameter, subscribes to a camera image topic and a
    camera-info topic, runs the DOPE detector on every received frame and
    publishes the estimated object pose as ``geometry_msgs/Pose``.
    """

    def __init__(self, **kwargs):
        """Construct the node.

        Args:
            **kwargs: forwarded to the lifecycle ``Node`` constructor.
                ``allow_undeclared_parameters`` and
                ``automatically_declare_parameters_from_overrides`` are
                always forced to ``True``.
        """
        # for other nodes
        kwargs["allow_undeclared_parameters"] = True
        kwargs["automatically_declare_parameters_from_overrides"] = True
        super().__init__(NODE_NAME_DEFAULT, **kwargs)

        str_cfg = self.get_parameter(PARAM_SKILL_CFG).get_parameter_value().string_value
        self.skill_cfg = json.loads(str_cfg)
        self.nodeName = NODE_NAME_DEFAULT  # self.skill_cfg["Module"]["node_name"]
        # Output topic: name of the first entry in "topicsOut".
        self.topicSrv = self.skill_cfg["topicsOut"][0]["name"]

        # Used to convert between ROS and OpenCV images
        self.br = CvBridge()
        self.dope_cfg = self._load_config_DOPE()

        self._cam_pose = Pose()

        # Camera state: filled in by listener_camera_info().
        self._is_camerainfo = False
        self._res = []
        self._K = []

        self.topicImage = ""
        self.topicCameraInfo = ""
        self._set_camera_topic(CAMERA_LINK_DEFAULT)

        # ROS entities and the detector are created in the lifecycle callbacks.
        self._sub = None
        self._sub_info = None
        self._pub = None
        self.dope_node = None
        self._image_cnt = 0

    def _load_config_DOPE(self):
        """Load the DOPE YAML settings from the package ``config`` directory.

        Returns:
            The parsed YAML document.
        """
        p = os.path.join(get_transfer_path_(), FILE_DOPE_CONFIG)
        with open(p, "r") as f:
            # NOTE(review): the file ships with the package; if it could ever
            # come from an untrusted source, switch to yaml.safe_load.
            return yaml.load(f, Loader=yaml.FullLoader)

    def _set_camera_topic(self, camera_link: str):
        """Derive the image and camera_info topic names from a camera link name."""
        self.topicImage = "/" + camera_link + "/image"
        self.topicCameraInfo = "/" + camera_link + "/camera_info"

    def listener_camera_info(self, data):
        """CameraInfo callback: store resolution and the 3x3 intrinsic matrix once."""
        if self._is_camerainfo:  # don't read camera info again
            return

        self._res = [data.height, data.width]
        k_ = data.k
        # Row-major 9-element K -> 3x3 nested list.
        self._K = [
            [k_[0], k_[1], k_[2]],
            [k_[3], k_[4], k_[5]],
            [k_[6], k_[7], k_[8]],
        ]
        # set the indicator for receiving the camera info
        self._is_camerainfo = True

    def on_configure(self, state: State) -> TransitionCallbackReturn:
        """Parse the skill configuration, load the DOPE model and create ROS entities.

        Returns:
            ``SUCCESS`` when the weights were found and the detector was built,
            ``FAILURE`` when the weights dependency, the image topic or the
            weights file is missing.
        """
        json_param = self.get_parameter(PARAM_SKILL_CFG).get_parameter_value().string_value
        jdata = json.loads(json_param)

        dependency = {}
        for comm in jdata["BTAction"]:
            for par in comm["param"]:
                if par["type"] == "weights":
                    dependency = par["dependency"]
                elif par["type"] == "topic":
                    t_dep = par["dependency"]
                    if "Image" in t_dep["topicType"]:
                        self.topicImage = t_dep["topicOut"]
                    else:
                        self.topicCameraInfo = t_dep["topicOut"]

        # Explicit checks instead of `assert`: asserts vanish under `python -O`
        # and would otherwise raise out of the lifecycle callback.
        if not dependency:
            self.get_logger().error("no weights dependency")
            return TransitionCallbackReturn.FAILURE
        if not self.topicImage:
            self.get_logger().error("no input topic dependency")
            return TransitionCallbackReturn.FAILURE

        # Load model weights.  This is done BEFORE any ROS entity is created so
        # that a failed configuration leaves no stray subscription/publisher.
        w = dependency["weights_file"]
        if not os.path.isfile(w):
            self.get_logger().warning(f"No weights found <{w}>")
            return TransitionCallbackReturn.FAILURE

        obj = dependency["object_name"]
        dim = dependency["dimensions"]
        self.dope_node = Dope(self.dope_cfg, w, obj, dim)

        # Create the subscribers.
        self._sub_info = self.create_subscription(
            CameraInfo, self.topicCameraInfo, self.listener_camera_info, 2
        )
        # Create the publisher.
        self._pub = self.create_lifecycle_publisher(Pose, self.topicSrv, 1)

        self.get_logger().info(f"configure is success (with '{obj}')")
        return TransitionCallbackReturn.SUCCESS

    def on_activate(self, state: State) -> TransitionCallbackReturn:
        """Start receiving images."""
        self.get_logger().info("on_activate is called")
        # Create the Image subscriber.
        self._sub = self.create_subscription(Image, self.topicImage, self.image_callback, 3)
        return super().on_activate(state)

    def on_deactivate(self, state: State) -> TransitionCallbackReturn:
        """Stop receiving images."""
        self.get_logger().info("on_deactivate is called")
        # Destroy the Image subscriber and drop the stale handle.
        if self._sub is not None:
            self.destroy_subscription(self._sub)
            self._sub = None
        return super().on_deactivate(state)

    def on_cleanup(self, state: State) -> TransitionCallbackReturn:
        """Release everything created in on_configure()."""
        if self._pub is not None:
            self.destroy_publisher(self._pub)
            self._pub = None
        if self._sub_info is not None:
            self.destroy_subscription(self._sub_info)
            self._sub_info = None

        self._is_camerainfo = False
        # on_configure() builds a fresh detector, so the old one can be dropped.
        self.dope_node = None

        self.get_logger().info("on_cleanup is called")
        return TransitionCallbackReturn.SUCCESS

    def on_shutdown(self, state: State) -> TransitionCallbackReturn:
        """Log the shutdown transition."""
        self.get_logger().info("on_shutdown is called")
        return TransitionCallbackReturn.SUCCESS

    def image_callback(self, data):
        """Image callback: run DOPE on the frame and publish the resulting pose."""
        if not self._is_camerainfo:
            self.get_logger().warning("No data from CameraInfo")
            return
        if self.dope_node is None:
            self.get_logger().warning("DOPE model is not loaded")
            return

        # Convert ROS Image message to an RGB OpenCV image.  Asking cv_bridge
        # for "rgb8" gives RGB for any source encoding; the previous
        # passthrough + unconditional channel reversal was only correct when
        # the camera published BGR.
        current_frame = self.br.imgmsg_to_cv2(data, desired_encoding="rgb8").copy()

        self._image_cnt += 1

        self.get_logger().info(f"dope: begin {self._image_cnt}")
        pose = self.dope_node.image_processing(img=current_frame, camera_info=self._K)
        self.get_logger().info(f"dope: end {self._image_cnt}")

        if self._pub is not None and self._pub.is_activated:
            # publish pose estimation result
            self._pub.publish(pose)
|
|||
|
|
|||
|
from detector import ModelData, ObjectDetector
|
|||
|
from cuboid_pnp_solver import CuboidPNPSolver
|
|||
|
from cuboid import Cuboid3d
|
|||
|
import numpy as np
|
|||
|
|
|||
|
class Dope(object):
    """DOPE inference wrapper: detects one object class in an image and returns its pose.

    This is a plain helper object (not a ROS node): it owns the network model
    and the cuboid PnP solver and is driven through image_processing().
    """

    def __init__(
        self,
        config,  # config yaml loaded eg dict
        weight,  # path to weight
        class_name,
        dim: list  # dimensions of model 'class_name'
    ):
        """Load the network and prepare the PnP solver.

        Args:
            config: mapping with the keys ``input_is_rectified``,
                ``downscale_height``, ``thresh_angle``, ``thresh_map``,
                ``sigma`` and ``thresh_points``.
            weight: path to the network weights file.
            class_name: name of the object class to detect.
            dim: dimensions of the model of ``class_name``.
        """
        # Local import keeps this block self-contained.
        from types import SimpleNamespace

        self.input_is_rectified = config["input_is_rectified"]
        self.downscale_height = config["downscale_height"]

        # Attribute bag handed to the detector.
        self.config_detect = SimpleNamespace(
            mask_edges=1,
            mask_faces=1,
            vertex=1,
            threshold=0.5,
            softmax=1000,
            thresh_angle=config["thresh_angle"],
            thresh_map=config["thresh_map"],
            sigma=config["sigma"],
            thresh_points=config["thresh_points"],
        )

        # load network model, create PNP solver
        self.model = ModelData(name=class_name, net_path=weight)
        # TODO: load_net_model() emits torchvision UserWarnings: the parameter
        # 'pretrained' is deprecated since 0.13, 'weights' should be used instead.
        self.model.load_net_model()

        self.draw_color = (0, 255, 0)

        # TODO load dim from config
        self.dimension = tuple(dim)
        self.class_id = 1  # config["class_ids"][class_name]

        self.pnp_solver = CuboidPNPSolver(class_name, cuboid3d=Cuboid3d(dim))
        self.class_name = class_name

    def image_processing(
        self,
        img,
        camera_info
    ) -> Pose:
        """Detect the object in ``img`` and return its pose.

        Args:
            img: image array whose first two axes are height and width.
            camera_info: 3x3 camera intrinsic matrix (nested sequence).

        Returns:
            Pose of the first detection that has a location; a
            default-constructed ``Pose`` when nothing was detected.
        """
        # !!! Always treated as self.input_is_rectified = True:
        # the intrinsics are used as given and distortion is assumed to be zero.
        camera_matrix = np.matrix(camera_info, dtype="float64").copy()
        dist_coeffs = np.zeros((4, 1))

        # Downscale image if necessary
        height, width = img.shape[:2]
        scaling_factor = float(self.downscale_height) / height
        if scaling_factor < 1.0:
            # Scale the first two rows of K together with the image.
            camera_matrix[:2] *= scaling_factor
            img = cv2.resize(img, (int(scaling_factor * width), int(scaling_factor * height)))

        self.pnp_solver.set_camera_intrinsic_matrix(camera_matrix)
        self.pnp_solver.set_dist_coeffs(dist_coeffs)

        # Detect object
        results, _ = ObjectDetector.detect_object_in_image(
            self.model.net, self.pnp_solver, img, self.config_detect
        )

        # Fill the pose from the first usable detection
        p = Pose()
        for result in results:
            location = result["location"]
            if location is None:
                continue

            quaternion = result["quaternion"]
            # float() guarantees plain Python floats for the message fields.
            p.position.x = float(location[0])
            p.position.y = float(location[1])
            p.position.z = float(location[2])
            p.orientation.x = float(quaternion[0])
            p.orientation.y = float(quaternion[1])
            p.orientation.z = float(quaternion[2])
            p.orientation.w = float(quaternion[3])
            break  # !!! only considering the first option for now
        return p
|
|||
|
|
|||
|
def main():
    """Entry point: spin the PE_DOPE lifecycle node until interrupted.

    The node is destroyed and rclpy is shut down on every exit path, not only
    when the spin is interrupted.
    """
    rclpy.init()

    lc_node = None
    try:
        executor = rclpy.executors.SingleThreadedExecutor()
        # executor = rclpy.executors.MultiThreadedExecutor()
        lc_node = PE_DOPE()
        executor.add_node(lc_node)
        executor.spin()
    except (KeyboardInterrupt, rclpy.executors.ExternalShutdownException):
        # Normal termination paths (Ctrl-C / external shutdown request).
        pass
    finally:
        if lc_node is not None:
            lc_node.destroy_node()
        # The context may already be shut down by an external request.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|