123 lines
3.9 KiB
Python
123 lines
3.9 KiB
Python
![]() |
"""A manipulation task in which randomly colored blocks are scattered around an arm."""
|
||
|
from typing import Optional
|
||
|
|
||
|
import numpy as np
|
||
|
from dm_control import mjcf
|
||
|
from dm_control.suite.utils.randomizers import random_limited_quaternion
|
||
|
|
||
|
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||
|
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
|
||
|
from gello.dm_control_tasks.manipulation.tasks import base
|
||
|
|
||
|
# RGBA color tuple (r, g, b, alpha).
# NOTE(review): not referenced anywhere else in this module; possibly a leftover
# from another task — confirm before relying on or removing it.
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
|
||
|
|
||
|
|
||
|
class BlockPlay(base.ManipulationTask):
    """Task for a manipulator. Blocks are randomly placed in the scene.

    Each block is a free-floating box with a random opaque color. The task has
    no objective: ``get_reward`` always returns 0.
    """

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        num_blocks: int = 10,
        size: float = 0.03,
        reset_joints: Optional[np.ndarray] = None,
    ) -> None:
        """Build the task and add ``num_blocks`` free-floating boxes to the world.

        Args:
            arm: The manipulator entity.
            arena: The arena the arm is placed in.
            physics_timestep: Physics step size, forwarded to the base task.
            control_timestep: Control step size, forwarded to the base task.
            num_blocks: Number of blocks to add.
            size: Value used for every component of each box geom's ``size``.
            reset_joints: If given, arm joint values applied at the start of
                each episode instead of a random configuration.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)

        mjcf_model = self.root_entity.mjcf_model
        block_joints = []
        for i in range(num_blocks):
            # Random RGB with alpha fixed to 1 (opaque).
            # NOTE(review): uses the global NumPy RNG, so colors are not seeded
            # by the episode's random_state.
            color = np.concatenate([np.random.uniform(0, 1, size=3), [1.0]])

            block = mjcf_model.worldbody.add("body", name=f"block_{i}", pos=(0, 0, 0))
            # A free joint lets the block's pose be set (and simulated) freely.
            block_joints.append(block.add("freejoint"))
            block.add(
                "geom",
                name=f"block_geom_{i}",
                type="box",
                size=(size, size, size),
                rgba=color,
            )

        # The added free joints grow the model's qpos/qvel, so an explicit
        # keyframe state vector must be padded to keep its length consistent.
        key_frames = mjcf_model.find_all("key")
        if key_frames and num_blocks > 0:
            # Only the first keyframe is padded (as before).
            # NOTE(review): any additional keyframes may also need padding.
            self._pad_keyframe(key_frames[0], num_blocks)

        self._block_joints = block_joints
        self._block_size = size
        self._reset_joints = reset_joints

    @staticmethod
    def _pad_keyframe(key, num_blocks: int) -> None:
        """Append state for ``num_blocks`` free joints to ``key`` in place.

        A free joint contributes 7 qpos entries (xyz position followed by a
        wxyz quaternion) and 6 qvel entries. Vectors the keyframe leaves unset
        (``None``) are not touched.
        """
        if key.qpos is not None:
            # Origin with the identity quaternion (w=1); an all-zero
            # quaternion is not a valid orientation.
            block_qpos = np.tile([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], num_blocks)
            key.qpos = np.concatenate([key.qpos, block_qpos])
        if key.qvel is not None:
            key.qvel = np.concatenate([key.qvel, np.zeros(6 * num_blocks)])

    def initialize_episode(self, physics, random_state):
        """Reset the arm and scatter the blocks at the start of an episode.

        The arm is set to ``reset_joints`` when provided, otherwise randomized.
        It is then re-randomized until ``in_collision`` reports no collision.
        Finally, each block gets a random pose via ``randomize_pose``.
        """
        if self._reset_joints is not None:
            self._arm.set_joints(physics, self._reset_joints)
        else:
            self._arm.randomize_joints(physics, random_state)
        physics.forward()

        # Re-sample the arm configuration until it is collision-free.
        # NOTE(review): this never terminates if no collision-free sample can
        # be drawn; consider capping the number of attempts.
        while self.in_collision(physics):
            self._arm.randomize_joints(physics, random_state)
            physics.forward()

        for block_joint in self._block_joints:
            randomize_pose(
                block_joint,
                physics,
                random_state=random_state,
                position_range=0.5,
                z_offset=self._block_size * 2,
            )

        physics.forward()

    def get_reward(self, physics):
        """Return a constant reward of 0; this task defines no objective."""
        del physics  # Unused.
        return 0
|
||
|
|
||
|
|
||
|
def randomize_pose(
    free_joint: "mjcf.Element",
    physics: "mjcf.Physics",
    random_state: np.random.RandomState,
    position_range: float = 0.5,
    z_offset: float = 0.0,
    min_distance: float = 0.3,
) -> None:
    """Randomize the pose of an entity attached through a free joint.

    The x/y position is drawn uniformly from the square
    ``[-position_range, position_range]^2``. It is rejection-sampled until it
    is at least ``min_distance`` from the origin. The z coordinate is set to
    ``z_offset``. The orientation comes from ``random_limited_quaternion``
    with a limit of pi.

    Args:
        free_joint: The MJCF free-joint element whose qpos is written.
        physics: Physics instance the joint is bound to.
        random_state: Source of randomness for position and orientation.
        position_range: Half-width of the square x/y sampling region.
        z_offset: Height assigned to the position.
        min_distance: Minimum x/y distance from the origin. The default of 0.3
            is the previously hard-coded value.

    Raises:
        ValueError: If no point of the sampling square lies farther than
            ``min_distance`` from the origin, so rejection sampling could
            never terminate.
    """
    # The farthest points of the square (its corners) lie at this distance.
    max_reach = abs(position_range) * np.sqrt(2.0)
    if min_distance > 0 and min_distance >= max_reach:
        raise ValueError(
            f"min_distance={min_distance} is unreachable with "
            f"position_range={position_range} (max reach {max_reach:.3f})"
        )

    xy = random_state.uniform(-position_range, position_range, size=2)
    while np.linalg.norm(xy) < min_distance:
        xy = random_state.uniform(-position_range, position_range, size=2)

    position = np.concatenate([xy, [z_offset]])
    quaternion = random_limited_quaternion(random_state, limit=np.pi)

    # Free-joint qpos layout: xyz position followed by the quaternion.
    physics.bind(free_joint).qpos = np.concatenate([position, quaternion])
|