#!/usr/bin/env -S python3 -O
import argparse
import difflib
import os
import uuid

import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed

# Imported for its side effect of registering the rbs_gym environments with gymnasium
from rbs_gym import envs as gz_envs  # noqa: F401
from rbs_gym.utils.exp_manager import ExperimentManager
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool


def main(args: argparse.Namespace):
    # Check if the selected environment is valid
    # If it could not be found, suggest the closest match
    registered_envs = set(gym.envs.registry.keys())
    if args.env not in registered_envs:
        try:
            closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
        except IndexError:
            closest_match = "'no close match found...'"
        raise ValueError(
            f"{args.env} not found in gym registry, did you mean {closest_match}?"
        )

    # If no specific seed is selected, choose a random one
    if args.seed < 0:
        args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()

    # Set the random seed across platforms
    set_random_seed(args.seed)

    # Limiting the number of PyTorch threads (e.g. to 1) can speed things up on CPU
    if args.num_threads > 0:
        if args.verbose > 1:
            print(f"Setting torch.num_threads to {args.num_threads}")
        th.set_num_threads(args.num_threads)

    # Verify that the pre-trained agent exists before continuing to train it
    if args.trained_agent != "":
        assert args.trained_agent.endswith(".zip") and os.path.isfile(
            args.trained_agent
        ), "The trained_agent must be a valid path to a .zip file"

    # If enabled, ensure that the run has a unique ID
    uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""

    print("=" * 10, args.env, "=" * 10)
    print(f"Seed: {args.seed}")

    # Default environment kwargs, merged with any user-provided --env-kwargs
    # (user-provided values take precedence)
    env_kwargs = {"render_mode": "human"}
    if args.env_kwargs is not None:
        env_kwargs.update(args.env_kwargs)
    args.env_kwargs = env_kwargs

    exp_manager = ExperimentManager(
        args,
        args.algo,
        args.env,
        args.log_folder,
        args.tensorboard_log,
        args.n_timesteps,
        args.eval_freq,
        args.eval_episodes,
        args.save_freq,
        args.hyperparams,
        args.env_kwargs,
        args.trained_agent,
        truncate_last_trajectory=args.truncate_last_trajectory,
        uuid_str=uuid_str,
        seed=args.seed,
        log_interval=args.log_interval,
        save_replay_buffer=args.save_replay_buffer,
        preload_replay_buffer=args.preload_replay_buffer,
        verbose=args.verbose,
        vec_env_type=args.vec_env,
    )

    # Prepare the experiment, train the agent and save the resulting model
    model = exp_manager.setup_experiment()
    exp_manager.learn(model)
    exp_manager.save_trained_model(model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Environment and its parameters
    parser.add_argument(
        "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
    )
    parser.add_argument(
        "--env-kwargs",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional keyword arguments to pass to the env constructor",
    )
    parser.add_argument(
        "--vec-env",
        type=str,
        choices=["dummy", "subproc"],
        default="dummy",
        help="Type of VecEnv to use",
    )
    # Algorithm and training
    parser.add_argument(
        "--algo",
        type=str,
        choices=list(ALGOS.keys()),
        required=False,
        default="sac",
        help="RL algorithm to use during the training",
    )
    parser.add_argument(
        "-params",
        "--hyperparams",
        type=str,
        nargs="+",
        action=StoreDict,
        help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)",
    )
    parser.add_argument(
        "-n",
        "--n-timesteps",
        type=int,
        default=-1,
        help="Overwrite the number of timesteps",
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=-1,
        help="Number of threads for PyTorch (-1 to use default)",
    )
    # Continue training an already trained agent
    parser.add_argument(
        "-i",
        "--trained-agent",
        type=str,
        default="",
        help="Path to a pretrained agent to continue training",
    )
    # Random seed
    parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")
    # Saving of model
    parser.add_argument(
        "--save-freq",
        type=int,
        default=10000,
        help="Save the model every n steps (if negative, no checkpoint)",
    )
    parser.add_argument(
        "--save-replay-buffer",
        type=str2bool,
        default=False,
        help="Save the replay buffer too (when applicable)",
    )
    # Pre-load a replay buffer and start training on it
    parser.add_argument(
        "--preload-replay-buffer",
        type=empty_str2none,
        default="",
        help="Path to a replay buffer that should be preloaded before starting the training process",
    )
    # Logging
    parser.add_argument(
        "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
    )
    parser.add_argument(
        "-tb",
        "--tensorboard-log",
        type=empty_str2none,
        default="tensorboard_logs",
        help="Tensorboard log dir",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=-1,
        help="Override log interval (default: -1, no change)",
    )
    parser.add_argument(
        "-uuid",
        "--uuid",
        type=str2bool,
        default=False,
        help="Ensure that the run has a unique ID",
    )
    # Evaluation
    parser.add_argument(
        "--eval-freq",
        type=int,
        default=-1,
        help="Evaluate the agent every n steps (if negative, no evaluation)",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=5,
        help="Number of episodes to use for evaluation",
    )
    # Verbosity
    parser.add_argument(
        "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
    )
    # HER specifics
    parser.add_argument(
        "--truncate-last-trajectory",
        type=str2bool,
        default=True,
        help="When using HER with online sampling, the last trajectory in the replay buffer will be truncated after reloading the replay buffer.",
    )

    # Ignore unrecognized arguments instead of failing on them
    args, unknown = parser.parse_known_args()

    main(args=args)
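
# Example invocations (illustrative only; the script name "train.py" and the
# checkpoint path below are placeholders, adjust them to your setup):
#
#   # Train SAC on the default reach task with a fixed seed
#   python3 train.py --env Reach-Gazebo-v0 --algo sac --seed 42
#
#   # Overwrite hyperparameters and continue training from a saved agent
#   python3 train.py --env Reach-Gazebo-v0 --algo sac \
#       --hyperparams learning_rate:0.01 train_freq:10 \
#       --trained-agent logs/sac/Reach-Gazebo-v0/model.zip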