104 lines
2.5 KiB
Python
Executable file
104 lines
2.5 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
from typing import Dict
|
|
|
|
import gymnasium as gym
|
|
from stable_baselines3.common.env_checker import check_env
|
|
|
|
from rbs_gym import envs as gz_envs
|
|
from rbs_gym.utils.utils import StoreDict, str2bool
|
|
|
|
|
|
def main(args: argparse.Namespace):
|
|
|
|
# Create the environment
|
|
if args.env_kwargs is not None:
|
|
env = gym.make(args.env, **args.env_kwargs)
|
|
else:
|
|
env = gym.make(args.env)
|
|
|
|
# Initialize random seed
|
|
env.seed(args.seed)
|
|
|
|
# Check the environment
|
|
if args.check_env:
|
|
check_env(env, warn=True, skip_render_check=True)
|
|
|
|
# Step environment for bunch of episodes
|
|
for episode in range(args.n_episodes):
|
|
|
|
# Initialize returned values
|
|
done = False
|
|
total_reward = 0
|
|
|
|
# Reset the environment
|
|
observation = env.reset()
|
|
|
|
# Step through the current episode until it is done
|
|
while not done:
|
|
|
|
# Sample random action
|
|
action = env.action_space.sample()
|
|
|
|
# Step the environment with the random action
|
|
observation, reward, truncated, terminated, info = env.step(action)
|
|
|
|
done = truncated or terminated
|
|
|
|
# Accumulate the reward
|
|
total_reward += reward
|
|
|
|
print(f"Episode #{episode}\n\treward: {total_reward}")
|
|
|
|
# Cleanup once done
|
|
env.close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
# Environment and its parameters
|
|
parser.add_argument(
|
|
"--env", type=str, default="Reach-ColorImage-Gazebo-v0", help="Environment ID"
|
|
)
|
|
parser.add_argument(
|
|
"--env-kwargs",
|
|
type=str,
|
|
nargs="+",
|
|
action=StoreDict,
|
|
help="Optional keyword argument to pass to the env constructor",
|
|
)
|
|
|
|
# Number of episodes to run
|
|
parser.add_argument(
|
|
"-n",
|
|
"--n-episodes",
|
|
type=int,
|
|
default=10000,
|
|
help="Overwrite the number of episodes",
|
|
)
|
|
|
|
# Random seed
|
|
parser.add_argument("--seed", type=int, default=69, help="Random generator seed")
|
|
|
|
# Flag to check environment
|
|
parser.add_argument(
|
|
"--check-env",
|
|
type=str2bool,
|
|
default=True,
|
|
help="Flag to check the environment before running the random agent",
|
|
)
|
|
|
|
# Flag to enable rendering
|
|
parser.add_argument(
|
|
"--render",
|
|
type=str2bool,
|
|
default=False,
|
|
help="Flag to enable rendering",
|
|
)
|
|
|
|
args, unknown = parser.parse_known_args()
|
|
|
|
main(args=args)
|