56 lines
2 KiB
Python
56 lines
2 KiB
Python
"""A abstract class for all walker tasks."""
|
|
|
|
from dm_control import composer, mjcf
|
|
|
|
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
|
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
|
|
|
|
# Timestep of the physics simulation.
|
|
_PHYSICS_TIMESTEP: float = 0.002
|
|
|
|
# Interval between agent actions, in seconds.
|
|
# We send a control signal every (_CONTROL_TIMESTEP / _PHYSICS_TIMESTEP) physics steps.
|
|
_CONTROL_TIMESTEP: float = 0.02 # 50 Hz.
|
|
|
|
|
|
class ManipulationTask(composer.Task):
|
|
"""Base composer task for walker robots."""
|
|
|
|
def __init__(
|
|
self,
|
|
arm: Manipulator,
|
|
arena: FixedManipulationArena,
|
|
physics_timestep=_PHYSICS_TIMESTEP,
|
|
control_timestep=_CONTROL_TIMESTEP,
|
|
) -> None:
|
|
self._arm = arm
|
|
self._arena = arena
|
|
|
|
self._arm_geoms = arm.root_body.find_all("geom")
|
|
self._arena_geoms = arena.root_body.find_all("geom")
|
|
arena.attach(arm, attach_site=arena.arm_attachment_site)
|
|
|
|
self.set_timesteps(
|
|
control_timestep=control_timestep, physics_timestep=physics_timestep
|
|
)
|
|
|
|
# Enable all robot observables. Note: No additional entity observables should
|
|
# be added in subclasses.
|
|
for observable in self._arm.observables.proprioception:
|
|
observable.enabled = True
|
|
|
|
def in_collision(self, physics: mjcf.Physics) -> bool:
|
|
"""Checks if the arm is in collision with the floor (self._arena)."""
|
|
arm_ids = [physics.bind(g).element_id for g in self._arm_geoms]
|
|
arena_ids = [physics.bind(g).element_id for g in self._arena_geoms]
|
|
|
|
return any(
|
|
(con.geom1 in arm_ids and con.geom2 in arena_ids) # arm and arena
|
|
or (con.geom1 in arena_ids and con.geom2 in arm_ids) # arena and arm
|
|
or (con.geom1 in arm_ids and con.geom2 in arm_ids) # self collision
|
|
for con in physics.data.contact[: physics.data.ncon]
|
|
)
|
|
|
|
@property
|
|
def root_entity(self):
|
|
return self._arena
|