#!/usr/bin/env -S python3 -O
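"""Train an RL agent in an rbs_gym environment with Stable-Baselines3.

Example invocation (a sketch only: the flags come from the argument parser
defined at the bottom of this file, while the script file name is assumed):

    python train.py --env Reach-Gazebo-v0 --algo sac --seed 42

Pass --track true to log the run to Weights & Biases, or
--optimize-hyperparameters true to tune hyperparameters instead of training.
"""
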
import argparse
import datetime
import difflib
import os
import uuid

import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed

# Importing rbs_gym.envs registers the custom environments with Gymnasium,
# which is why the alias is kept even though it is never referenced directly
from rbs_gym import envs as gz_envs
from rbs_gym.utils.exp_manager import ExperimentManager
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool


def main(args: argparse.Namespace):
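    """Validate the parsed CLI arguments, then run training or hyperparameter optimization."""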

    # Check if the selected environment is valid
    # If it could not be found, suggest the closest match
    registered_envs = set(gym.envs.registry.keys())
    if args.env not in registered_envs:
        try:
            closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
        except IndexError:
            closest_match = "'no close match found...'"
        raise ValueError(
            f"{args.env} not found in gym registry, maybe you meant {closest_match}?"
        )

    # If no specific seed is selected, choose a random one
    if args.seed < 0:
        args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()

    # Set the random seed across platforms
    set_random_seed(args.seed)

    # Limiting the number of torch threads (e.g. to 1) can make things run faster on CPU
    if args.num_threads > 0:
        if args.verbose > 1:
            print(f"Setting torch.num_threads to {args.num_threads}")
        th.set_num_threads(args.num_threads)

    # Verify that the pre-trained agent exists before continuing to train it
    if args.trained_agent != "":
        assert args.trained_agent.endswith(".zip") and os.path.isfile(
            args.trained_agent
        ), "The trained_agent must be a valid path to a .zip file"

    # If enabled, ensure that the run has a unique ID
    uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""

    print("=" * 10, args.env, "=" * 10)
    print(f"Seed: {args.seed}")

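    # Optionally track the experiment with Weights & Biases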
    if args.track:
        try:
            import wandb
        except ImportError as e:
            raise ImportError(
                "If you want to use Weights & Biases to track experiments, please install W&B via `pip install wandb`"
            ) from e

        run_name = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

        tags = []
        run = wandb.init(
            name=run_name,
            project="rbs-gym",
            entity=None,
            tags=tags,
            config=vars(args),
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=False,  # auto-upload the videos of agents playing the game
            save_code=True,  # optional
        )
        args.tensorboard_log = f"runs/{run_name}"

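    # Gather all settings in the experiment manager, which sets up the
    # environment and the model before training starts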
    exp_manager = ExperimentManager(
        args,
        args.algo,
        args.env,
        args.log_folder,
        args.tensorboard_log,
        args.n_timesteps,
        args.eval_freq,
        args.eval_episodes,
        args.save_freq,
        args.hyperparams,
        args.env_kwargs,
        args.trained_agent,
        args.optimize_hyperparameters,
        truncate_last_trajectory=args.truncate_last_trajectory,
        uuid_str=uuid_str,
        seed=args.seed,
        log_interval=args.log_interval,
        save_replay_buffer=args.save_replay_buffer,
        preload_replay_buffer=args.preload_replay_buffer,
        verbose=args.verbose,
        vec_env_type=args.vec_env,
        no_optim_plots=args.no_optim_plots,
    )

    # Prepare experiment
    results = exp_manager.setup_experiment()
    if results is not None:
        model, saved_hyperparams = results
        if args.track:
            # we need to save the loaded hyperparameters
            args.saved_hyperparams = saved_hyperparams
            assert run is not None  # make mypy happy
            run.config.setdefaults(vars(args))

        # Normal training
        if model is not None:
            exp_manager.learn(model)
            exp_manager.save_trained_model(model)
    else:
        exp_manager.hyperparameters_optimization()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Environment and its parameters
    parser.add_argument(
        "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
    )
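    # Env kwargs are passed as space-separated key:value pairs (parsed by StoreDict),
    # mirroring the --hyperparams format, e.g. --env-kwargs some_key:some_value
    # (placeholder names for illustration)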
    parser.add_argument(
        "--env-kwargs",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional keyword arguments to pass to the env constructor",
    )
    parser.add_argument(
        "--vec-env",
        type=str,
        choices=["dummy", "subproc"],
        default="dummy",
        help="Type of VecEnv to use",
    )

    # Algorithm and training
    parser.add_argument(
        "--algo",
        type=str,
        choices=list(ALGOS.keys()),
        required=False,
        default="sac",
        help="RL algorithm to use during the training",
    )
    parser.add_argument(
        "-params",
        "--hyperparams",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional RL hyperparameter overrides (e.g. learning_rate:0.01 train_freq:10)",
    )
    parser.add_argument(
        "-n",
        "--n-timesteps",
        type=int,
        default=-1,
        help="Overwrite the number of timesteps",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=-1,
        help="Number of threads for PyTorch (-1 to use default)",
    )

    # Continue training an already trained agent
    parser.add_argument(
        "-i",
        "--trained-agent",
        type=str,
        default="",
        help="Path to a pretrained agent to continue training",
    )

    # Random seed
    parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")

    # Saving of model
    parser.add_argument(
        "--save-freq",
        type=int,
        default=10000,
        help="Save the model every n steps (if negative, no checkpoint)",
    )
    parser.add_argument(
        "--save-replay-buffer",
        type=str2bool,
        default=False,
        help="Save the replay buffer too (when applicable)",
    )

    # Pre-load a replay buffer and start training on it
    parser.add_argument(
        "--preload-replay-buffer",
        type=empty_str2none,
        default="",
        help="Path to a replay buffer that should be preloaded before starting the training process",
    )

    # Experiment tracking
    parser.add_argument(
        "--track",
        type=str2bool,
        default=False,
        help="Track the experiment using wandb",
    )

    # Optimization parameters
    parser.add_argument(
        "--optimize-hyperparameters",
        type=str2bool,
        default=False,
        help="Run hyperparameter optimization instead of regular training",
    )
    parser.add_argument(
        "--no-optim-plots",
        action="store_true",
        default=False,
        help="Disable hyperparameter optimization plots",
    )

    # Logging
    parser.add_argument(
        "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
    )
    parser.add_argument(
        "-tb",
        "--tensorboard-log",
        type=empty_str2none,
        default="tensorboard_logs",
        help="Tensorboard log dir",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=-1,
        help="Override log interval (default: -1, no change)",
    )
    parser.add_argument(
        "-uuid",
        "--uuid",
        type=str2bool,
        default=False,
        help="Ensure that the run has a unique ID",
    )

    # Evaluation
    parser.add_argument(
        "--eval-freq",
        type=int,
        default=-1,
        help="Evaluate the agent every n steps (if negative, no evaluation)",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=5,
        help="Number of episodes to use for evaluation",
    )

    # Verbosity
    parser.add_argument(
        "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
    )

    # HER specifics
    parser.add_argument(
        "--truncate-last-trajectory",
        type=str2bool,
        default=True,
        help="When using HER with online sampling, truncate the last trajectory in the replay buffer after it is reloaded",
    )

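    # parse_known_args is used so that unrecognized arguments are ignored
    # instead of raising an error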
    args, unknown = parser.parse_known_args()

    main(args=args)