diff --git a/experiments/launch_nodes.py b/experiments/launch_nodes.py index 9251711..40463ce 100644 --- a/experiments/launch_nodes.py +++ b/experiments/launch_nodes.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from pathlib import Path +import signal import tyro @@ -17,6 +18,7 @@ class Args: def launch_robot_server(args: Args): port = args.robot_port + robot = None if args.robot == "sim_ur": MENAGERIE_ROOT: Path = ( Path(__file__).parent.parent / "third_party" / "mujoco_menagerie" @@ -28,7 +30,6 @@ def launch_robot_server(args: Args): server = MujocoRobotServer( xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname ) - server.serve() elif args.robot == "sim_panda": from gello.robots.sim_robot import MujocoRobotServer @@ -40,7 +41,6 @@ def launch_robot_server(args: Args): server = MujocoRobotServer( xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname ) - server.serve() elif args.robot == "sim_xarm": from gello.robots.sim_robot import MujocoRobotServer @@ -52,8 +52,6 @@ def launch_robot_server(args: Args): server = MujocoRobotServer( xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname ) - server.serve() - else: if args.robot == "xarm": from gello.robots.xarm_robot import XArmRobot @@ -74,6 +72,10 @@ def launch_robot_server(args: Args): _robot_l = URRobot(robot_ip="192.168.2.10") _robot_r = URRobot(robot_ip="192.168.1.10") robot = BimanualRobot(_robot_l, _robot_r) + elif args.robot == "viperx": + from gello.robots.viperx import ViperXRobot + + robot = ViperXRobot(robot_ip=args.robot_ip) elif args.robot == "none" or args.robot == "print": robot = PrintRobot(8) @@ -83,7 +85,15 @@ def launch_robot_server(args: Args): ) server = ZMQServerRobot(robot, port=port, host=args.hostname) print(f"Starting robot server on port {port}") + + def sigint(sig, frame): + server.stop() + signal.signal(signal.SIGINT, sigint) + try: server.serve() + finally: + if robot is not None and hasattr(robot, 'stop'): + robot.stop() def main(args): diff --git a/experiments/run_env.py b/experiments/run_env.py index bd2319a..8d6846a 100644 --- a/experiments/run_env.py +++ b/experiments/run_env.py @@ -9,7 +9,7 @@ import numpy as np import tyro from gello.agents.agent import BimanualAgent, DummyAgent -from gello.agents.gello_agent import GelloAgent +from gello.agents.gello_agent import GelloAgent, known_gello_port from gello.data_utils.format_obs import save_frame from gello.env import RobotEnv from gello.robots.robot import PrintRobot @@ -108,8 +108,10 @@ def main(args): if gello_port is None: usb_ports = glob.glob("/dev/serial/by-id/*") print(f"Found {len(usb_ports)} ports") - if len(usb_ports) > 0: - gello_port = usb_ports[0] + # Need to filter out non-gello Dynamixel based robots + gello_ports = [port for port in usb_ports if known_gello_port(port)] + if len(gello_ports) > 0: + gello_port = gello_ports[0] print(f"using port {gello_port}") else: raise ValueError( diff --git a/gello/agents/gello_agent.py b/gello/agents/gello_agent.py index fca60c3..8953a57 100644 --- a/gello/agents/gello_agent.py +++ b/gello/agents/gello_agent.py @@ -105,6 +105,8 @@ PORT_CONFIG_MAP: Dict[str, DynamixelRobotConfig] = { ), } +def known_gello_port(port): + return port in PORT_CONFIG_MAP class GelloAgent(Agent): def __init__( diff --git a/gello/robots/viperx.py b/gello/robots/viperx.py new file mode 100644 index 0000000..bb07ee5 --- /dev/null +++ b/gello/robots/viperx.py @@ -0,0 +1,68 @@ +import numpy as np +from pyquaternion import Quaternion + +from gello.robots.robot import Robot +from interbotix_xs_modules.xs_robot.arm import InterbotixManipulatorXS +from interbotix_xs_msgs.msg import JointSingleCommand + +GRIPPER_POSITION_OPEN = 0.05800 +GRIPPER_POSITION_CLOSE = 0.01844 + +GRIPPER_JOINT_OPEN = 1.4910 +GRIPPER_JOINT_CLOSE = -0.6213 + +class ViperXRobot(Robot): + def __init__(self, robot_ip): + print('ViperXRobot __init__') + super().__init__() + self.bot = InterbotixManipulatorXS(robot_model='vx300s', group_name='arm', gripper_name='gripper') + self._gripper_cmd = JointSingleCommand(name="gripper") + + self.bot.core.robot_reboot_motors("single", "gripper", True) + self.bot.core.robot_set_operating_modes("single", "gripper", "current_based_position") + + self.bot.core.robot_set_motor_registers("group", "arm", 'Profile_Velocity', 100) + self.bot.core.robot_set_motor_registers("group", "arm", 'Profile_Acceleration', 0) + + def stop(self): + self.bot.core.robot_set_operating_modes("single", "gripper", "pwm") + + self.bot.core.robot_set_motor_registers("group", "arm", 'Profile_Velocity', 2000) + self.bot.core.robot_set_motor_registers("group", "arm", 'Profile_Acceleration', 300) + + def num_dofs(self) -> int: + return 7 + + def get_joint_state(self) -> np.ndarray: + state = np.concatenate([self.bot.arm.get_joint_commands(), 0]) + print(f'get_joint_state: {state}') + return state + + def command_joint_state(self, joint_state: np.ndarray) -> None: + assert len(joint_state) == (self.num_dofs()), ( + f"Expected joint state of length {self.num_dofs()}, " + f"got {len(joint_state)}." + ) + + self.bot.arm.set_joint_positions(joint_state[:6], blocking=False) + + gripper_angle = ((1 - joint_state[6]) * (GRIPPER_JOINT_OPEN - GRIPPER_JOINT_CLOSE) + GRIPPER_JOINT_CLOSE) + self._gripper_cmd.cmd = gripper_angle + self.bot.gripper.core.pub_single.publish(self._gripper_cmd) + + def get_observations(self): + gripper_angle = self.bot.core.joint_states.position[-2] + gripper_pos = 1 - ((gripper_angle - GRIPPER_POSITION_CLOSE) / (GRIPPER_POSITION_OPEN - GRIPPER_POSITION_CLOSE)) + + joints = np.concatenate([self.bot.arm.get_joint_commands(), [gripper_pos]]) + ee_pos_matrix = self.bot.arm.get_ee_pose_command() + ee_pos = np.array([ee_pos_matrix[0][3], ee_pos_matrix[1][3], ee_pos_matrix[2][3]]) + ee_quat = Quaternion(matrix=ee_pos_matrix[:3, :3]) + + obs = { + "joint_positions": joints, + "joint_velocities": joints, + "ee_pos_quat": np.concatenate([ee_pos, ee_quat.elements]), + "gripper_position": np.array([gripper_pos]), + } + return obs