initial commit, add gello software code and instructions
This commit is contained in:
parent
e7d842ad35
commit
18cc23a38e
70 changed files with 5875 additions and 4 deletions
7
gello/dm_control_tasks/arms/__init__.py
Normal file
7
gello/dm_control_tasks/arms/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from gello.dm_control_tasks.arms.franka import Franka
|
||||
from gello.dm_control_tasks.arms.ur5e import UR5e
|
||||
|
||||
__all__ = [
|
||||
"UR5e",
|
||||
"Franka",
|
||||
]
|
28
gello/dm_control_tasks/arms/franka.py
Normal file
28
gello/dm_control_tasks/arms/franka.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
"""Franka composer class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
|
||||
|
||||
XML = MENAGERIE_ROOT / "franka_emika_panda" / "panda_nohand.xml"
|
||||
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
|
||||
|
||||
class Franka(Manipulator):
|
||||
"""Franka Robot."""
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str = "franka",
|
||||
xml_path: Union[str, Path] = XML,
|
||||
gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
|
||||
) -> None:
|
||||
super()._build(name="franka", xml_path=XML, gripper_xml_path=GRIPPER_XML)
|
||||
|
||||
@property
|
||||
def flange(self) -> mjcf.Element:
|
||||
return self._mjcf_root.find("site", "attachment_site")
|
229
gello/dm_control_tasks/arms/manipulator.py
Normal file
229
gello/dm_control_tasks/arms/manipulator.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
"""Manipulator composer class."""
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from dm_control import composer, mjcf
|
||||
from dm_control.composer.observation import observable
|
||||
from dm_control.mujoco.wrapper import mjbindings
|
||||
from dm_control.suite.utils.randomizers import random_limited_quaternion
|
||||
|
||||
from gello.dm_control_tasks import mjcf_utils
|
||||
|
||||
|
||||
def attach_hand_to_arm(
|
||||
arm_mjcf: mjcf.RootElement,
|
||||
hand_mjcf: mjcf.RootElement,
|
||||
# attach_site: str,
|
||||
) -> None:
|
||||
"""Attaches a hand to an arm.
|
||||
|
||||
The arm must have a site named "attachment_site".
|
||||
|
||||
Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm
|
||||
|
||||
Args:
|
||||
arm_mjcf: The mjcf.RootElement of the arm.
|
||||
hand_mjcf: The mjcf.RootElement of the hand.
|
||||
attach_site: The name of the site to attach the hand to.
|
||||
|
||||
Raises:
|
||||
ValueError: If the arm does not have a site named "attachment_site".
|
||||
"""
|
||||
physics = mjcf.Physics.from_mjcf_model(hand_mjcf)
|
||||
|
||||
# attachment_site = arm_mjcf.find("site", attach_site)
|
||||
attachment_site = arm_mjcf.find("site", "attachment_site")
|
||||
if attachment_site is None:
|
||||
raise ValueError("No attachment site found in the arm model.")
|
||||
|
||||
# Expand the ctrl and qpos keyframes to account for the new hand DoFs.
|
||||
arm_key = arm_mjcf.find("key", "home")
|
||||
if arm_key is not None:
|
||||
hand_key = hand_mjcf.find("key", "home")
|
||||
if hand_key is None:
|
||||
arm_key.ctrl = np.concatenate([arm_key.ctrl, np.zeros(physics.model.nu)])
|
||||
arm_key.qpos = np.concatenate([arm_key.qpos, np.zeros(physics.model.nq)])
|
||||
else:
|
||||
arm_key.ctrl = np.concatenate([arm_key.ctrl, hand_key.ctrl])
|
||||
arm_key.qpos = np.concatenate([arm_key.qpos, hand_key.qpos])
|
||||
|
||||
attachment_site.attach(hand_mjcf)
|
||||
|
||||
|
||||
class Manipulator(composer.Entity, abc.ABC):
|
||||
"""A manipulator entity."""
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str,
|
||||
xml_path: Union[str, Path],
|
||||
gripper_xml_path: Optional[Union[str, Path]],
|
||||
) -> None:
|
||||
"""Builds the manipulator.
|
||||
|
||||
Subclasses can not override this method, but should call this method in their
|
||||
own _build() method.
|
||||
"""
|
||||
self._mjcf_root = mjcf.from_path(str(xml_path))
|
||||
self._arm_joints = tuple(mjcf_utils.safe_find_all(self._mjcf_root, "joint"))
|
||||
if gripper_xml_path:
|
||||
gripper_mjcf = mjcf.from_path(str(gripper_xml_path))
|
||||
attach_hand_to_arm(self._mjcf_root, gripper_mjcf)
|
||||
|
||||
self._mjcf_root.model = name
|
||||
self._add_mjcf_elements()
|
||||
|
||||
def set_joints(self, physics: mjcf.Physics, joints: np.ndarray) -> None:
|
||||
assert len(joints) == len(self._arm_joints)
|
||||
for joint, joint_value in zip(self._arm_joints, joints):
|
||||
joint_id = physics.bind(joint).element_id
|
||||
joint_name = physics.model.id2name(joint_id, "joint")
|
||||
physics.named.data.qpos[joint_name] = joint_value
|
||||
|
||||
def randomize_joints(
|
||||
self,
|
||||
physics: mjcf.Physics,
|
||||
random: Optional[np.random.RandomState] = None,
|
||||
) -> None:
|
||||
random = random or np.random # type: ignore
|
||||
assert random is not None
|
||||
hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
|
||||
slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
|
||||
ball = mjbindings.enums.mjtJoint.mjJNT_BALL
|
||||
free = mjbindings.enums.mjtJoint.mjJNT_FREE
|
||||
|
||||
qpos = physics.named.data.qpos
|
||||
|
||||
for joint in self._arm_joints:
|
||||
joint_id = physics.bind(joint).element_id
|
||||
# joint_id = physics.model.name2id(joint.name, "joint")
|
||||
joint_name = physics.model.id2name(joint_id, "joint")
|
||||
joint_type = physics.model.jnt_type[joint_id]
|
||||
is_limited = physics.model.jnt_limited[joint_id]
|
||||
range_min, range_max = physics.model.jnt_range[joint_id]
|
||||
|
||||
if is_limited:
|
||||
if joint_type in [hinge, slide]:
|
||||
qpos[joint_name] = random.uniform(range_min, range_max)
|
||||
|
||||
elif joint_type == ball:
|
||||
qpos[joint_name] = random_limited_quaternion(random, range_max)
|
||||
|
||||
else:
|
||||
if joint_type == hinge:
|
||||
qpos[joint_name] = random.uniform(-np.pi, np.pi)
|
||||
|
||||
elif joint_type == ball:
|
||||
quat = random.randn(4)
|
||||
quat /= np.linalg.norm(quat)
|
||||
qpos[joint_name] = quat
|
||||
|
||||
elif joint_type == free:
|
||||
# this should be random.randn, but changing it now could significantly
|
||||
# affect benchmark results.
|
||||
quat = random.rand(4)
|
||||
quat /= np.linalg.norm(quat)
|
||||
qpos[joint_name][3:] = quat
|
||||
|
||||
def _add_mjcf_elements(self) -> None:
|
||||
# Parse joints.
|
||||
joints = mjcf_utils.safe_find_all(self._mjcf_root, "joint")
|
||||
joints = [joint for joint in joints if joint.tag != "freejoint"]
|
||||
self._joints = tuple(joints)
|
||||
|
||||
# Parse actuators.
|
||||
actuators = mjcf_utils.safe_find_all(self._mjcf_root, "actuator")
|
||||
self._actuators = tuple(actuators)
|
||||
|
||||
# Parse qpos / ctrl keyframes.
|
||||
self._keyframes = {}
|
||||
keyframes = mjcf_utils.safe_find_all(self._mjcf_root, "key")
|
||||
if keyframes:
|
||||
for frame in keyframes:
|
||||
if frame.qpos is not None:
|
||||
qpos = np.array(frame.qpos)
|
||||
self._keyframes[frame.name] = qpos
|
||||
|
||||
# add a visualizeation the flange position that is green
|
||||
self.flange.parent.add(
|
||||
"geom",
|
||||
name="flange_geom",
|
||||
type="sphere",
|
||||
size="0.01",
|
||||
rgba="0 1 0 1",
|
||||
pos=self.flange.pos,
|
||||
contype="0",
|
||||
conaffinity="0",
|
||||
)
|
||||
|
||||
def _build_observables(self):
|
||||
return ArmObservables(self)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def flange(self) -> mjcf.Element:
|
||||
"""Returns the flange element.
|
||||
|
||||
The flange is the end effector of the manipulator where tools can be
|
||||
attached, such as a gripper.
|
||||
"""
|
||||
|
||||
@property
|
||||
def mjcf_model(self) -> mjcf.RootElement:
|
||||
return self._mjcf_root
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._mjcf_root.model
|
||||
|
||||
@property
|
||||
def joints(self) -> Tuple[mjcf.Element, ...]:
|
||||
return self._joints
|
||||
|
||||
@property
|
||||
def actuators(self) -> Tuple[mjcf.Element, ...]:
|
||||
return self._actuators
|
||||
|
||||
@property
|
||||
def keyframes(self) -> Dict[str, np.ndarray]:
|
||||
return self._keyframes
|
||||
|
||||
|
||||
class ArmObservables(composer.Observables):
|
||||
"""Base class for quadruped observables."""
|
||||
|
||||
@composer.observable
|
||||
def joints_pos(self):
|
||||
return observable.MJCFFeature("qpos", self._entity.joints)
|
||||
|
||||
@composer.observable
|
||||
def joints_vel(self):
|
||||
return observable.MJCFFeature("qvel", self._entity.joints)
|
||||
|
||||
@composer.observable
|
||||
def flange_position(self):
|
||||
return observable.MJCFFeature("xpos", self._entity.flange)
|
||||
|
||||
@composer.observable
|
||||
def flange_orientation(self):
|
||||
return observable.MJCFFeature("xmat", self._entity.flange)
|
||||
|
||||
# Semantic grouping of observables.
|
||||
def _collect_from_attachments(self, attribute_name: str):
|
||||
out: List[observable.MJCFFeature] = []
|
||||
for entity in self._entity.iter_entities(exclude_self=True):
|
||||
out.extend(getattr(entity.observables, attribute_name, []))
|
||||
return out
|
||||
|
||||
@property
|
||||
def proprioception(self):
|
||||
return [
|
||||
self.joints_pos,
|
||||
self.joints_vel,
|
||||
self.flange_position,
|
||||
# self.flange_orientation,
|
||||
# self.flange_velocity,
|
||||
self.flange_angular_velocity,
|
||||
] + self._collect_from_attachments("proprioception")
|
26
gello/dm_control_tasks/arms/ur5e.py
Normal file
26
gello/dm_control_tasks/arms/ur5e.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""UR5e composer class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
|
||||
|
||||
|
||||
class UR5e(Manipulator):
|
||||
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
XML = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str = "UR5e",
|
||||
xml_path: Union[str, Path] = XML,
|
||||
gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
|
||||
) -> None:
|
||||
super()._build(name=name, xml_path=xml_path, gripper_xml_path=gripper_xml_path)
|
||||
|
||||
@property
|
||||
def flange(self) -> mjcf.Element:
|
||||
return self._mjcf_root.find("site", "attachment_site")
|
22
gello/dm_control_tasks/arms/ur5e_test.py
Normal file
22
gello/dm_control_tasks/arms/ur5e_test.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""Tests for ur5e.py."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms import ur5e
|
||||
|
||||
|
||||
class UR5eTest(absltest.TestCase):
|
||||
def test_compiles_and_steps(self) -> None:
|
||||
robot = ur5e.UR5e()
|
||||
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
|
||||
physics.step()
|
||||
|
||||
def test_joints(self) -> None:
|
||||
robot = ur5e.UR5e()
|
||||
for joint in robot.joints:
|
||||
self.assertEqual(joint.tag, "joint")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
261
gello/dm_control_tasks/arms/utils.py
Normal file
261
gello/dm_control_tasks/arms/utils.py
Normal file
|
@ -0,0 +1,261 @@
|
|||
import collections
|
||||
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
from dm_control.mujoco.wrapper import mjbindings
|
||||
|
||||
mjlib = mjbindings.mjlib
|
||||
|
||||
|
||||
_INVALID_JOINT_NAMES_TYPE = (
|
||||
"`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}."
|
||||
)
|
||||
_REQUIRE_TARGET_POS_OR_QUAT = (
|
||||
"At least one of `target_pos` or `target_quat` must be specified."
|
||||
)
|
||||
|
||||
IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"])
|
||||
|
||||
|
||||
def qpos_from_site_pose(
|
||||
physics,
|
||||
site_name,
|
||||
target_pos=None,
|
||||
target_quat=None,
|
||||
joint_names=None,
|
||||
tol=1e-14,
|
||||
rot_weight=1.0,
|
||||
regularization_threshold=0.1,
|
||||
regularization_strength=3e-2,
|
||||
max_update_norm=2.0,
|
||||
progress_thresh=20.0,
|
||||
max_steps=100,
|
||||
inplace=False,
|
||||
):
|
||||
"""Find joint positions that satisfy a target site position and/or rotation.
|
||||
|
||||
Args:
|
||||
physics: A `mujoco.Physics` instance.
|
||||
site_name: A string specifying the name of the target site.
|
||||
target_pos: A (3,) numpy array specifying the desired Cartesian position of
|
||||
the site, or None if the position should be unconstrained (default).
|
||||
One or both of `target_pos` or `target_quat` must be specified.
|
||||
target_quat: A (4,) numpy array specifying the desired orientation of the
|
||||
site as a quaternion, or None if the orientation should be unconstrained
|
||||
(default). One or both of `target_pos` or `target_quat` must be specified.
|
||||
joint_names: (optional) A list, tuple or numpy array specifying the names of
|
||||
one or more joints that can be manipulated in order to achieve the target
|
||||
site pose. If None (default), all joints may be manipulated.
|
||||
tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
|
||||
in the stopping criterion).
|
||||
rot_weight: (optional) Determines the weight given to rotational error
|
||||
relative to translational error.
|
||||
regularization_threshold: (optional) L2 regularization will be used when
|
||||
inverting the Jacobian whilst `err_norm` is greater than this value.
|
||||
regularization_strength: (optional) Coefficient of the quadratic penalty
|
||||
on joint movements.
|
||||
max_update_norm: (optional) The maximum L2 norm of the update applied to
|
||||
the joint positions on each iteration. The update vector will be scaled
|
||||
such that its magnitude never exceeds this value.
|
||||
progress_thresh: (optional) If `err_norm` divided by the magnitude of the
|
||||
joint position update is greater than this value then the optimization
|
||||
will terminate prematurely. This is a useful heuristic to avoid getting
|
||||
stuck in local minima.
|
||||
max_steps: (optional) The maximum number of iterations to perform.
|
||||
inplace: (optional) If True, `physics.data` will be modified in place.
|
||||
Default value is False, i.e. a copy of `physics.data` will be made.
|
||||
|
||||
Returns:
|
||||
An `IKResult` namedtuple with the following fields:
|
||||
qpos: An (nq,) numpy array of joint positions.
|
||||
err_norm: A float, the weighted sum of L2 norms for the residual
|
||||
translational and rotational errors.
|
||||
steps: An int, the number of iterations that were performed.
|
||||
success: Boolean, True if we converged on a solution within `max_steps`,
|
||||
False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `target_pos` and `target_quat` are None, or if
|
||||
`joint_names` has an invalid type.
|
||||
"""
|
||||
dtype = physics.data.qpos.dtype
|
||||
|
||||
if target_pos is not None and target_quat is not None:
|
||||
jac = np.empty((6, physics.model.nv), dtype=dtype)
|
||||
err = np.empty(6, dtype=dtype)
|
||||
jac_pos, jac_rot = jac[:3], jac[3:]
|
||||
err_pos, err_rot = err[:3], err[3:]
|
||||
else:
|
||||
jac = np.empty((3, physics.model.nv), dtype=dtype)
|
||||
err = np.empty(3, dtype=dtype)
|
||||
if target_pos is not None:
|
||||
jac_pos, jac_rot = jac, None
|
||||
err_pos, err_rot = err, None
|
||||
elif target_quat is not None:
|
||||
jac_pos, jac_rot = None, jac
|
||||
err_pos, err_rot = None, err
|
||||
else:
|
||||
raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
|
||||
|
||||
update_nv = np.zeros(physics.model.nv, dtype=dtype)
|
||||
|
||||
if target_quat is not None:
|
||||
site_xquat = np.empty(4, dtype=dtype)
|
||||
neg_site_xquat = np.empty(4, dtype=dtype)
|
||||
err_rot_quat = np.empty(4, dtype=dtype)
|
||||
|
||||
if not inplace:
|
||||
physics = physics.copy(share_model=True)
|
||||
|
||||
# Ensure that the Cartesian position of the site is up to date.
|
||||
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
|
||||
|
||||
# Convert site name to index.
|
||||
site_id = physics.model.name2id(site_name, "site")
|
||||
|
||||
# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
|
||||
# update them in place, so we can avoid indexing overhead in the main loop.
|
||||
site_xpos = physics.named.data.site_xpos[site_name]
|
||||
site_xmat = physics.named.data.site_xmat[site_name]
|
||||
|
||||
# This is an index into the rows of `update` and the columns of `jac`
|
||||
# that selects DOFs associated with joints that we are allowed to manipulate.
|
||||
if joint_names is None:
|
||||
dof_indices = slice(None) # Update all DOFs.
|
||||
elif isinstance(joint_names, (list, np.ndarray, tuple)):
|
||||
if isinstance(joint_names, tuple):
|
||||
joint_names = list(joint_names)
|
||||
# Find the indices of the DOFs belonging to each named joint. Note that
|
||||
# these are not necessarily the same as the joint IDs, since a single joint
|
||||
# may have >1 DOF (e.g. ball joints).
|
||||
indexer = physics.named.model.dof_jntid.axes.row
|
||||
# `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
|
||||
# indexer to map each joint name to the indices of its corresponding DOFs.
|
||||
dof_indices = indexer.convert_key_item(joint_names)
|
||||
else:
|
||||
raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
|
||||
|
||||
steps = 0
|
||||
success = False
|
||||
|
||||
for steps in range(max_steps):
|
||||
err_norm = 0.0
|
||||
|
||||
if target_pos is not None:
|
||||
# Translational error.
|
||||
err_pos[:] = target_pos - site_xpos
|
||||
err_norm += np.linalg.norm(err_pos)
|
||||
if target_quat is not None:
|
||||
# Rotational error.
|
||||
mjlib.mju_mat2Quat(site_xquat, site_xmat)
|
||||
mjlib.mju_negQuat(neg_site_xquat, site_xquat)
|
||||
mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
|
||||
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
|
||||
err_norm += np.linalg.norm(err_rot) * rot_weight
|
||||
|
||||
if err_norm < tol:
|
||||
logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm)
|
||||
success = True
|
||||
break
|
||||
else:
|
||||
# TODO(b/112141670): Generalize this to other entities besides sites.
|
||||
mjlib.mj_jacSite(
|
||||
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id
|
||||
)
|
||||
jac_joints = jac[:, dof_indices]
|
||||
|
||||
# TODO(b/112141592): This does not take joint limits into consideration.
|
||||
reg_strength = (
|
||||
regularization_strength if err_norm > regularization_threshold else 0.0
|
||||
)
|
||||
update_joints = nullspace_method(
|
||||
jac_joints, err, regularization_strength=reg_strength
|
||||
)
|
||||
|
||||
update_norm = np.linalg.norm(update_joints)
|
||||
|
||||
# Check whether we are still making enough progress, and halt if not.
|
||||
progress_criterion = err_norm / update_norm
|
||||
if progress_criterion > progress_thresh:
|
||||
logging.debug(
|
||||
"Step %2i: err_norm / update_norm (%3g) > "
|
||||
"tolerance (%3g). Halting due to insufficient progress",
|
||||
steps,
|
||||
progress_criterion,
|
||||
progress_thresh,
|
||||
)
|
||||
break
|
||||
|
||||
if update_norm > max_update_norm:
|
||||
update_joints *= max_update_norm / update_norm
|
||||
|
||||
# Write the entries for the specified joints into the full `update_nv`
|
||||
# vector.
|
||||
update_nv[dof_indices] = update_joints
|
||||
|
||||
# Update `physics.qpos`, taking quaternions into account.
|
||||
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
|
||||
|
||||
# Compute the new Cartesian position of the site.
|
||||
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
|
||||
|
||||
logging.debug(
|
||||
"Step %2i: err_norm=%-10.3g update_norm=%-10.3g",
|
||||
steps,
|
||||
err_norm,
|
||||
update_norm,
|
||||
)
|
||||
|
||||
if not success and steps == max_steps - 1:
|
||||
logging.warning(
|
||||
"Failed to converge after %i steps: err_norm=%3g", steps, err_norm
|
||||
)
|
||||
|
||||
if not inplace:
|
||||
# Our temporary copy of physics.data is about to go out of scope, and when
|
||||
# it does the underlying mjData pointer will be freed and physics.data.qpos
|
||||
# will be a view onto a block of deallocated memory. We therefore need to
|
||||
# make a copy of physics.data.qpos while physics.data is still alive.
|
||||
qpos = physics.data.qpos.copy()
|
||||
else:
|
||||
# If we're modifying physics.data in place then it's fine to return a view.
|
||||
qpos = physics.data.qpos
|
||||
|
||||
return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
|
||||
|
||||
|
||||
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
|
||||
"""Calculates the joint velocities to achieve a specified end effector delta.
|
||||
|
||||
Args:
|
||||
jac_joints: The Jacobian of the end effector with respect to the joints. A
|
||||
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
|
||||
and `nv` is the number of degrees of freedom.
|
||||
delta: The desired end-effector delta. A numpy array of shape `(3,)` or
|
||||
`(6,)` containing either position deltas, rotation deltas, or both.
|
||||
regularization_strength: (optional) Coefficient of the quadratic penalty
|
||||
on joint movements. Default is zero, i.e. no regularization.
|
||||
|
||||
Returns:
|
||||
An `(nv,)` numpy array of joint velocities.
|
||||
|
||||
Reference:
|
||||
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
|
||||
transpose, pseudoinverse and damped least squares methods.
|
||||
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
|
||||
"""
|
||||
hess_approx = jac_joints.T.dot(jac_joints)
|
||||
joint_delta = jac_joints.T.dot(delta)
|
||||
if regularization_strength > 0:
|
||||
# L2 regularization
|
||||
hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
|
||||
return np.linalg.solve(hess_approx, joint_delta)
|
||||
else:
|
||||
return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
|
||||
|
||||
|
||||
class InverseKinematics:
|
||||
def __init__(self, xml_path: str):
|
||||
"""Initializes the inverse kinematics class."""
|
||||
...
|
||||
# TODO
|
Loading…
Add table
Add a link
Reference in a new issue