initial commit, add gello software code and instructions

This commit is contained in:
Philipp Wu 2023-11-13 09:17:27 -08:00
parent e7d842ad35
commit 18cc23a38e
70 changed files with 5875 additions and 4 deletions

View file

@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import tyro
from gello.zmq_core.camera_node import ZMQClientCamera
@dataclass
class Args:
ports: Tuple[int, ...] = (5000, 5001)
hostname: str = "127.0.0.1"
# hostname: str = "128.32.175.167"
def main(args):
cameras = []
import cv2
images_display_names = []
for port in args.ports:
cameras.append(ZMQClientCamera(port=port, host=args.hostname))
images_display_names.append(f"image_{port}")
cv2.namedWindow(images_display_names[-1], cv2.WINDOW_NORMAL)
while True:
for display_name, camera in zip(images_display_names, cameras):
image, depth = camera.read()
stacked_depth = np.dstack([depth, depth, depth]).astype(np.uint8)
image_depth = cv2.hconcat([image[:, :, ::-1], stacked_depth])
cv2.imshow(display_name, image_depth)
cv2.waitKey(1)
if __name__ == "__main__":
main(tyro.cli(Args))

View file

@ -0,0 +1,40 @@
from dataclasses import dataclass
from multiprocessing import Process
import tyro
from gello.cameras.realsense_camera import RealSenseCamera, get_device_ids
from gello.zmq_core.camera_node import ZMQServerCamera
@dataclass
class Args:
# hostname: str = "127.0.0.1"
hostname: str = "128.32.175.167"
def launch_server(port: int, camera_id: int, args: Args):
camera = RealSenseCamera(camera_id)
server = ZMQServerCamera(camera, port=port, host=args.hostname)
print(f"Starting camera server on port {port}")
server.serve()
def main(args):
ids = get_device_ids()
camera_port = 5000
camera_servers = []
for camera_id in ids:
# start a python process for each camera
print(f"Launching camera {camera_id} on port {camera_port}")
camera_servers.append(
Process(target=launch_server, args=(camera_port, camera_id, args))
)
camera_port += 1
for server in camera_servers:
server.start()
if __name__ == "__main__":
main(tyro.cli(Args))

View file

@ -0,0 +1,94 @@
from dataclasses import dataclass
from pathlib import Path
import tyro
from gello.robots.robot import BimanualRobot, PrintRobot
from gello.zmq_core.robot_node import ZMQServerRobot
@dataclass
class Args:
robot: str = "xarm"
robot_port: int = 6001
hostname: str = "127.0.0.1"
robot_ip: str = "192.168.1.10"
def launch_robot_server(args: Args):
port = args.robot_port
if args.robot == "sim_ur":
MENAGERIE_ROOT: Path = (
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
)
xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
from gello.robots.sim_robot import MujocoRobotServer
server = MujocoRobotServer(
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
)
server.serve()
elif args.robot == "sim_panda":
from gello.robots.sim_robot import MujocoRobotServer
MENAGERIE_ROOT: Path = (
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
)
xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
gripper_xml = None
server = MujocoRobotServer(
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
)
server.serve()
elif args.robot == "sim_xarm":
from gello.robots.sim_robot import MujocoRobotServer
MENAGERIE_ROOT: Path = (
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
)
xml = MENAGERIE_ROOT / "ufactory_xarm7" / "xarm7.xml"
gripper_xml = None
server = MujocoRobotServer(
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
)
server.serve()
else:
if args.robot == "xarm":
from gello.robots.xarm_robot import XArmRobot
robot = XArmRobot(ip=args.robot_ip)
elif args.robot == "ur":
from gello.robots.ur import URRobot
robot = URRobot(robot_ip=args.robot_ip)
elif args.robot == "panda":
from gello.robots.panda import PandaRobot
robot = PandaRobot(robot_ip=args.robot_ip)
elif args.robot == "bimanual_ur":
from gello.robots.ur import URRobot
# IP for the bimanual robot setup is hardcoded
_robot_l = URRobot(robot_ip="192.168.2.10")
_robot_r = URRobot(robot_ip="192.168.1.10")
robot = BimanualRobot(_robot_l, _robot_r)
elif args.robot == "none" or args.robot == "print":
robot = PrintRobot(8)
else:
raise NotImplementedError(
f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
)
server = ZMQServerRobot(robot, port=port, host=args.hostname)
print(f"Starting robot server on port {port}")
server.serve()
def main(args):
launch_robot_server(args)
if __name__ == "__main__":
main(tyro.cli(Args))

192
experiments/quick_run.py Normal file
View file

@ -0,0 +1,192 @@
import atexit
import glob
import time
from dataclasses import dataclass
from multiprocessing import Process
from pathlib import Path
from typing import Optional
import numpy as np
import tyro
from gello.agents.agent import DummyAgent
from gello.agents.gello_agent import GelloAgent
from gello.agents.spacemouse_agent import SpacemouseAgent
from gello.env import RobotEnv
from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
@dataclass
class Args:
hz: int = 100
agent: str = "gello"
robot: str = "ur5"
gello_port: Optional[str] = None
mock: bool = False
verbose: bool = False
hostname: str = "127.0.0.1"
robot_port: int = 6001
def launch_robot_server(port: int, args: Args):
if args.robot == "sim_ur":
MENAGERIE_ROOT: Path = (
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
)
xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
from gello.robots.sim_robot import MujocoRobotServer
server = MujocoRobotServer(
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
)
server.serve()
elif args.robot == "sim_panda":
from gello.robots.sim_robot import MujocoRobotServer
MENAGERIE_ROOT: Path = (
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
)
xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
gripper_xml = None
server = MujocoRobotServer(
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
)
server.serve()
else:
if args.robot == "xarm":
from gello.robots.xarm_robot import XArmRobot
robot = XArmRobot()
elif args.robot == "ur5":
from gello.robots.ur import URRobot
robot = URRobot(robot_ip=args.robot_ip)
else:
raise NotImplementedError(
f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
)
server = ZMQServerRobot(robot, port=port, host=args.hostname)
print(f"Starting robot server on port {port}")
server.serve()
def start_robot_process(args: Args):
process = Process(target=launch_robot_server, args=(args.robot_port, args))
# Function to kill the child process
def kill_child_process(process):
print("Killing child process...")
process.terminate()
# Register the kill_child_process function to be called at exit
atexit.register(kill_child_process, process)
process.start()
def main(args: Args):
start_robot_process(args)
robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
env = RobotEnv(robot_client, control_rate_hz=args.hz)
if args.agent == "gello":
gello_port = args.gello_port
if gello_port is None:
usb_ports = glob.glob("/dev/serial/by-id/*")
print(f"Found {len(usb_ports)} ports")
if len(usb_ports) > 0:
gello_port = usb_ports[0]
print(f"using port {gello_port}")
else:
raise ValueError(
"No gello port found, please specify one or plug in gello"
)
agent = GelloAgent(port=gello_port)
reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
curr_joints = env.get_obs()["joint_positions"]
if reset_joints.shape == curr_joints.shape:
max_delta = (np.abs(curr_joints - reset_joints)).max()
steps = min(int(max_delta / 0.01), 100)
for jnt in np.linspace(curr_joints, reset_joints, steps):
env.step(jnt)
time.sleep(0.001)
elif args.agent == "quest":
from gello.agents.quest_agent import SingleArmQuestAgent
agent = SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
elif args.agent == "spacemouse":
agent = SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
elif args.agent == "dummy" or args.agent == "none":
agent = DummyAgent(num_dofs=robot_client.num_dofs())
else:
raise ValueError("Invalid agent name")
# going to start position
print("Going to start position")
start_pos = agent.act(env.get_obs())
obs = env.get_obs()
joints = obs["joint_positions"]
abs_deltas = np.abs(start_pos - joints)
id_max_joint_delta = np.argmax(abs_deltas)
max_joint_delta = 0.8
if abs_deltas[id_max_joint_delta] > max_joint_delta:
id_mask = abs_deltas > max_joint_delta
print()
ids = np.arange(len(id_mask))[id_mask]
for i, delta, joint, current_j in zip(
ids,
abs_deltas[id_mask],
start_pos[id_mask],
joints[id_mask],
):
print(
f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
)
return
print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
assert len(start_pos) == len(
joints
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
max_delta = 0.05
for _ in range(25):
obs = env.get_obs()
command_joints = agent.act(obs)
current_joints = obs["joint_positions"]
delta = command_joints - current_joints
max_joint_delta = np.abs(delta).max()
if max_joint_delta > max_delta:
delta = delta / max_joint_delta * max_delta
env.step(current_joints + delta)
obs = env.get_obs()
joints = obs["joint_positions"]
action = agent.act(obs)
if (action - joints > 0.5).any():
print("Action is too big")
# print which joints are too big
joint_index = np.where(action - joints > 0.8)
for j in joint_index:
print(
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
)
exit()
while True:
action = agent.act(obs)
obs = env.step(action)
if __name__ == "__main__":
main(tyro.cli(Args))

246
experiments/run_env.py Normal file
View file

@ -0,0 +1,246 @@
import datetime
import glob
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import tyro
from gello.agents.agent import BimanualAgent, DummyAgent
from gello.agents.gello_agent import GelloAgent
from gello.data_utils.format_obs import save_frame
from gello.env import RobotEnv
from gello.robots.robot import PrintRobot
from gello.zmq_core.robot_node import ZMQClientRobot
def print_color(*args, color=None, attrs=(), **kwargs):
import termcolor
if len(args) > 0:
args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
print(*args, **kwargs)
@dataclass
class Args:
agent: str = "none"
robot_port: int = 6001
wrist_camera_port: int = 5000
base_camera_port: int = 5001
hostname: str = "127.0.0.1"
robot_type: str = None # only needed for quest agent or spacemouse agent
hz: int = 100
start_joints: Optional[Tuple[float, ...]] = None
gello_port: Optional[str] = None
mock: bool = False
use_save_interface: bool = False
data_dir: str = "~/bc_data"
bimanual: bool = False
verbose: bool = False
def main(args):
if args.mock:
robot_client = PrintRobot(8, dont_print=True)
camera_clients = {}
else:
camera_clients = {
# you can optionally add camera nodes here for imitation learning purposes
# "wrist": ZMQClientCamera(port=args.wrist_camera_port, host=args.hostname),
# "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
}
robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
env = RobotEnv(robot_client, control_rate_hz=args.hz, camera_dict=camera_clients)
if args.bimanual:
if args.agent == "gello":
# dynamixel control box port map (to distinguish left and right gello)
right = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0"
left = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0"
left_agent = GelloAgent(port=left)
right_agent = GelloAgent(port=right)
agent = BimanualAgent(left_agent, right_agent)
elif args.agent == "quest":
from gello.agents.quest_agent import SingleArmQuestAgent
left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
right_agent = SingleArmQuestAgent(
robot_type=args.robot_type, which_hand="r"
)
agent = BimanualAgent(left_agent, right_agent)
# raise NotImplementedError
elif args.agent == "spacemouse":
from gello.agents.spacemouse_agent import SpacemouseAgent
left_path = "/dev/hidraw0"
right_path = "/dev/hidraw1"
left_agent = SpacemouseAgent(
robot_type=args.robot_type, device_path=left_path, verbose=args.verbose
)
right_agent = SpacemouseAgent(
robot_type=args.robot_type,
device_path=right_path,
verbose=args.verbose,
invert_button=True,
)
agent = BimanualAgent(left_agent, right_agent)
else:
raise ValueError(f"Invalid agent name for bimanual: {args.agent}")
# System setup specific. This reset configuration works well on our setup. If you are mounting the robot
# differently, you need a separate reset joint configuration.
reset_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0, 0])
reset_joints_right = np.deg2rad([0, -90, 90, -90, -90, 0, 0])
reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
curr_joints = env.get_obs()["joint_positions"]
max_delta = (np.abs(curr_joints - reset_joints)).max()
steps = min(int(max_delta / 0.01), 100)
for jnt in np.linspace(curr_joints, reset_joints, steps):
env.step(jnt)
else:
if args.agent == "gello":
gello_port = args.gello_port
if gello_port is None:
usb_ports = glob.glob("/dev/serial/by-id/*")
print(f"Found {len(usb_ports)} ports")
if len(usb_ports) > 0:
gello_port = usb_ports[0]
print(f"using port {gello_port}")
else:
raise ValueError(
"No gello port found, please specify one or plug in gello"
)
if args.start_joints is None:
reset_joints = np.deg2rad(
[0, -90, 90, -90, -90, 0, 0]
) # Change this to your own reset joints
else:
reset_joints = args.start_joints
agent = GelloAgent(port=gello_port, start_joints=args.start_joints)
curr_joints = env.get_obs()["joint_positions"]
if reset_joints.shape == curr_joints.shape:
max_delta = (np.abs(curr_joints - reset_joints)).max()
steps = min(int(max_delta / 0.01), 100)
for jnt in np.linspace(curr_joints, reset_joints, steps):
env.step(jnt)
time.sleep(0.001)
elif args.agent == "quest":
from gello.agents.quest_agent import SingleArmQuestAgent
agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
elif args.agent == "spacemouse":
from gello.agents.spacemouse_agent import SpacemouseAgent
agent = SpacemouseAgent(robot_type=args.robot_type, verbose=args.verbose)
elif args.agent == "dummy" or args.agent == "none":
agent = DummyAgent(num_dofs=robot_client.num_dofs())
elif args.agent == "policy":
raise NotImplementedError("add your imitation policy here if there is one")
else:
raise ValueError("Invalid agent name")
# going to start position
print("Going to start position")
start_pos = agent.act(env.get_obs())
obs = env.get_obs()
joints = obs["joint_positions"]
abs_deltas = np.abs(start_pos - joints)
id_max_joint_delta = np.argmax(abs_deltas)
max_joint_delta = 0.8
if abs_deltas[id_max_joint_delta] > max_joint_delta:
id_mask = abs_deltas > max_joint_delta
print()
ids = np.arange(len(id_mask))[id_mask]
for i, delta, joint, current_j in zip(
ids,
abs_deltas[id_mask],
start_pos[id_mask],
joints[id_mask],
):
print(
f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
)
return
print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
assert len(start_pos) == len(
joints
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
max_delta = 0.05
for _ in range(25):
obs = env.get_obs()
command_joints = agent.act(obs)
current_joints = obs["joint_positions"]
delta = command_joints - current_joints
max_joint_delta = np.abs(delta).max()
if max_joint_delta > max_delta:
delta = delta / max_joint_delta * max_delta
env.step(current_joints + delta)
obs = env.get_obs()
joints = obs["joint_positions"]
action = agent.act(obs)
if (action - joints > 0.5).any():
print("Action is too big")
# print which joints are too big
joint_index = np.where(action - joints > 0.8)
for j in joint_index:
print(
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
)
exit()
if args.use_save_interface:
from gello.data_utils.keyboard_interface import KBReset
kb_interface = KBReset()
print_color("\nStart 🚀🚀🚀", color="green", attrs=("bold",))
save_path = None
start_time = time.time()
while True:
num = time.time() - start_time
message = f"\rTime passed: {round(num, 2)} "
print_color(
message,
color="white",
attrs=("bold",),
end="",
flush=True,
)
action = agent.act(obs)
dt = datetime.datetime.now()
if args.use_save_interface:
state = kb_interface.update()
if state == "start":
dt_time = datetime.datetime.now()
save_path = (
Path(args.data_dir).expanduser()
/ args.agent
/ dt_time.strftime("%m%d_%H%M%S")
)
save_path.mkdir(parents=True, exist_ok=True)
print(f"Saving to {save_path}")
elif state == "save":
assert save_path is not None, "something went wrong"
save_frame(save_path, dt, obs, action)
elif state == "normal":
save_path = None
else:
raise ValueError(f"Invalid state {state}")
obs = env.step(action)
if __name__ == "__main__":
main(tyro.cli(Args))