128 lines
3.7 KiB
Python
128 lines
3.7 KiB
Python
from abc import abstractmethod
|
|
from typing import Dict, Protocol
|
|
|
|
import numpy as np
|
|
|
|
|
|
class Robot(Protocol):
|
|
"""Robot protocol.
|
|
|
|
A protocol for a robot that can be controlled.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def num_dofs(self) -> int:
|
|
"""Get the number of joints of the robot.
|
|
|
|
Returns:
|
|
int: The number of joints of the robot.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_joint_state(self) -> np.ndarray:
|
|
"""Get the current state of the leader robot.
|
|
|
|
Returns:
|
|
T: The current state of the leader robot.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
|
"""Command the leader robot to a given state.
|
|
|
|
Args:
|
|
joint_state (np.ndarray): The state to command the leader robot to.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def get_observations(self) -> Dict[str, np.ndarray]:
|
|
"""Get the current observations of the robot.
|
|
|
|
This is to extract all the information that is available from the robot,
|
|
such as joint positions, joint velocities, etc. This may also include
|
|
information from additional sensors, such as cameras, force sensors, etc.
|
|
|
|
Returns:
|
|
Dict[str, np.ndarray]: A dictionary of observations.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
|
|
class PrintRobot(Robot):
|
|
"""A robot that prints the commanded joint state."""
|
|
|
|
def __init__(self, num_dofs: int, dont_print: bool = False):
|
|
self._num_dofs = num_dofs
|
|
self._joint_state = np.zeros(num_dofs)
|
|
self._dont_print = dont_print
|
|
|
|
def num_dofs(self) -> int:
|
|
return self._num_dofs
|
|
|
|
def get_joint_state(self) -> np.ndarray:
|
|
return self._joint_state
|
|
|
|
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
|
assert len(joint_state) == (self._num_dofs), (
|
|
f"Expected joint state of length {self._num_dofs}, "
|
|
f"got {len(joint_state)}."
|
|
)
|
|
self._joint_state = joint_state
|
|
if not self._dont_print:
|
|
print(self._joint_state)
|
|
|
|
def get_observations(self) -> Dict[str, np.ndarray]:
|
|
joint_state = self.get_joint_state()
|
|
pos_quat = np.zeros(7)
|
|
return {
|
|
"joint_positions": joint_state,
|
|
"joint_velocities": joint_state,
|
|
"ee_pos_quat": pos_quat,
|
|
"gripper_position": np.array(0),
|
|
}
|
|
|
|
|
|
class BimanualRobot(Robot):
|
|
def __init__(self, robot_l: Robot, robot_r: Robot):
|
|
self._robot_l = robot_l
|
|
self._robot_r = robot_r
|
|
|
|
def num_dofs(self) -> int:
|
|
return self._robot_l.num_dofs() + self._robot_r.num_dofs()
|
|
|
|
def get_joint_state(self) -> np.ndarray:
|
|
return np.concatenate(
|
|
(self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
|
|
)
|
|
|
|
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
|
self._robot_l.command_joint_state(joint_state[: self._robot_l.num_dofs()])
|
|
self._robot_r.command_joint_state(joint_state[self._robot_l.num_dofs() :])
|
|
|
|
def get_observations(self) -> Dict[str, np.ndarray]:
|
|
l_obs = self._robot_l.get_observations()
|
|
r_obs = self._robot_r.get_observations()
|
|
assert l_obs.keys() == r_obs.keys()
|
|
return_obs = {}
|
|
for k in l_obs.keys():
|
|
try:
|
|
return_obs[k] = np.concatenate((l_obs[k], r_obs[k]))
|
|
except Exception as e:
|
|
print(e)
|
|
print(k)
|
|
print(l_obs[k])
|
|
print(r_obs[k])
|
|
raise RuntimeError()
|
|
|
|
return return_obs
|
|
|
|
|
|
def main():
|
|
pass
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|