Added rbs_gym package for RL & multi-robot launch setup
This commit is contained in:
parent f92670cd0d
commit b58307dea1
103 changed files with 15170 additions and 653 deletions
284 env_manager/rbs_gym/scripts/train.py Executable file
@@ -0,0 +1,284 @@
#!/usr/bin/env -S python3 -O

import argparse
import difflib
import os
import uuid

import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed

# Imported for its side effect of registering the rbs_gym environments with gymnasium
from rbs_gym import envs as gz_envs
from rbs_gym.utils.exp_manager import ExperimentManager
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool

def main(args: argparse.Namespace) -> None:
    # Check if the selected environment is valid
    # If it could not be found, suggest the closest match
    registered_envs = set(gym.envs.registry.keys())
    if args.env not in registered_envs:
        try:
            closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
        except IndexError:
            closest_match = "'no close match found...'"
        raise ValueError(
            f"{args.env} not found in gym registry, did you mean {closest_match}?"
        )

    # If no specific seed is selected, choose a random one
    if args.seed < 0:
        args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()

    # Set the random seed across platforms
    set_random_seed(args.seed)

    # Limiting the number of torch threads (e.g. to 1) can make things run faster on CPU
    if args.num_threads > 0:
        if args.verbose > 1:
            print(f"Setting torch.num_threads to {args.num_threads}")
        th.set_num_threads(args.num_threads)

    # Verify that the pre-trained agent exists before continuing to train it
    if args.trained_agent != "":
        assert args.trained_agent.endswith(".zip") and os.path.isfile(
            args.trained_agent
        ), "The trained_agent must be a valid path to a .zip file"

    # If enabled, ensure that the run has a unique ID
    uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""

print("=" * 10, args.env, "=" * 10)
|
||||
print(f"Seed: {args.seed}")
|
||||
|
||||
if args.track:
|
||||
try:
|
||||
import wandb
|
||||
import datetime
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
|
||||
) from e
|
||||
|
||||
run_name = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
|
||||
|
||||
tags = []
|
||||
run = wandb.init(
|
||||
name=run_name,
|
||||
project="rbs-gym",
|
||||
entity=None,
|
||||
tags=tags,
|
||||
config=vars(args),
|
||||
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
|
||||
monitor_gym=False, # auto-upload the videos of agents playing the game
|
||||
save_code=True, # optional
|
||||
)
|
||||
args.tensorboard_log = f"runs/{run_name}"
|
||||
|
||||
|
||||
    exp_manager = ExperimentManager(
        args,
        args.algo,
        args.env,
        args.log_folder,
        args.tensorboard_log,
        args.n_timesteps,
        args.eval_freq,
        args.eval_episodes,
        args.save_freq,
        args.hyperparams,
        args.env_kwargs,
        args.trained_agent,
        args.optimize_hyperparameters,
        truncate_last_trajectory=args.truncate_last_trajectory,
        uuid_str=uuid_str,
        seed=args.seed,
        log_interval=args.log_interval,
        save_replay_buffer=args.save_replay_buffer,
        preload_replay_buffer=args.preload_replay_buffer,
        verbose=args.verbose,
        vec_env_type=args.vec_env,
        no_optim_plots=args.no_optim_plots,
    )

    # Prepare experiment
    results = exp_manager.setup_experiment()
    if results is not None:
        model, saved_hyperparams = results
        if args.track:
            # we need to save the loaded hyperparameters
            args.saved_hyperparams = saved_hyperparams
            assert run is not None  # make mypy happy
            run.config.setdefaults(vars(args))

        # Normal training
        if model is not None:
            exp_manager.learn(model)
            exp_manager.save_trained_model(model)
    else:
        exp_manager.hyperparameters_optimization()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Environment and its parameters
    parser.add_argument(
        "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
    )
    parser.add_argument(
        "--env-kwargs",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional keyword arguments to pass to the env constructor",
    )
    parser.add_argument(
        "--vec-env",
        type=str,
        choices=["dummy", "subproc"],
        default="dummy",
        help="Type of VecEnv to use",
    )

    # Algorithm and training
    parser.add_argument(
        "--algo",
        type=str,
        choices=list(ALGOS.keys()),
        required=False,
        default="sac",
        help="RL algorithm to use during the training",
    )
    parser.add_argument(
        "-params",
        "--hyperparams",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional RL hyperparameter overwrites (e.g. learning_rate:0.01 train_freq:10)",
    )
    parser.add_argument(
        "-n",
        "--n-timesteps",
        type=int,
        default=-1,
        help="Overwrite the number of timesteps",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=-1,
        help="Number of threads for PyTorch (-1 to use default)",
    )

    # Continue training an already trained agent
    parser.add_argument(
        "-i",
        "--trained-agent",
        type=str,
        default="",
        help="Path to a pretrained agent to continue training",
    )

    # Random seed
    parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")

    # Saving of model
    parser.add_argument(
        "--save-freq",
        type=int,
        default=10000,
        help="Save the model every n steps (if negative, no checkpoint)",
    )
    parser.add_argument(
        "--save-replay-buffer",
        type=str2bool,
        default=False,
        help="Save the replay buffer too (when applicable)",
    )

    # Pre-load a replay buffer and start training on it
    parser.add_argument(
        "--preload-replay-buffer",
        type=empty_str2none,
        default="",
        help="Path to a replay buffer that should be preloaded before starting the training process",
    )
    parser.add_argument(
        "--track",
        type=str2bool,
        default=False,
        help="Track the experiment using wandb",
    )

    # Optimization parameters
    parser.add_argument(
        "--optimize-hyperparameters",
        type=str2bool,
        default=False,
        help="Run hyperparameter optimization instead of training",
    )
    parser.add_argument(
        "--no-optim-plots",
        action="store_true",
        default=False,
        help="Disable hyperparameter optimization plots",
    )

    # Logging
    parser.add_argument(
        "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
    )
    parser.add_argument(
        "-tb",
        "--tensorboard-log",
        type=empty_str2none,
        default="tensorboard_logs",
        help="Tensorboard log dir",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=-1,
        help="Override log interval (default: -1, no change)",
    )
    parser.add_argument(
        "-uuid",
        "--uuid",
        type=str2bool,
        default=False,
        help="Ensure that the run has a unique ID",
    )

    # Evaluation
    parser.add_argument(
        "--eval-freq",
        type=int,
        default=-1,
        help="Evaluate the agent every n steps (if negative, no evaluation)",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=5,
        help="Number of episodes to use for evaluation",
    )

    # Verbosity
    parser.add_argument(
        "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
    )

    # HER specifics
    parser.add_argument(
        "--truncate-last-trajectory",
        type=str2bool,
        default=True,
        help="When using HER with online sampling, the last trajectory in the replay buffer will be truncated after reloading the replay buffer.",
    )

    args, unknown = parser.parse_known_args()

    main(args=args)
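For reference, the script is committed as an executable and exposes its options through the argparser above. A typical invocation might look like the following (all flags shown are defined above; the seed value and log directory are arbitrary, and the environment ID and algorithm are simply the script's defaults made explicit):

    ./env_manager/rbs_gym/scripts/train.py --env Reach-Gazebo-v0 --algo sac --seed 42 -f logs

With no arguments, the script trains SAC on Reach-Gazebo-v0 with a random seed, checkpointing every 10000 steps into logs/.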