Added rbs_gym package for RL & multi-robot launch setup

Ilya Uraev 2024-07-04 11:38:08 +00:00 committed by Igor Brylyov
parent f92670cd0d
commit b58307dea1
103 changed files with 15170 additions and 653 deletions


@@ -0,0 +1,294 @@
#!/usr/bin/env -S python3 -O
import argparse
import os
from typing import Dict
import numpy as np
import torch as th
import yaml
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecEnvWrapper
from rbs_gym import envs as gz_envs
from rbs_gym.utils import create_test_env, get_latest_run_id, get_saved_hyperparams
from rbs_gym.utils.utils import ALGOS, StoreDict, str2bool
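# Evaluate a previously trained agent: resolve the model path inside the log
# folder (latest experiment, best model, or a specific checkpoint), rebuild the
# test environment with the saved hyperparameters, and report episode rewards,
# episode lengths and success rates.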
def main(args: Dict):
if args.exp_id == 0:
args.exp_id = get_latest_run_id(
os.path.join(args.log_folder, args.algo), args.env
)
print(f"Loading latest experiment, id={args.exp_id}")
# Sanity checks
if args.exp_id > 0:
log_path = os.path.join(args.log_folder, args.algo, f"{args.env}_{args.exp_id}")
else:
log_path = os.path.join(args.log_folder, args.algo)
assert os.path.isdir(log_path), f"The {log_path} folder was not found"
found = False
for ext in ["zip"]:
model_path = os.path.join(log_path, f"{args.env}.{ext}")
found = os.path.isfile(model_path)
if found:
break
if args.load_best:
model_path = os.path.join(log_path, "best_model.zip")
found = os.path.isfile(model_path)
if args.load_checkpoint is not None:
model_path = os.path.join(
log_path, f"rl_model_{args.load_checkpoint}_steps.zip"
)
found = os.path.isfile(model_path)
if not found:
raise ValueError(
f"No model found for {args.algo} on {args.env}, path: {model_path}"
)
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
if args.algo in off_policy_algos:
args.n_envs = 1
set_random_seed(args.seed)
if args.num_threads > 0:
if args.verbose > 1:
print(f"Setting torch.num_threads to {args.num_threads}")
th.set_num_threads(args.num_threads)
stats_path = os.path.join(log_path, args.env)
hyperparams, stats_path = get_saved_hyperparams(
stats_path, norm_reward=args.norm_reward, test_mode=True
)
# load env_kwargs if existing
env_kwargs = {}
args_path = os.path.join(log_path, args.env, "args.yml")
if os.path.isfile(args_path):
with open(args_path, "r") as f:
# pytype: disable=module-attr
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
if loaded_args["env_kwargs"] is not None:
env_kwargs = loaded_args["env_kwargs"]
# overwrite with command line arguments
if args.env_kwargs is not None:
env_kwargs.update(args.env_kwargs)
log_dir = args.reward_log if args.reward_log != "" else None
env = create_test_env(
args.env,
n_envs=args.n_envs,
stats_path=stats_path,
seed=args.seed,
log_dir=log_dir,
should_render=not args.no_render,
hyperparams=hyperparams,
env_kwargs=env_kwargs,
)
kwargs = dict(seed=args.seed)
if args.algo in off_policy_algos:
# Dummy buffer size as we don't need memory to evaluate the trained agent
kwargs.update(dict(buffer_size=1))
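# Passing the current env spaces via custom_objects makes SB3 use them instead
# of the pickled ones, which avoids load errors when the saved spaces cannot be
# deserialized (e.g. after a gym/gymnasium or numpy upgrade).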
custom_objects = {'observation_space': env.observation_space, 'action_space': env.action_space}
model = ALGOS[args.algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs)
obs = env.reset()
# Deterministic by default
stochastic = args.stochastic
deterministic = not stochastic
print(
f"Evaluating for {args.n_episodes} episodes with a",
"deterministic" if deterministic else "stochastic",
"policy.",
)
state = None
episode_reward = 0.0
episode_rewards, episode_lengths, success_episode_lengths = [], [], []
ep_len = 0
episode = 0
# For HER, monitor success rate
successes = []
while episode < args.n_episodes:
action, state = model.predict(obs, state=state, deterministic=deterministic)
obs, reward, done, infos = env.step(action)
if not args.no_render:
env.render("human")
episode_reward += reward[0]
ep_len += 1
if done and args.verbose > 0:
episode += 1
print(f"--- Episode {episode}/{args.n_episodes}")
# NOTE: for env using VecNormalize, the mean reward
# is a normalized reward when the `--norm-reward` flag is passed
print(f"Episode Reward: {episode_reward:.2f}")
episode_rewards.append(episode_reward)
print("Episode Length", ep_len)
episode_lengths.append(ep_len)
if infos[0].get("is_success") is not None:
print("Success?:", infos[0].get("is_success", False))
successes.append(infos[0].get("is_success", False))
if infos[0].get("is_success"):
success_episode_lengths.append(ep_len)
print(f"Current success rate: {100 * np.mean(successes):.2f}%")
episode_reward = 0.0
ep_len = 0
state = None
if args.verbose > 0 and len(successes) > 0:
print(f"Success rate: {100 * np.mean(successes):.2f}%")
if args.verbose > 0 and len(episode_rewards) > 0:
print(
f"Mean reward: {np.mean(episode_rewards):.2f} "
f"+/- {np.std(episode_rewards):.2f}"
)
if args.verbose > 0 and len(episode_lengths) > 0:
print(
f"Mean episode length: {np.mean(episode_lengths):.2f} "
f"+/- {np.std(episode_lengths):.2f}"
)
if args.verbose > 0 and len(success_episode_lengths) > 0:
print(
f"Mean episode length of successful episodes: {np.mean(success_episode_lengths):.2f} "
f"+/- {np.std(success_episode_lengths):.2f}"
)
# Workaround for https://github.com/openai/gym/issues/893
if not args.no_render:
if args.n_envs == 1 and "Bullet" not in args.env and isinstance(env, VecEnv):
# DummyVecEnv
# Unwrap env
while isinstance(env, VecEnvWrapper):
env = env.venv
if isinstance(env, DummyVecEnv):
env.envs[0].env.close()
else:
env.close()
else:
# SubprocVecEnv
env.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Environment and its parameters
parser.add_argument(
"--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
)
parser.add_argument(
"--env-kwargs",
type=str,
nargs="+",
action=StoreDict,
help="Optional keyword argument to pass to the env constructor",
)
parser.add_argument("--n-envs", type=int, default=1, help="Number of environments")
# Algorithm
parser.add_argument(
"--algo",
type=str,
choices=list(ALGOS.keys()),
required=False,
default="sac",
help="RL algorithm to use during the training",
)
parser.add_argument(
"--num-threads",
type=int,
default=-1,
help="Number of threads for PyTorch (-1 to use default)",
)
# Test duration
parser.add_argument(
"-n",
"--n-episodes",
type=int,
default=200,
help="Number of evaluation episodes",
)
# Random seed
parser.add_argument("--seed", type=int, default=0, help="Random generator seed")
# Model to test
parser.add_argument(
"-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
)
parser.add_argument(
"--exp-id",
type=int,
default=0,
help="Experiment ID (default: 0: latest, -1: no exp folder)",
)
parser.add_argument(
"--load-best",
type=str2bool,
default=False,
help="Load best model instead of last model if available",
)
parser.add_argument(
"--load-checkpoint",
type=int,
help="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it",
)
# Deterministic/stochastic actions
parser.add_argument(
"--stochastic",
type=str2bool,
default=False,
help="Use stochastic actions instead of deterministic",
)
# Logging
parser.add_argument(
"--reward-log", type=str, default="reward_logs", help="Where to log reward"
)
parser.add_argument(
"--norm-reward",
type=str2bool,
default=False,
help="Normalize reward if applicable (trained with VecNormalize)",
)
# Disable render
parser.add_argument(
"--no-render",
type=str2bool,
default=False,
help="Do not render the environment (useful for tests)",
)
# Verbosity
parser.add_argument(
"--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
)
args, unknown = parser.parse_known_args()
main(args)
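# Example invocation (the script name is hypothetical; the flags are the ones
# defined above):
#   python3 evaluate.py --algo sac --env Reach-Gazebo-v0 --load-best True \
#       --n-episodes 20 --no-render True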


@@ -0,0 +1,233 @@
#!/usr/bin/env -S python3 -O
import argparse
import difflib
import os
import uuid
from typing import Dict
import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed
from rbs_gym import envs as gz_envs
from rbs_gym.utils.exp_manager import ExperimentManager
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool
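# Train an agent: validate the environment ID, seed the run, and hand the actual
# training loop over to ExperimentManager (setup, learning, saving the model).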
def main(args: Dict):
# Check if the selected environment is valid
# If it could not be found, suggest the closest match
registered_envs = set(gym.envs.registry.keys())
if args.env not in registered_envs:
try:
closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
except IndexError:
closest_match = "'no close match found...'"
raise ValueError(
f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
)
# If no specific seed is selected, choose a random one
if args.seed < 0:
args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
# Set the random seed across platforms
set_random_seed(args.seed)
# Limiting the number of PyTorch threads can make things run faster on CPU
if args.num_threads > 0:
if args.verbose > 1:
print(f"Setting torch.num_threads to {args.num_threads}")
th.set_num_threads(args.num_threads)
# Verify that pre-trained agent exists before continuing to train it
if args.trained_agent != "":
assert args.trained_agent.endswith(".zip") and os.path.isfile(
args.trained_agent
), "The trained_agent must be a valid path to a .zip file"
# If enabled, ensure that the run has a unique ID
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
print("=" * 10, args.env, "=" * 10)
print(f"Seed: {args.seed}")
env_kwargs = {
"render_mode": "human"
}
exp_manager = ExperimentManager(
args,
args.algo,
args.env,
args.log_folder,
args.tensorboard_log,
args.n_timesteps,
args.eval_freq,
args.eval_episodes,
args.save_freq,
args.hyperparams,
args.env_kwargs,
args.trained_agent,
truncate_last_trajectory=args.truncate_last_trajectory,
uuid_str=uuid_str,
seed=args.seed,
log_interval=args.log_interval,
save_replay_buffer=args.save_replay_buffer,
preload_replay_buffer=args.preload_replay_buffer,
verbose=args.verbose,
vec_env_type=args.vec_env,
)
# Prepare experiment
model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Environment and its parameters
parser.add_argument(
"--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
)
parser.add_argument(
"--env-kwargs",
type=str,
nargs="+",
action=StoreDict,
help="Optional keyword argument to pass to the env constructor",
)
parser.add_argument(
"--vec-env",
type=str,
choices=["dummy", "subproc"],
default="dummy",
help="Type of VecEnv to use",
)
# Algorithm and training
parser.add_argument(
"--algo",
type=str,
choices=list(ALGOS.keys()),
required=False,
default="sac",
help="RL algorithm to use during the training",
)
parser.add_argument(
"-params",
"--hyperparams",
type=str,
nargs="+",
action=StoreDict,
help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument(
"-n",
"--n-timesteps",
type=int,
default=-1,
help="Overwrite the number of timesteps",
)
parser.add_argument(
"--num-threads",
type=int,
default=-1,
help="Number of threads for PyTorch (-1 to use default)",
)
# Continue training an already trained agent
parser.add_argument(
"-i",
"--trained-agent",
type=str,
default="",
help="Path to a pretrained agent to continue training",
)
# Random seed
parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")
# Saving of model
parser.add_argument(
"--save-freq",
type=int,
default=10000,
help="Save the model every n steps (if negative, no checkpoint)",
)
parser.add_argument(
"--save-replay-buffer",
type=str2bool,
default=False,
help="Save the replay buffer too (when applicable)",
)
# Pre-load a replay buffer and start training on it
parser.add_argument(
"--preload-replay-buffer",
type=empty_str2none,
default="",
help="Path to a replay buffer that should be preloaded before starting the training process",
)
# Logging
parser.add_argument(
"-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
)
parser.add_argument(
"-tb",
"--tensorboard-log",
type=empty_str2none,
default="tensorboard_logs",
help="Tensorboard log dir",
)
parser.add_argument(
"--log-interval",
type=int,
default=-1,
help="Override log interval (default: -1, no change)",
)
parser.add_argument(
"-uuid",
"--uuid",
type=str2bool,
default=False,
help="Ensure that the run has a unique ID",
)
# Evaluation
parser.add_argument(
"--eval-freq",
type=int,
default=-1,
help="Evaluate the agent every n steps (if negative, no evaluation)",
)
parser.add_argument(
"--eval-episodes",
type=int,
default=5,
help="Number of episodes to use for evaluation",
)
# Verbosity
parser.add_argument(
"--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
)
# HER specifics
parser.add_argument(
"--truncate-last-trajectory",
type=str2bool,
default=True,
help="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.",
)
args, unknown = parser.parse_known_args()
main(args=args)
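# Example invocation (the script name is hypothetical):
#   python3 train.py --algo sac --env Reach-Gazebo-v0 -n 100000 --seed 42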


@@ -0,0 +1,200 @@
#!/usr/bin/env python3
import time
import gym_gz_models
import gym_gz
from scenario import gazebo as scenario_gazebo
from scenario import core as scenario_core
import rclpy
from rclpy.node import Node
from scipy.spatial.transform import Rotation as R
import numpy as np
from geometry_msgs.msg import PoseStamped
from rclpy.executors import MultiThreadedExecutor
from rbs_skill_servers import CartesianControllerPublisher, TakePose
from rclpy.action import ActionClient
from control_msgs.action import GripperCommand
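# Stand-alone pick demo: step a Gazebo simulation from a ROS 2 node, spawn a
# wood cube from Fuel, and move the arm above and then down to the cube through
# the cartesian_motion_controller while commanding the gripper via a
# GripperCommand action.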
class Spawner(Node):
def __init__(self):
super().__init__("spawner")
self.gazebo = scenario_gazebo.GazeboSimulator(step_size=0.001,
rtf=1.0,
steps_per_run=1)
self.cartesian_pose = self.create_publisher(
PoseStamped,
"/" + "arm0" + "/cartesian_motion_controller/target_frame", 10)
self.current_pose_sub = self.create_subscription(PoseStamped,
"/arm0/cartesian_motion_controller/current_pose", self.callback, 10)
self._action_client = ActionClient(self,
GripperCommand,
"/" + "arm0" + "/gripper_controller/gripper_cmd")
timer_period = 0.001 # seconds
self.timer = self.create_timer(timer_period, self.timer_callback)
self.ano_timer = self.create_timer(timer_period, self.another_timer)
scenario_gazebo.set_verbosity(scenario_gazebo.Verbosity_info)
self.gazebo.insert_world_from_sdf(
"/home/bill-finger/rbs_ws/install/rbs_simulation/share/rbs_simulation/worlds/asm2.sdf")
self.gazebo.initialize()
self.world = self.gazebo.get_world()
self.current_pose: PoseStamped = PoseStamped()
self.init_sim()
self.cube = self.world.get_model("cube")
self.stage = 0
self.gripper_open = False
def callback(self, msg: PoseStamped):
self.current_pose = msg
def timer_callback(self):
self.gazebo.run()
def send_goal(self, goal: float):
goal_msg = GripperCommand.Goal()
goal_msg._command.position = goal
goal_msg._command.max_effort = 1.0
self._action_client.wait_for_server()
self.gripper_open = not self.gripper_open
self._send_goal_future = self._action_client.send_goal_async(goal_msg)
self._send_goal_future.add_done_callback(self.goal_response_callback)
def goal_response_callback(self, future):
goal_handle = future.result()
if not goal_handle.accepted:
self.get_logger().info('Goal rejected :(')
return
self.get_logger().info('Goal accepted :)')
self._get_result_future = goal_handle.get_result_async()
self._get_result_future.add_done_callback(self.get_result_callback)
def get_result_callback(self, future):
result = future.result().result
self.get_logger().info('Result: {0}'.format(result.position))
def another_timer(self):
position_over_cube = np.array(self.cube.base_position()) + np.array([0, 0, 0.2])
position_cube = np.array(self.cube.base_position()) + np.array([0, 0, 0.03])
quat_xyzw = R.from_euler(seq="y", angles=180, degrees=True).as_quat()
if self.stage == 0:
if self.distance_to_target(position_over_cube, quat_xyzw) > 0.01:
self.cartesian_pose.publish(self.get_pose(position_over_cube, quat_xyzw))
if self.distance_to_target(position_over_cube, quat_xyzw) < 0.01:
self.stage += 1
if self.stage == 1:
if self.distance_to_target(position_cube, quat_xyzw) > 0.01:
if not self.gripper_open:
self.send_goal(0.064)
# rclpy.spin_until_future_complete(self, future)
self.cartesian_pose.publish(self.get_pose(position_cube, quat_xyzw))
if self.distance_to_target(position_cube, quat_xyzw) < 0.01:
self.stage += 1
def distance_to_target(self, position, orientation):
target_pose = self.get_pose(position, orientation)
current_position = np.array([
self.current_pose.pose.position.x,
self.current_pose.pose.position.y,
self.current_pose.pose.position.z
])
target_position = np.array([
target_pose.pose.position.x,
target_pose.pose.position.y,
target_pose.pose.position.z
])
distance = np.linalg.norm(current_position - target_position)
return distance
def init_sim(self):
# Create the simulator
self.gazebo.gui()
self.gazebo.run(paused=True)
self.world.to_gazebo().set_gravity((0, 0, -9.8))
self.world.insert_model("/home/bill-finger/rbs_ws/current.urdf")
self.gazebo.run(paused=True)
for model_name in self.world.model_names():
model = self.world.get_model(model_name)
print(f"Model: {model_name}")
print(f" Base link: {model.base_frame()}")
print("LINKS")
for name in model.link_names():
position = model.get_link(name).position()
orientation_wxyz = np.asarray(model.get_link(name).orientation())
orientation = R.from_quat(orientation_wxyz[[1, 2, 3, 0]]).as_euler("xyz")
print(f" {name}:", (*position, *tuple(orientation)))
print("JOINTS")
for name in model.joint_names():
print(f"{name}")
uri = lambda org, name: f"https://fuel.gazebosim.org/{org}/models/{name}"
# Download the cube SDF file
cube_sdf = scenario_gazebo.get_model_file_from_fuel(
uri=uri(org="openrobotics", name="wood cube 5cm"), use_cache=False
)
# Sample a random position
random_position = np.random.uniform(low=[-0.2, -0.2, 0.0], high=[-0.3, 0.2, 0.0])
# Get a unique name
model_name = gym_gz.utils.scenario.get_unique_model_name(
world=self.world, model_name="cube"
)
# Insert the model
assert self.world.insert_model(
cube_sdf, scenario_core.Pose(random_position, [1.0, 0, 0, 0]), model_name
)
model = self.world.get_model("rbs_arm")
self.cube = self.world.get_model("cube")
ok_reset_pos = model.to_gazebo().reset_joint_positions(
[0.0, -0.240, -3.142, 1.090, 0, 1.617, 0.0, 0.0, 0.0],
[name for name in model.joint_names() if "_joint" in name]
)
if not ok_reset_pos:
raise RuntimeError("Failed to reset the robot state")
def get_pose(self, position, orientation) -> PoseStamped:
msg = PoseStamped()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = "base_link"
msg.pose.position.x = position[0]
msg.pose.position.y = position[1]
msg.pose.position.z = position[2]
msg.pose.orientation.x = orientation[0]
msg.pose.orientation.y = orientation[1]
msg.pose.orientation.z = orientation[2]
msg.pose.orientation.w = orientation[3]
return msg
def main(args=None):
rclpy.init(args=args)
executor = MultiThreadedExecutor()
my_node = Spawner()
executor.add_node(my_node)
executor.spin()
my_node.gazebo.close()
my_node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()


@@ -0,0 +1,101 @@
#!/usr/bin/env python3
import argparse
from typing import Dict
import gymnasium as gym
from stable_baselines3.common.env_checker import check_env
from rbs_gym import envs as gz_envs
from rbs_gym.utils.utils import StoreDict, str2bool
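# Sanity-check an environment: optionally run SB3's check_env on it, then step
# it with a random agent for a number of episodes and print each episode return.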
def main(args: Dict):
# Create the environment
env = gym.make(args.env, **args.env_kwargs)
# Initialize random seed
env.seed(args.seed)
# Check the environment
if args.check_env:
check_env(env, warn=True, skip_render_check=True)
# Step environment for bunch of episodes
for episode in range(args.n_episodes):
# Initialize returned values
done = False
total_reward = 0
# Reset the environment
observation, info = env.reset()
# Step through the current episode until it is done
while not done:
# Sample random action
action = env.action_space.sample()
# Step the environment with the random action
observation, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
# Accumulate the reward
total_reward += reward
print(f"Episode #{episode}\n\treward: {total_reward}")
# Cleanup once done
env.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Environment and its parameters
parser.add_argument(
"--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
)
parser.add_argument(
"--env-kwargs",
type=str,
nargs="+",
action=StoreDict,
help="Optional keyword argument to pass to the env constructor",
)
# Number of episodes to run
parser.add_argument(
"-n",
"--n-episodes",
type=int,
default=10000,
help="Overwrite the number of episodes",
)
# Random seed
parser.add_argument("--seed", type=int, default=69, help="Random generator seed")
# Flag to check environment
parser.add_argument(
"--check-env",
type=str2bool,
default=True,
help="Flag to check the environment before running the random agent",
)
# Flag to enable rendering
parser.add_argument(
"--render",
type=str2bool,
default=False,
help="Flag to enable rendering",
)
args, unknown = parser.parse_known_args()
main(args=args)
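# Example invocation (the script name is hypothetical):
#   python3 test_agent.py --env Reach-Gazebo-v0 -n 3 --check-env True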


@@ -0,0 +1,284 @@
#!/usr/bin/env -S python3 -O
import argparse
import difflib
import os
import uuid
from typing import Dict
import gymnasium as gym
import numpy as np
import torch as th
from stable_baselines3.common.utils import set_random_seed
from rbs_gym import envs as gz_envs
from rbs_gym.utils.exp_manager import ExperimentManager
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool
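# Train an agent: same flow as the plain training script (validate the env ID,
# seed, delegate to ExperimentManager), plus optional Weights & Biases tracking
# and hyperparameter optimization.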
def main(args: Dict):
# Check if the selected environment is valid
# If it could not be found, suggest the closest match
registered_envs = set(gym.envs.registry.keys())
if args.env not in registered_envs:
try:
closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
except IndexError:
closest_match = "'no close match found...'"
raise ValueError(
f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
)
# If no specific seed is selected, choose a random one
if args.seed < 0:
args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
# Set the random seed across platforms
set_random_seed(args.seed)
# Limiting the number of PyTorch threads can make things run faster on CPU
if args.num_threads > 0:
if args.verbose > 1:
print(f"Setting torch.num_threads to {args.num_threads}")
th.set_num_threads(args.num_threads)
# Verify that pre-trained agent exists before continuing to train it
if args.trained_agent != "":
assert args.trained_agent.endswith(".zip") and os.path.isfile(
args.trained_agent
), "The trained_agent must be a valid path to a .zip file"
# If enabled, ensure that the run has a unique ID
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
print("=" * 10, args.env, "=" * 10)
print(f"Seed: {args.seed}")
if args.track:
try:
import wandb
import datetime
except ImportError as e:
raise ImportError(
"if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
) from e
run_name = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
tags = []
run = wandb.init(
name=run_name,
project="rbs-gym",
entity=None,
tags=tags,
config=vars(args),
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
monitor_gym=False, # auto-upload the videos of agents playing the game
save_code=True, # optional
)
args.tensorboard_log = f"runs/{run_name}"
exp_manager = ExperimentManager(
args,
args.algo,
args.env,
args.log_folder,
args.tensorboard_log,
args.n_timesteps,
args.eval_freq,
args.eval_episodes,
args.save_freq,
args.hyperparams,
args.env_kwargs,
args.trained_agent,
args.optimize_hyperparameters,
truncate_last_trajectory=args.truncate_last_trajectory,
uuid_str=uuid_str,
seed=args.seed,
log_interval=args.log_interval,
save_replay_buffer=args.save_replay_buffer,
preload_replay_buffer=args.preload_replay_buffer,
verbose=args.verbose,
vec_env_type=args.vec_env,
no_optim_plots=args.no_optim_plots,
)
# Prepare experiment
results = exp_manager.setup_experiment()
if results is not None:
model, saved_hyperparams = results
if args.track:
# we need to save the loaded hyperparameters
args.saved_hyperparams = saved_hyperparams
assert run is not None # make mypy happy
run.config.setdefaults(vars(args))
# Normal training
if model is not None:
exp_manager.learn(model)
exp_manager.save_trained_model(model)
else:
exp_manager.hyperparameters_optimization()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Environment and its parameters
parser.add_argument(
"--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
)
parser.add_argument(
"--env-kwargs",
type=str,
nargs="+",
action=StoreDict,
help="Optional keyword argument to pass to the env constructor",
)
parser.add_argument(
"--vec-env",
type=str,
choices=["dummy", "subproc"],
default="dummy",
help="Type of VecEnv to use",
)
# Algorithm and training
parser.add_argument(
"--algo",
type=str,
choices=list(ALGOS.keys()),
required=False,
default="sac",
help="RL algorithm to use during the training",
)
parser.add_argument(
"-params",
"--hyperparams",
type=str,
nargs="+",
action=StoreDict,
help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)",
)
parser.add_argument(
"-n",
"--n-timesteps",
type=int,
default=-1,
help="Overwrite the number of timesteps",
)
parser.add_argument(
"--num-threads",
type=int,
default=-1,
help="Number of threads for PyTorch (-1 to use default)",
)
# Continue training an already trained agent
parser.add_argument(
"-i",
"--trained-agent",
type=str,
default="",
help="Path to a pretrained agent to continue training",
)
# Random seed
parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")
# Saving of model
parser.add_argument(
"--save-freq",
type=int,
default=10000,
help="Save the model every n steps (if negative, no checkpoint)",
)
parser.add_argument(
"--save-replay-buffer",
type=str2bool,
default=False,
help="Save the replay buffer too (when applicable)",
)
# Pre-load a replay buffer and start training on it
parser.add_argument(
"--preload-replay-buffer",
type=empty_str2none,
default="",
help="Path to a replay buffer that should be preloaded before starting the training process",
)
parser.add_argument(
"--track",
type=str2bool,
default=False,
help="Track experiment using wandb"
)
# optimization parameters
parser.add_argument(
"--optimize-hyperparameters",
type=str2bool,
default=False,
help="Run optimization or not?"
)
parser.add_argument(
"--no-optim-plots", action="store_true", default=False, help="Disable hyperparameter optimization plots"
)
# Logging
parser.add_argument(
"-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
)
parser.add_argument(
"-tb",
"--tensorboard-log",
type=empty_str2none,
default="tensorboard_logs",
help="Tensorboard log dir",
)
parser.add_argument(
"--log-interval",
type=int,
default=-1,
help="Override log interval (default: -1, no change)",
)
parser.add_argument(
"-uuid",
"--uuid",
type=str2bool,
default=False,
help="Ensure that the run has a unique ID",
)
# Evaluation
parser.add_argument(
"--eval-freq",
type=int,
default=-1,
help="Evaluate the agent every n steps (if negative, no evaluation)",
)
parser.add_argument(
"--eval-episodes",
type=int,
default=5,
help="Number of episodes to use for evaluation",
)
# Verbosity
parser.add_argument(
"--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
)
# HER specifics
parser.add_argument(
"--truncate-last-trajectory",
type=str2bool,
default=True,
help="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.",
)
args, unknown = parser.parse_known_args()
main(args=args)
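# Example invocation with W&B tracking enabled (the script name is hypothetical):
#   python3 train.py --algo sac --env Reach-Gazebo-v0 --track True -n 500000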


@@ -0,0 +1,138 @@
#!/usr/bin/env python3
import rclpy
from rclpy.node import Node
import numpy as np
import quaternion
from geometry_msgs.msg import Twist
from geometry_msgs.msg import PoseStamped
import tf2_ros
import sys
import time
import threading
import os
class Converter(Node):
"""Convert Twist messages to PoseStamped
Use this node to integrate twist messages into a moving target pose in
Cartesian space. An initial TF lookup assures that the target pose always
starts at the robot's end-effector.
"""
def __init__(self):
super().__init__("converter")
self.twist_topic = self.declare_parameter("twist_topic", "/cartesian_motion_controller/target_twist").value
self.pose_topic = self.declare_parameter("pose_topic", "/cartesian_motion_controller/target_frame").value
self.frame_id = self.declare_parameter("frame_id", "base_link").value
self.end_effector = self.declare_parameter("end_effector", "gripper_grasp_point").value
self.tf_buffer = tf2_ros.Buffer()
self.tf_listener = tf2_ros.TransformListener(self.tf_buffer, self)
self.rot = np.quaternion(0, 0, 0, 1)
self.pos = [0, 0, 0]
self.pub = self.create_publisher(PoseStamped, self.pose_topic, 3)
self.sub = self.create_subscription(Twist, self.twist_topic, self.twist_cb, 1)
self.last = time.time()
self.startup_done = False
period = 1.0 / self.declare_parameter("publishing_rate", 100).value
self.timer = self.create_timer(period, self.publish)
self.thread = threading.Thread(target=self.startup, daemon=True)
self.thread.start()
def startup(self):
"""Make sure to start at the robot's current pose"""
# Wait until we entered spinning in the main thread.
time.sleep(1)
try:
start = self.tf_buffer.lookup_transform(
target_frame=self.frame_id,
source_frame=self.end_effector,
time=rclpy.time.Time(),
)
except (
tf2_ros.InvalidArgumentException,
tf2_ros.LookupException,
tf2_ros.ConnectivityException,
tf2_ros.ExtrapolationException,
) as e:
print(f"Startup failed: {e}")
os._exit(1)
self.pos[0] = start.transform.translation.x
self.pos[1] = start.transform.translation.y
self.pos[2] = start.transform.translation.z
self.rot.x = start.transform.rotation.x
self.rot.y = start.transform.rotation.y
self.rot.z = start.transform.rotation.z
self.rot.w = start.transform.rotation.w
self.startup_done = True
def twist_cb(self, data):
"""Numerically integrate twist message into a pose
Use global self.frame_id as reference for the navigation commands.
"""
now = time.time()
dt = now - self.last
self.last = now
# Position update
self.pos[0] += data.linear.x * dt
self.pos[1] += data.linear.y * dt
self.pos[2] += data.linear.z * dt
# Orientation update
wx = data.angular.x
wy = data.angular.y
wz = data.angular.z
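# numpy-quaternion integrates the angular velocity (treated as constant over
# this dt) and returns the sampled times and orientation quaternions.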
_, q = quaternion.integrate_angular_velocity(
lambda _: (wx, wy, wz), 0, dt, self.rot
)
self.rot = q[-1] # the last one is after dt passed
def publish(self):
if not self.startup_done:
return
try:
msg = PoseStamped()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = self.frame_id
msg.pose.position.x = self.pos[0]
msg.pose.position.y = self.pos[1]
msg.pose.position.z = self.pos[2]
msg.pose.orientation.x = self.rot.x
msg.pose.orientation.y = self.rot.y
msg.pose.orientation.z = self.rot.z
msg.pose.orientation.w = self.rot.w
self.pub.publish(msg)
except Exception:
# Swallow 'publish() to closed topic' error.
# This occasionally happens when this node is killed.
pass
def main(args=None):
rclpy.init(args=args)
node = Converter()
rclpy.spin(node)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
rclpy.shutdown()
sys.exit(0)
except Exception as e:
print(e)
sys.exit(1)