gello_software/gello/robots/sim_robot.py

257 lines
8.6 KiB
Python
Raw Normal View History

import pickle
import threading
import time
from typing import Any, Dict, Optional
import mujoco
import mujoco.viewer
import numpy as np
import zmq
from dm_control import mjcf
from gello.robots.robot import Robot
# NOTE(review): this identity check is always true and has no runtime effect.
# It presumably exists only to reference `mujoco.viewer` so that linters do not
# flag `import mujoco.viewer` as unused -- confirm before removing.
assert mujoco.viewer is mujoco.viewer
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
) -> None:
    """Mount a hand model on an arm model, in place.

    The arm must expose a site named "attachment_site"; the hand is attached
    there. If the arm defines a "home" keyframe, its ``ctrl`` and ``qpos``
    vectors are extended to cover the hand: with the hand's own "home"
    keyframe values when it has one, otherwise with zeros sized from the
    compiled hand model (``nu`` controls, ``nq`` positions).

    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm (modified in place).
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    # Compile the hand on its own so its actuator / position counts are known.
    hand_physics = mjcf.Physics.from_mjcf_model(hand_mjcf)

    mount_site = arm_mjcf.find("site", "attachment_site")
    if mount_site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Grow the arm's "home" keyframe so it stays valid once the hand's
    # degrees of freedom are part of the same model.
    arm_home = arm_mjcf.find("key", "home")
    if arm_home is not None:
        hand_home = hand_mjcf.find("key", "home")
        if hand_home is not None:
            extra_ctrl = hand_home.ctrl
            extra_qpos = hand_home.qpos
        else:
            extra_ctrl = np.zeros(hand_physics.model.nu)
            extra_qpos = np.zeros(hand_physics.model.nq)
        arm_home.ctrl = np.concatenate([arm_home.ctrl, extra_ctrl])
        arm_home.qpos = np.concatenate([arm_home.qpos, extra_qpos])

    mount_site.attach(hand_mjcf)
def build_scene(robot_xml_path: str, gripper_xml_path: Optional[str] = None):
    """Build an MJCF scene containing the robot and, optionally, a gripper.

    Args:
        robot_xml_path: Path to the robot arm's MJCF file.
        gripper_xml_path: Optional path to a gripper MJCF file. When given,
            the gripper is attached to the arm's "attachment_site".

    Returns:
        A new mjcf.RootElement whose worldbody has the arm attached.
    """
    scene = mjcf.RootElement()
    arm_model = mjcf.from_path(robot_xml_path)

    if gripper_xml_path is not None:
        # Mount the gripper before the arm is placed into the scene.
        gripper_model = mjcf.from_path(gripper_xml_path)
        attach_hand_to_arm(arm_model, gripper_model)

    scene.worldbody.attach(arm_model)
    return scene
class ZMQServerThread(threading.Thread):
    """Thread that runs a server's ``serve()`` loop in the background.

    The wrapped object only needs ``serve()`` and ``stop()`` methods.
    """

    def __init__(self, server) -> None:
        threading.Thread.__init__(self)
        self._server = server

    def run(self) -> None:
        """Thread body: run the wrapped server's serve loop."""
        wrapped = self._server
        wrapped.serve()

    def terminate(self) -> None:
        """Forward a stop request to the wrapped server."""
        wrapped = self._server
        wrapped.stop()
class ZMQRobotServer:
    """A ZMQ REP server that exposes a robot's methods to remote clients.

    Each request is a pickled dict with a "method" key and an optional
    "args" dict of keyword arguments. Each reply is the pickled return value
    of the matching robot method, or ``{"error": "Invalid method"}`` when the
    method name is not recognised.
    """

    def __init__(self, robot: Robot, host: str = "127.0.0.1", port: int = 5556):
        """Create the REP socket and bind it to ``tcp://{host}:{port}``.

        Args:
            robot: Object whose methods are served to clients.
            host: Interface to bind to.
            port: TCP port to bind to.
        """
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def serve(self) -> None:
        """Serve the robot state and commands over ZMQ until stopped.

        Blocks the calling thread. The receive timeout is one second so the
        loop can notice a stop request between messages.
        """
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                message = self._socket.recv()
            except zmq.error.Again:
                # Timeout occurred; loop around and re-check the stop event.
                print("Timeout in ZMQLeaderServer serve")
                continue
            except zmq.error.ZMQError:
                # stop() closes the socket, possibly from another thread while
                # recv() is pending. Treat that as a normal shutdown and only
                # propagate errors that were not caused by a stop request.
                if self._stop_event.is_set():
                    break
                raise

            # SECURITY: pickle.loads on data received from the network can
            # execute arbitrary code. Only bind this server to trusted
            # interfaces / networks.
            request = pickle.loads(message)
            result = self._dispatch(request.get("method"), request.get("args", {}))
            # A REP socket must answer every request, including invalid ones,
            # otherwise the client blocks forever waiting for its reply.
            self._socket.send(pickle.dumps(result))

    def _dispatch(self, method: Any, args: Dict[str, Any]) -> Any:
        """Call the robot method named by ``method`` and return its result.

        Unknown methods are reported and answered with an error dict rather
        than raising, so one bad request cannot kill the serving thread.
        """
        if method == "num_dofs":
            return self._robot.num_dofs()
        if method == "get_joint_state":
            return self._robot.get_joint_state()
        if method == "command_joint_state":
            return self._robot.command_joint_state(**args)
        if method == "get_observations":
            return self._robot.get_observations()

        result = {"error": "Invalid method"}
        print(f"Invalid method: {method}, {args, result}")
        return result

    def stop(self) -> None:
        """Request the serve loop to exit and release the ZMQ resources."""
        self._stop_event.set()
        self._socket.close()
        self._context.term()
class MujocoRobotServer:
    """MuJoCo-simulated robot that is also served to clients over ZMQ.

    The simulation and viewer run in the thread that calls ``serve()``; a
    background thread answers ZMQ requests by calling this object's
    robot-style methods (``num_dofs``, ``get_joint_state``,
    ``command_joint_state``, ``get_observations``).
    """

    def __init__(
        self,
        xml_path: str,
        gripper_xml_path: Optional[str] = None,
        host: str = "127.0.0.1",
        port: int = 5556,
        print_joints: bool = False,
    ):
        """Build the scene, compile the MuJoCo model and create the ZMQ server.

        Args:
            xml_path: Path to the robot arm's MJCF file.
            gripper_xml_path: Optional path to a gripper MJCF file.
            host: Interface the ZMQ server binds to.
            port: TCP port the ZMQ server binds to.
            print_joints: If True, print the joint state after every step.
        """
        self._has_gripper = gripper_xml_path is not None
        arena = build_scene(xml_path, gripper_xml_path)

        # The model is compiled from an XML string, so mesh files have to be
        # handed over explicitly, keyed by the filename the XML refers to.
        assets: Dict[str, bytes] = {}
        for asset in arena.asset.all_children():
            if asset.tag == "mesh":
                mesh_file = asset.file
                assets[mesh_file.get_vfs_filename()] = mesh_file.contents

        xml_string = arena.to_xml_string()
        # Save the generated scene XML to "arena.xml" in the working directory.
        with open("arena.xml", "w") as xml_file:
            xml_file.write(xml_string)

        self._model = mujoco.MjModel.from_xml_string(xml_string, assets)
        self._data = mujoco.MjData(self._model)
        # One commanded value per actuator.
        self._num_joints = self._model.nu
        self._joint_state = np.zeros(self._num_joints)
        # Independent copy so state and command never share one buffer.
        self._joint_cmd = self._joint_state.copy()
        self._print_joints = print_joints
        self._stopped = False

        self._zmq_server = ZMQRobotServer(robot=self, host=host, port=port)
        self._zmq_server_thread = ZMQServerThread(self._zmq_server)

    def num_dofs(self) -> int:
        """Return the number of commanded values (the model's actuator count)."""
        return self._num_joints

    def get_joint_state(self) -> np.ndarray:
        """Return the joint positions recorded after the latest simulation step."""
        return self._joint_state

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Set the control target applied on the next simulation step.

        Args:
            joint_state: Array of length ``num_dofs()``. When a gripper is
                attached, the last entry is multiplied by 255 before use.
        """
        assert len(joint_state) == self._num_joints, (
            f"Expected joint state of length {self._num_joints}, "
            f"got {len(joint_state)}."
        )
        # Always work on a copy so the caller's array is never modified.
        command = joint_state.copy()
        if self._has_gripper:
            # NOTE(review): presumably callers send the gripper command in
            # [0, 1] and the gripper actuator expects [0, 255] -- confirm
            # against the gripper XML's ctrlrange.
            command[-1] = command[-1] * 255
        self._joint_cmd = command

    def freedrive_enabled(self) -> bool:
        """Always True for the simulated robot."""
        return True

    def set_freedrive_mode(self, enable: bool):
        """No-op for the simulated robot."""
        pass

    def _find_site_id(self, name: str) -> int:
        """Return the id of site ``name``, or -1 if the model has no such site.

        An exact name match is tried first. Because the arm is attached to an
        arena by dm_control, its element names may carry a "<model>/" prefix,
        so a site whose name ends with "/<name>" is accepted as a fallback.
        """
        site_type = mujoco.mjtObj.mjOBJ_SITE
        site_id = mujoco.mj_name2id(self._model, site_type, name)
        if site_id >= 0:
            return site_id
        suffix = "/" + name
        for candidate in range(self._model.nsite):
            candidate_name = mujoco.mj_id2name(self._model, site_type, candidate)
            if candidate_name and candidate_name.endswith(suffix):
                return candidate
        return -1

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the current simulated observations.

        Returns:
            Dict with "joint_positions" and "joint_velocities" (first
            ``num_dofs()`` entries of qpos / qvel), "ee_pos_quat" (position
            followed by quaternion of the "attachment_site" site, or a zero
            position with identity quaternion when that site does not exist)
            and "gripper_position" (qpos entry at index ``num_dofs() - 1``).
        """
        joint_positions = self._data.qpos[: self._num_joints].copy()
        joint_velocities = self._data.qvel[: self._num_joints].copy()

        ee_site = "attachment_site"
        site_id = self._find_site_id(ee_site)
        if site_id >= 0:
            ee_pos = self._data.site_xpos[site_id].copy()
            ee_mat = self._data.site_xmat[site_id].copy()
            ee_quat = np.zeros(4)
            mujoco.mju_mat2Quat(ee_quat, ee_mat)
        else:
            # A missing site yields id -1; indexing with it would silently
            # return the last site's pose, so fall back explicitly.
            ee_pos = np.zeros(3)
            ee_quat = np.zeros(4)
            ee_quat[0] = 1

        gripper_pos = self._data.qpos.copy()[self._num_joints - 1]
        return {
            "joint_positions": joint_positions,
            "joint_velocities": joint_velocities,
            "ee_pos_quat": np.concatenate([ee_pos, ee_quat]),
            "gripper_position": gripper_pos,
        }

    def serve(self) -> None:
        """Run the simulation and viewer until the viewer window is closed.

        Starts the ZMQ thread, then steps the physics in the calling thread,
        applying the latest commanded joint state as control. When the loop
        ends, normally or through an exception, the ZMQ thread is shut down
        so it cannot keep the process alive.
        """
        # start the zmq server
        self._zmq_server_thread.start()
        try:
            with mujoco.viewer.launch_passive(self._model, self._data) as viewer:
                while viewer.is_running():
                    step_start = time.time()

                    # Apply the latest command, then advance the physics.
                    self._data.ctrl[:] = self._joint_cmd
                    mujoco.mj_step(self._model, self._data)
                    self._joint_state = self._data.qpos.copy()[: self._num_joints]

                    if self._print_joints:
                        print(self._joint_state)

                    # Example modification of a viewer option: toggle contact
                    # points every two seconds.
                    with viewer.lock():
                        # TODO remove?
                        viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(
                            self._data.time % 2
                        )

                    # Pick up changes to the physics state, apply
                    # perturbations, update options from GUI.
                    viewer.sync()

                    # Rudimentary time keeping, will drift relative to wall clock.
                    time_until_next_step = self._model.opt.timestep - (
                        time.time() - step_start
                    )
                    if time_until_next_step > 0:
                        time.sleep(time_until_next_step)
        finally:
            self.stop()

    def stop(self) -> None:
        """Shut down the ZMQ thread and release its socket.

        Safe to call more than once, before ``serve()``, and on an instance
        whose ``__init__`` did not complete.
        """
        server = getattr(self, "_zmq_server", None)
        if server is None or getattr(self, "_stopped", False):
            return
        self._stopped = True

        thread = getattr(self, "_zmq_server_thread", None)
        if (
            thread is not None
            and thread.is_alive()
            and thread is not threading.current_thread()
        ):
            # Signal the serve loop first and wait for it to exit, so the
            # socket is only closed once no thread is using it. Joining
            # without signalling would block forever.
            server._stop_event.set()
            thread.join()
        server.stop()

    def __del__(self) -> None:
        self.stop()