256 lines
8.6 KiB
Python
256 lines
8.6 KiB
Python
import pickle
|
|
import threading
|
|
import time
|
|
from typing import Any, Dict, Optional
|
|
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import numpy as np
|
|
import zmq
|
|
from dm_control import mjcf
|
|
|
|
from gello.robots.robot import Robot
|
|
|
|
# This assertion is trivially true and checks nothing at runtime.
# NOTE(review): presumably it exists only so the `import mujoco.viewer` line is
# referenced and not flagged/removed as an unused import — confirm before deleting.
assert mujoco.viewer is mujoco.viewer
|
|
|
|
|
|
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
) -> None:
    """Mount a hand model on an arm model, in place.

    The hand is attached at the arm's site named "attachment_site". When the
    arm defines a "home" keyframe, its ``ctrl`` and ``qpos`` vectors are
    extended so they also cover the hand: with the hand's own "home" keyframe
    values if it has one, otherwise with zeros.

    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm (modified in place).
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    # Compile the hand on its own so its actuator (nu) and position (nq)
    # counts are available for zero-padding the arm keyframe below.
    hand_physics = mjcf.Physics.from_mjcf_model(hand_mjcf)

    mount_site = arm_mjcf.find("site", "attachment_site")
    if mount_site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Grow the arm's "home" keyframe to account for the new hand DoFs.
    arm_home = arm_mjcf.find("key", "home")
    if arm_home is not None:
        hand_home = hand_mjcf.find("key", "home")
        if hand_home is not None:
            extra_ctrl = hand_home.ctrl
            extra_qpos = hand_home.qpos
        else:
            extra_ctrl = np.zeros(hand_physics.model.nu)
            extra_qpos = np.zeros(hand_physics.model.nq)
        arm_home.ctrl = np.concatenate([arm_home.ctrl, extra_ctrl])
        arm_home.qpos = np.concatenate([arm_home.qpos, extra_qpos])

    mount_site.attach(hand_mjcf)
|
|
|
|
|
|
def build_scene(robot_xml_path: str, gripper_xml_path: Optional[str] = None):
    """Assemble an MJCF scene containing the robot and, optionally, a gripper.

    Args:
        robot_xml_path: Path of the robot arm's MJCF file.
        gripper_xml_path: Path of the gripper's MJCF file. When given, the
            gripper is attached to the arm via ``attach_hand_to_arm`` before
            the arm is placed in the scene.

    Returns:
        A new, otherwise empty ``mjcf.RootElement`` with the arm (plus gripper,
        if any) attached to its worldbody.
    """
    scene_root = mjcf.RootElement()
    arm_model = mjcf.from_path(robot_xml_path)

    if gripper_xml_path is not None:
        # The gripper goes on the arm's "attachment_site".
        gripper_model = mjcf.from_path(gripper_xml_path)
        attach_hand_to_arm(arm_model, gripper_model)

    scene_root.worldbody.attach(arm_model)
    return scene_root
|
|
|
|
|
|
class ZMQServerThread(threading.Thread):
    """Thread that runs a server's ``serve()`` loop.

    The wrapped object only needs to provide ``serve()`` and ``stop()``.
    """

    def __init__(self, server):
        """Store the server to run; the thread is not started here."""
        threading.Thread.__init__(self)
        self._server = server

    def run(self):
        """Thread body: delegate to the server's ``serve()``."""
        zmq_server = self._server
        zmq_server.serve()

    def terminate(self):
        """Ask the wrapped server to stop by calling its ``stop()``."""
        zmq_server = self._server
        zmq_server.stop()
|
|
|
|
|
|
class ZMQRobotServer:
    """A ZMQ REP server that exposes a robot's methods to remote clients.

    Each request is a pickled dict ``{"method": <name>, "args": {...}}``; the
    reply is the pickled return value of the corresponding robot method.
    """

    def __init__(self, robot: "Robot", host: str = "127.0.0.1", port: int = 5556):
        """Create the REP socket and bind it to ``tcp://{host}:{port}``.

        Args:
            robot: Object whose ``num_dofs``, ``get_joint_state``,
                ``command_joint_state`` and ``get_observations`` are served.
            host: Interface to bind to.
            port: TCP port to bind to.
        """
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def _dispatch(self, method: Optional[str], args: Dict[str, Any]) -> Any:
        """Call the robot method named by ``method`` and return its result.

        Only ``command_joint_state`` receives ``args`` (as keyword arguments);
        the other methods are called without arguments.

        Raises:
            NotImplementedError: If ``method`` is not one of the served names.
        """
        if method == "num_dofs":
            return self._robot.num_dofs()
        if method == "get_joint_state":
            return self._robot.get_joint_state()
        if method == "command_joint_state":
            return self._robot.command_joint_state(**args)
        if method == "get_observations":
            return self._robot.get_observations()

        result = {"error": "Invalid method"}
        print(result)
        raise NotImplementedError(f"Invalid method: {method}, {args, result}")

    def serve(self) -> None:
        """Serve the robot state and commands over ZMQ until ``stop()`` is called.

        Raises:
            NotImplementedError: If a request names an unknown method.
            Exception: Anything raised while decoding or executing a request.
                In every such case an ``{"error": ...}`` reply is sent to the
                client before the exception propagates.
        """
        # Wake up at least once a second so the stop event gets checked.
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                message = self._socket.recv()
            except zmq.error.Again:
                print("Timeout in ZMQLeaderServer serve")
                # Timeout occurred, check if the stop event is set
                continue
            except zmq.error.ZMQError:
                # stop() closes the socket/context from another thread, which
                # makes a pending recv() fail; that is a normal shutdown.
                if self._stop_event.is_set():
                    break
                raise

            try:
                # SECURITY NOTE: pickle.loads executes arbitrary code from the
                # sender. Only bind this server to trusted interfaces/clients.
                request = pickle.loads(message)
                result = self._dispatch(
                    request.get("method"), request.get("args", {})
                )
            except Exception as exc:
                # A REP socket owes the client exactly one reply per request;
                # answer before propagating so the client is not left blocked.
                error_reply = {"error": f"{type(exc).__name__}: {exc}"}
                self._socket.send(pickle.dumps(error_reply))
                raise

            self._socket.send(pickle.dumps(result))

    def stop(self) -> None:
        """Signal the serve loop to exit, then close the socket and context."""
        self._stop_event.set()
        self._socket.close()
        self._context.term()
|
|
|
|
|
|
class MujocoRobotServer:
    """MuJoCo-simulated robot that is served to clients over ZMQ.

    The simulation loop (``serve``) runs in the calling thread together with a
    passive MuJoCo viewer, while a ``ZMQRobotServer`` running in a background
    thread answers client requests by calling the methods of this object.
    """

    def __init__(
        self,
        xml_path: str,
        gripper_xml_path: Optional[str] = None,
        host: str = "127.0.0.1",
        port: int = 5556,
        print_joints: bool = False,
    ):
        """Build the scene, compile the MuJoCo model and create the ZMQ server.

        Side effect: the composed scene XML is written to ``arena.xml`` in the
        current working directory.

        Args:
            xml_path: Path of the robot arm's MJCF file.
            gripper_xml_path: Optional path of a gripper MJCF file to attach.
            host: Interface the ZMQ server binds to.
            port: TCP port the ZMQ server binds to.
            print_joints: If True, print the joint state after every sim step.
        """
        self._stopped = False
        self._has_gripper = gripper_xml_path is not None
        arena = build_scene(xml_path, gripper_xml_path)

        # Collect mesh file contents keyed by the name they get in the
        # exported XML so the model can be compiled from a string.
        assets: Dict[str, bytes] = {}
        for asset in arena.asset.all_children():
            if asset.tag == "mesh":
                mesh_file = asset.file
                assets[mesh_file.get_vfs_filename()] = mesh_file.contents

        xml_string = arena.to_xml_string()
        # save xml_string to file
        with open("arena.xml", "w") as xml_file:
            xml_file.write(xml_string)

        self._model = mujoco.MjModel.from_xml_string(xml_string, assets)
        self._data = mujoco.MjData(self._model)

        # One commanded value per actuator control (`model.nu`).
        self._num_joints = self._model.nu

        self._joint_state = np.zeros(self._num_joints)
        # Separate array: the command must not alias the reported state.
        self._joint_cmd = self._joint_state.copy()

        self._zmq_server = ZMQRobotServer(robot=self, host=host, port=port)
        self._zmq_server_thread = ZMQServerThread(self._zmq_server)

        self._print_joints = print_joints

    def num_dofs(self) -> int:
        """Return the number of commanded joints (the model's ``nu``)."""
        return self._num_joints

    def get_joint_state(self) -> np.ndarray:
        """Return the joint positions recorded after the latest sim step."""
        return self._joint_state

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Set the control target applied on the next simulation step.

        The input is copied, never modified. When a gripper is attached, the
        last entry is multiplied by 255.
        NOTE(review): the 255 factor presumably maps a normalized [0, 1]
        gripper command onto the gripper actuator's control range — confirm
        against the gripper XML.

        Args:
            joint_state: Sequence of ``num_dofs()`` target values.

        Raises:
            AssertionError: If the length does not match ``num_dofs()``.
        """
        assert len(joint_state) == self._num_joints, (
            f"Expected joint state of length {self._num_joints}, "
            f"got {len(joint_state)}."
        )
        # np.array copies, so the caller's data is left untouched and plain
        # sequences (lists/tuples) are accepted as well as arrays.
        command = np.array(joint_state)
        if self._has_gripper:
            command[-1] = command[-1] * 255
        self._joint_cmd = command

    def freedrive_enabled(self) -> bool:
        """Always report freedrive as enabled."""
        return True

    def set_freedrive_mode(self, enable: bool):
        """No-op in this simulated robot."""
        pass

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the current simulated robot state.

        Returns:
            Dict with keys:
            ``joint_positions`` / ``joint_velocities``: first ``num_dofs()``
            entries of ``qpos`` / ``qvel``;
            ``ee_pos_quat``: 7-vector, position followed by quaternion of the
            site named "attachment_site" (zeros and identity quaternion if the
            lookup fails);
            ``gripper_position``: ``qpos[num_dofs() - 1]``.
        """
        joint_positions = self._data.qpos.copy()[: self._num_joints]
        joint_velocities = self._data.qvel.copy()[: self._num_joints]
        ee_site = "attachment_site"
        try:
            site_id = mujoco.mj_name2id(
                self._model, mujoco.mjtObj.mjOBJ_SITE, ee_site
            )
            # NOTE(review): mj_name2id returns -1 when the name is unknown,
            # and indexing with -1 silently selects the LAST site instead of
            # raising. Kept as-is to preserve current output; verify that the
            # site name really resolves in the composed scene.
            ee_pos = self._data.site_xpos.copy()[site_id]
            ee_mat = self._data.site_xmat.copy()[site_id]
            ee_quat = np.zeros(4)
            mujoco.mju_mat2Quat(ee_quat, ee_mat)
        except Exception:
            # Best effort: fall back to the origin with identity orientation.
            ee_pos = np.zeros(3)
            ee_quat = np.zeros(4)
            ee_quat[0] = 1
        gripper_pos = self._data.qpos.copy()[self._num_joints - 1]
        return {
            "joint_positions": joint_positions,
            "joint_velocities": joint_velocities,
            "ee_pos_quat": np.concatenate([ee_pos, ee_quat]),
            "gripper_position": gripper_pos,
        }

    def serve(self) -> None:
        """Start the ZMQ thread and run the simulation until the viewer closes.

        Blocks the calling thread. When the loop ends (viewer closed or an
        exception), the ZMQ server thread is stopped via ``stop()`` so it does
        not keep the process alive.
        """
        # start the zmq server
        self._zmq_server_thread.start()
        try:
            with mujoco.viewer.launch_passive(self._model, self._data) as viewer:
                while viewer.is_running():
                    step_start = time.time()

                    # Apply the latest commanded controls, then step physics.
                    self._data.ctrl[:] = self._joint_cmd
                    mujoco.mj_step(self._model, self._data)
                    self._joint_state = self._data.qpos.copy()[: self._num_joints]

                    if self._print_joints:
                        print(self._joint_state)

                    # Example modification of a viewer option: toggle contact
                    # points every two seconds.
                    with viewer.lock():
                        # TODO remove?
                        viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(
                            self._data.time % 2
                        )

                    # Pick up changes to the physics state, apply
                    # perturbations, update options from GUI.
                    viewer.sync()

                    # Rudimentary time keeping, will drift relative to wall clock.
                    time_until_next_step = self._model.opt.timestep - (
                        time.time() - step_start
                    )
                    if time_until_next_step > 0:
                        time.sleep(time_until_next_step)
        finally:
            self.stop()

    def stop(self) -> None:
        """Stop the ZMQ server thread and wait for it to finish.

        Safe to call more than once, before ``serve()`` was ever called, and
        on an instance whose ``__init__`` did not complete.
        """
        server_thread = getattr(self, "_zmq_server_thread", None)
        if server_thread is None or getattr(self, "_stopped", False):
            return
        self._stopped = True
        # Joining alone would block forever: the serve loop only exits once
        # the server has been told to stop.
        server_thread.terminate()
        # join() raises on a thread that was never started.
        if server_thread.is_alive():
            server_thread.join()

    def __del__(self) -> None:
        """Make sure the ZMQ server thread is stopped on garbage collection."""
        self.stop()
|