192 lines
6.1 KiB
Python
192 lines
6.1 KiB
Python
import atexit
|
|
import glob
|
|
import time
|
|
from dataclasses import dataclass
|
|
from multiprocessing import Process
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
import tyro
|
|
|
|
from gello.agents.agent import DummyAgent
|
|
from gello.agents.gello_agent import GelloAgent
|
|
from gello.agents.spacemouse_agent import SpacemouseAgent
|
|
from gello.env import RobotEnv
|
|
from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
|
|
|
|
|
|
@dataclass
class Args:
    """Command-line options for the teleoperation launcher (parsed with ``tyro.cli``)."""

    hz: int = 100  # control rate passed to RobotEnv(control_rate_hz=...)

    agent: str = "gello"  # one of: gello, quest, spacemouse, dummy, none
    robot: str = "ur5"  # one of: sim_ur, sim_panda, xarm, ur5
    gello_port: Optional[str] = None  # serial device; None -> first /dev/serial/by-id/* entry
    mock: bool = False  # NOTE(review): not read anywhere in this file
    verbose: bool = False  # forwarded to SpacemouseAgent

    hostname: str = "127.0.0.1"  # host for both the ZMQ robot server and client
    robot_port: int = 6001  # port for both the ZMQ robot server and client
    # Address handed to URRobot(robot_ip=...) when robot == "ur5"; it was read by
    # launch_robot_server but previously missing here.
    # NOTE(review): placeholder default — override with your controller's address.
    robot_ip: str = "192.168.1.10"
|
|
|
|
|
|
def launch_robot_server(port: int, args: Args):
    """Build the robot selected by ``args.robot`` and serve it on ``port``.

    ``sim_ur`` and ``sim_panda`` run a ``MujocoRobotServer`` loaded from the
    models under ``<repo>/third_party/mujoco_menagerie``. ``xarm`` and ``ur5``
    wrap the corresponding robot driver in a ``ZMQServerRobot``. Both paths
    end in ``server.serve()``.

    Args:
        port: TCP port the server listens on.
        args: launcher options; reads ``robot``, ``hostname`` and, for ``ur5``,
            ``robot_ip``.

    Raises:
        NotImplementedError: if ``args.robot`` is not one of the names above.
    """
    if args.robot in ("sim_ur", "sim_panda"):
        from gello.robots.sim_robot import MujocoRobotServer

        menagerie_root = Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
        if args.robot == "sim_ur":
            xml = menagerie_root / "universal_robots_ur5e" / "ur5e.xml"
            gripper_xml = menagerie_root / "robotiq_2f85" / "2f85.xml"
        else:
            xml = menagerie_root / "franka_emika_panda" / "panda.xml"
            gripper_xml = None  # the panda model is loaded without a separate gripper XML

        server = MujocoRobotServer(
            xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
        )
        server.serve()
        return

    # Robot drivers are imported lazily so only the selected backend's
    # dependencies need to be installed.
    if args.robot == "xarm":
        from gello.robots.xarm_robot import XArmRobot

        robot = XArmRobot()
    elif args.robot == "ur5":
        from gello.robots.ur import URRobot

        robot = URRobot(robot_ip=args.robot_ip)
    else:
        raise NotImplementedError(
            f"Robot {args.robot} not implemented, choose one of: sim_ur, sim_panda, xarm, ur5"
        )

    server = ZMQServerRobot(robot, port=port, host=args.hostname)
    print(f"Starting robot server on port {port}")
    server.serve()
|
|
|
|
|
|
def start_robot_process(args: Args):
    """Run ``launch_robot_server(args.robot_port, args)`` in a child process.

    An atexit hook terminates the child and reaps it when the parent exits. The
    hook is registered only after ``start()`` succeeds, so it never acts on a
    process that was never started.

    Returns:
        The started ``multiprocessing.Process``. Earlier callers that ignore the
        return value are unaffected.
    """
    process = Process(target=launch_robot_server, args=(args.robot_port, args))

    def kill_child_process(proc: Process) -> None:
        # Terminate and reap the server child; it has already exited if it is not alive.
        if not proc.is_alive():
            return
        print("Killing child process...")
        proc.terminate()
        proc.join(timeout=5)

    process.start()
    atexit.register(kill_child_process, process)
    return process
|
|
|
|
|
|
def _resolve_gello_port(gello_port: Optional[str]) -> str:
    """Return ``gello_port``, or auto-detect one when it is None.

    Auto-detection picks the first entry of ``/dev/serial/by-id/*``.

    Raises:
        ValueError: if no port was given and no serial device is present.
    """
    if gello_port is not None:
        return gello_port
    usb_ports = glob.glob("/dev/serial/by-id/*")
    print(f"Found {len(usb_ports)} ports")
    if not usb_ports:
        raise ValueError("No gello port found, please specify one or plug in gello")
    gello_port = usb_ports[0]
    print(f"using port {gello_port}")
    return gello_port


def _interpolate_to_joints(
    env: RobotEnv, target: np.ndarray, step: float = 0.01, max_steps: int = 100
) -> None:
    """Step ``env`` along a straight joint-space line from its current pose to ``target``.

    The number of waypoints is ``int(max|current - target| / step)``, capped at
    ``max_steps``. The call does nothing if ``target`` and the observed joint
    vector have different shapes.
    """
    current = env.get_obs()["joint_positions"]
    if target.shape != current.shape:
        return
    n_steps = min(int(np.abs(current - target).max() / step), max_steps)
    for waypoint in np.linspace(current, target, n_steps):
        env.step(waypoint)
        time.sleep(0.001)


def _make_agent(args: Args, env: RobotEnv, robot_client: ZMQClientRobot):
    """Construct the agent named by ``args.agent``.

    For ``gello``, the robot is also driven to a fixed reset pose before the
    agent is returned.

    Raises:
        ValueError: for an unknown agent name, or for ``gello`` when no serial
            port is found.
    """
    if args.agent == "gello":
        agent = GelloAgent(port=_resolve_gello_port(args.gello_port))
        # NOTE(review): 8 entries, presumably 7 arm joints + gripper — TODO confirm.
        # For robots with a different joint count, the shape check skips the reset.
        reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
        _interpolate_to_joints(env, reset_joints)
        return agent
    if args.agent == "quest":
        from gello.agents.quest_agent import SingleArmQuestAgent

        return SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
    if args.agent == "spacemouse":
        return SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
    if args.agent in ("dummy", "none"):
        return DummyAgent(num_dofs=robot_client.num_dofs())
    raise ValueError(f"Invalid agent name: {args.agent!r}")


def _joints_exceeding(leader, follower, threshold: float) -> np.ndarray:
    """Return the indices of joints where ``|leader - follower| > threshold``."""
    return np.flatnonzero(np.abs(np.asarray(leader) - np.asarray(follower)) > threshold)


def _print_joint_deltas(ids, leader, follower) -> None:
    """Print the absolute gap and both positions for each joint index in ``ids``."""
    for i in ids:
        delta = abs(leader[i] - follower[i])
        print(
            f"joint[{i}]: \t delta: {delta:4.3f} , "
            f"leader: \t{leader[i]:4.3f} , follower: \t{follower[i]:4.3f}"
        )


def main(args: Args):
    """Run teleoperation.

    Steps:
    1. Start the robot server in a child process.
    2. Connect a ``RobotEnv`` to the server through a ZMQ client.
    3. Build the selected agent.
    4. Refuse to start if the agent's pose is far from the robot's.
    5. Ease the robot toward the agent's pose.
    6. Stream agent actions to the robot in an endless loop.

    Raises:
        ValueError: if the agent's output dimension differs from the robot's
            joint count, or on bad agent/port configuration.
    """
    start_robot_process(args)

    robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
    env = RobotEnv(robot_client, control_rate_hz=args.hz)
    agent = _make_agent(args, env, robot_client)

    # going to start position
    print("Going to start position")
    start_pos = agent.act(env.get_obs())
    joints = env.get_obs()["joint_positions"]

    # Validate dimensions before any elementwise comparison. A mismatch is then
    # reported clearly rather than as a numpy broadcasting error.
    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
    if len(start_pos) != len(joints):
        raise ValueError(
            f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
        )

    # Refuse to start if any joint is further than this from the leader pose.
    max_joint_delta = 0.8
    too_far = _joints_exceeding(start_pos, joints, max_joint_delta)
    if too_far.size:
        print()
        print(f"Leader/follower joint delta exceeds {max_joint_delta}:")
        _print_joint_deltas(too_far, start_pos, joints)
        return

    # Ease toward the leader over 25 commands. Each command is scaled so that no
    # joint moves more than max_delta.
    max_delta = 0.05
    for _ in range(25):
        obs = env.get_obs()
        command_joints = agent.act(obs)
        current_joints = obs["joint_positions"]
        delta = command_joints - current_joints
        largest = np.abs(delta).max()
        if largest > max_delta:
            delta = delta / largest * max_delta
        env.step(current_joints + delta)

    obs = env.get_obs()
    joints = obs["joint_positions"]
    action = agent.act(obs)
    # Abort if the leader is still far away after easing in. Both directions are
    # checked (absolute gap), and the same threshold drives detection and report.
    too_far = _joints_exceeding(action, joints, 0.5)
    if too_far.size:
        print("Action is too big")
        _print_joint_deltas(too_far, action, joints)
        return

    while True:
        action = agent.act(obs)
        obs = env.step(action)
|
|
|
|
|
|
# Script entry point: parse the Args dataclass from the command line, then run.
if __name__ == "__main__":
    cli_args = tyro.cli(Args)
    main(cli_args)
|