initial commit, add gello software code and instructions

This commit is contained in:
Philipp Wu 2023-11-13 09:17:27 -08:00
parent e7d842ad35
commit 18cc23a38e
70 changed files with 5875 additions and 4 deletions

View file

@ -0,0 +1,7 @@
from gello.dm_control_tasks.arms.franka import Franka
from gello.dm_control_tasks.arms.ur5e import UR5e
__all__ = [
"UR5e",
"Franka",
]

View file

@ -0,0 +1,28 @@
"""Franka composer class."""
from pathlib import Path
from typing import Optional, Union
from dm_control import mjcf
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
XML = MENAGERIE_ROOT / "franka_emika_panda" / "panda_nohand.xml"
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
class Franka(Manipulator):
"""Franka Robot."""
def _build(
self,
name: str = "franka",
xml_path: Union[str, Path] = XML,
gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
) -> None:
super()._build(name="franka", xml_path=XML, gripper_xml_path=GRIPPER_XML)
@property
def flange(self) -> mjcf.Element:
return self._mjcf_root.find("site", "attachment_site")

View file

@ -0,0 +1,229 @@
"""Manipulator composer class."""
import abc
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from dm_control import composer, mjcf
from dm_control.composer.observation import observable
from dm_control.mujoco.wrapper import mjbindings
from dm_control.suite.utils.randomizers import random_limited_quaternion
from gello.dm_control_tasks import mjcf_utils
def attach_hand_to_arm(
arm_mjcf: mjcf.RootElement,
hand_mjcf: mjcf.RootElement,
# attach_site: str,
) -> None:
"""Attaches a hand to an arm.
The arm must have a site named "attachment_site".
Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm
Args:
arm_mjcf: The mjcf.RootElement of the arm.
hand_mjcf: The mjcf.RootElement of the hand.
attach_site: The name of the site to attach the hand to.
Raises:
ValueError: If the arm does not have a site named "attachment_site".
"""
physics = mjcf.Physics.from_mjcf_model(hand_mjcf)
# attachment_site = arm_mjcf.find("site", attach_site)
attachment_site = arm_mjcf.find("site", "attachment_site")
if attachment_site is None:
raise ValueError("No attachment site found in the arm model.")
# Expand the ctrl and qpos keyframes to account for the new hand DoFs.
arm_key = arm_mjcf.find("key", "home")
if arm_key is not None:
hand_key = hand_mjcf.find("key", "home")
if hand_key is None:
arm_key.ctrl = np.concatenate([arm_key.ctrl, np.zeros(physics.model.nu)])
arm_key.qpos = np.concatenate([arm_key.qpos, np.zeros(physics.model.nq)])
else:
arm_key.ctrl = np.concatenate([arm_key.ctrl, hand_key.ctrl])
arm_key.qpos = np.concatenate([arm_key.qpos, hand_key.qpos])
attachment_site.attach(hand_mjcf)
class Manipulator(composer.Entity, abc.ABC):
"""A manipulator entity."""
def _build(
self,
name: str,
xml_path: Union[str, Path],
gripper_xml_path: Optional[Union[str, Path]],
) -> None:
"""Builds the manipulator.
Subclasses can not override this method, but should call this method in their
own _build() method.
"""
self._mjcf_root = mjcf.from_path(str(xml_path))
self._arm_joints = tuple(mjcf_utils.safe_find_all(self._mjcf_root, "joint"))
if gripper_xml_path:
gripper_mjcf = mjcf.from_path(str(gripper_xml_path))
attach_hand_to_arm(self._mjcf_root, gripper_mjcf)
self._mjcf_root.model = name
self._add_mjcf_elements()
def set_joints(self, physics: mjcf.Physics, joints: np.ndarray) -> None:
assert len(joints) == len(self._arm_joints)
for joint, joint_value in zip(self._arm_joints, joints):
joint_id = physics.bind(joint).element_id
joint_name = physics.model.id2name(joint_id, "joint")
physics.named.data.qpos[joint_name] = joint_value
def randomize_joints(
self,
physics: mjcf.Physics,
random: Optional[np.random.RandomState] = None,
) -> None:
random = random or np.random # type: ignore
assert random is not None
hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
ball = mjbindings.enums.mjtJoint.mjJNT_BALL
free = mjbindings.enums.mjtJoint.mjJNT_FREE
qpos = physics.named.data.qpos
for joint in self._arm_joints:
joint_id = physics.bind(joint).element_id
# joint_id = physics.model.name2id(joint.name, "joint")
joint_name = physics.model.id2name(joint_id, "joint")
joint_type = physics.model.jnt_type[joint_id]
is_limited = physics.model.jnt_limited[joint_id]
range_min, range_max = physics.model.jnt_range[joint_id]
if is_limited:
if joint_type in [hinge, slide]:
qpos[joint_name] = random.uniform(range_min, range_max)
elif joint_type == ball:
qpos[joint_name] = random_limited_quaternion(random, range_max)
else:
if joint_type == hinge:
qpos[joint_name] = random.uniform(-np.pi, np.pi)
elif joint_type == ball:
quat = random.randn(4)
quat /= np.linalg.norm(quat)
qpos[joint_name] = quat
elif joint_type == free:
# this should be random.randn, but changing it now could significantly
# affect benchmark results.
quat = random.rand(4)
quat /= np.linalg.norm(quat)
qpos[joint_name][3:] = quat
def _add_mjcf_elements(self) -> None:
# Parse joints.
joints = mjcf_utils.safe_find_all(self._mjcf_root, "joint")
joints = [joint for joint in joints if joint.tag != "freejoint"]
self._joints = tuple(joints)
# Parse actuators.
actuators = mjcf_utils.safe_find_all(self._mjcf_root, "actuator")
self._actuators = tuple(actuators)
# Parse qpos / ctrl keyframes.
self._keyframes = {}
keyframes = mjcf_utils.safe_find_all(self._mjcf_root, "key")
if keyframes:
for frame in keyframes:
if frame.qpos is not None:
qpos = np.array(frame.qpos)
self._keyframes[frame.name] = qpos
# add a visualizeation the flange position that is green
self.flange.parent.add(
"geom",
name="flange_geom",
type="sphere",
size="0.01",
rgba="0 1 0 1",
pos=self.flange.pos,
contype="0",
conaffinity="0",
)
def _build_observables(self):
return ArmObservables(self)
@property
@abc.abstractmethod
def flange(self) -> mjcf.Element:
"""Returns the flange element.
The flange is the end effector of the manipulator where tools can be
attached, such as a gripper.
"""
@property
def mjcf_model(self) -> mjcf.RootElement:
return self._mjcf_root
@property
def name(self) -> str:
return self._mjcf_root.model
@property
def joints(self) -> Tuple[mjcf.Element, ...]:
return self._joints
@property
def actuators(self) -> Tuple[mjcf.Element, ...]:
return self._actuators
@property
def keyframes(self) -> Dict[str, np.ndarray]:
return self._keyframes
class ArmObservables(composer.Observables):
"""Base class for quadruped observables."""
@composer.observable
def joints_pos(self):
return observable.MJCFFeature("qpos", self._entity.joints)
@composer.observable
def joints_vel(self):
return observable.MJCFFeature("qvel", self._entity.joints)
@composer.observable
def flange_position(self):
return observable.MJCFFeature("xpos", self._entity.flange)
@composer.observable
def flange_orientation(self):
return observable.MJCFFeature("xmat", self._entity.flange)
# Semantic grouping of observables.
def _collect_from_attachments(self, attribute_name: str):
out: List[observable.MJCFFeature] = []
for entity in self._entity.iter_entities(exclude_self=True):
out.extend(getattr(entity.observables, attribute_name, []))
return out
@property
def proprioception(self):
return [
self.joints_pos,
self.joints_vel,
self.flange_position,
# self.flange_orientation,
# self.flange_velocity,
self.flange_angular_velocity,
] + self._collect_from_attachments("proprioception")

View file

@ -0,0 +1,26 @@
"""UR5e composer class."""
from pathlib import Path
from typing import Optional, Union
from dm_control import mjcf
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
class UR5e(Manipulator):
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
XML = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
def _build(
self,
name: str = "UR5e",
xml_path: Union[str, Path] = XML,
gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
) -> None:
super()._build(name=name, xml_path=xml_path, gripper_xml_path=gripper_xml_path)
@property
def flange(self) -> mjcf.Element:
return self._mjcf_root.find("site", "attachment_site")

View file

@ -0,0 +1,22 @@
"""Tests for ur5e.py."""
from absl.testing import absltest
from dm_control import mjcf
from gello.dm_control_tasks.arms import ur5e
class UR5eTest(absltest.TestCase):
def test_compiles_and_steps(self) -> None:
robot = ur5e.UR5e()
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
physics.step()
def test_joints(self) -> None:
robot = ur5e.UR5e()
for joint in robot.joints:
self.assertEqual(joint.tag, "joint")
if __name__ == "__main__":
absltest.main()

View file

@ -0,0 +1,261 @@
import collections
import numpy as np
from absl import logging
from dm_control.mujoco.wrapper import mjbindings
mjlib = mjbindings.mjlib
_INVALID_JOINT_NAMES_TYPE = (
"`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}."
)
_REQUIRE_TARGET_POS_OR_QUAT = (
"At least one of `target_pos` or `target_quat` must be specified."
)
IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"])
def qpos_from_site_pose(
physics,
site_name,
target_pos=None,
target_quat=None,
joint_names=None,
tol=1e-14,
rot_weight=1.0,
regularization_threshold=0.1,
regularization_strength=3e-2,
max_update_norm=2.0,
progress_thresh=20.0,
max_steps=100,
inplace=False,
):
"""Find joint positions that satisfy a target site position and/or rotation.
Args:
physics: A `mujoco.Physics` instance.
site_name: A string specifying the name of the target site.
target_pos: A (3,) numpy array specifying the desired Cartesian position of
the site, or None if the position should be unconstrained (default).
One or both of `target_pos` or `target_quat` must be specified.
target_quat: A (4,) numpy array specifying the desired orientation of the
site as a quaternion, or None if the orientation should be unconstrained
(default). One or both of `target_pos` or `target_quat` must be specified.
joint_names: (optional) A list, tuple or numpy array specifying the names of
one or more joints that can be manipulated in order to achieve the target
site pose. If None (default), all joints may be manipulated.
tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
in the stopping criterion).
rot_weight: (optional) Determines the weight given to rotational error
relative to translational error.
regularization_threshold: (optional) L2 regularization will be used when
inverting the Jacobian whilst `err_norm` is greater than this value.
regularization_strength: (optional) Coefficient of the quadratic penalty
on joint movements.
max_update_norm: (optional) The maximum L2 norm of the update applied to
the joint positions on each iteration. The update vector will be scaled
such that its magnitude never exceeds this value.
progress_thresh: (optional) If `err_norm` divided by the magnitude of the
joint position update is greater than this value then the optimization
will terminate prematurely. This is a useful heuristic to avoid getting
stuck in local minima.
max_steps: (optional) The maximum number of iterations to perform.
inplace: (optional) If True, `physics.data` will be modified in place.
Default value is False, i.e. a copy of `physics.data` will be made.
Returns:
An `IKResult` namedtuple with the following fields:
qpos: An (nq,) numpy array of joint positions.
err_norm: A float, the weighted sum of L2 norms for the residual
translational and rotational errors.
steps: An int, the number of iterations that were performed.
success: Boolean, True if we converged on a solution within `max_steps`,
False otherwise.
Raises:
ValueError: If both `target_pos` and `target_quat` are None, or if
`joint_names` has an invalid type.
"""
dtype = physics.data.qpos.dtype
if target_pos is not None and target_quat is not None:
jac = np.empty((6, physics.model.nv), dtype=dtype)
err = np.empty(6, dtype=dtype)
jac_pos, jac_rot = jac[:3], jac[3:]
err_pos, err_rot = err[:3], err[3:]
else:
jac = np.empty((3, physics.model.nv), dtype=dtype)
err = np.empty(3, dtype=dtype)
if target_pos is not None:
jac_pos, jac_rot = jac, None
err_pos, err_rot = err, None
elif target_quat is not None:
jac_pos, jac_rot = None, jac
err_pos, err_rot = None, err
else:
raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
update_nv = np.zeros(physics.model.nv, dtype=dtype)
if target_quat is not None:
site_xquat = np.empty(4, dtype=dtype)
neg_site_xquat = np.empty(4, dtype=dtype)
err_rot_quat = np.empty(4, dtype=dtype)
if not inplace:
physics = physics.copy(share_model=True)
# Ensure that the Cartesian position of the site is up to date.
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
# Convert site name to index.
site_id = physics.model.name2id(site_name, "site")
# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
# update them in place, so we can avoid indexing overhead in the main loop.
site_xpos = physics.named.data.site_xpos[site_name]
site_xmat = physics.named.data.site_xmat[site_name]
# This is an index into the rows of `update` and the columns of `jac`
# that selects DOFs associated with joints that we are allowed to manipulate.
if joint_names is None:
dof_indices = slice(None) # Update all DOFs.
elif isinstance(joint_names, (list, np.ndarray, tuple)):
if isinstance(joint_names, tuple):
joint_names = list(joint_names)
# Find the indices of the DOFs belonging to each named joint. Note that
# these are not necessarily the same as the joint IDs, since a single joint
# may have >1 DOF (e.g. ball joints).
indexer = physics.named.model.dof_jntid.axes.row
# `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
# indexer to map each joint name to the indices of its corresponding DOFs.
dof_indices = indexer.convert_key_item(joint_names)
else:
raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
steps = 0
success = False
for steps in range(max_steps):
err_norm = 0.0
if target_pos is not None:
# Translational error.
err_pos[:] = target_pos - site_xpos
err_norm += np.linalg.norm(err_pos)
if target_quat is not None:
# Rotational error.
mjlib.mju_mat2Quat(site_xquat, site_xmat)
mjlib.mju_negQuat(neg_site_xquat, site_xquat)
mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
err_norm += np.linalg.norm(err_rot) * rot_weight
if err_norm < tol:
logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm)
success = True
break
else:
# TODO(b/112141670): Generalize this to other entities besides sites.
mjlib.mj_jacSite(
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id
)
jac_joints = jac[:, dof_indices]
# TODO(b/112141592): This does not take joint limits into consideration.
reg_strength = (
regularization_strength if err_norm > regularization_threshold else 0.0
)
update_joints = nullspace_method(
jac_joints, err, regularization_strength=reg_strength
)
update_norm = np.linalg.norm(update_joints)
# Check whether we are still making enough progress, and halt if not.
progress_criterion = err_norm / update_norm
if progress_criterion > progress_thresh:
logging.debug(
"Step %2i: err_norm / update_norm (%3g) > "
"tolerance (%3g). Halting due to insufficient progress",
steps,
progress_criterion,
progress_thresh,
)
break
if update_norm > max_update_norm:
update_joints *= max_update_norm / update_norm
# Write the entries for the specified joints into the full `update_nv`
# vector.
update_nv[dof_indices] = update_joints
# Update `physics.qpos`, taking quaternions into account.
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
# Compute the new Cartesian position of the site.
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
logging.debug(
"Step %2i: err_norm=%-10.3g update_norm=%-10.3g",
steps,
err_norm,
update_norm,
)
if not success and steps == max_steps - 1:
logging.warning(
"Failed to converge after %i steps: err_norm=%3g", steps, err_norm
)
if not inplace:
# Our temporary copy of physics.data is about to go out of scope, and when
# it does the underlying mjData pointer will be freed and physics.data.qpos
# will be a view onto a block of deallocated memory. We therefore need to
# make a copy of physics.data.qpos while physics.data is still alive.
qpos = physics.data.qpos.copy()
else:
# If we're modifying physics.data in place then it's fine to return a view.
qpos = physics.data.qpos
return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
"""Calculates the joint velocities to achieve a specified end effector delta.
Args:
jac_joints: The Jacobian of the end effector with respect to the joints. A
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
and `nv` is the number of degrees of freedom.
delta: The desired end-effector delta. A numpy array of shape `(3,)` or
`(6,)` containing either position deltas, rotation deltas, or both.
regularization_strength: (optional) Coefficient of the quadratic penalty
on joint movements. Default is zero, i.e. no regularization.
Returns:
An `(nv,)` numpy array of joint velocities.
Reference:
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
transpose, pseudoinverse and damped least squares methods.
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
"""
hess_approx = jac_joints.T.dot(jac_joints)
joint_delta = jac_joints.T.dot(delta)
if regularization_strength > 0:
# L2 regularization
hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
return np.linalg.solve(hess_approx, joint_delta)
else:
return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
class InverseKinematics:
def __init__(self, xml_path: str):
"""Initializes the inverse kinematics class."""
...
# TODO