initial commit, add gello software code and instructions
This commit is contained in:
parent
e7d842ad35
commit
18cc23a38e
70 changed files with 5875 additions and 4 deletions
128
gello/robots/robot.py
Normal file
128
gello/robots/robot.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Dict, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Robot(Protocol):
|
||||
"""Robot protocol.
|
||||
|
||||
A protocol for a robot that can be controlled.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def num_dofs(self) -> int:
|
||||
"""Get the number of joints of the robot.
|
||||
|
||||
Returns:
|
||||
int: The number of joints of the robot.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
"""Get the current state of the leader robot.
|
||||
|
||||
Returns:
|
||||
T: The current state of the leader robot.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
"""Command the leader robot to a given state.
|
||||
|
||||
Args:
|
||||
joint_state (np.ndarray): The state to command the leader robot to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
"""Get the current observations of the robot.
|
||||
|
||||
This is to extract all the information that is available from the robot,
|
||||
such as joint positions, joint velocities, etc. This may also include
|
||||
information from additional sensors, such as cameras, force sensors, etc.
|
||||
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: A dictionary of observations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PrintRobot(Robot):
|
||||
"""A robot that prints the commanded joint state."""
|
||||
|
||||
def __init__(self, num_dofs: int, dont_print: bool = False):
|
||||
self._num_dofs = num_dofs
|
||||
self._joint_state = np.zeros(num_dofs)
|
||||
self._dont_print = dont_print
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return self._num_dofs
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
return self._joint_state
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
assert len(joint_state) == (self._num_dofs), (
|
||||
f"Expected joint state of length {self._num_dofs}, "
|
||||
f"got {len(joint_state)}."
|
||||
)
|
||||
self._joint_state = joint_state
|
||||
if not self._dont_print:
|
||||
print(self._joint_state)
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
joint_state = self.get_joint_state()
|
||||
pos_quat = np.zeros(7)
|
||||
return {
|
||||
"joint_positions": joint_state,
|
||||
"joint_velocities": joint_state,
|
||||
"ee_pos_quat": pos_quat,
|
||||
"gripper_position": np.array(0),
|
||||
}
|
||||
|
||||
|
||||
class BimanualRobot(Robot):
|
||||
def __init__(self, robot_l: Robot, robot_r: Robot):
|
||||
self._robot_l = robot_l
|
||||
self._robot_r = robot_r
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return self._robot_l.num_dofs() + self._robot_r.num_dofs()
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
return np.concatenate(
|
||||
(self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
|
||||
)
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
self._robot_l.command_joint_state(joint_state[: self._robot_l.num_dofs()])
|
||||
self._robot_r.command_joint_state(joint_state[self._robot_l.num_dofs() :])
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
l_obs = self._robot_l.get_observations()
|
||||
r_obs = self._robot_r.get_observations()
|
||||
assert l_obs.keys() == r_obs.keys()
|
||||
return_obs = {}
|
||||
for k in l_obs.keys():
|
||||
try:
|
||||
return_obs[k] = np.concatenate((l_obs[k], r_obs[k]))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(k)
|
||||
print(l_obs[k])
|
||||
print(r_obs[k])
|
||||
raise RuntimeError()
|
||||
|
||||
return return_obs
|
||||
|
||||
|
||||
def main():
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Add a link
Reference in a new issue