gello_software/experiments/quick_run.py

192 lines
6.1 KiB
Python

import atexit
import glob
import time
from dataclasses import dataclass
from multiprocessing import Process
from pathlib import Path
from typing import Optional
import numpy as np
import tyro
from gello.agents.agent import DummyAgent
from gello.agents.gello_agent import GelloAgent
from gello.agents.spacemouse_agent import SpacemouseAgent
from gello.env import RobotEnv
from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
@dataclass
class Args:
    """Command-line options for the quick-run teleoperation script (parsed by tyro)."""

    # Control rate passed to RobotEnv as ``control_rate_hz``.
    hz: int = 100
    # Agent name: "gello", "quest", "spacemouse", "dummy" or "none".
    agent: str = "gello"
    # Robot name: "sim_ur", "sim_panda", "xarm" or "ur5".
    robot: str = "ur5"
    # Serial device for the GELLO arm; None auto-selects the first
    # entry under /dev/serial/by-id.
    gello_port: Optional[str] = None
    mock: bool = False
    verbose: bool = False
    # Host the robot server binds to and the client connects to.
    hostname: str = "127.0.0.1"
    # Port the robot server binds to and the client connects to.
    robot_port: int = 6001
    # IP address of the UR controller; read only when ``robot == "ur5"``.
    # Placeholder default - set it to your robot's address.
    robot_ip: str = "192.168.1.10"
def launch_robot_server(port: int, args: Args):
    """Create the robot selected by ``args.robot`` and call ``serve()`` on it.

    Supported values of ``args.robot``:
      * ``"sim_ur"``    - ``MujocoRobotServer`` with the UR5e model and the
        Robotiq 2F-85 gripper model.
      * ``"sim_panda"`` - ``MujocoRobotServer`` with the Franka Panda model and
        no separate gripper model.
      * ``"xarm"``      - ``XArmRobot`` wrapped in a ``ZMQServerRobot``.
      * ``"ur5"``       - ``URRobot`` at ``args.robot_ip`` wrapped in a
        ``ZMQServerRobot``.

    Simulation models are loaded from ``third_party/mujoco_menagerie`` two
    directories above this file. Backend modules are imported inside their
    branches, so only the selected backend's module is loaded.
    NOTE(review): ``serve()`` is presumably blocking (this function is used as
    a ``Process`` target) - confirm against the server classes.

    Args:
        port: Port the server binds to.
        args: Parsed CLI options; reads ``robot``, ``hostname`` and, for
            ``"ur5"``, ``robot_ip``.

    Raises:
        NotImplementedError: if ``args.robot`` is not one of the names above.
    """
    if args.robot in ("sim_ur", "sim_panda"):
        from gello.robots.sim_robot import MujocoRobotServer

        menagerie_root = (
            Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
        )
        if args.robot == "sim_ur":
            xml = menagerie_root / "universal_robots_ur5e" / "ur5e.xml"
            gripper_xml = menagerie_root / "robotiq_2f85" / "2f85.xml"
        else:
            xml = menagerie_root / "franka_emika_panda" / "panda.xml"
            gripper_xml = None
        server = MujocoRobotServer(
            xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
        )
        server.serve()
        return

    if args.robot == "xarm":
        from gello.robots.xarm_robot import XArmRobot

        robot = XArmRobot()
    elif args.robot == "ur5":
        from gello.robots.ur import URRobot

        robot = URRobot(robot_ip=args.robot_ip)
    else:
        # Keep this list in sync with the branches above.
        raise NotImplementedError(
            f"Robot {args.robot} not implemented, "
            "choose one of: sim_ur, sim_panda, xarm, ur5"
        )
    server = ZMQServerRobot(robot, port=port, host=args.hostname)
    print(f"Starting robot server on port {port}")
    server.serve()
def start_robot_process(args: Args):
    """Start ``launch_robot_server`` in a child process.

    The child runs ``launch_robot_server(args.robot_port, args)``. A handler is
    registered with ``atexit`` before the child starts. When this interpreter
    exits normally, the handler prints a notice and calls ``terminate()`` on
    the child.
    """
    server_proc = Process(target=launch_robot_server, args=(args.robot_port, args))

    def _stop_server(proc):
        # Runs at interpreter exit so the server does not outlive the script.
        print("Killing child process...")
        proc.terminate()

    atexit.register(_stop_server, server_proc)
    server_proc.start()
def _find_gello_port(gello_port: Optional[str]) -> str:
    """Return ``gello_port``, or the first device under /dev/serial/by-id if None.

    Raises:
        ValueError: if no port was given and no serial device is present.
    """
    if gello_port is not None:
        return gello_port
    usb_ports = glob.glob("/dev/serial/by-id/*")
    print(f"Found {len(usb_ports)} ports")
    if not usb_ports:
        raise ValueError("No gello port found, please specify one or plug in gello")
    gello_port = usb_ports[0]
    print(f"using port {gello_port}")
    return gello_port


def _move_to_reset_pose(env: RobotEnv) -> None:
    """Interpolate the robot to a fixed 8-element reset pose.

    Does nothing if the robot's joint vector has a different shape. The number
    of waypoints is one per 0.01 of the largest joint change, capped at 100.
    """
    reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
    curr_joints = env.get_obs()["joint_positions"]
    if reset_joints.shape != curr_joints.shape:
        return
    max_delta = np.abs(curr_joints - reset_joints).max()
    steps = min(int(max_delta / 0.01), 100)
    for jnt in np.linspace(curr_joints, reset_joints, steps):
        env.step(jnt)
        time.sleep(0.001)


def _make_agent(args: Args, env: RobotEnv, robot_client: ZMQClientRobot):
    """Construct the agent named by ``args.agent``.

    For ``"gello"`` this also moves the robot toward the reset pose (see
    ``_move_to_reset_pose``) after the agent is created.

    Raises:
        ValueError: if the name is unknown or no GELLO port can be found.
    """
    if args.agent == "gello":
        agent = GelloAgent(port=_find_gello_port(args.gello_port))
        _move_to_reset_pose(env)
        return agent
    if args.agent == "quest":
        from gello.agents.quest_agent import SingleArmQuestAgent

        return SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
    if args.agent == "spacemouse":
        return SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
    if args.agent in ("dummy", "none"):
        return DummyAgent(num_dofs=robot_client.num_dofs())
    raise ValueError("Invalid agent name")


def _start_pos_within_limit(start_pos, joints, max_joint_delta: float = 0.8) -> bool:
    """Return True if every ``|start_pos - joints|`` is at most ``max_joint_delta``.

    Otherwise print one line per offending joint (index, delta, leader value,
    follower value) and return False.
    """
    abs_deltas = np.abs(start_pos - joints)
    too_far = np.flatnonzero(abs_deltas > max_joint_delta)
    if too_far.size == 0:
        return True
    print(
        f"Leader/follower joint difference exceeds {max_joint_delta}; "
        "align the leader with the robot and retry."
    )
    for i in too_far:
        print(
            f"joint[{i}]: \t delta: {abs_deltas[i]:4.3f} , leader: \t{start_pos[i]:4.3f} , follower: \t{joints[i]:4.3f}"
        )
    return False


def _ramp_toward_agent(env, agent, steps: int = 25, max_step: float = 0.05) -> None:
    """Step the env toward the agent's command ``steps`` times.

    Each step's delta is scaled so that its largest absolute joint change is
    at most ``max_step``.
    """
    for _ in range(steps):
        obs = env.get_obs()
        command_joints = agent.act(obs)
        current_joints = obs["joint_positions"]
        delta = command_joints - current_joints
        largest = np.abs(delta).max()
        if largest > max_step:
            delta = delta / largest * max_step
        env.step(current_joints + delta)


def main(args: Args):
    """Launch the robot server, connect an agent, and run teleoperation forever.

    1. Start the robot server in a child process and connect a
       ``ZMQClientRobot`` wrapped in a ``RobotEnv`` running at ``args.hz``.
    2. Build the agent named by ``args.agent``.
    3. Abort (return) if any joint of the agent's first command is more than
       0.8 away from the robot (units presumably radians - confirm).
    4. Ease the robot toward the agent over 25 rate-limited steps.
    5. Abort (return) if any joint still differs by more than 0.5 in either
       direction; otherwise stream agent actions to the env in an endless loop.

    Raises:
        ValueError: if the agent name is unknown, no GELLO port is found, or
            the agent's output dimension differs from the robot's.
    """
    start_robot_process(args)
    robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
    env = RobotEnv(robot_client, control_rate_hz=args.hz)
    agent = _make_agent(args, env, robot_client)

    print("Going to start position")
    start_pos = agent.act(env.get_obs())
    obs = env.get_obs()
    joints = obs["joint_positions"]

    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
    # Validate dimensions before any elementwise arithmetic; a mismatch would
    # otherwise surface as a broadcast error or, worse, broadcast silently.
    if len(start_pos) != len(joints):
        raise ValueError(
            f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
        )

    if not _start_pos_within_limit(start_pos, joints, max_joint_delta=0.8):
        return

    _ramp_toward_agent(env, agent, steps=25, max_step=0.05)

    obs = env.get_obs()
    joints = obs["joint_positions"]
    action = agent.act(obs)
    diff = action - joints
    # Use |diff| so large negative mismatches are caught too, and list every
    # flagged joint (same threshold for detection and reporting).
    too_big = np.flatnonzero(np.abs(diff) > 0.5)
    if too_big.size > 0:
        print("Action is too big")
        for j in too_big:
            print(
                f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {diff[j]}"
            )
        return

    while True:
        action = agent.act(obs)
        obs = env.step(action)
if __name__ == "__main__":
    # Parse the command line into an Args instance, then run the script.
    cli_args = tyro.cli(Args)
    main(cli_args)