gello_software/gello/dm_control_tasks/manipulation/tasks/block_play.py

122 lines
3.9 KiB
Python

"""A task where a manipulator arm interacts with randomly placed blocks."""
from typing import Optional
import numpy as np
from dm_control import mjcf
from dm_control.suite.utils.randomizers import random_limited_quaternion
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
from gello.dm_control_tasks.manipulation.tasks import base
# Presumably an RGBA colour (reddish, alpha 0.6) intended for a target marker;
# NOTE(review): not referenced by any code in this module — candidate for removal.
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
class BlockPlay(base.ManipulationTask):
    """Task for a manipulator. Blocks are randomly placed in the scene.

    ``num_blocks`` cubes, each on its own free joint, are added to the world
    body. At the start of every episode the arm is reset and every cube is
    dropped at a random pose (see ``initialize_episode``). No reward is
    defined: ``get_reward`` always returns 0.
    """

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        num_blocks: int = 10,
        size: float = 0.03,
        reset_joints: Optional[np.ndarray] = None,
    ) -> None:
        """Build the base scene and add ``num_blocks`` free-floating cubes.

        Args:
            arm: Manipulator, forwarded to ``base.ManipulationTask``.
            arena: Arena, forwarded to ``base.ManipulationTask``.
            physics_timestep: Physics timestep, forwarded to the base task.
            control_timestep: Control timestep, forwarded to the base task.
            num_blocks: Number of cubes to add to the scene.
            size: Half-extent of each cube (MuJoCo box ``size`` is a half-size).
            reset_joints: Arm joint configuration applied at the start of each
                episode; ``None`` means a random configuration is used.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)
        mjcf_model = self.root_entity.mjcf_model

        block_joints = []
        for i in range(num_blocks):
            # Random opaque colour. NOTE(review): drawn from NumPy's global RNG,
            # so it is not controlled by the episode ``random_state``.
            color = np.concatenate([np.random.uniform(0, 1, size=3), [1.0]])
            block = mjcf_model.worldbody.add(
                "body", name=f"block_{i}", pos=(0, 0, 0)
            )
            # A free joint lets the block be placed anywhere and move freely.
            block_joints.append(block.add("freejoint"))
            block.add(
                "geom",
                name=f"block_geom_{i}",
                type="box",
                size=(size, size, size),
                rgba=color,
            )

        # Keyframes must stay consistent with the enlarged qpos/qvel vectors.
        self._pad_key_frames(mjcf_model, num_blocks)

        self._block_joints = block_joints
        self._block_size = size
        self._reset_joints = reset_joints

    @staticmethod
    def _pad_key_frames(mjcf_model, num_blocks: int) -> None:
        """Extend every keyframe with rest-state entries for the block joints.

        MuJoCo requires an explicitly given keyframe ``qpos``/``qvel`` to have
        exactly nq/nv entries. Each free joint adds 7 qpos entries (position +
        quaternion) and 6 qvel entries. The blocks are added to the world body
        after the base scene is built, so their joints are assumed to come last
        in qpos/qvel order and their entries are appended.

        Args:
            mjcf_model: Root MJCF model whose ``key`` elements are updated in place.
            num_blocks: Number of free-jointed blocks that were added.
        """
        if num_blocks == 0:
            return
        # Free-joint rest pose: origin position plus the identity quaternion
        # (w, x, y, z) = (1, 0, 0, 0); an all-zero quaternion is not a rotation.
        block_qpos = np.tile([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], num_blocks)
        block_qvel = np.zeros(6 * num_blocks)
        for key in mjcf_model.find_all("key"):
            # Unset attributes fall back to qpos0 / zero velocity in MuJoCo and
            # already account for the new joints, so only explicit arrays grow.
            if key.qpos is not None:
                key.qpos = np.concatenate([key.qpos, block_qpos])
            if key.qvel is not None:
                key.qvel = np.concatenate([key.qvel, block_qvel])

    def initialize_episode(self, physics, random_state):
        """Reset the arm and scatter the blocks at the start of an episode.

        Args:
            physics: Physics instance of the compiled task model.
            random_state: RNG used for arm and block randomization.
        """
        if self._reset_joints is not None:
            self._arm.set_joints(physics, self._reset_joints)
        else:
            self._arm.randomize_joints(physics, random_state)
        physics.forward()
        # Resample random arm configurations until ``in_collision`` reports no
        # collision. Note this also replaces ``reset_joints`` if that pose
        # collides.
        while self.in_collision(physics):
            self._arm.randomize_joints(physics, random_state)
            physics.forward()

        # Box ``size`` is a half-extent, so a centre height of 2 * size keeps
        # the cube above z = 0 for any orientation (its half-diagonal is
        # sqrt(3) * size) — assuming the floor sits at z = 0.
        for block_joint in self._block_joints:
            randomize_pose(
                block_joint,
                physics,
                random_state=random_state,
                position_range=0.5,
                z_offset=self._block_size * 2,
            )
            physics.forward()

    def get_reward(self, physics):
        """Return the task reward; this task defines none, so it is always 0."""
        return 0
def randomize_pose(
    free_joint: mjcf.Element,
    physics: mjcf.Physics,
    random_state: np.random.RandomState,
    position_range: float = 0.5,
    z_offset: float = 0.0,
    min_radius: float = 0.3,
) -> None:
    """Randomize the pose of a free-jointed entity.

    The x/y position is drawn uniformly from the square
    ``[-position_range, position_range]^2`` by rejection sampling, discarding
    points closer than ``min_radius`` to the origin. The z coordinate is fixed
    to ``z_offset``, and the orientation comes from
    ``random_limited_quaternion(random_state, limit=np.pi)``. The result is
    written to the joint's ``qpos`` as ``[x, y, z, quaternion]``.

    Args:
        free_joint: MJCF free-joint element whose qpos is overwritten.
        physics: Physics instance used to bind ``free_joint``.
        random_state: RNG used for all sampling.
        position_range: Half-width of the square the x/y position is drawn from.
        z_offset: Height assigned to the entity.
        min_radius: Minimum x/y distance from the origin (keeps entities away
            from the centre of the workspace).

    Raises:
        ValueError: If no point of the sampling square is farther than
            ``min_radius`` from the origin, so rejection sampling could never
            terminate.
    """
    # The farthest points of the square are its corners, at distance
    # position_range * sqrt(2). At or beyond that, the loop below never exits.
    if min_radius > 0 and min_radius >= abs(position_range) * np.sqrt(2):
        raise ValueError(
            f"min_radius={min_radius} is unreachable with "
            f"position_range={position_range}"
        )
    xy = random_state.uniform(-position_range, position_range, size=2)
    while np.linalg.norm(xy) < min_radius:
        xy = random_state.uniform(-position_range, position_range, size=2)
    position = np.concatenate([xy, [z_offset]])
    quaternion = random_limited_quaternion(random_state, limit=np.pi)
    physics.bind(free_joint).qpos = np.concatenate([position, quaternion])