43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
from typing import Any, Dict, Protocol
|
|
|
|
import numpy as np
|
|
|
|
|
|
class Agent(Protocol):
    """Structural interface for anything that maps observations to actions."""

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Compute the action to apply for one observation.

        Args:
            obs: Observation from the environment, given as a mapping from
                observation names to their values.

        Returns:
            The action to take on the environment.

        Raises:
            NotImplementedError: Always, for this interface definition;
                concrete agents are expected to override ``act``.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class DummyAgent(Agent):
    """Agent that ignores its observation and always outputs zeros."""

    def __init__(self, num_dofs: int):
        """Remember the action dimensionality.

        Args:
            num_dofs: Passed straight to ``np.zeros`` in ``act``; determines
                the shape of every action this agent returns.
        """
        self.num_dofs = num_dofs

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Return an all-zero action regardless of ``obs``.

        Args:
            obs: Ignored by this agent.

        Returns:
            ``np.zeros(self.num_dofs)`` (float64).
        """
        del obs  # the dummy policy never looks at the observation
        return np.zeros(self.num_dofs, dtype=float)
|
|
|
|
|
|
class BimanualAgent(Agent):
    """Drives two arms with two independent single-arm agents.

    Every observation entry is split in half along its first axis: the first
    half goes to ``agent_left`` and the second half to ``agent_right``. The
    returned action is the left agent's action followed by the right agent's
    action.
    """

    def __init__(self, agent_left: Agent, agent_right: Agent):
        """Store the two per-arm agents.

        Args:
            agent_left: Agent that receives the first half of every
                observation entry.
            agent_right: Agent that receives the second half of every
                observation entry.
        """
        self.agent_left = agent_left
        self.agent_right = agent_right

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Split ``obs`` between the two agents and join their actions.

        Args:
            obs: Observation mapping. Every value must expose ``shape`` and
                support slicing along its first axis (e.g. a NumPy array),
                and that first axis must have even length.
                NOTE(review): presumably the environment stacks left-arm data
                before right-arm data in each entry — confirm against the env.

        Returns:
            ``np.concatenate([left_action, right_action])``.

        Raises:
            ValueError: If any observation entry has an odd first-axis length.
        """
        left_obs: Dict[str, Any] = {}
        right_obs: Dict[str, Any] = {}
        for key, val in obs.items():
            length = val.shape[0]
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently give the right agent one extra
            # element instead of failing.
            if length % 2 != 0:
                raise ValueError(
                    f"{key} must be even, something is wrong "
                    f"(first-axis length is {length})"
                )
            half_dim = length // 2
            left_obs[key] = val[:half_dim]
            right_obs[key] = val[half_dim:]
        return np.concatenate(
            [self.agent_left.act(left_obs), self.agent_right.act(right_obs)]
        )
|