gello_software/gello/dm_control_tasks/manipulation/tasks/reach.py

63 lines
2.1 KiB
Python

"""A task where a walker must learn to stand."""
import numpy as np
from dm_control.suite.utils import randomizers
from dm_control.utils import rewards
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
from gello.dm_control_tasks.manipulation.tasks import base
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
class Reach(base.ManipulationTask):
"""Reach task for a manipulator."""
def __init__(
self,
arm: Manipulator,
arena: FixedManipulationArena,
physics_timestep=base._PHYSICS_TIMESTEP,
control_timestep=base._CONTROL_TIMESTEP,
distance_tolerance: float = 0.5,
) -> None:
super().__init__(arm, arena, physics_timestep, control_timestep)
self._distance_tolerance = distance_tolerance
# Create target.
self._target = self.root_entity.mjcf_model.worldbody.add(
"site",
name="target",
type="sphere",
pos=(0, 0, 0),
size=(0.1,),
rgba=_TARGET_COLOR,
)
def initialize_episode(self, physics, random_state):
# Randomly set feasible target position
randomizers.randomize_limited_and_rotational_joints(physics, random_state)
physics.forward()
flange_position = physics.bind(self._arm.flange).xpos[:3]
print(flange_position)
# set target position to flange position
physics.bind(self._target).pos = flange_position
# Randomize initial position of the arm.
randomizers.randomize_limited_and_rotational_joints(physics, random_state)
physics.forward()
def get_reward(self, physics):
# flange position
flange_pos = physics.bind(self._arm.flange).pos[:3]
distance = np.linalg.norm(physics.bind(self._target).pos[:3] - flange_pos)
return -rewards.tolerance(
distance,
bounds=(0, self._distance_tolerance),
margin=self._distance_tolerance,
value_at_margin=0,
sigmoid="linear",
)