initial commit, add gello software code and instructions
This commit is contained in:
parent
e7d842ad35
commit
18cc23a38e
70 changed files with 5875 additions and 4 deletions
37
experiments/launch_camera_clients.py
Normal file
37
experiments/launch_camera_clients.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.zmq_core.camera_node import ZMQClientCamera
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
ports: Tuple[int, ...] = (5000, 5001)
|
||||
hostname: str = "127.0.0.1"
|
||||
# hostname: str = "128.32.175.167"
|
||||
|
||||
|
||||
def main(args):
|
||||
cameras = []
|
||||
import cv2
|
||||
|
||||
images_display_names = []
|
||||
for port in args.ports:
|
||||
cameras.append(ZMQClientCamera(port=port, host=args.hostname))
|
||||
images_display_names.append(f"image_{port}")
|
||||
cv2.namedWindow(images_display_names[-1], cv2.WINDOW_NORMAL)
|
||||
|
||||
while True:
|
||||
for display_name, camera in zip(images_display_names, cameras):
|
||||
image, depth = camera.read()
|
||||
stacked_depth = np.dstack([depth, depth, depth]).astype(np.uint8)
|
||||
image_depth = cv2.hconcat([image[:, :, ::-1], stacked_depth])
|
||||
cv2.imshow(display_name, image_depth)
|
||||
cv2.waitKey(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
40
experiments/launch_camera_nodes.py
Normal file
40
experiments/launch_camera_nodes.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
|
||||
import tyro
|
||||
|
||||
from gello.cameras.realsense_camera import RealSenseCamera, get_device_ids
|
||||
from gello.zmq_core.camera_node import ZMQServerCamera
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
# hostname: str = "127.0.0.1"
|
||||
hostname: str = "128.32.175.167"
|
||||
|
||||
|
||||
def launch_server(port: int, camera_id: int, args: Args):
|
||||
camera = RealSenseCamera(camera_id)
|
||||
server = ZMQServerCamera(camera, port=port, host=args.hostname)
|
||||
print(f"Starting camera server on port {port}")
|
||||
server.serve()
|
||||
|
||||
|
||||
def main(args):
|
||||
ids = get_device_ids()
|
||||
camera_port = 5000
|
||||
camera_servers = []
|
||||
for camera_id in ids:
|
||||
# start a python process for each camera
|
||||
print(f"Launching camera {camera_id} on port {camera_port}")
|
||||
camera_servers.append(
|
||||
Process(target=launch_server, args=(camera_port, camera_id, args))
|
||||
)
|
||||
camera_port += 1
|
||||
|
||||
for server in camera_servers:
|
||||
server.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
94
experiments/launch_nodes.py
Normal file
94
experiments/launch_nodes.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import tyro
|
||||
|
||||
from gello.robots.robot import BimanualRobot, PrintRobot
|
||||
from gello.zmq_core.robot_node import ZMQServerRobot
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
robot: str = "xarm"
|
||||
robot_port: int = 6001
|
||||
hostname: str = "127.0.0.1"
|
||||
robot_ip: str = "192.168.1.10"
|
||||
|
||||
|
||||
def launch_robot_server(args: Args):
|
||||
port = args.robot_port
|
||||
if args.robot == "sim_ur":
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
elif args.robot == "sim_panda":
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
|
||||
gripper_xml = None
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
elif args.robot == "sim_xarm":
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "ufactory_xarm7" / "xarm7.xml"
|
||||
gripper_xml = None
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
|
||||
else:
|
||||
if args.robot == "xarm":
|
||||
from gello.robots.xarm_robot import XArmRobot
|
||||
|
||||
robot = XArmRobot(ip=args.robot_ip)
|
||||
elif args.robot == "ur":
|
||||
from gello.robots.ur import URRobot
|
||||
|
||||
robot = URRobot(robot_ip=args.robot_ip)
|
||||
elif args.robot == "panda":
|
||||
from gello.robots.panda import PandaRobot
|
||||
|
||||
robot = PandaRobot(robot_ip=args.robot_ip)
|
||||
elif args.robot == "bimanual_ur":
|
||||
from gello.robots.ur import URRobot
|
||||
|
||||
# IP for the bimanual robot setup is hardcoded
|
||||
_robot_l = URRobot(robot_ip="192.168.2.10")
|
||||
_robot_r = URRobot(robot_ip="192.168.1.10")
|
||||
robot = BimanualRobot(_robot_l, _robot_r)
|
||||
elif args.robot == "none" or args.robot == "print":
|
||||
robot = PrintRobot(8)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
|
||||
)
|
||||
server = ZMQServerRobot(robot, port=port, host=args.hostname)
|
||||
print(f"Starting robot server on port {port}")
|
||||
server.serve()
|
||||
|
||||
|
||||
def main(args):
|
||||
launch_robot_server(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
192
experiments/quick_run.py
Normal file
192
experiments/quick_run.py
Normal file
|
@ -0,0 +1,192 @@
|
|||
import atexit
|
||||
import glob
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.agents.agent import DummyAgent
|
||||
from gello.agents.gello_agent import GelloAgent
|
||||
from gello.agents.spacemouse_agent import SpacemouseAgent
|
||||
from gello.env import RobotEnv
|
||||
from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
hz: int = 100
|
||||
|
||||
agent: str = "gello"
|
||||
robot: str = "ur5"
|
||||
gello_port: Optional[str] = None
|
||||
mock: bool = False
|
||||
verbose: bool = False
|
||||
|
||||
hostname: str = "127.0.0.1"
|
||||
robot_port: int = 6001
|
||||
|
||||
|
||||
def launch_robot_server(port: int, args: Args):
|
||||
if args.robot == "sim_ur":
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
elif args.robot == "sim_panda":
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
|
||||
gripper_xml = None
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
|
||||
else:
|
||||
if args.robot == "xarm":
|
||||
from gello.robots.xarm_robot import XArmRobot
|
||||
|
||||
robot = XArmRobot()
|
||||
elif args.robot == "ur5":
|
||||
from gello.robots.ur import URRobot
|
||||
|
||||
robot = URRobot(robot_ip=args.robot_ip)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
|
||||
)
|
||||
server = ZMQServerRobot(robot, port=port, host=args.hostname)
|
||||
print(f"Starting robot server on port {port}")
|
||||
server.serve()
|
||||
|
||||
|
||||
def start_robot_process(args: Args):
|
||||
process = Process(target=launch_robot_server, args=(args.robot_port, args))
|
||||
|
||||
# Function to kill the child process
|
||||
def kill_child_process(process):
|
||||
print("Killing child process...")
|
||||
process.terminate()
|
||||
|
||||
# Register the kill_child_process function to be called at exit
|
||||
atexit.register(kill_child_process, process)
|
||||
process.start()
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
start_robot_process(args)
|
||||
|
||||
robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
|
||||
env = RobotEnv(robot_client, control_rate_hz=args.hz)
|
||||
|
||||
if args.agent == "gello":
|
||||
gello_port = args.gello_port
|
||||
if gello_port is None:
|
||||
usb_ports = glob.glob("/dev/serial/by-id/*")
|
||||
print(f"Found {len(usb_ports)} ports")
|
||||
if len(usb_ports) > 0:
|
||||
gello_port = usb_ports[0]
|
||||
print(f"using port {gello_port}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"No gello port found, please specify one or plug in gello"
|
||||
)
|
||||
agent = GelloAgent(port=gello_port)
|
||||
|
||||
reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
|
||||
curr_joints = env.get_obs()["joint_positions"]
|
||||
if reset_joints.shape == curr_joints.shape:
|
||||
max_delta = (np.abs(curr_joints - reset_joints)).max()
|
||||
steps = min(int(max_delta / 0.01), 100)
|
||||
|
||||
for jnt in np.linspace(curr_joints, reset_joints, steps):
|
||||
env.step(jnt)
|
||||
time.sleep(0.001)
|
||||
|
||||
elif args.agent == "quest":
|
||||
from gello.agents.quest_agent import SingleArmQuestAgent
|
||||
|
||||
agent = SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
|
||||
elif args.agent == "spacemouse":
|
||||
agent = SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
|
||||
elif args.agent == "dummy" or args.agent == "none":
|
||||
agent = DummyAgent(num_dofs=robot_client.num_dofs())
|
||||
else:
|
||||
raise ValueError("Invalid agent name")
|
||||
|
||||
# going to start position
|
||||
print("Going to start position")
|
||||
start_pos = agent.act(env.get_obs())
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
|
||||
abs_deltas = np.abs(start_pos - joints)
|
||||
id_max_joint_delta = np.argmax(abs_deltas)
|
||||
|
||||
max_joint_delta = 0.8
|
||||
if abs_deltas[id_max_joint_delta] > max_joint_delta:
|
||||
id_mask = abs_deltas > max_joint_delta
|
||||
print()
|
||||
ids = np.arange(len(id_mask))[id_mask]
|
||||
for i, delta, joint, current_j in zip(
|
||||
ids,
|
||||
abs_deltas[id_mask],
|
||||
start_pos[id_mask],
|
||||
joints[id_mask],
|
||||
):
|
||||
print(
|
||||
f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
|
||||
assert len(start_pos) == len(
|
||||
joints
|
||||
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
|
||||
|
||||
max_delta = 0.05
|
||||
for _ in range(25):
|
||||
obs = env.get_obs()
|
||||
command_joints = agent.act(obs)
|
||||
current_joints = obs["joint_positions"]
|
||||
delta = command_joints - current_joints
|
||||
max_joint_delta = np.abs(delta).max()
|
||||
if max_joint_delta > max_delta:
|
||||
delta = delta / max_joint_delta * max_delta
|
||||
env.step(current_joints + delta)
|
||||
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
action = agent.act(obs)
|
||||
if (action - joints > 0.5).any():
|
||||
print("Action is too big")
|
||||
|
||||
# print which joints are too big
|
||||
joint_index = np.where(action - joints > 0.8)
|
||||
for j in joint_index:
|
||||
print(
|
||||
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
while True:
|
||||
action = agent.act(obs)
|
||||
obs = env.step(action)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
246
experiments/run_env.py
Normal file
246
experiments/run_env.py
Normal file
|
@ -0,0 +1,246 @@
|
|||
import datetime
|
||||
import glob
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.agents.agent import BimanualAgent, DummyAgent
|
||||
from gello.agents.gello_agent import GelloAgent
|
||||
from gello.data_utils.format_obs import save_frame
|
||||
from gello.env import RobotEnv
|
||||
from gello.robots.robot import PrintRobot
|
||||
from gello.zmq_core.robot_node import ZMQClientRobot
|
||||
|
||||
|
||||
def print_color(*args, color=None, attrs=(), **kwargs):
|
||||
import termcolor
|
||||
|
||||
if len(args) > 0:
|
||||
args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
agent: str = "none"
|
||||
robot_port: int = 6001
|
||||
wrist_camera_port: int = 5000
|
||||
base_camera_port: int = 5001
|
||||
hostname: str = "127.0.0.1"
|
||||
robot_type: str = None # only needed for quest agent or spacemouse agent
|
||||
hz: int = 100
|
||||
start_joints: Optional[Tuple[float, ...]] = None
|
||||
|
||||
gello_port: Optional[str] = None
|
||||
mock: bool = False
|
||||
use_save_interface: bool = False
|
||||
data_dir: str = "~/bc_data"
|
||||
bimanual: bool = False
|
||||
verbose: bool = False
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mock:
|
||||
robot_client = PrintRobot(8, dont_print=True)
|
||||
camera_clients = {}
|
||||
else:
|
||||
camera_clients = {
|
||||
# you can optionally add camera nodes here for imitation learning purposes
|
||||
# "wrist": ZMQClientCamera(port=args.wrist_camera_port, host=args.hostname),
|
||||
# "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
|
||||
}
|
||||
robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
|
||||
env = RobotEnv(robot_client, control_rate_hz=args.hz, camera_dict=camera_clients)
|
||||
|
||||
if args.bimanual:
|
||||
if args.agent == "gello":
|
||||
# dynamixel control box port map (to distinguish left and right gello)
|
||||
right = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0"
|
||||
left = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0"
|
||||
left_agent = GelloAgent(port=left)
|
||||
right_agent = GelloAgent(port=right)
|
||||
agent = BimanualAgent(left_agent, right_agent)
|
||||
elif args.agent == "quest":
|
||||
from gello.agents.quest_agent import SingleArmQuestAgent
|
||||
|
||||
left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
|
||||
right_agent = SingleArmQuestAgent(
|
||||
robot_type=args.robot_type, which_hand="r"
|
||||
)
|
||||
agent = BimanualAgent(left_agent, right_agent)
|
||||
# raise NotImplementedError
|
||||
elif args.agent == "spacemouse":
|
||||
from gello.agents.spacemouse_agent import SpacemouseAgent
|
||||
|
||||
left_path = "/dev/hidraw0"
|
||||
right_path = "/dev/hidraw1"
|
||||
left_agent = SpacemouseAgent(
|
||||
robot_type=args.robot_type, device_path=left_path, verbose=args.verbose
|
||||
)
|
||||
right_agent = SpacemouseAgent(
|
||||
robot_type=args.robot_type,
|
||||
device_path=right_path,
|
||||
verbose=args.verbose,
|
||||
invert_button=True,
|
||||
)
|
||||
agent = BimanualAgent(left_agent, right_agent)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent name for bimanual: {args.agent}")
|
||||
|
||||
# System setup specific. This reset configuration works well on our setup. If you are mounting the robot
|
||||
# differently, you need a separate reset joint configuration.
|
||||
reset_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0, 0])
|
||||
reset_joints_right = np.deg2rad([0, -90, 90, -90, -90, 0, 0])
|
||||
reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
|
||||
curr_joints = env.get_obs()["joint_positions"]
|
||||
max_delta = (np.abs(curr_joints - reset_joints)).max()
|
||||
steps = min(int(max_delta / 0.01), 100)
|
||||
|
||||
for jnt in np.linspace(curr_joints, reset_joints, steps):
|
||||
env.step(jnt)
|
||||
else:
|
||||
if args.agent == "gello":
|
||||
gello_port = args.gello_port
|
||||
if gello_port is None:
|
||||
usb_ports = glob.glob("/dev/serial/by-id/*")
|
||||
print(f"Found {len(usb_ports)} ports")
|
||||
if len(usb_ports) > 0:
|
||||
gello_port = usb_ports[0]
|
||||
print(f"using port {gello_port}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"No gello port found, please specify one or plug in gello"
|
||||
)
|
||||
if args.start_joints is None:
|
||||
reset_joints = np.deg2rad(
|
||||
[0, -90, 90, -90, -90, 0, 0]
|
||||
) # Change this to your own reset joints
|
||||
else:
|
||||
reset_joints = args.start_joints
|
||||
agent = GelloAgent(port=gello_port, start_joints=args.start_joints)
|
||||
curr_joints = env.get_obs()["joint_positions"]
|
||||
if reset_joints.shape == curr_joints.shape:
|
||||
max_delta = (np.abs(curr_joints - reset_joints)).max()
|
||||
steps = min(int(max_delta / 0.01), 100)
|
||||
|
||||
for jnt in np.linspace(curr_joints, reset_joints, steps):
|
||||
env.step(jnt)
|
||||
time.sleep(0.001)
|
||||
elif args.agent == "quest":
|
||||
from gello.agents.quest_agent import SingleArmQuestAgent
|
||||
|
||||
agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
|
||||
elif args.agent == "spacemouse":
|
||||
from gello.agents.spacemouse_agent import SpacemouseAgent
|
||||
|
||||
agent = SpacemouseAgent(robot_type=args.robot_type, verbose=args.verbose)
|
||||
elif args.agent == "dummy" or args.agent == "none":
|
||||
agent = DummyAgent(num_dofs=robot_client.num_dofs())
|
||||
elif args.agent == "policy":
|
||||
raise NotImplementedError("add your imitation policy here if there is one")
|
||||
else:
|
||||
raise ValueError("Invalid agent name")
|
||||
|
||||
# going to start position
|
||||
print("Going to start position")
|
||||
start_pos = agent.act(env.get_obs())
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
|
||||
abs_deltas = np.abs(start_pos - joints)
|
||||
id_max_joint_delta = np.argmax(abs_deltas)
|
||||
|
||||
max_joint_delta = 0.8
|
||||
if abs_deltas[id_max_joint_delta] > max_joint_delta:
|
||||
id_mask = abs_deltas > max_joint_delta
|
||||
print()
|
||||
ids = np.arange(len(id_mask))[id_mask]
|
||||
for i, delta, joint, current_j in zip(
|
||||
ids,
|
||||
abs_deltas[id_mask],
|
||||
start_pos[id_mask],
|
||||
joints[id_mask],
|
||||
):
|
||||
print(
|
||||
f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
|
||||
assert len(start_pos) == len(
|
||||
joints
|
||||
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
|
||||
|
||||
max_delta = 0.05
|
||||
for _ in range(25):
|
||||
obs = env.get_obs()
|
||||
command_joints = agent.act(obs)
|
||||
current_joints = obs["joint_positions"]
|
||||
delta = command_joints - current_joints
|
||||
max_joint_delta = np.abs(delta).max()
|
||||
if max_joint_delta > max_delta:
|
||||
delta = delta / max_joint_delta * max_delta
|
||||
env.step(current_joints + delta)
|
||||
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
action = agent.act(obs)
|
||||
if (action - joints > 0.5).any():
|
||||
print("Action is too big")
|
||||
|
||||
# print which joints are too big
|
||||
joint_index = np.where(action - joints > 0.8)
|
||||
for j in joint_index:
|
||||
print(
|
||||
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
if args.use_save_interface:
|
||||
from gello.data_utils.keyboard_interface import KBReset
|
||||
|
||||
kb_interface = KBReset()
|
||||
|
||||
print_color("\nStart 🚀🚀🚀", color="green", attrs=("bold",))
|
||||
|
||||
save_path = None
|
||||
start_time = time.time()
|
||||
while True:
|
||||
num = time.time() - start_time
|
||||
message = f"\rTime passed: {round(num, 2)} "
|
||||
print_color(
|
||||
message,
|
||||
color="white",
|
||||
attrs=("bold",),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
action = agent.act(obs)
|
||||
dt = datetime.datetime.now()
|
||||
if args.use_save_interface:
|
||||
state = kb_interface.update()
|
||||
if state == "start":
|
||||
dt_time = datetime.datetime.now()
|
||||
save_path = (
|
||||
Path(args.data_dir).expanduser()
|
||||
/ args.agent
|
||||
/ dt_time.strftime("%m%d_%H%M%S")
|
||||
)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Saving to {save_path}")
|
||||
elif state == "save":
|
||||
assert save_path is not None, "something went wrong"
|
||||
save_frame(save_path, dt, obs, action)
|
||||
elif state == "normal":
|
||||
save_path = None
|
||||
else:
|
||||
raise ValueError(f"Invalid state {state}")
|
||||
obs = env.step(action)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
Loading…
Add table
Add a link
Reference in a new issue