#!/usr/bin/env python3
"""
pose_estimation_lifecycle_node_with_DOPE
ROS 2 program for 6D Pose Estimation

Reference invocation of the original DOPE inference script:
    python inference.py --weights ../output/weights_2996 --data ../sample_dataset100 --object fork --exts jpg \
        --config config/config_pose_fork.yaml --camera config/camera_info_640x480.yaml

@shalenikol release 0.3
"""
|
||
import json
import os
from types import SimpleNamespace

import yaml

import rclpy
from rclpy.lifecycle import Node
from rclpy.lifecycle import State
from rclpy.lifecycle import TransitionCallbackReturn

from ament_index_python.packages import get_package_share_directory
from sensor_msgs.msg import Image, CameraInfo
from geometry_msgs.msg import Pose
# from tf2_ros import TransformException
from tf2_ros.buffer import Buffer

from cv_bridge import CvBridge  # Package to convert between ROS and OpenCV Images
import cv2  # OpenCV library
|
||
|
||
# File name of the DOPE settings YAML; it is looked up in the package "config" directory.
FILE_DOPE_CONFIG = "pe_dope_config.yaml"
# FILE_TEMP_IMAGE = "image_rgb.png"
# Camera link used to build the image / camera_info topic names until the
# skill configuration provides a "cameraLink" setting.
CAMERA_LINK_DEFAULT = "outer_rgbd_camera"
NODE_NAME_DEFAULT = "lc_dope"  # the name doesn't matter in this node (defined in Launch)
# Name of the node parameter that carries the skill configuration as a JSON string.
PARAM_SKILL_CFG = "lc_dope_cfg"
# PARAM_SUFFIX = "_cfg"
# node_name = self.cfg_data["Launch"]["name"]
# par_name = node_name + PARAM_SUFFIX
|
||
|
||
|
||
def get_transfer_path_() -> str:
    """Return the "config" directory inside the rbs_perception package share directory."""
    share_dir = get_package_share_directory("rbs_perception")
    return os.path.join(share_dir, "config")
|
||
|
||
class PE_DOPE(Node):
    """Pose estimation lifecycle node with DOPE.

    Lifecycle summary:
      * configure  - reads the skill configuration, loads the DOPE model and
                     creates the CameraInfo subscription and the Pose publisher;
      * activate   - subscribes to the image topic; every frame is run through
                     DOPE and the resulting pose is published;
      * deactivate - drops the image subscription;
      * cleanup    - destroys the publisher / CameraInfo subscription and
                     releases the model.
    """

    def __init__(self, **kwargs):
        """Construct the node.

        The skill configuration is read from the PARAM_SKILL_CFG parameter as a
        JSON string; it supplies the launch name and the name of the first
        output interface, which together form the output topic name.
        """
        # for other nodes
        kwargs["allow_undeclared_parameters"] = True
        kwargs["automatically_declare_parameters_from_overrides"] = True
        super().__init__(NODE_NAME_DEFAULT, **kwargs)

        str_cfg = self.get_parameter(PARAM_SKILL_CFG).get_parameter_value().string_value
        self.skill_cfg = json.loads(str_cfg)
        self.nodeName = self.skill_cfg["Launch"]["name"]
        out_par = self.skill_cfg["Interface"]["Output"][0]
        self.topicSrv = self.nodeName + "/" + out_par["name"]

        # Used to convert between ROS and OpenCV images
        self.br = CvBridge()
        self.dope_cfg = self._load_config_DOPE()
        # DOPE wrapper; created in on_configure, released in on_cleanup.
        self.dope_node = None

        self._cam_pose = Pose()
        self.tf_buffer = Buffer()

        self._is_camerainfo = False
        self.topicImage = ""
        self.topicCameraInfo = ""
        self._set_camera_topic(CAMERA_LINK_DEFAULT)
        self._sub = None
        self._sub_info = None
        self._pub = None
        self._image_cnt = 0
        self._K = []    # 3x3 camera matrix as nested lists, filled by listener_camera_info
        self._res = []  # [height, width], filled by listener_camera_info

    def _load_config_DOPE(self):
        """Load the DOPE settings YAML from the package config directory.

        return: the parsed YAML document.
        """
        p = os.path.join(get_transfer_path_(), FILE_DOPE_CONFIG)
        with open(p, "r", encoding="utf-8") as f:
            # safe_load only builds plain Python objects, which is all a settings file needs.
            return yaml.safe_load(f)

    def _set_camera_topic(self, camera_link: str):
        """Derive the image and camera_info topic names from the camera link name."""
        self.topicImage = "/" + camera_link + "/image"
        self.topicCameraInfo = "/" + camera_link + "/camera_info"

    @staticmethod
    def _find_weights_dependency(skill_cfg: dict) -> dict:
        """Return the "dependency" of the last BTAction parameter whose type is "weights".

        An empty dict is returned when the configuration has no such parameter.
        """
        dependency = {}
        for comm in skill_cfg["BTAction"]:
            for par in comm["param"]:
                if par["type"] == "weights":
                    dependency = par["dependency"]
        return dependency

    def listener_camera_info(self, data):
        """CameraInfo callback: store the resolution and camera matrix once."""
        if self._is_camerainfo:  # don't read camera info again
            return

        self._res = [data.height, data.width]
        k_ = data.k
        # Reshape the flat 9-element K into a row-major 3x3 nested list.
        self._K = [[k_[3 * row + col] for col in range(3)] for row in range(3)]
        # set the indicator for receiving the camera info
        self._is_camerainfo = True

    def on_configure(self, state: State) -> TransitionCallbackReturn:
        """
        Configure the node, after a configuring transition is requested.

        The model is loaded before any ROS entity is created, so a failed
        configuration leaves no subscription or publisher behind.

        return: The state machine either invokes a transition to the "inactive" state or stays
            in "unconfigured" depending on the return value.
            TransitionCallbackReturn.SUCCESS transitions to "inactive".
            TransitionCallbackReturn.FAILURE transitions to "unconfigured".
            TransitionCallbackReturn.ERROR or any uncaught exceptions to "errorprocessing"
        """
        json_param = self.get_parameter(PARAM_SKILL_CFG).get_parameter_value().string_value
        jdata = json.loads(json_param)

        dependency = self._find_weights_dependency(jdata)
        if not dependency:
            # Explicit check instead of `assert`, which is removed under `python -O`.
            self.get_logger().error("no dependency")
            return TransitionCallbackReturn.FAILURE

        for par in jdata["Settings"]:
            if par["name"] == "cameraLink":
                self._set_camera_topic(par["value"])

        # Load model weights
        w = dependency["weights_file"]
        if not os.path.isfile(w):
            self.get_logger().warning(f"No weights found <{w}>")
            return TransitionCallbackReturn.FAILURE

        obj = dependency["object_name"]
        dim = dependency["dimensions"]
        self.dope_node = Dope(self.dope_cfg, w, obj, dim)

        # Create the subscribers.
        self._sub_info = self.create_subscription(
            CameraInfo, self.topicCameraInfo, self.listener_camera_info, 2
        )
        # Create the publisher.
        self._pub = self.create_lifecycle_publisher(Pose, self.topicSrv, 1)

        self.get_logger().info(f"configure is success (with '{obj}')")
        return TransitionCallbackReturn.SUCCESS

    def on_activate(self, state: State) -> TransitionCallbackReturn:
        """Start processing: subscribe to the image topic."""
        self.get_logger().info("on_activate is called")
        # Create the Image subscriber.
        self._sub = self.create_subscription(Image, self.topicImage, self.image_callback, 3)
        return super().on_activate(state)

    def on_deactivate(self, state: State) -> TransitionCallbackReturn:
        """Stop processing: destroy the image subscription."""
        self.get_logger().info("on_deactivate is called")
        # Destroy the Image subscriber.
        if self._sub is not None:
            self.destroy_subscription(self._sub)
            self._sub = None
        return super().on_deactivate(state)

    def on_cleanup(self, state: State) -> TransitionCallbackReturn:
        """Undo on_configure: destroy the ROS entities and release the model."""
        # Camera info must be received again after the next configure.
        self._is_camerainfo = False

        if self._pub is not None:
            self.destroy_publisher(self._pub)
            self._pub = None
        if self._sub_info is not None:
            self.destroy_subscription(self._sub_info)
            self._sub_info = None
        # Drop the reference to the detector; on_configure builds a new one.
        self.dope_node = None

        self.get_logger().info("on_cleanup is called")
        return TransitionCallbackReturn.SUCCESS

    def on_shutdown(self, state: State) -> TransitionCallbackReturn:
        """Log the shutdown transition."""
        self.get_logger().info("on_shutdown is called")
        return TransitionCallbackReturn.SUCCESS

    def image_callback(self, data):
        """Image callback: run DOPE on the frame and publish the estimated pose.

        Frames are ignored until camera info has been received.
        """
        if not self._is_camerainfo:
            self.get_logger().warning("No data from CameraInfo")
            return

        # NOTE: looking up the camera pose through TF is currently disabled;
        # self._cam_pose and self.tf_buffer are kept for that purpose.

        # Convert ROS Image message to OpenCV image
        current_frame = self.br.imgmsg_to_cv2(data)
        self._image_cnt += 1

        self.get_logger().info(f"dope: begin {self._image_cnt}")
        # Reverse the order of the last (channel) axis before inference.
        # NOTE(review): presumably converts BGR to RGB - confirm against the camera encoding.
        current_frame = current_frame[..., ::-1].copy()
        pose = self.dope_node.image_processing(img=current_frame, camera_info=self._K)
        self.get_logger().info(f"dope: end {self._image_cnt}")
        if self._pub is not None and self._pub.is_activated:
            # publish pose estimation result
            self._pub.publish(pose)
|
||
|
||
from detector import ModelData, ObjectDetector
|
||
from cuboid_pnp_solver import CuboidPNPSolver
|
||
from cuboid import Cuboid3d
|
||
import numpy as np
|
||
# robossembler_ws/src/robossembler-ros2/rbs_perception/scripts/utils.py
|
||
# from utils import Draw
|
||
|
||
class Dope(object):
    """DOPE inference helper: loads the network once and estimates a pose per image.

    This is a plain Python object, not a ROS node; it neither subscribes nor
    publishes. The caller passes in images and receives a Pose message.
    """

    def __init__(
        self,
        config,  # config yaml loaded eg dict
        weight,  # path to weight
        class_name,
        dim: list,  # dimensions of model 'class_name'
    ):
        """Load the DOPE network and prepare the PnP solver.

        config: mapping with the keys "input_is_rectified", "downscale_height",
            "thresh_angle", "thresh_map", "sigma" and "thresh_points".
        weight: path to the network weights file.
        class_name: name of the object class the weights were trained for.
        dim: cuboid dimensions of the object.
        """
        self.input_is_rectified = config["input_is_rectified"]
        self.downscale_height = config["downscale_height"]

        # Attribute bag handed to ObjectDetector.detect_object_in_image().
        self.config_detect = SimpleNamespace(
            mask_edges=1,
            mask_faces=1,
            vertex=1,
            threshold=0.5,
            softmax=1000,
            thresh_angle=config["thresh_angle"],
            thresh_map=config["thresh_map"],
            sigma=config["sigma"],
            thresh_points=config["thresh_points"],
        )

        # load network model, create PNP solver
        self.model = ModelData(name=class_name, net_path=weight)
        # TODO: load_net_model() triggers torchvision UserWarnings (the
        # 'pretrained' parameter and non-enum 'weights' arguments are
        # deprecated since torchvision 0.13).
        self.model.load_net_model()

        self.draw_color = (0, 255, 0)

        # TODO load dim from config
        self.dimension = tuple(dim)
        self.class_id = 1  # config["class_ids"][class_name]

        self.pnp_solver = CuboidPNPSolver(class_name, cuboid3d=Cuboid3d(dim))
        self.class_name = class_name

    def image_processing(
        self,
        img,
        camera_info
    ) -> Pose:
        """Detect the object in one image and return its pose.

        img: image array whose first two axes are height and width.
        camera_info: 3x3 camera matrix (nested sequence); it is copied, so the
            caller's data is never modified.

        return: Pose filled from the first detection that has a location. When
            nothing is detected, a default-constructed Pose() is returned.
        """
        # !!! The input is always treated as rectified: zero distortion coefficients.
        camera_matrix = np.array(camera_info, dtype=np.float64)
        dist_coeffs = np.zeros((4, 1))

        # Downscale image if necessary
        height, width = img.shape[:2]
        scaling_factor = float(self.downscale_height) / height
        if scaling_factor < 1.0:
            # The first two rows of the camera matrix scale together with the image.
            camera_matrix[:2] *= scaling_factor
            img = cv2.resize(img, (int(scaling_factor * width), int(scaling_factor * height)))

        self.pnp_solver.set_camera_intrinsic_matrix(camera_matrix)
        self.pnp_solver.set_dist_coeffs(dist_coeffs)

        # Detect object
        results, _ = ObjectDetector.detect_object_in_image(
            self.model.net, self.pnp_solver, img, self.config_detect
        )

        p = Pose()
        # !!! only considering the first option for now
        detection = next((r for r in results if r["location"] is not None), None)
        if detection is not None:
            loc = detection["location"]
            quat = detection["quaternion"]  # read as (x, y, z, w)
            p.position.x = loc[0]
            p.position.y = loc[1]
            p.position.z = loc[2]
            p.orientation.x = quat[0]
            p.orientation.y = quat[1]
            p.orientation.z = quat[2]
            p.orientation.w = quat[3]
        # NOTE: drawing the projected cuboid and dumping results to JSON / image
        # files is not done in this node.
        return p
|
||
|
||
def main():
    """Spin the DOPE lifecycle node on a single-threaded executor until shutdown.

    The node is destroyed and the rclpy context is shut down on every exit
    path, not only after Ctrl-C or an external shutdown request.
    """
    rclpy.init()

    executor = rclpy.executors.SingleThreadedExecutor()
    # executor = rclpy.executors.MultiThreadedExecutor()
    lc_node = PE_DOPE()
    executor.add_node(lc_node)
    try:
        executor.spin()
    except (KeyboardInterrupt, rclpy.executors.ExternalShutdownException):
        # Normal ways to stop the node; nothing to report.
        pass
    finally:
        lc_node.destroy_node()
        # After an external shutdown the context is already closed.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()