# runtime/env_manager/rbs_gym/launch/train.launch.py

from launch import LaunchDescription
from launch.actions import (
DeclareLaunchArgument,
IncludeLaunchDescription,
OpaqueFunction,
SetEnvironmentVariable,
TimerAction
)
from launch.launch_description_sources import PythonLaunchDescriptionSource
from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
from launch_ros.substitutions import FindPackageShare
from launch_ros.actions import Node
import os
from os import cpu_count
from ament_index_python.packages import get_package_share_directory
import yaml
import xacro
def launch_setup(context, *args, **kwargs):
    """Resolve launch arguments and build the actions of a training run.

    Used as the callable of the ``OpaqueFunction`` created in
    ``generate_launch_description`` so that launch configurations can be
    evaluated with ``perform(context)`` before the robot description is
    generated.

    Args:
        context: Launch context used to resolve ``LaunchConfiguration`` values.
        *args: Unused; accepted for the ``OpaqueFunction`` calling convention.
        **kwargs: Unused; accepted for the ``OpaqueFunction`` calling convention.

    Returns:
        A list with the training node, the ``/clock`` bridge node and a timer
        action that includes the robot control stack after a delay.
    """
    # NOTE: ``locals()`` is handed to the YAML ``!variable`` constructor further
    # down, so every local bound before those calls may be referenced by name
    # from the xacro argument files. Do not rename or drop them, even when they
    # look unused here (e.g. ``scene_config_file``).

    # Initialize Arguments
    robot_type = LaunchConfiguration("robot_type")
    # General arguments
    with_gripper_condition = LaunchConfiguration("with_gripper")
    description_package = LaunchConfiguration("description_package")
    description_file = LaunchConfiguration("description_file")
    use_moveit = LaunchConfiguration("use_moveit")
    moveit_config_package = LaunchConfiguration("moveit_config_package")
    moveit_config_file = LaunchConfiguration("moveit_config_file")
    use_sim_time = LaunchConfiguration("use_sim_time")
    scene_config_file = LaunchConfiguration("scene_config_file").perform(context)
    ee_link_name = LaunchConfiguration("ee_link_name").perform(context)
    base_link_name = LaunchConfiguration("base_link_name").perform(context)
    control_space = LaunchConfiguration("control_space").perform(context)
    control_strategy = LaunchConfiguration("control_strategy").perform(context)
    interactive = LaunchConfiguration("interactive").perform(context)

    # training arguments
    env = LaunchConfiguration("env")
    algo = LaunchConfiguration("algo")
    hyperparams = LaunchConfiguration("hyperparams")
    n_timesteps = LaunchConfiguration("n_timesteps")
    num_threads = LaunchConfiguration("num_threads")
    seed = LaunchConfiguration("seed")
    trained_agent = LaunchConfiguration("trained_agent")
    save_freq = LaunchConfiguration("save_freq")
    save_replay_buffer = LaunchConfiguration("save_replay_buffer")
    preload_replay_buffer = LaunchConfiguration("preload_replay_buffer")
    log_folder = LaunchConfiguration("log_folder")
    tensorboard_log = LaunchConfiguration("tensorboard_log")
    log_interval = LaunchConfiguration("log_interval")
    uuid = LaunchConfiguration("uuid")
    eval_freq = LaunchConfiguration("eval_freq")
    eval_episodes = LaunchConfiguration("eval_episodes")
    verbose = LaunchConfiguration("verbose")
    truncate_last_trajectory = LaunchConfiguration("truncate_last_trajectory")
    log_level = LaunchConfiguration("log_level")
    env_kwargs = LaunchConfiguration("env_kwargs")
    track = LaunchConfiguration("track")

    description_package_abs_path = get_package_share_directory(
        description_package.perform(context)
    )

    simulation_controllers = os.path.join(
        description_package_abs_path, "config", "controllers.yaml"
    )

    xacro_file = os.path.join(
        description_package_abs_path,
        "urdf",
        description_file.perform(context),
    )

    xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml"

    # TODO: hide this to another place
    # Load xacro_args
    def param_constructor(loader, node, local_vars):
        """Resolve a ``!param`` scalar as a launch configuration value."""
        value = loader.construct_scalar(node)
        launch_context = local_vars.get("context")
        if launch_context is None:
            # Fail with a clear message instead of passing a placeholder
            # string to ``perform``.
            raise RuntimeError(
                f"Cannot resolve '!param {value}': launch context is not defined"
            )
        return LaunchConfiguration(value).perform(launch_context)

    def variable_constructor(loader, node, local_vars):
        """Resolve a ``!variable`` scalar from the supplied local variables."""
        value = loader.construct_scalar(node)
        # Best effort: an unknown name yields a descriptive placeholder string
        # instead of raising.
        return local_vars.get(value, f"Variable '{value}' not found")

    def load_xacro_args(yaml_file, local_vars):
        """Load a xacro argument mapping from ``yaml_file``.

        ``!param`` entries are read from launch arguments and ``!variable``
        entries from ``local_vars``; the referenced variables must be
        initialized before this loader is called.
        """

        # Register the custom tags on a private loader class so the
        # process-wide PyYAML loaders are not modified on every call.
        class XacroArgsLoader(yaml.FullLoader):
            pass

        # Get value from ros2 argument
        XacroArgsLoader.add_constructor(
            "!param",
            lambda loader, node: param_constructor(loader, node, local_vars),
        )
        # Get value from local variable in this code
        XacroArgsLoader.add_constructor(
            "!variable",
            lambda loader, node: variable_constructor(loader, node, local_vars),
        )

        # NOTE(review): FullLoader-based loading is only appropriate for
        # trusted configuration files - confirm these files are never
        # user-supplied.
        with open(yaml_file, "r", encoding="utf-8") as file:
            return yaml.load(file, Loader=XacroArgsLoader)

    mappings_data = load_xacro_args(xacro_config_file, locals())

    robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data)

    robot_description_semantic_content = ""

    if use_moveit.perform(context) == "true":
        srdf_config_file = f"{description_package_abs_path}/config/srdf_xacro_args.yaml"
        srdf_file = os.path.join(
            get_package_share_directory(moveit_config_package.perform(context)),
            "srdf",
            moveit_config_file.perform(context),
        )
        srdf_mappings = load_xacro_args(srdf_config_file, locals())
        robot_description_semantic_content = xacro.process_file(
            srdf_file, mappings=srdf_mappings
        )
        robot_description_semantic_content = (
            robot_description_semantic_content.toprettyxml(indent=" ")
        )
        # NOTE(review): these overrides are assumed to belong to the MoveIt
        # branch (forcing joint-space position control without the interactive
        # handle) - confirm against the upstream file.
        control_space = "joint"
        control_strategy = "position"
        interactive = "false"

    robot_description_content = robot_description_doc.toprettyxml(indent=" ")

    single_robot_setup = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            [
                PathJoinSubstitution(
                    [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"]
                )
            ]
        ),
        launch_arguments={
            "with_gripper": with_gripper_condition,
            "controllers_file": simulation_controllers,
            "robot_type": robot_type,
            "description_package": description_package,
            "description_file": description_file,
            "robot_name": robot_type,
            "use_moveit": use_moveit,
            "moveit_config_package": moveit_config_package,
            "moveit_config_file": moveit_config_file,
            "use_sim_time": use_sim_time,
            "use_controllers": "true",
            "robot_description": robot_description_content,
            "robot_description_semantic": robot_description_semantic_content,
            "base_link_name": base_link_name,
            "ee_link_name": ee_link_name,
            "control_space": control_space,
            "control_strategy": control_strategy,
            "interactive_control": interactive,
        }.items(),
    )

    # Command line of the training executable; every option is forwarded from
    # the launch argument of the same name.
    args = [
        "--env",
        env,
        "--env-kwargs",
        env_kwargs,
        "--algo",
        algo,
        "--hyperparams",
        hyperparams,
        "--n-timesteps",
        n_timesteps,
        "--num-threads",
        num_threads,
        "--seed",
        seed,
        "--trained-agent",
        trained_agent,
        "--save-freq",
        save_freq,
        "--save-replay-buffer",
        save_replay_buffer,
        "--preload-replay-buffer",
        preload_replay_buffer,
        "--log-folder",
        log_folder,
        "--tensorboard-log",
        tensorboard_log,
        "--log-interval",
        log_interval,
        "--uuid",
        uuid,
        "--eval-freq",
        eval_freq,
        "--eval-episodes",
        eval_episodes,
        "--verbose",
        verbose,
        "--track",
        track,
        "--truncate-last-trajectory",
        truncate_last_trajectory,
        "--ros-args",
        "--log-level",
        log_level,
    ]

    clock_bridge = Node(
        package='ros_gz_bridge',
        executable='parameter_bridge',
        arguments=['/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock'],
        output='screen',
    )

    # NOTE(review): ``use_sim_time`` is hard-coded to True here and does not
    # follow the ``use_sim_time`` launch argument - confirm this is intended.
    rl_task = Node(
        package="rbs_gym",
        executable="train",
        output="log",
        arguments=args,
        parameters=[{"use_sim_time": True}],
    )

    # The robot control stack is included 20 seconds after the other actions.
    delay_robot_control_stack = TimerAction(
        period=20.0,
        actions=[single_robot_setup],
    )

    nodes_to_start = [
        rl_task,
        clock_bridge,
        delay_robot_control_stack,
    ]
    return nodes_to_start
def generate_launch_description():
    """Declare the launch arguments and assemble the training launch description.

    Returns:
        A ``LaunchDescription`` containing all declared arguments, the OpenMP
        environment variable actions and an ``OpaqueFunction`` that runs
        ``launch_setup``.
    """
    declared_arguments = [
        # Robot and control stack arguments
        DeclareLaunchArgument(
            "robot_type",
            description="Type of robot by name",
            choices=[
                "rbs_arm",
                "ar4",
                "ur3",
                "ur3e",
                "ur5",
                "ur5e",
                "ur10",
                "ur10e",
                "ur16e",
            ],
            default_value="rbs_arm",
        ),
        DeclareLaunchArgument(
            "description_package",
            default_value="rbs_arm",
            description="Description package with robot URDF/XACRO files. Usually the argument "
            "is not set, it enables use of a custom description.",
        ),
        DeclareLaunchArgument(
            "description_file",
            default_value="rbs_arm_modular.xacro",
            description="URDF/XACRO description file with the robot.",
        ),
        DeclareLaunchArgument(
            "robot_name",
            default_value="arm0",
            description="Name for robot, used to apply namespace for specific robot in multirobot setup",
        ),
        DeclareLaunchArgument(
            "moveit_config_package",
            default_value="rbs_arm",
            description="MoveIt config package with robot SRDF/XACRO files. Usually the argument "
            "is not set, it enables use of a custom moveit config.",
        ),
        DeclareLaunchArgument(
            "moveit_config_file",
            default_value="rbs_arm.srdf.xacro",
            description="MoveIt SRDF/XACRO description file with the robot.",
        ),
        DeclareLaunchArgument(
            "use_sim_time",
            default_value="true",
            description="Make MoveIt to use simulation time."
            "This is needed for the trajectory planing in simulation.",
        ),
        DeclareLaunchArgument(
            "with_gripper", default_value="true", description="With gripper or not?"
        ),
        DeclareLaunchArgument(
            "use_moveit", default_value="false", description="Launch moveit?"
        ),
        DeclareLaunchArgument(
            "launch_perception", default_value="false", description="Launch perception?"
        ),
        DeclareLaunchArgument(
            "use_controllers",
            default_value="true",
            description="Launch controllers?",
        ),
        DeclareLaunchArgument(
            "scene_config_file",
            default_value="",
            description="Path to a scene configuration file",
        ),
        DeclareLaunchArgument(
            "ee_link_name",
            default_value="",
            description="End effector name of robot arm",
        ),
        DeclareLaunchArgument(
            "base_link_name",
            default_value="",
            description="Base link name if robot arm",
        ),
        DeclareLaunchArgument(
            "control_space",
            default_value="task",
            choices=["task", "joint"],
            description="Specify the control space for the robot (e.g., task space).",
        ),
        DeclareLaunchArgument(
            "control_strategy",
            default_value="position",
            choices=["position", "velocity", "effort"],
            description="Specify the control strategy (e.g., position control).",
        ),
        DeclareLaunchArgument(
            "interactive",
            default_value="true",
            description="Wheter to run the motion_control_handle controller",
        ),
        # training arguments
        DeclareLaunchArgument(
            "env",
            default_value="Reach-Gazebo-v0",
            description="Environment ID",
        ),
        DeclareLaunchArgument(
            "env_kwargs",
            default_value="",
            description="Optional keyword argument to pass to the env constructor.",
        ),
        DeclareLaunchArgument(
            "vec_env",
            default_value="dummy",
            description="Type of VecEnv to use (dummy or subproc).",
        ),
        # Algorithm and training
        DeclareLaunchArgument(
            "algo",
            default_value="sac",
            description="RL algorithm to use during the training.",
        ),
        DeclareLaunchArgument(
            "n_timesteps",
            default_value="-1",
            description="Overwrite the number of timesteps.",
        ),
        DeclareLaunchArgument(
            "hyperparams",
            default_value="",
            description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).",
        ),
        DeclareLaunchArgument(
            "num_threads",
            default_value="-1",
            description="Number of threads for PyTorch (-1 to use default).",
        ),
        # Continue training an already trained agent
        DeclareLaunchArgument(
            "trained_agent",
            default_value="",
            description="Path to a pretrained agent to continue training.",
        ),
        # Random seed
        DeclareLaunchArgument(
            "seed",
            default_value="-1",
            description="Random generator seed.",
        ),
        # Saving of model
        DeclareLaunchArgument(
            "save_freq",
            default_value="10000",
            description="Save the model every n steps (if negative, no checkpoint).",
        ),
        DeclareLaunchArgument(
            "save_replay_buffer",
            default_value="False",
            description="Save the replay buffer too (when applicable).",
        ),
        # Pre-load a replay buffer and start training on it
        DeclareLaunchArgument(
            "preload_replay_buffer",
            default_value="",
            description="Path to a replay buffer that should be preloaded before starting the training process.",
        ),
        # Logging
        DeclareLaunchArgument(
            "log_folder",
            default_value="logs",
            description="Path to the log directory.",
        ),
        DeclareLaunchArgument(
            "tensorboard_log",
            default_value="tensorboard_logs",
            description="Tensorboard log dir.",
        ),
        DeclareLaunchArgument(
            "log_interval",
            default_value="-1",
            description="Override log interval (default: -1, no change).",
        ),
        DeclareLaunchArgument(
            "uuid",
            default_value="False",
            description="Ensure that the run has a unique ID.",
        ),
        # Evaluation
        DeclareLaunchArgument(
            "eval_freq",
            default_value="-1",
            description="Evaluate the agent every n steps (if negative, no evaluation).",
        ),
        DeclareLaunchArgument(
            "eval_episodes",
            default_value="5",
            description="Number of episodes to use for evaluation.",
        ),
        # Verbosity
        DeclareLaunchArgument(
            "verbose",
            default_value="1",
            description="Verbose mode (0: no output, 1: INFO).",
        ),
        DeclareLaunchArgument(
            "truncate_last_trajectory",
            default_value="True",
            description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after) reloading the replay buffer.",
        ),
        DeclareLaunchArgument(
            "log_level",
            default_value="error",
            description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
        ),
        # NOTE(review): this description is identical to the one of
        # ``log_level`` and looks copy-pasted; replace it with the actual
        # meaning of the training script's ``--track`` option once confirmed.
        DeclareLaunchArgument(
            "track",
            default_value="true",
            description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
        ),
    ]

    # ``cpu_count()`` returns None when the number of CPUs cannot be
    # determined, and halving a single core would give 0; always request at
    # least one OpenMP thread.
    omp_num_threads = max(1, (cpu_count() or 1) // 2)

    env_variables = [
        SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"),
        SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(omp_num_threads)),
    ]

    # Entities are processed in order, so the environment variables are placed
    # ahead of the OpaqueFunction to have them set before the training node
    # is started.
    return LaunchDescription(
        declared_arguments
        + env_variables
        + [OpaqueFunction(function=launch_setup)]
    )