runtime/rbs_perception/scripts/pose_estimation.py

252 lines
8 KiB
Python
Executable file
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
detection_service
ROS 2 program for 6D Pose Estimation
@shalenikol release 0.2
"""
# Import the necessary libraries
import rclpy # Python library for ROS 2
from rclpy.node import Node # Handles the creation of nodes
from sensor_msgs.msg import Image # Image is the message type
from geometry_msgs.msg import Quaternion, TransformStamped
from tf2_ros import TransformBroadcaster
from cv_bridge import CvBridge # Package to convert between ROS and OpenCV Images
import cv2 # OpenCV library
from rbs_skill_interfaces.srv import DetectObject
from rbs_skill_interfaces.msg import ObjectInfo
#import subprocess
import os
import shutil
import json
import tempfile
from pathlib import Path
import numpy as np
from ament_index_python.packages import get_package_share_directory
# import megapose
from megapose.scripts.run_inference_on_example import run_inference
# Module-level default: whether the estimated object pose is also broadcast
# to tf2 (in addition to being published as a Quaternion).
tf2_send_pose = True
class ImageSubscriber(Node):
    """
    ROS 2 node that estimates an object's 6D pose with MegaPose.

    The node offers a ``DetectObject`` service (name from ``topicSrv``).
    A request for an object whose mesh file exists prepares a MegaPose
    example directory for that object and subscribes to ``topicImage``;
    every received frame is then run through MegaPose and the estimated
    rotation is published as a ``Quaternion`` on ``pose6D_<object name>``
    (and, if ``tf2_send_pose`` is enabled, the pose is broadcast to tf2).
    A request with ``object.id == -1`` stops the processing again.
    """

    # Config keys that are copied verbatim onto attributes of the same name.
    _PLAIN_CONFIG_KEYS = ("nodeName", "topicImage", "topicPubName", "topicSrv", "tf2_send_pose")

    def _InitService(self):
        """
        Load node settings from the package's ``pose_estimation_config.json``.

        ``nodeName``, ``topicImage``, ``topicPubName``, ``topicSrv`` and
        ``tf2_send_pose`` are stored as attributes of the same name;
        ``camera_info`` is converted into ``self.K_`` (intrinsic matrix) and
        ``self.res_`` (resolution). Any other key is ignored.
        """
        config_path = os.path.join(
            get_package_share_directory("rbs_perception"),
            "config",
            "pose_estimation_config.json",
        )
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        for name, val in config.items():
            if name in self._PLAIN_CONFIG_KEYS:
                setattr(self, name, val)
            elif name == "camera_info":
                self.K_, self.res_ = self._getCameraParam(val)

    def _getCameraParam(self, info):
        """
        Build the intrinsic matrix and resolution from a ``camera_info`` dict.

        ``info`` must provide ``fx``, ``fy``, ``width`` and ``height``. The
        principal point is assumed to be the image centre (width/2, height/2).

        Returns:
            ``(K, resolution)`` where ``K`` is a 3x3 nested list and
            ``resolution`` is ``[height, width]``.
        """
        width, height = info["width"], info["height"]
        intrinsic_matrix = [
            [info["fx"], 0.0, width / 2.0],
            [0.0, info["fy"], height / 2.0],
            [0.0, 0.0, 1.0],
        ]
        return intrinsic_matrix, [height, width]

    def __init__(self):
        """Set defaults, apply the JSON config, and create the detection service."""
        # Defaults; _InitService() may override them from the config file.
        self.nodeName = "image_sub2"
        self.topicImage = "/outer_rgbd_camera/image"
        # NOTE(review): topicPubName is loaded but not used; the publisher
        # topic is derived from the object name in service_callback().
        self.topicPubName = self.nodeName + "/pose6D_images"
        self.topicSrv = self.nodeName + "/detect6Dpose"
        # Fall back to the module-level default if the config omits the key.
        self.tf2_send_pose = tf2_send_pose
        self._InitService()
        # Scratch directory that holds the MegaPose inputs/outputs.
        self.tmpdir = tempfile.gettempdir()
        self.mytemppath = Path(self.tmpdir) / "rbs_per"
        self.mytemppath.mkdir(exist_ok=True)
        # The node name may come from the config, so the Node base class is
        # initialised only after the config has been read.
        super().__init__(self.nodeName)
        self.tf_broadcaster = TransformBroadcaster(self)
        self.subscription = None
        self.publisher = None
        self.objName = ""
        self.objMeshFile = ""
        self.objPath = ""
        # Converts between ROS Image messages and OpenCV arrays.
        self.br = CvBridge()
        # Number of frames processed so far (used in log messages).
        self.cnt = 0
        self.service = self.create_service(DetectObject, self.topicSrv, self.service_callback)

    def _stop_tracking(self):
        """Destroy the image subscription and the pose publisher, if present."""
        # Dropping our reference is not enough: the Node keeps its own, so the
        # callback would keep firing. Destroy the entities explicitly.
        if self.subscription is not None:
            self.destroy_subscription(self.subscription)
            self.subscription = None
        if self.publisher is not None:
            self.destroy_publisher(self.publisher)
            self.publisher = None

    def _prepare_example_dir(self):
        """
        Write the MegaPose example directory for ``self.objName``.

        Layout under ``<tmp>/rbs_per/examples/<objName>``:
          * ``inputs/object_data.json``: the label and a bbox that spans the
            image minus a small border (no 2D detector is used);
          * ``meshes/<objName>/<objName>.ply``: copy of the requested mesh;
          * ``camera_data.json``: intrinsics ``K`` and ``resolution``.

        Raises:
            OSError: if a directory or file cannot be created or copied.
        """
        self.objPath = self.mytemppath / "examples" / self.objName
        self.objPath.mkdir(parents=True, exist_ok=True)

        inputs_dir = self.objPath / "inputs"
        inputs_dir.mkdir(exist_ok=True)
        # res_ is [height, width]; example entry:
        # {"label": "fork", "bbox_modal": [329, 189, 430, 270]}
        object_data = [{
            "label": self.objName,
            "bbox_modal": [2, 2, self.res_[1] - 4, self.res_[0] - 4],
        }]
        (inputs_dir / "object_data.json").write_text(json.dumps(object_data))

        mesh_dir = self.objPath / "meshes" / self.objName
        mesh_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the copy is always named *.ply regardless of the
        # source mesh format.
        shutil.copyfile(self.objMeshFile, str(mesh_dir / (self.objName + ".ply")))

        # Example: {"K": [[25.0, 0.0, 8.65], [0.0, 25.0, 6.5], [0.0, 0.0, 1.0]], "resolution": [480, 640]}
        camera_data = {"K": self.K_, "resolution": self.res_}
        (self.objPath / "camera_data.json").write_text(json.dumps(camera_data))

    def service_callback(self, request, response):
        """
        Handle a ``DetectObject`` request.

        * ``object.id == -1``: stop processing frames and report success.
        * ``object.mesh_path`` is not an existing file: report failure.
        * No object tracked yet: prepare the MegaPose inputs for
          ``object.name``, subscribe to the image topic and create the
          ``pose6D_<name>`` publisher.
        * An object is already tracked: acknowledge without changes.
        """
        self.get_logger().info(f"Incoming request for pose estimation ObjectInfo(name: {request.object.name}, mesh_path: {request.object.mesh_path})")
        # A reset needs no mesh, so handle it before validating mesh_path.
        if request.object.id == -1:
            self._stop_tracking()
            response.call_status = True
            return response
        if not os.path.isfile(request.object.mesh_path):
            response.call_status = False
            response.error_msg = f"{request.object.mesh_path}: no such file"
            return response
        if self.subscription is None:
            self.objName = request.object.name
            self.objMeshFile = request.object.mesh_path
            try:
                self._prepare_example_dir()
            except OSError as err:
                response.call_status = False
                response.error_msg = str(err)
                return response
            # Queue depth 3: frames are processed slowly, so old ones are dropped.
            self.subscription = self.create_subscription(Image, self.topicImage, self.listener_callback, 3)
            self.publisher = self.create_publisher(Quaternion, "pose6D_" + self.objName, 10)
        # NOTE(review): a request for a different object while one is already
        # tracked is acknowledged but ignored.
        response.call_status = True
        return response

    def load_result(self, example_dir: Path, json_name = "object_data.json"):
        """
        Return the text of ``<example_dir>/outputs/<json_name>``.

        If the file does not exist, a human-readable "No result file" message
        is returned instead (it does not start with ``[``).
        """
        result_file = example_dir / "outputs" / json_name
        if result_file.is_file():
            return result_file.read_text()
        return "No result file: '" + str(result_file) + "'"

    def tf_obj_pose(self, pose):
        """
        Broadcast the object's pose to tf2 as ``world`` -> ``<objName>``.

        ``pose`` is ``[quaternion, translation]``; the quaternion is read as
        ``[w, x, y, z]`` (assumed MegaPose output order; TODO confirm).
        NOTE(review): MegaPose calls this pose ``TWO``; if it is expressed in
        the camera frame, ``world`` may be the wrong parent frame.
        """
        t = TransformStamped()
        t.header.stamp = self.get_clock().now().to_msg()
        t.header.frame_id = 'world'
        t.child_frame_id = self.objName
        translation = pose[1]
        t.transform.translation.x = translation[0]
        t.transform.translation.y = translation[1]
        t.transform.translation.z = translation[2]
        q = pose[0]
        t.transform.rotation.x = q[1]
        t.transform.rotation.y = q[2]
        t.transform.rotation.z = q[3]
        t.transform.rotation.w = q[0]
        self.tf_broadcaster.sendTransform(t)

    def listener_callback(self, data):
        """
        Run MegaPose on one camera frame and publish the estimated pose.

        The frame is saved as ``<objPath>/image_rgb.png``, inference is run
        with the ``megapose-1.0-RGB-multi-hypothesis`` model, and the first
        object of the result file is published as a Quaternion (and to tf2
        when ``self.tf2_send_pose`` is true).
        """
        self.get_logger().info("Receiving video frame")
        # cv2.imwrite expects BGR channel order; request it explicitly so
        # rgb8 frames are not saved with red and blue swapped.
        current_frame = self.br.imgmsg_to_cv2(data, desired_encoding="bgr8")
        cv2.imwrite(str(self.objPath / "image_rgb.png"), current_frame)
        self.cnt += 1
        self.get_logger().info(f"megapose: begin {self.cnt}")
        self.get_logger().debug(str(self.objPath))
        run_inference(self.objPath, "megapose-1.0-RGB-multi-hypothesis")
        # Publish the pose estimation result.
        result = self.load_result(self.objPath)
        if result.startswith("["):
            objects = json.loads(result)
            if objects:
                pose = objects[0]["TWO"]
                quat = pose[0]  # pose[1] is the 3D translation
                self.publisher.publish(Quaternion(x=quat[1], y=quat[2], z=quat[3], w=quat[0]))
                if self.tf2_send_pose:
                    self.tf_obj_pose(pose)
            else:
                self.get_logger().warning("megapose: empty result")
        else:
            self.get_logger().warning(result)
        self.get_logger().info(f"megapose: end {self.cnt}")
def main(args=None):
    """Initialise rclpy, spin an ImageSubscriber node, and clean up on exit."""
    rclpy.init(args=args)
    image_subscriber = ImageSubscriber()
    try:
        # Spin so the service and subscription callbacks are executed.
        rclpy.spin(image_subscriber)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop this node; exit quietly.
        pass
    finally:
        # Always release the node, even if spinning raised.
        image_subscriber.destroy_node()
        # The context may already be shut down (e.g. by rclpy's SIGINT
        # handling); calling shutdown() again would raise.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()