runtime/env_manager/rbs_gym/scripts/optimize.py

#!/usr/bin/env -S python3 -O
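"""Train an RL agent in rbs_gym environments with Stable-Baselines3.

The script validates the requested environment against the gymnasium
registry, seeds all random number generators, and delegates training,
evaluation, and checkpointing to ExperimentManager.
"""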
import argparse
import difflib
import os
import uuid
import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed
# Importing rbs_gym.envs registers its environments with gymnasium
from rbs_gym import envs as gz_envs  # noqa: F401
from rbs_gym.utils.exp_manager import ExperimentManager
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool


def main(args: argparse.Namespace):
    # Check if the selected environment is valid
    # If it could not be found, suggest the closest match
    registered_envs = set(gym.envs.registry.keys())
    if args.env not in registered_envs:
        try:
            closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
        except IndexError:
            closest_match = "'no close match found...'"
        raise ValueError(
            f"{args.env} not found in gym registry, did you mean {closest_match}?"
        )
    # If no specific seed is selected, choose a random one
    if args.seed < 0:
        args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()

    # Set the random seed across platforms
    set_random_seed(args.seed)

    # Setting num threads to 1 makes things run faster on cpu
    if args.num_threads > 0:
        if args.verbose > 1:
            print(f"Setting torch.num_threads to {args.num_threads}")
        th.set_num_threads(args.num_threads)

    # Verify that pre-trained agent exists before continuing to train it
    if args.trained_agent != "":
        assert args.trained_agent.endswith(".zip") and os.path.isfile(
            args.trained_agent
        ), "The trained_agent must be a valid path to a .zip file"

    # If enabled, ensure that the run has a unique ID
    uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""

    print("=" * 10, args.env, "=" * 10)
    print(f"Seed: {args.seed}")
    # Render in human mode by default; user-provided --env-kwargs take precedence
    env_kwargs = {"render_mode": "human"}
    if args.env_kwargs is not None:
        env_kwargs.update(args.env_kwargs)

    exp_manager = ExperimentManager(
        args,
        args.algo,
        args.env,
        args.log_folder,
        args.tensorboard_log,
        args.n_timesteps,
        args.eval_freq,
        args.eval_episodes,
        args.save_freq,
        args.hyperparams,
        env_kwargs,
        args.trained_agent,
        truncate_last_trajectory=args.truncate_last_trajectory,
        uuid_str=uuid_str,
        seed=args.seed,
        log_interval=args.log_interval,
        save_replay_buffer=args.save_replay_buffer,
        preload_replay_buffer=args.preload_replay_buffer,
        verbose=args.verbose,
        vec_env_type=args.vec_env,
    )
    # Prepare the experiment, then train and save the model
    model = exp_manager.setup_experiment()
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Environment and its parameters
    parser.add_argument(
        "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
    )
    parser.add_argument(
        "--env-kwargs",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional keyword arguments to pass to the env constructor",
    )
    parser.add_argument(
        "--vec-env",
        type=str,
        choices=["dummy", "subproc"],
        default="dummy",
        help="Type of VecEnv to use",
    )
    # Algorithm and training
    parser.add_argument(
        "--algo",
        type=str,
        choices=list(ALGOS.keys()),
        required=False,
        default="sac",
        help="RL algorithm to use during training",
    )
    parser.add_argument(
        "-params",
        "--hyperparams",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)",
    )
    parser.add_argument(
        "-n",
        "--n-timesteps",
        type=int,
        default=-1,
        help="Overwrite the number of timesteps",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=-1,
        help="Number of threads for PyTorch (-1 to use default)",
    )

    # Continue training an already trained agent
    parser.add_argument(
        "-i",
        "--trained-agent",
        type=str,
        default="",
        help="Path to a pretrained agent to continue training",
    )

    # Random seed
    parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")

    # Saving of model
    parser.add_argument(
        "--save-freq",
        type=int,
        default=10000,
        help="Save the model every n steps (if negative, no checkpoint)",
    )
    parser.add_argument(
        "--save-replay-buffer",
        type=str2bool,
        default=False,
        help="Save the replay buffer too (when applicable)",
    )

    # Pre-load a replay buffer and start training on it
    parser.add_argument(
        "--preload-replay-buffer",
        type=empty_str2none,
        default="",
        help="Path to a replay buffer that should be preloaded before starting the training process",
    )

    # Logging
    parser.add_argument(
        "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
    )
    parser.add_argument(
        "-tb",
        "--tensorboard-log",
        type=empty_str2none,
        default="tensorboard_logs",
        help="Tensorboard log dir",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=-1,
        help="Override log interval (default: -1, no change)",
    )
    parser.add_argument(
        "-uuid",
        "--uuid",
        type=str2bool,
        default=False,
        help="Ensure that the run has a unique ID",
    )

    # Evaluation
    parser.add_argument(
        "--eval-freq",
        type=int,
        default=-1,
        help="Evaluate the agent every n steps (if negative, no evaluation)",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=5,
        help="Number of episodes to use for evaluation",
    )

    # Verbosity
    parser.add_argument(
        "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
    )
    # HER specifics
    parser.add_argument(
        "--truncate-last-trajectory",
        type=str2bool,
        default=True,
        help="When using HER with online sampling, the last trajectory in the replay buffer will be truncated after reloading the replay buffer",
    )

    # Unknown CLI arguments are tolerated (parsed but ignored)
    args, unknown = parser.parse_known_args()
    main(args=args)
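
# Example invocations (values are illustrative; flags are defined above):
#   python3 optimize.py --env Reach-Gazebo-v0 --algo sac -n 100000 --seed 42
#   python3 optimize.py --env Reach-Gazebo-v0 --algo sac \
#       -i path/to/pretrained_agent.zip \
#       -params learning_rate:0.01 train_freq:10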