63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
"""A task where a manipulator must reach a target position with its flange."""
|
|
|
|
import numpy as np
|
|
from dm_control.suite.utils import randomizers
|
|
from dm_control.utils import rewards
|
|
|
|
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
|
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
|
|
from gello.dm_control_tasks.manipulation.tasks import base
|
|
|
|
# RGBA colour passed as `rgba` to the target site created by `Reach`
# (mostly red, alpha 0.6 so the marker is semi-transparent).
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
|
|
|
|
|
|
class Reach(base.ManipulationTask):
    """Reach task for a manipulator.

    A spherical target site is added to the root entity's world body. At the
    start of each episode the joints are randomized, the target is moved to
    the flange position of that configuration, and the joints are randomized
    a second time so the arm starts away from the target. The reward grows
    as the flange approaches the target.
    """

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        distance_tolerance: float = 0.5,
    ) -> None:
        """Build the task and add the target site to the scene.

        Args:
            arm: Manipulator whose flange must reach the target.
            arena: Arena the task takes place in.
            physics_timestep: Forwarded to the base task constructor.
            control_timestep: Forwarded to the base task constructor.
            distance_tolerance: Flange-to-target distance below which the
                reward is maximal; also used as the width of the linear
                fall-off margin beyond that distance.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)

        self._distance_tolerance = distance_tolerance

        # Create target. It hangs off the world body, so its `pos` is
        # expressed directly in world coordinates.
        self._target = self.root_entity.mjcf_model.worldbody.add(
            "site",
            name="target",
            type="sphere",
            pos=(0, 0, 0),
            size=(0.1,),
            rgba=_TARGET_COLOR,
        )

    def initialize_episode(self, physics, random_state) -> None:
        """Place the target at a sampled flange pose, then randomize the arm.

        Args:
            physics: Physics instance the task elements are bound to.
            random_state: Random state used for joint randomization.
        """
        # Sample a joint configuration and use the resulting flange position
        # as the target, so the target comes from an actual arm configuration.
        randomizers.randomize_limited_and_rotational_joints(physics, random_state)
        physics.forward()
        # Copy: `xpos` is a view into the physics buffer, which is recomputed
        # when the joints are randomized again below.
        target_position = np.array(physics.bind(self._arm.flange).xpos[:3])

        # Set target position to the sampled flange position.
        physics.bind(self._target).pos = target_position

        # Randomize initial position of the arm.
        randomizers.randomize_limited_and_rotational_joints(physics, random_state)
        physics.forward()

    def get_reward(self, physics) -> float:
        """Return a reward in [0, 1] that increases as the flange nears the target.

        The reward is 1 while the flange is within `distance_tolerance` of the
        target and decays linearly to 0 over a further `distance_tolerance`.

        Args:
            physics: Physics instance the task elements are bound to.

        Returns:
            The shaped reach reward.
        """
        # Use the world-frame flange position (`xpos`), matching the frame in
        # which the target was placed during `initialize_episode`; `pos` would
        # be relative to the flange's parent.
        flange_pos = physics.bind(self._arm.flange).xpos[:3]
        target_pos = physics.bind(self._target).pos[:3]
        distance = np.linalg.norm(target_pos - flange_pos)

        # `tolerance` already yields 1 at the target and 0 far away; returning
        # it un-negated makes being closer strictly better.
        return rewards.tolerance(
            distance,
            bounds=(0, self._distance_tolerance),
            margin=self._distance_tolerance,
            value_at_margin=0,
            sigmoid="linear",
        )
|