gello_software/gello/dm_control_tasks/manipulation/tasks/base.py

56 lines
2 KiB
Python

"""A abstract class for all walker tasks."""
from dm_control import composer, mjcf
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
# Timestep of the physics simulation.
_PHYSICS_TIMESTEP: float = 0.002
# Interval between agent actions, in seconds.
# We send a control signal every (_CONTROL_TIMESTEP / _PHYSICS_TIMESTEP) physics steps.
_CONTROL_TIMESTEP: float = 0.02 # 50 Hz.
class ManipulationTask(composer.Task):
"""Base composer task for walker robots."""
def __init__(
self,
arm: Manipulator,
arena: FixedManipulationArena,
physics_timestep=_PHYSICS_TIMESTEP,
control_timestep=_CONTROL_TIMESTEP,
) -> None:
self._arm = arm
self._arena = arena
self._arm_geoms = arm.root_body.find_all("geom")
self._arena_geoms = arena.root_body.find_all("geom")
arena.attach(arm, attach_site=arena.arm_attachment_site)
self.set_timesteps(
control_timestep=control_timestep, physics_timestep=physics_timestep
)
# Enable all robot observables. Note: No additional entity observables should
# be added in subclasses.
for observable in self._arm.observables.proprioception:
observable.enabled = True
def in_collision(self, physics: mjcf.Physics) -> bool:
"""Checks if the arm is in collision with the floor (self._arena)."""
arm_ids = [physics.bind(g).element_id for g in self._arm_geoms]
arena_ids = [physics.bind(g).element_id for g in self._arena_geoms]
return any(
(con.geom1 in arm_ids and con.geom2 in arena_ids) # arm and arena
or (con.geom1 in arena_ids and con.geom2 in arm_ids) # arena and arm
or (con.geom1 in arm_ids and con.geom2 in arm_ids) # self collision
for con in physics.data.contact[: physics.data.ncon]
)
@property
def root_entity(self):
return self._arena