gello_software/gello/dm_control_tasks/arms/utils.py

261 lines
11 KiB
Python

import collections
import numpy as np
from absl import logging
from dm_control.mujoco.wrapper import mjbindings
mjlib = mjbindings.mjlib
_INVALID_JOINT_NAMES_TYPE = (
"`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}."
)
_REQUIRE_TARGET_POS_OR_QUAT = (
"At least one of `target_pos` or `target_quat` must be specified."
)
IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"])
def qpos_from_site_pose(
physics,
site_name,
target_pos=None,
target_quat=None,
joint_names=None,
tol=1e-14,
rot_weight=1.0,
regularization_threshold=0.1,
regularization_strength=3e-2,
max_update_norm=2.0,
progress_thresh=20.0,
max_steps=100,
inplace=False,
):
"""Find joint positions that satisfy a target site position and/or rotation.
Args:
physics: A `mujoco.Physics` instance.
site_name: A string specifying the name of the target site.
target_pos: A (3,) numpy array specifying the desired Cartesian position of
the site, or None if the position should be unconstrained (default).
One or both of `target_pos` or `target_quat` must be specified.
target_quat: A (4,) numpy array specifying the desired orientation of the
site as a quaternion, or None if the orientation should be unconstrained
(default). One or both of `target_pos` or `target_quat` must be specified.
joint_names: (optional) A list, tuple or numpy array specifying the names of
one or more joints that can be manipulated in order to achieve the target
site pose. If None (default), all joints may be manipulated.
tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
in the stopping criterion).
rot_weight: (optional) Determines the weight given to rotational error
relative to translational error.
regularization_threshold: (optional) L2 regularization will be used when
inverting the Jacobian whilst `err_norm` is greater than this value.
regularization_strength: (optional) Coefficient of the quadratic penalty
on joint movements.
max_update_norm: (optional) The maximum L2 norm of the update applied to
the joint positions on each iteration. The update vector will be scaled
such that its magnitude never exceeds this value.
progress_thresh: (optional) If `err_norm` divided by the magnitude of the
joint position update is greater than this value then the optimization
will terminate prematurely. This is a useful heuristic to avoid getting
stuck in local minima.
max_steps: (optional) The maximum number of iterations to perform.
inplace: (optional) If True, `physics.data` will be modified in place.
Default value is False, i.e. a copy of `physics.data` will be made.
Returns:
An `IKResult` namedtuple with the following fields:
qpos: An (nq,) numpy array of joint positions.
err_norm: A float, the weighted sum of L2 norms for the residual
translational and rotational errors.
steps: An int, the number of iterations that were performed.
success: Boolean, True if we converged on a solution within `max_steps`,
False otherwise.
Raises:
ValueError: If both `target_pos` and `target_quat` are None, or if
`joint_names` has an invalid type.
"""
dtype = physics.data.qpos.dtype
if target_pos is not None and target_quat is not None:
jac = np.empty((6, physics.model.nv), dtype=dtype)
err = np.empty(6, dtype=dtype)
jac_pos, jac_rot = jac[:3], jac[3:]
err_pos, err_rot = err[:3], err[3:]
else:
jac = np.empty((3, physics.model.nv), dtype=dtype)
err = np.empty(3, dtype=dtype)
if target_pos is not None:
jac_pos, jac_rot = jac, None
err_pos, err_rot = err, None
elif target_quat is not None:
jac_pos, jac_rot = None, jac
err_pos, err_rot = None, err
else:
raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
update_nv = np.zeros(physics.model.nv, dtype=dtype)
if target_quat is not None:
site_xquat = np.empty(4, dtype=dtype)
neg_site_xquat = np.empty(4, dtype=dtype)
err_rot_quat = np.empty(4, dtype=dtype)
if not inplace:
physics = physics.copy(share_model=True)
# Ensure that the Cartesian position of the site is up to date.
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
# Convert site name to index.
site_id = physics.model.name2id(site_name, "site")
# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
# update them in place, so we can avoid indexing overhead in the main loop.
site_xpos = physics.named.data.site_xpos[site_name]
site_xmat = physics.named.data.site_xmat[site_name]
# This is an index into the rows of `update` and the columns of `jac`
# that selects DOFs associated with joints that we are allowed to manipulate.
if joint_names is None:
dof_indices = slice(None) # Update all DOFs.
elif isinstance(joint_names, (list, np.ndarray, tuple)):
if isinstance(joint_names, tuple):
joint_names = list(joint_names)
# Find the indices of the DOFs belonging to each named joint. Note that
# these are not necessarily the same as the joint IDs, since a single joint
# may have >1 DOF (e.g. ball joints).
indexer = physics.named.model.dof_jntid.axes.row
# `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
# indexer to map each joint name to the indices of its corresponding DOFs.
dof_indices = indexer.convert_key_item(joint_names)
else:
raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
steps = 0
success = False
for steps in range(max_steps):
err_norm = 0.0
if target_pos is not None:
# Translational error.
err_pos[:] = target_pos - site_xpos
err_norm += np.linalg.norm(err_pos)
if target_quat is not None:
# Rotational error.
mjlib.mju_mat2Quat(site_xquat, site_xmat)
mjlib.mju_negQuat(neg_site_xquat, site_xquat)
mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
err_norm += np.linalg.norm(err_rot) * rot_weight
if err_norm < tol:
logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm)
success = True
break
else:
# TODO(b/112141670): Generalize this to other entities besides sites.
mjlib.mj_jacSite(
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id
)
jac_joints = jac[:, dof_indices]
# TODO(b/112141592): This does not take joint limits into consideration.
reg_strength = (
regularization_strength if err_norm > regularization_threshold else 0.0
)
update_joints = nullspace_method(
jac_joints, err, regularization_strength=reg_strength
)
update_norm = np.linalg.norm(update_joints)
# Check whether we are still making enough progress, and halt if not.
progress_criterion = err_norm / update_norm
if progress_criterion > progress_thresh:
logging.debug(
"Step %2i: err_norm / update_norm (%3g) > "
"tolerance (%3g). Halting due to insufficient progress",
steps,
progress_criterion,
progress_thresh,
)
break
if update_norm > max_update_norm:
update_joints *= max_update_norm / update_norm
# Write the entries for the specified joints into the full `update_nv`
# vector.
update_nv[dof_indices] = update_joints
# Update `physics.qpos`, taking quaternions into account.
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
# Compute the new Cartesian position of the site.
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
logging.debug(
"Step %2i: err_norm=%-10.3g update_norm=%-10.3g",
steps,
err_norm,
update_norm,
)
if not success and steps == max_steps - 1:
logging.warning(
"Failed to converge after %i steps: err_norm=%3g", steps, err_norm
)
if not inplace:
# Our temporary copy of physics.data is about to go out of scope, and when
# it does the underlying mjData pointer will be freed and physics.data.qpos
# will be a view onto a block of deallocated memory. We therefore need to
# make a copy of physics.data.qpos while physics.data is still alive.
qpos = physics.data.qpos.copy()
else:
# If we're modifying physics.data in place then it's fine to return a view.
qpos = physics.data.qpos
return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
"""Calculates the joint velocities to achieve a specified end effector delta.
Args:
jac_joints: The Jacobian of the end effector with respect to the joints. A
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
and `nv` is the number of degrees of freedom.
delta: The desired end-effector delta. A numpy array of shape `(3,)` or
`(6,)` containing either position deltas, rotation deltas, or both.
regularization_strength: (optional) Coefficient of the quadratic penalty
on joint movements. Default is zero, i.e. no regularization.
Returns:
An `(nv,)` numpy array of joint velocities.
Reference:
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
transpose, pseudoinverse and damped least squares methods.
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
"""
hess_approx = jac_joints.T.dot(jac_joints)
joint_delta = jac_joints.T.dot(delta)
if regularization_strength > 0:
# L2 regularization
hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
return np.linalg.solve(hess_approx, joint_delta)
else:
return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
class InverseKinematics:
def __init__(self, xml_path: str):
"""Initializes the inverse kinematics class."""
...
# TODO