# gello_software/gello/robots/sim_robot.py
#
# 256 lines
# 8.6 KiB
# Python

import pickle
import threading
import time
from typing import Any, Dict, Optional
import mujoco
import mujoco.viewer
import numpy as np
import zmq
from dm_control import mjcf
from gello.robots.robot import Robot
# Trivially true. Presumably present only so that the `import mujoco.viewer`
# line above counts as "used" and is not removed by linters/import cleaners
# -- TODO confirm before deleting.
assert mujoco.viewer is mujoco.viewer
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
) -> None:
    """Mount a hand model on an arm model, in place.

    The hand is attached at the arm's site named "attachment_site". If the arm
    defines a "home" keyframe, its ``ctrl`` and ``qpos`` vectors are extended
    with the hand's "home" keyframe values, or with zeros when the hand has no
    such keyframe.

    Adapted from
    https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm (modified in place).
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    # Compiled up front: the hand's actuator/coordinate counts size the zero
    # padding below when the hand has no "home" keyframe.
    hand_physics = mjcf.Physics.from_mjcf_model(hand_mjcf)

    mount_site = arm_mjcf.find("site", "attachment_site")
    if mount_site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Grow the arm's "home" keyframe so it also covers the hand's DoFs.
    arm_home = arm_mjcf.find("key", "home")
    if arm_home is not None:
        hand_home = hand_mjcf.find("key", "home")
        if hand_home is not None:
            extra_ctrl = hand_home.ctrl
            extra_qpos = hand_home.qpos
        else:
            extra_ctrl = np.zeros(hand_physics.model.nu)
            extra_qpos = np.zeros(hand_physics.model.nq)
        arm_home.ctrl = np.concatenate([arm_home.ctrl, extra_ctrl])
        arm_home.qpos = np.concatenate([arm_home.qpos, extra_qpos])

    mount_site.attach(hand_mjcf)
def build_scene(robot_xml_path: str, gripper_xml_path: Optional[str] = None):
    """Assemble an MJCF scene containing the robot and, optionally, a gripper.

    Args:
        robot_xml_path: Path to the robot arm's MJCF file.
        gripper_xml_path: Optional path to a gripper MJCF file. When given,
            the gripper is attached to the arm via ``attach_hand_to_arm``.

    Returns:
        A new mjcf root element with the arm (plus gripper, if any) attached
        to its worldbody.
    """
    scene_root = mjcf.RootElement()
    arm_model = mjcf.from_path(robot_xml_path)

    if gripper_xml_path is not None:
        # The gripper goes onto the arm before the arm goes into the scene.
        gripper_model = mjcf.from_path(gripper_xml_path)
        attach_hand_to_arm(arm_model, gripper_model)

    scene_root.worldbody.attach(arm_model)
    return scene_root
class ZMQServerThread(threading.Thread):
    """Thread that runs a server object's blocking ``serve()`` call.

    The wrapped object must provide ``serve()`` and ``stop()`` methods.
    """

    def __init__(self, server):
        """Remember ``server``; the thread is not started here."""
        super().__init__()
        self._server = server

    def run(self):
        """Thread body: delegate to the server's ``serve()``."""
        wrapped = self._server
        wrapped.serve()

    def terminate(self):
        """Ask the wrapped server to shut down by calling its ``stop()``."""
        wrapped = self._server
        wrapped.stop()
class ZMQRobotServer:
    """Expose a robot's methods over a ZMQ REP socket.

    Each request is a pickled dict with a "method" key and an optional "args"
    dict of keyword arguments. The reply is the pickled return value of the
    corresponding robot method.
    """

    def __init__(self, robot: Robot, host: str = "127.0.0.1", port: int = 5556):
        """Create the REP socket and bind it to ``tcp://host:port``.

        Args:
            robot: Object whose methods are invoked on behalf of clients.
            host: Interface to bind to.
            port: TCP port to bind to.
        """
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()
        # ZMQ sockets are not thread-safe, so teardown has to happen on the
        # thread that is using the socket. The lock/flags below decide whether
        # serve() or stop() performs it.
        self._lifecycle_lock = threading.Lock()
        self._serving = False
        self._closed = False

    def serve(self) -> None:
        """Serve the robot state and commands over ZMQ until stopped.

        Blocks the calling thread. The socket and context are closed on this
        thread when the loop exits, whether via ``stop()`` or an exception.

        Raises:
            NotImplementedError: If a request names an unknown method. An
                error reply is sent to the client before raising.
        """
        with self._lifecycle_lock:
            if self._stop_event.is_set():
                # stop() ran before we started and has already closed the socket.
                return
            self._serving = True
        try:
            self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
            while not self._stop_event.is_set():
                try:
                    message = self._socket.recv()
                except zmq.error.Again:
                    print("Timeout in ZMQLeaderServer serve")
                    # Timeout occurred, loop around to re-check the stop event
                    continue

                # SECURITY: pickle.loads executes arbitrary code from the
                # payload. Only bind this server to trusted interfaces.
                request = pickle.loads(message)

                # Call the appropriate method based on the request
                method = request.get("method")
                args = request.get("args", {})
                result: Any
                if method == "num_dofs":
                    result = self._robot.num_dofs()
                elif method == "get_joint_state":
                    result = self._robot.get_joint_state()
                elif method == "command_joint_state":
                    result = self._robot.command_joint_state(**args)
                elif method == "get_observations":
                    result = self._robot.get_observations()
                else:
                    result = {"error": "Invalid method"}
                    print(result)
                    # A REP socket must answer every request; reply before
                    # raising so the REQ peer is not left blocked forever.
                    self._socket.send(pickle.dumps(result))
                    raise NotImplementedError(
                        f"Invalid method: {method}, {args, result}"
                    )
                self._socket.send(pickle.dumps(result))
        finally:
            with self._lifecycle_lock:
                self._serving = False
                self._close()

    def stop(self) -> None:
        """Request shutdown.

        If ``serve()`` is running, it notices the request within the receive
        timeout and closes the socket on its own thread; join the serving
        thread to wait for that. If ``serve()`` is not running, the socket and
        context are closed immediately. Safe to call more than once.
        """
        self._stop_event.set()
        with self._lifecycle_lock:
            if not self._serving:
                self._close()

    def _close(self) -> None:
        """Close the socket and terminate the context once.

        The caller must hold ``_lifecycle_lock``.
        """
        if self._closed:
            return
        self._closed = True
        # linger=0 drops any undelivered reply so context.term() cannot block.
        self._socket.close(linger=0)
        self._context.term()
class MujocoRobotServer:
    """MuJoCo-simulated robot that is driven and queried over ZMQ.

    ``serve()`` steps the simulation inside a passive viewer on the calling
    thread while a background thread answers ZMQ requests through a
    ``ZMQRobotServer`` that calls back into this object.
    """

    def __init__(
        self,
        xml_path: str,
        gripper_xml_path: Optional[str] = None,
        host: str = "127.0.0.1",
        port: int = 5556,
        print_joints: bool = False,
    ):
        """Build the scene, compile the MuJoCo model and set up the ZMQ server.

        Args:
            xml_path: Path to the robot MJCF file.
            gripper_xml_path: Optional path to a gripper MJCF file to attach.
            host: Interface the ZMQ server binds to.
            port: TCP port the ZMQ server binds to.
            print_joints: If True, print the joint state after every sim step.
        """
        self._stopped = False
        self._has_gripper = gripper_xml_path is not None
        arena = build_scene(xml_path, gripper_xml_path)

        # Collect mesh contents keyed by the filename the exported XML uses,
        # so the model can be compiled from the XML string.
        assets: Dict[str, bytes] = {}
        for asset in arena.asset.all_children():
            if asset.tag == "mesh":
                mesh_file = asset.file
                assets[mesh_file.get_vfs_filename()] = mesh_file.contents

        xml_string = arena.to_xml_string()
        # save xml_string to file (written to the current working directory)
        with open("arena.xml", "w") as f:
            f.write(xml_string)

        self._model = mujoco.MjModel.from_xml_string(xml_string, assets)
        self._data = mujoco.MjData(self._model)
        # One commanded value per actuator.
        self._num_joints = self._model.nu
        self._joint_state = np.zeros(self._num_joints)
        # Separate array so the command never aliases the reported state.
        self._joint_cmd = self._joint_state.copy()

        self._zmq_server = ZMQRobotServer(robot=self, host=host, port=port)
        self._zmq_server_thread = ZMQServerThread(self._zmq_server)
        self._print_joints = print_joints

    def num_dofs(self) -> int:
        """Return the number of commanded joints (the model's actuator count)."""
        return self._num_joints

    def get_joint_state(self) -> np.ndarray:
        """Return the joint positions recorded after the latest sim step."""
        return self._joint_state

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Set the control target applied on the next simulation step.

        Args:
            joint_state: One value per joint; its length must equal
                ``num_dofs()``. The caller's array is not modified.
        """
        assert len(joint_state) == self._num_joints, (
            f"Expected joint state of length {self._num_joints}, "
            f"got {len(joint_state)}."
        )
        command = joint_state.copy()
        if self._has_gripper:
            # The last entry is the gripper command. Presumably a normalized
            # value scaled to a 0-255 actuator control range -- TODO confirm.
            command[-1] = command[-1] * 255
        self._joint_cmd = command

    def freedrive_enabled(self) -> bool:
        """Always report freedrive as enabled."""
        return True

    def set_freedrive_mode(self, enable: bool):
        """No-op in simulation."""
        pass

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return a snapshot of the simulated robot state.

        Returns:
            Dict with "joint_positions" and "joint_velocities" (the first
            ``num_dofs()`` entries of qpos/qvel), "ee_pos_quat" (3 position
            values followed by a 4-element quaternion) and "gripper_position"
            (the qpos entry at index ``num_dofs() - 1``).
        """
        joint_positions = self._data.qpos.copy()[: self._num_joints]
        joint_velocities = self._data.qvel.copy()[: self._num_joints]
        ee_site = "attachment_site"
        try:
            # NOTE(review): mj_name2id returns -1 for an unknown name, which
            # silently indexes the *last* site instead of reaching the
            # fallback below. Left as-is in case callers rely on it.
            site_id = mujoco.mj_name2id(
                self._model, mujoco.mjtObj.mjOBJ_SITE, ee_site
            )
            ee_pos = self._data.site_xpos[site_id].copy()
            ee_mat = self._data.site_xmat[site_id].copy()
            ee_quat = np.zeros(4)
            mujoco.mju_mat2Quat(ee_quat, ee_mat)
        except Exception:
            # Deliberate best-effort: fall back to the origin and the
            # identity quaternion when the site pose cannot be read.
            ee_pos = np.zeros(3)
            ee_quat = np.zeros(4)
            ee_quat[0] = 1
        gripper_pos = self._data.qpos.copy()[self._num_joints - 1]
        return {
            "joint_positions": joint_positions,
            "joint_velocities": joint_velocities,
            "ee_pos_quat": np.concatenate([ee_pos, ee_quat]),
            "gripper_position": gripper_pos,
        }

    def serve(self) -> None:
        """Run the ZMQ server thread and the simulation/viewer loop.

        Blocks until the viewer window is closed. The ZMQ server is stopped
        on the way out so its non-daemon thread cannot keep the process alive.
        """
        # start the zmq server
        self._zmq_server_thread.start()
        try:
            with mujoco.viewer.launch_passive(self._model, self._data) as viewer:
                while viewer.is_running():
                    step_start = time.time()

                    # Apply the latest command, then advance the physics.
                    self._data.ctrl[:] = self._joint_cmd
                    mujoco.mj_step(self._model, self._data)
                    self._joint_state = self._data.qpos.copy()[: self._num_joints]

                    if self._print_joints:
                        print(self._joint_state)

                    # Example modification of a viewer option: toggle contact
                    # points every two seconds.
                    with viewer.lock():
                        # TODO remove?
                        viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(
                            self._data.time % 2
                        )

                    # Pick up changes to the physics state, apply
                    # perturbations, update options from GUI.
                    viewer.sync()

                    # Rudimentary time keeping, will drift relative to wall clock.
                    time_until_next_step = self._model.opt.timestep - (
                        time.time() - step_start
                    )
                    if time_until_next_step > 0:
                        time.sleep(time_until_next_step)
        finally:
            self.stop()

    def stop(self) -> None:
        """Stop the ZMQ server and wait for its thread to finish.

        Idempotent, and safe on an object whose ``__init__`` did not complete
        or whose ``serve()`` was never called.
        """
        thread = getattr(self, "_zmq_server_thread", None)
        if thread is None or getattr(self, "_stopped", False):
            return
        self._stopped = True
        # Signal the server loop first; joining without this never returns.
        thread.terminate()
        # join() raises on a thread that was never started or on the current
        # thread, so only join a live thread other than ourselves.
        if thread.is_alive() and thread is not threading.current_thread():
            thread.join()

    def __del__(self) -> None:
        self.stop()