gello_software/gello/zmq_core/robot_node.py

126 lines
4.2 KiB
Python
Raw Normal View History

import pickle
import threading
from typing import Any, Dict
import numpy as np
import zmq
from gello.robots.robot import Robot
# Default TCP port used by both ZMQServerRobot and ZMQClientRobot when no
# explicit port is passed to their constructors.
DEFAULT_ROBOT_PORT = 6000
class ZMQServerRobot:
    """Expose a ``Robot`` to remote clients over a ZMQ REP socket.

    Each request is a pickled dict of the form
    ``{"method": <name>, "args": {<kwargs>}}``. The server calls the matching
    method on the wrapped robot and replies with the pickled return value.
    """

    def __init__(
        self,
        robot: Robot,
        port: int = DEFAULT_ROBOT_PORT,
        host: str = "127.0.0.1",
    ):
        """Create a REP socket and bind it to ``tcp://{host}:{port}``.

        Args:
            robot: The robot whose methods are served to clients.
            port: TCP port to bind.
            host: Interface address to bind; defaults to loopback only.
        """
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        debug_message = f"Robot Server Binding to {addr}, Robot: {robot}"
        print(debug_message)
        self._timout_message = f"Timeout in Robot Server, Robot: {robot}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def serve(self) -> None:
        """Serve robot method calls over ZMQ until ``stop()`` is called.

        The receive timeout is 1000 ms, so while idle the loop re-checks the
        stop event at least once per timeout.

        Raises:
            NotImplementedError: If a request names an unsupported method.
                An error reply is sent to the client before raising. The REQ
                peer is therefore not left blocked, and this REP socket stays
                usable if the caller catches the error and serves again.
        """
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                # Wait for next request from client
                message = self._socket.recv()
            except zmq.Again:
                # Timeout occurred; loop around to check the stop event.
                print(self._timout_message)
                continue

            # NOTE(review): pickle.loads can execute arbitrary code when given a
            # crafted payload. Only bind to interfaces reachable by trusted peers.
            request = pickle.loads(message)

            # Call the appropriate method based on the request
            method = request.get("method")
            args = request.get("args", {})
            result: Any
            if method == "num_dofs":
                result = self._robot.num_dofs()
            elif method == "get_joint_state":
                result = self._robot.get_joint_state()
            elif method == "command_joint_state":
                result = self._robot.command_joint_state(**args)
            elif method == "get_observations":
                result = self._robot.get_observations()
            else:
                result = {"error": "Invalid method"}
                print(result)
                # A REP socket must answer every request before the next recv().
                # Reply first so the client is not blocked forever, then raise.
                self._socket.send(pickle.dumps(result))
                raise NotImplementedError(
                    f"Invalid method: {method}, {args, result}"
                )
            self._socket.send(pickle.dumps(result))

    def stop(self) -> None:
        """Signal the server to stop serving.

        ``serve()`` checks this flag between requests and after each receive
        timeout, then returns.
        """
        self._stop_event.set()
class ZMQClientRobot(Robot):
    """A ``Robot`` proxy that forwards every call to a remote robot server over ZMQ.

    Each method sends a pickled ``{"method": ..., "args": ...}`` request on a
    REQ socket and returns the unpickled reply. No receive timeout is set, so
    each call blocks until the server answers.
    """

    def __init__(self, port: int = DEFAULT_ROBOT_PORT, host: str = "127.0.0.1"):
        """Connect a REQ socket to ``tcp://{host}:{port}``.

        Args:
            port: TCP port the server is bound to.
            host: Address of the server.
        """
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.connect(f"tcp://{host}:{port}")

    def _call(self, method: str, **kwargs: Any) -> Any:
        """Send one request to the server and return its unpickled reply.

        Args:
            method: Name of the robot method to invoke on the server.
            **kwargs: Keyword arguments forwarded to that method. When empty,
                the ``"args"`` key is left out of the request.

        Returns:
            Whatever object the server pickled as its reply.
        """
        request: Dict[str, Any] = {"method": method}
        if kwargs:
            request["args"] = kwargs
        self._socket.send(pickle.dumps(request))
        # NOTE(review): unpickling the reply trusts the server completely.
        return pickle.loads(self._socket.recv())

    def num_dofs(self) -> int:
        """Get the number of joints in the robot.

        Returns:
            int: The number of joints in the robot.
        """
        return self._call("num_dofs")

    def get_joint_state(self) -> np.ndarray:
        """Get the current joint state of the remote robot.

        Returns:
            np.ndarray: The joint state reported by the server.
        """
        return self._call("get_joint_state")

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the remote robot to the given joint state.

        Args:
            joint_state (np.ndarray): The joint state to command.

        The server's reply is returned unchanged to preserve existing behavior.
        """
        return self._call("command_joint_state", joint_state=joint_state)

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Get the current observations of the remote robot.

        Returns:
            Dict[str, np.ndarray]: The observations reported by the server.
        """
        return self._call("get_observations")