build: migrate env_manager
, rbs_gym
, rbs_runtime
to ament_python
Migrate `env_manager`, `rbs_gym`, and `rbs_runtime` from ament_cmake to ament_python. Removed unnecessary files including .json and .yaml config files
This commit is contained in:
parent
e4e3e4e3af
commit
860f7d6566
40 changed files with 750 additions and 1239 deletions
|
@ -1,31 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.8)
|
||||
project(env_manager)
|
||||
|
||||
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||
endif()
|
||||
|
||||
# find dependencies
|
||||
find_package(ament_cmake REQUIRED)
|
||||
# uncomment the following section in order to fill in
|
||||
# further dependencies manually.
|
||||
# find_package(<dependency> REQUIRED)
|
||||
|
||||
ament_python_install_package(${PROJECT_NAME})
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(ament_lint_auto REQUIRED)
|
||||
# the following line skips the linter which checks for copyrights
|
||||
# comment the line when a copyright and license is added to all source files
|
||||
set(ament_cmake_copyright_FOUND TRUE)
|
||||
# the following line skips cpplint (only works in a git repo)
|
||||
# comment the line when this package is in a git repo and when
|
||||
# a copyright and license is added to all source files
|
||||
set(ament_cmake_cpplint_FOUND TRUE)
|
||||
ament_lint_auto_find_test_dependencies()
|
||||
endif()
|
||||
|
||||
|
||||
install(DIRECTORY env_manager/worlds DESTINATION share/${PROJECT_NAME})
|
||||
|
||||
ament_package()
|
|
@ -1,24 +0,0 @@
|
|||
<?xml version="1.0"?>
|
||||
<sdf version="1.9">
|
||||
<world name="rbs_gym_world">
|
||||
<!-- <physics name='1ms' type='ignored'> -->
|
||||
<!-- <dart> -->
|
||||
<!-- <collision_detector>bullet</collision_detector> -->
|
||||
<!-- <solver> -->
|
||||
<!-- <solver_type>pgs</solver_type> -->
|
||||
<!-- </solver> -->
|
||||
<!-- </dart> -->
|
||||
<!-- </physics> -->
|
||||
<!-- <plugin name='ignition::gazebo::systems::Contact' filename='ignition-gazebo-contact-system'/> -->
|
||||
|
||||
<!-- -->
|
||||
<!-- Scene -->
|
||||
<!-- -->
|
||||
<!-- <gravity>0 0 0</gravity> -->
|
||||
<scene>
|
||||
<ambient>1.0 1.0 1.0</ambient>
|
||||
<grid>false</grid>
|
||||
</scene>
|
||||
|
||||
</world>
|
||||
</sdf>
|
|
@ -7,12 +7,12 @@
|
|||
<maintainer email="ur.narmak@gmail.com">narmak</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
|
||||
<test_depend>ament_lint_auto</test_depend>
|
||||
<test_depend>ament_lint_common</test_depend>
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
|
|
0
env_manager/env_manager/resource/env_manager
Normal file
0
env_manager/env_manager/resource/env_manager
Normal file
4
env_manager/env_manager/setup.cfg
Normal file
4
env_manager/env_manager/setup.cfg
Normal file
|
@ -0,0 +1,4 @@
|
|||
[develop]
|
||||
script_dir=$base/lib/env_manager
|
||||
[install]
|
||||
install_scripts=$base/lib/env_manager
|
25
env_manager/env_manager/setup.py
Normal file
25
env_manager/env_manager/setup.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'env_manager'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.0.0',
|
||||
packages=find_packages(exclude=['test']),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='narmak',
|
||||
maintainer_email='ur.narmak@gmail.com',
|
||||
description='TODO: Package description',
|
||||
license='Apache-2.0',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
],
|
||||
},
|
||||
)
|
25
env_manager/env_manager/test/test_copyright.py
Normal file
25
env_manager/env_manager/test/test_copyright.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_copyright.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
# Remove the `skip` decorator once the source file(s) have a copyright header
|
||||
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
|
||||
@pytest.mark.copyright
|
||||
@pytest.mark.linter
|
||||
def test_copyright():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found errors'
|
25
env_manager/env_manager/test/test_flake8.py
Normal file
25
env_manager/env_manager/test/test_flake8.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2017 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_flake8.main import main_with_errors
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.flake8
|
||||
@pytest.mark.linter
|
||||
def test_flake8():
|
||||
rc, errors = main_with_errors(argv=[])
|
||||
assert rc == 0, \
|
||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
||||
'\n'.join(errors)
|
23
env_manager/env_manager/test/test_pep257.py
Normal file
23
env_manager/env_manager/test/test_pep257.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_pep257.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.linter
|
||||
@pytest.mark.pep257
|
||||
def test_pep257():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found code style errors / warnings'
|
|
@ -1,39 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.8)
|
||||
project(rbs_gym)
|
||||
|
||||
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||
endif()
|
||||
|
||||
# find dependencies
|
||||
find_package(ament_cmake REQUIRED)
|
||||
# uncomment the following section in order to fill in
|
||||
# further dependencies manually.
|
||||
# find_package(<dependency> REQUIRED)
|
||||
|
||||
ament_python_install_package(${PROJECT_NAME})
|
||||
|
||||
install(PROGRAMS
|
||||
scripts/train.py
|
||||
scripts/spawner.py
|
||||
scripts/velocity.py
|
||||
scripts/test_agent.py
|
||||
scripts/evaluate.py
|
||||
DESTINATION lib/${PROJECT_NAME}
|
||||
)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(ament_lint_auto REQUIRED)
|
||||
# the following line skips the linter which checks for copyrights
|
||||
# comment the line when a copyright and license is added to all source files
|
||||
set(ament_cmake_copyright_FOUND TRUE)
|
||||
# the following line skips cpplint (only works in a git repo)
|
||||
# comment the line when this package is in a git repo and when
|
||||
# a copyright and license is added to all source files
|
||||
set(ament_cmake_cpplint_FOUND TRUE)
|
||||
ament_lint_auto_find_test_dependencies()
|
||||
endif()
|
||||
|
||||
install(DIRECTORY launch DESTINATION share/${PROJECT_NAME})
|
||||
|
||||
ament_package()
|
|
@ -141,7 +141,7 @@ def launch_setup(context, *args, **kwargs):
|
|||
|
||||
rl_task = Node(
|
||||
package="rbs_gym",
|
||||
executable="evaluate.py",
|
||||
executable="evaluate",
|
||||
output="log",
|
||||
arguments=args,
|
||||
parameters=[{"use_sim_time": use_sim_time}],
|
||||
|
|
|
@ -162,7 +162,7 @@ def launch_setup(context, *args, **kwargs):
|
|||
|
||||
rl_task = Node(
|
||||
package="rbs_gym",
|
||||
executable="train.py",
|
||||
executable="train",
|
||||
output="log",
|
||||
arguments = args,
|
||||
parameters=[{"use_sim_time": True}]
|
||||
|
|
|
@ -120,7 +120,7 @@ def launch_setup(context, *args, **kwargs):
|
|||
|
||||
rl_task = Node(
|
||||
package="rbs_gym",
|
||||
executable="test_agent.py",
|
||||
executable="test_agent",
|
||||
output="log",
|
||||
arguments=args,
|
||||
parameters=[{"use_sim_time": True}, robot_description],
|
||||
|
|
|
@ -161,7 +161,7 @@ def launch_setup(context, *args, **kwargs):
|
|||
|
||||
rl_task = Node(
|
||||
package="rbs_gym",
|
||||
executable="train.py",
|
||||
executable="train",
|
||||
output="log",
|
||||
arguments=args,
|
||||
parameters=[{"use_sim_time": True}]
|
||||
|
|
|
@ -4,15 +4,15 @@
|
|||
<name>rbs_gym</name>
|
||||
<version>0.0.0</version>
|
||||
<description>TODO: Package description</description>
|
||||
<maintainer email="ur.narmak@gmail.com">bill-finger</maintainer>
|
||||
<maintainer email="ur.narmak@gmail.com">narmak</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
|
||||
<test_depend>ament_lint_auto</test_depend>
|
||||
<test_depend>ament_lint_common</test_depend>
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
|
|
|
@ -7,191 +7,36 @@ from typing import Dict
|
|||
import numpy as np
|
||||
import torch as th
|
||||
import yaml
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecEnvWrapper
|
||||
|
||||
from rbs_gym import envs as gz_envs
|
||||
from rbs_gym.utils import create_test_env, get_latest_run_id, get_saved_hyperparams
|
||||
from rbs_gym.utils.utils import ALGOS, StoreDict, str2bool
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecEnvWrapper
|
||||
|
||||
|
||||
def main(args: Dict):
|
||||
|
||||
if args.exp_id == 0:
|
||||
args.exp_id = get_latest_run_id(
|
||||
os.path.join(args.log_folder, args.algo), args.env
|
||||
)
|
||||
print(f"Loading latest experiment, id={args.exp_id}")
|
||||
|
||||
# Sanity checks
|
||||
if args.exp_id > 0:
|
||||
log_path = os.path.join(args.log_folder, args.algo, f"{args.env}_{args.exp_id}")
|
||||
else:
|
||||
log_path = os.path.join(args.log_folder, args.algo)
|
||||
|
||||
assert os.path.isdir(log_path), f"The {log_path} folder was not found"
|
||||
|
||||
found = False
|
||||
for ext in ["zip"]:
|
||||
def find_model_path(log_path: str, args) -> str:
|
||||
model_extensions = ["zip"]
|
||||
for ext in model_extensions:
|
||||
model_path = os.path.join(log_path, f"{args.env}.{ext}")
|
||||
found = os.path.isfile(model_path)
|
||||
if found:
|
||||
break
|
||||
if os.path.isfile(model_path):
|
||||
return model_path
|
||||
|
||||
if args.load_best:
|
||||
model_path = os.path.join(log_path, "best_model.zip")
|
||||
found = os.path.isfile(model_path)
|
||||
best_model_path = os.path.join(log_path, "best_model.zip")
|
||||
if os.path.isfile(best_model_path):
|
||||
return best_model_path
|
||||
|
||||
if args.load_checkpoint is not None:
|
||||
model_path = os.path.join(
|
||||
checkpoint_model_path = os.path.join(
|
||||
log_path, f"rl_model_{args.load_checkpoint}_steps.zip"
|
||||
)
|
||||
found = os.path.isfile(model_path)
|
||||
if os.path.isfile(checkpoint_model_path):
|
||||
return checkpoint_model_path
|
||||
|
||||
if not found:
|
||||
raise ValueError(
|
||||
f"No model found for {args.algo} on {args.env}, path: {model_path}"
|
||||
)
|
||||
|
||||
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
|
||||
|
||||
if args.algo in off_policy_algos:
|
||||
args.n_envs = 1
|
||||
|
||||
set_random_seed(args.seed)
|
||||
|
||||
if args.num_threads > 0:
|
||||
if args.verbose > 1:
|
||||
print(f"Setting torch.num_threads to {args.num_threads}")
|
||||
th.set_num_threads(args.num_threads)
|
||||
|
||||
stats_path = os.path.join(log_path, args.env)
|
||||
hyperparams, stats_path = get_saved_hyperparams(
|
||||
stats_path, norm_reward=args.norm_reward, test_mode=True
|
||||
)
|
||||
|
||||
# load env_kwargs if existing
|
||||
env_kwargs = {}
|
||||
args_path = os.path.join(log_path, args.env, "args.yml")
|
||||
if os.path.isfile(args_path):
|
||||
with open(args_path, "r") as f:
|
||||
# pytype: disable=module-attr
|
||||
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
|
||||
if loaded_args["env_kwargs"] is not None:
|
||||
env_kwargs = loaded_args["env_kwargs"]
|
||||
# overwrite with command line arguments
|
||||
if args.env_kwargs is not None:
|
||||
env_kwargs.update(args.env_kwargs)
|
||||
|
||||
log_dir = args.reward_log if args.reward_log != "" else None
|
||||
|
||||
env = create_test_env(
|
||||
args.env,
|
||||
n_envs=args.n_envs,
|
||||
stats_path=stats_path,
|
||||
seed=args.seed,
|
||||
log_dir=log_dir,
|
||||
should_render=not args.no_render,
|
||||
hyperparams=hyperparams,
|
||||
env_kwargs=env_kwargs,
|
||||
)
|
||||
|
||||
kwargs = dict(seed=args.seed)
|
||||
if args.algo in off_policy_algos:
|
||||
# Dummy buffer size as we don't need memory to evaluate the trained agent
|
||||
kwargs.update(dict(buffer_size=1))
|
||||
raise ValueError(f"No model found for {args.algo} on {args.env}, path: {log_path}")
|
||||
|
||||
|
||||
custom_objects = {'observation_space': env.observation_space, 'action_space': env.action_space}
|
||||
|
||||
model = ALGOS[args.algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
# Deterministic by default
|
||||
stochastic = args.stochastic
|
||||
deterministic = not stochastic
|
||||
|
||||
print(
|
||||
f"Evaluating for {args.n_episodes} episodes with a",
|
||||
"deterministic" if deterministic else "stochastic",
|
||||
"policy.",
|
||||
)
|
||||
|
||||
state = None
|
||||
episode_reward = 0.0
|
||||
episode_rewards, episode_lengths, success_episode_lengths = [], [], []
|
||||
ep_len = 0
|
||||
episode = 0
|
||||
# For HER, monitor success rate
|
||||
successes = []
|
||||
while episode < args.n_episodes:
|
||||
action, state = model.predict(obs, state=state, deterministic=deterministic)
|
||||
obs, reward, done, infos = env.step(action)
|
||||
if not args.no_render:
|
||||
env.render("human")
|
||||
|
||||
episode_reward += reward[0]
|
||||
ep_len += 1
|
||||
|
||||
if done and args.verbose > 0:
|
||||
episode += 1
|
||||
print(f"--- Episode {episode}/{args.n_episodes}")
|
||||
# NOTE: for env using VecNormalize, the mean reward
|
||||
# is a normalized reward when `--norm_reward` flag is passed
|
||||
print(f"Episode Reward: {episode_reward:.2f}")
|
||||
episode_rewards.append(episode_reward)
|
||||
print("Episode Length", ep_len)
|
||||
episode_lengths.append(ep_len)
|
||||
if infos[0].get("is_success") is not None:
|
||||
print("Success?:", infos[0].get("is_success", False))
|
||||
successes.append(infos[0].get("is_success", False))
|
||||
if infos[0].get("is_success"):
|
||||
success_episode_lengths.append(ep_len)
|
||||
print(f"Current success rate: {100 * np.mean(successes):.2f}%")
|
||||
episode_reward = 0.0
|
||||
ep_len = 0
|
||||
state = None
|
||||
|
||||
if args.verbose > 0 and len(successes) > 0:
|
||||
print(f"Success rate: {100 * np.mean(successes):.2f}%")
|
||||
|
||||
if args.verbose > 0 and len(episode_rewards) > 0:
|
||||
print(
|
||||
f"Mean reward: {np.mean(episode_rewards):.2f} "
|
||||
f"+/- {np.std(episode_rewards):.2f}"
|
||||
)
|
||||
|
||||
if args.verbose > 0 and len(episode_lengths) > 0:
|
||||
print(
|
||||
f"Mean episode length: {np.mean(episode_lengths):.2f} "
|
||||
f"+/- {np.std(episode_lengths):.2f}"
|
||||
)
|
||||
|
||||
if args.verbose > 0 and len(success_episode_lengths) > 0:
|
||||
print(
|
||||
f"Mean episode length of successful episodes: {np.mean(success_episode_lengths):.2f} "
|
||||
f"+/- {np.std(success_episode_lengths):.2f}"
|
||||
)
|
||||
|
||||
# Workaround for https://github.com/openai/gym/issues/893
|
||||
if not args.no_render:
|
||||
if args.n_envs == 1 and "Bullet" not in args.env and isinstance(env, VecEnv):
|
||||
# DummyVecEnv
|
||||
# Unwrap env
|
||||
while isinstance(env, VecEnvWrapper):
|
||||
env = env.venv
|
||||
if isinstance(env, DummyVecEnv):
|
||||
env.envs[0].env.close()
|
||||
else:
|
||||
env.close()
|
||||
else:
|
||||
# SubprocVecEnv
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Environment and its parameters
|
||||
|
@ -291,4 +136,160 @@ if __name__ == "__main__":
|
|||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
main(args)
|
||||
if args.exp_id == 0:
|
||||
args.exp_id = get_latest_run_id(
|
||||
os.path.join(args.log_folder, args.algo), args.env
|
||||
)
|
||||
print(f"Loading latest experiment, id={args.exp_id}")
|
||||
|
||||
# Sanity checks
|
||||
if args.exp_id > 0:
|
||||
log_path = os.path.join(args.log_folder, args.algo, f"{args.env}_{args.exp_id}")
|
||||
else:
|
||||
log_path = os.path.join(args.log_folder, args.algo)
|
||||
|
||||
assert os.path.isdir(log_path), f"The {log_path} folder was not found"
|
||||
|
||||
model_path = find_model_path(log_path, args)
|
||||
off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
|
||||
|
||||
if args.algo in off_policy_algos:
|
||||
args.n_envs = 1
|
||||
|
||||
set_random_seed(args.seed)
|
||||
|
||||
if args.num_threads > 0:
|
||||
if args.verbose > 1:
|
||||
print(f"Setting torch.num_threads to {args.num_threads}")
|
||||
th.set_num_threads(args.num_threads)
|
||||
|
||||
stats_path = os.path.join(log_path, args.env)
|
||||
hyperparams, stats_path = get_saved_hyperparams(
|
||||
stats_path, norm_reward=args.norm_reward, test_mode=True
|
||||
)
|
||||
|
||||
# load env_kwargs if existing
|
||||
env_kwargs = {}
|
||||
args_path = os.path.join(log_path, args.env, "args.yml")
|
||||
if os.path.isfile(args_path):
|
||||
with open(args_path, "r") as f:
|
||||
loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)
|
||||
if loaded_args["env_kwargs"] is not None:
|
||||
env_kwargs = loaded_args["env_kwargs"]
|
||||
# overwrite with command line arguments
|
||||
if args.env_kwargs is not None:
|
||||
env_kwargs.update(args.env_kwargs)
|
||||
|
||||
log_dir = args.reward_log if args.reward_log != "" else None
|
||||
|
||||
env = create_test_env(
|
||||
args.env,
|
||||
n_envs=args.n_envs,
|
||||
stats_path=stats_path,
|
||||
seed=args.seed,
|
||||
log_dir=log_dir,
|
||||
should_render=not args.no_render,
|
||||
hyperparams=hyperparams,
|
||||
env_kwargs=env_kwargs,
|
||||
)
|
||||
|
||||
kwargs = dict(seed=args.seed)
|
||||
if args.algo in off_policy_algos:
|
||||
# Dummy buffer size as we don't need memory to evaluate the trained agent
|
||||
kwargs.update(dict(buffer_size=1))
|
||||
|
||||
custom_objects = {
|
||||
"observation_space": env.observation_space,
|
||||
"action_space": env.action_space,
|
||||
}
|
||||
|
||||
model = ALGOS[args.algo].load(
|
||||
model_path, env=env, custom_objects=custom_objects, **kwargs
|
||||
)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
# Deterministic by default
|
||||
stochastic = args.stochastic
|
||||
deterministic = not stochastic
|
||||
|
||||
print(
|
||||
f"Evaluating for {args.n_episodes} episodes with a",
|
||||
"deterministic" if deterministic else "stochastic",
|
||||
"policy.",
|
||||
)
|
||||
|
||||
state = None
|
||||
episode_reward = 0.0
|
||||
episode_rewards, episode_lengths, success_episode_lengths = [], [], []
|
||||
ep_len = 0
|
||||
episode = 0
|
||||
# For HER, monitor success rate
|
||||
successes = []
|
||||
while episode < args.n_episodes:
|
||||
action, state = model.predict(obs, state=state, deterministic=deterministic)
|
||||
obs, reward, done, infos = env.step(action)
|
||||
if not args.no_render:
|
||||
env.render("human")
|
||||
|
||||
episode_reward += reward[0]
|
||||
ep_len += 1
|
||||
|
||||
if done and args.verbose > 0:
|
||||
episode += 1
|
||||
print(f"--- Episode {episode}/{args.n_episodes}")
|
||||
# NOTE: for env using VecNormalize, the mean reward
|
||||
# is a normalized reward when `--norm_reward` flag is passed
|
||||
print(f"Episode Reward: {episode_reward:.2f}")
|
||||
episode_rewards.append(episode_reward)
|
||||
print("Episode Length", ep_len)
|
||||
episode_lengths.append(ep_len)
|
||||
if infos[0].get("is_success") is not None:
|
||||
print("Success?:", infos[0].get("is_success", False))
|
||||
successes.append(infos[0].get("is_success", False))
|
||||
if infos[0].get("is_success"):
|
||||
success_episode_lengths.append(ep_len)
|
||||
print(f"Current success rate: {100 * np.mean(successes):.2f}%")
|
||||
episode_reward = 0.0
|
||||
ep_len = 0
|
||||
state = None
|
||||
|
||||
if args.verbose > 0 and len(successes) > 0:
|
||||
print(f"Success rate: {100 * np.mean(successes):.2f}%")
|
||||
|
||||
if args.verbose > 0 and len(episode_rewards) > 0:
|
||||
print(
|
||||
f"Mean reward: {np.mean(episode_rewards):.2f} "
|
||||
f"+/- {np.std(episode_rewards):.2f}"
|
||||
)
|
||||
|
||||
if args.verbose > 0 and len(episode_lengths) > 0:
|
||||
print(
|
||||
f"Mean episode length: {np.mean(episode_lengths):.2f} "
|
||||
f"+/- {np.std(episode_lengths):.2f}"
|
||||
)
|
||||
|
||||
if args.verbose > 0 and len(success_episode_lengths) > 0:
|
||||
print(
|
||||
f"Mean episode length of successful episodes: {np.mean(success_episode_lengths):.2f} "
|
||||
f"+/- {np.std(success_episode_lengths):.2f}"
|
||||
)
|
||||
|
||||
# Workaround for https://github.com/openai/gym/issues/893
|
||||
if not args.no_render:
|
||||
if args.n_envs == 1 and "Bullet" not in args.env and isinstance(env, VecEnv):
|
||||
# DummyVecEnv
|
||||
# Unwrap env
|
||||
while isinstance(env, VecEnvWrapper):
|
||||
env = env.venv
|
||||
if isinstance(env, DummyVecEnv):
|
||||
env.envs[0].env.close()
|
||||
else:
|
||||
env.close()
|
||||
else:
|
||||
# SubprocVecEnv
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -9,87 +9,13 @@ from typing import Dict
|
|||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
||||
from rbs_gym import envs as gz_envs
|
||||
from rbs_gym.utils.exp_manager import ExperimentManager
|
||||
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
||||
|
||||
def main(args: Dict):
|
||||
|
||||
# Check if the selected environment is valid
|
||||
# If it could not be found, suggest the closest match
|
||||
registered_envs = set(gym.envs.registry.keys())
|
||||
if args.env not in registered_envs:
|
||||
try:
|
||||
closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
|
||||
except IndexError:
|
||||
closest_match = "'no close match found...'"
|
||||
raise ValueError(
|
||||
f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
|
||||
)
|
||||
|
||||
# If no specific seed is selected, choose a random one
|
||||
if args.seed < 0:
|
||||
args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
|
||||
|
||||
# Set the random seed across platforms
|
||||
set_random_seed(args.seed)
|
||||
|
||||
# Setting num threads to 1 makes things run faster on cpu
|
||||
if args.num_threads > 0:
|
||||
if args.verbose > 1:
|
||||
print(f"Setting torch.num_threads to {args.num_threads}")
|
||||
th.set_num_threads(args.num_threads)
|
||||
|
||||
# Verify that pre-trained agent exists before continuing to train it
|
||||
if args.trained_agent != "":
|
||||
assert args.trained_agent.endswith(".zip") and os.path.isfile(
|
||||
args.trained_agent
|
||||
), "The trained_agent must be a valid path to a .zip file"
|
||||
|
||||
# If enabled, ensure that the run has a unique ID
|
||||
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
|
||||
|
||||
print("=" * 10, args.env, "=" * 10)
|
||||
print(f"Seed: {args.seed}")
|
||||
env_kwargs = {
|
||||
"render_mode": "human"
|
||||
}
|
||||
|
||||
exp_manager = ExperimentManager(
|
||||
args,
|
||||
args.algo,
|
||||
args.env,
|
||||
args.log_folder,
|
||||
args.tensorboard_log,
|
||||
args.n_timesteps,
|
||||
args.eval_freq,
|
||||
args.eval_episodes,
|
||||
args.save_freq,
|
||||
args.hyperparams,
|
||||
args.env_kwargs,
|
||||
args.trained_agent,
|
||||
truncate_last_trajectory=args.truncate_last_trajectory,
|
||||
uuid_str=uuid_str,
|
||||
seed=args.seed,
|
||||
log_interval=args.log_interval,
|
||||
save_replay_buffer=args.save_replay_buffer,
|
||||
preload_replay_buffer=args.preload_replay_buffer,
|
||||
verbose=args.verbose,
|
||||
vec_env_type=args.vec_env,
|
||||
)
|
||||
|
||||
# Prepare experiment
|
||||
model = exp_manager.setup_experiment()
|
||||
|
||||
exp_manager.learn(model)
|
||||
exp_manager.save_trained_model(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Environment and its parameters
|
||||
|
@ -230,4 +156,79 @@ if __name__ == "__main__":
|
|||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
main(args=args)
|
||||
# Check if the selected environment is valid
|
||||
# If it could not be found, suggest the closest match
|
||||
registered_envs = set(gym.envs.registry.keys())
|
||||
if args.env not in registered_envs:
|
||||
try:
|
||||
closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
|
||||
except IndexError:
|
||||
closest_match = "'no close match found...'"
|
||||
raise ValueError(
|
||||
f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
|
||||
)
|
||||
|
||||
# If no specific seed is selected, choose a random one
|
||||
if args.seed < 0:
|
||||
args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
|
||||
|
||||
# Set the random seed across platforms
|
||||
set_random_seed(args.seed)
|
||||
|
||||
# Setting num threads to 1 makes things run faster on cpu
|
||||
if args.num_threads > 0:
|
||||
if args.verbose > 1:
|
||||
print(f"Setting torch.num_threads to {args.num_threads}")
|
||||
th.set_num_threads(args.num_threads)
|
||||
|
||||
# Verify that pre-trained agent exists before continuing to train it
|
||||
if args.trained_agent != "":
|
||||
assert args.trained_agent.endswith(".zip") and os.path.isfile(
|
||||
args.trained_agent
|
||||
), "The trained_agent must be a valid path to a .zip file"
|
||||
|
||||
# If enabled, ensure that the run has a unique ID
|
||||
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
|
||||
|
||||
print("=" * 10, args.env, "=" * 10)
|
||||
print(f"Seed: {args.seed}")
|
||||
env_kwargs = {"render_mode": "human"}
|
||||
|
||||
exp_manager = ExperimentManager(
|
||||
args,
|
||||
args.algo,
|
||||
args.env,
|
||||
args.log_folder,
|
||||
args.tensorboard_log,
|
||||
args.n_timesteps,
|
||||
args.eval_freq,
|
||||
args.eval_episodes,
|
||||
args.save_freq,
|
||||
args.hyperparams,
|
||||
args.env_kwargs,
|
||||
args.trained_agent,
|
||||
truncate_last_trajectory=args.truncate_last_trajectory,
|
||||
uuid_str=uuid_str,
|
||||
seed=args.seed,
|
||||
log_interval=args.log_interval,
|
||||
save_replay_buffer=args.save_replay_buffer,
|
||||
preload_replay_buffer=args.preload_replay_buffer,
|
||||
verbose=args.verbose,
|
||||
vec_env_type=args.vec_env,
|
||||
)
|
||||
|
||||
# Prepare experiment
|
||||
results = exp_manager.setup_experiment()
|
||||
if results is not None:
|
||||
model, saved_hyperparams = results
|
||||
|
||||
# Normal training
|
||||
if model is not None:
|
||||
exp_manager.learn(model)
|
||||
exp_manager.save_trained_model(model)
|
||||
else:
|
||||
exp_manager.hyperparameters_optimization()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -10,53 +10,7 @@ from rbs_gym import envs as gz_envs
|
|||
from rbs_gym.utils.utils import StoreDict, str2bool
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
# Create the environment
|
||||
if args.env_kwargs is not None:
|
||||
env = gym.make(args.env, **args.env_kwargs)
|
||||
else:
|
||||
env = gym.make(args.env)
|
||||
|
||||
# Initialize random seed
|
||||
env.seed(args.seed)
|
||||
|
||||
# Check the environment
|
||||
if args.check_env:
|
||||
check_env(env, warn=True, skip_render_check=True)
|
||||
|
||||
# Step environment for bunch of episodes
|
||||
for episode in range(args.n_episodes):
|
||||
|
||||
# Initialize returned values
|
||||
done = False
|
||||
total_reward = 0
|
||||
|
||||
# Reset the environment
|
||||
observation = env.reset()
|
||||
|
||||
# Step through the current episode until it is done
|
||||
while not done:
|
||||
|
||||
# Sample random action
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Step the environment with the random action
|
||||
observation, reward, truncated, terminated, info = env.step(action)
|
||||
|
||||
done = truncated or terminated
|
||||
|
||||
# Accumulate the reward
|
||||
total_reward += reward
|
||||
|
||||
print(f"Episode #{episode}\n\treward: {total_reward}")
|
||||
|
||||
# Cleanup once done
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Environment and its parameters
|
||||
|
@ -101,4 +55,46 @@ if __name__ == "__main__":
|
|||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
main(args=args)
|
||||
# Create the environment
|
||||
if args.env_kwargs is not None:
|
||||
env = gym.make(args.env, **args.env_kwargs)
|
||||
else:
|
||||
env = gym.make(args.env)
|
||||
|
||||
# Initialize random seed
|
||||
env.seed(args.seed)
|
||||
|
||||
# Check the environment
|
||||
if args.check_env:
|
||||
check_env(env, warn=True, skip_render_check=True)
|
||||
|
||||
# Step environment for bunch of episodes
|
||||
for episode in range(args.n_episodes):
|
||||
# Initialize returned values
|
||||
done = False
|
||||
total_reward = 0
|
||||
|
||||
# Reset the environment
|
||||
observation = env.reset()
|
||||
|
||||
# Step through the current episode until it is done
|
||||
while not done:
|
||||
# Sample random action
|
||||
action = env.action_space.sample()
|
||||
|
||||
# Step the environment with the random action
|
||||
observation, reward, truncated, terminated, info = env.step(action)
|
||||
|
||||
done = truncated or terminated
|
||||
|
||||
# Accumulate the reward
|
||||
total_reward += reward
|
||||
|
||||
print(f"Episode #{episode}\n\treward: {total_reward}")
|
||||
|
||||
# Cleanup once done
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -4,7 +4,6 @@ import argparse
|
|||
import difflib
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
@ -16,114 +15,7 @@ from rbs_gym.utils.exp_manager import ExperimentManager
|
|||
from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool
|
||||
|
||||
|
||||
def main(args: Dict):
|
||||
|
||||
# Check if the selected environment is valid
|
||||
# If it could not be found, suggest the closest match
|
||||
registered_envs = set(gym.envs.registry.keys())
|
||||
if args.env not in registered_envs:
|
||||
try:
|
||||
closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
|
||||
except IndexError:
|
||||
closest_match = "'no close match found...'"
|
||||
raise ValueError(
|
||||
f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
|
||||
)
|
||||
|
||||
# If no specific seed is selected, choose a random one
|
||||
if args.seed < 0:
|
||||
args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
|
||||
|
||||
# Set the random seed across platforms
|
||||
set_random_seed(args.seed)
|
||||
|
||||
# Setting num threads to 1 makes things run faster on cpu
|
||||
if args.num_threads > 0:
|
||||
if args.verbose > 1:
|
||||
print(f"Setting torch.num_threads to {args.num_threads}")
|
||||
th.set_num_threads(args.num_threads)
|
||||
|
||||
# Verify that pre-trained agent exists before continuing to train it
|
||||
if args.trained_agent != "":
|
||||
assert args.trained_agent.endswith(".zip") and os.path.isfile(
|
||||
args.trained_agent
|
||||
), "The trained_agent must be a valid path to a .zip file"
|
||||
|
||||
# If enabled, ensure that the run has a unique ID
|
||||
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
|
||||
|
||||
print("=" * 10, args.env, "=" * 10)
|
||||
print(f"Seed: {args.seed}")
|
||||
|
||||
if args.track:
|
||||
try:
|
||||
import wandb
|
||||
import datetime
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
|
||||
) from e
|
||||
|
||||
run_name = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
|
||||
|
||||
tags = []
|
||||
run = wandb.init(
|
||||
name=run_name,
|
||||
project="rbs-gym",
|
||||
entity=None,
|
||||
tags=tags,
|
||||
config=vars(args),
|
||||
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
|
||||
monitor_gym=False, # auto-upload the videos of agents playing the game
|
||||
save_code=True, # optional
|
||||
)
|
||||
args.tensorboard_log = f"runs/{run_name}"
|
||||
|
||||
|
||||
exp_manager = ExperimentManager(
|
||||
args,
|
||||
args.algo,
|
||||
args.env,
|
||||
args.log_folder,
|
||||
args.tensorboard_log,
|
||||
args.n_timesteps,
|
||||
args.eval_freq,
|
||||
args.eval_episodes,
|
||||
args.save_freq,
|
||||
args.hyperparams,
|
||||
args.env_kwargs,
|
||||
args.trained_agent,
|
||||
args.optimize_hyperparameters,
|
||||
truncate_last_trajectory=args.truncate_last_trajectory,
|
||||
uuid_str=uuid_str,
|
||||
seed=args.seed,
|
||||
log_interval=args.log_interval,
|
||||
save_replay_buffer=args.save_replay_buffer,
|
||||
preload_replay_buffer=args.preload_replay_buffer,
|
||||
verbose=args.verbose,
|
||||
vec_env_type=args.vec_env,
|
||||
no_optim_plots=args.no_optim_plots,
|
||||
)
|
||||
|
||||
# Prepare experiment
|
||||
results = exp_manager.setup_experiment()
|
||||
if results is not None:
|
||||
model, saved_hyperparams = results
|
||||
if args.track:
|
||||
# we need to save the loaded hyperparameters
|
||||
args.saved_hyperparams = saved_hyperparams
|
||||
assert run is not None # make mypy happy
|
||||
run.config.setdefaults(vars(args))
|
||||
|
||||
# Normal training
|
||||
if model is not None:
|
||||
exp_manager.learn(model)
|
||||
exp_manager.save_trained_model(model)
|
||||
else:
|
||||
exp_manager.hyperparameters_optimization()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
@ -281,4 +173,111 @@ if __name__ == "__main__":
|
|||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
main(args=args)
|
||||
# Check if the selected environment is valid
|
||||
# If it could not be found, suggest the closest match
|
||||
registered_envs = set(gym.envs.registry.keys())
|
||||
if args.env not in registered_envs:
|
||||
try:
|
||||
closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
|
||||
except IndexError:
|
||||
closest_match = "'no close match found...'"
|
||||
raise ValueError(
|
||||
f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
|
||||
)
|
||||
|
||||
# If no specific seed is selected, choose a random one
|
||||
if args.seed < 0:
|
||||
args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
|
||||
|
||||
# Set the random seed across platforms
|
||||
set_random_seed(args.seed)
|
||||
|
||||
# Setting num threads to 1 makes things run faster on cpu
|
||||
if args.num_threads > 0:
|
||||
if args.verbose > 1:
|
||||
print(f"Setting torch.num_threads to {args.num_threads}")
|
||||
th.set_num_threads(args.num_threads)
|
||||
|
||||
# Verify that pre-trained agent exists before continuing to train it
|
||||
if args.trained_agent != "":
|
||||
assert args.trained_agent.endswith(".zip") and os.path.isfile(
|
||||
args.trained_agent
|
||||
), "The trained_agent must be a valid path to a .zip file"
|
||||
|
||||
# If enabled, ensure that the run has a unique ID
|
||||
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
|
||||
|
||||
print("=" * 10, args.env, "=" * 10)
|
||||
print(f"Seed: {args.seed}")
|
||||
|
||||
run = None
|
||||
if args.track:
|
||||
try:
|
||||
import wandb
|
||||
import datetime
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
|
||||
) from e
|
||||
|
||||
run_name = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
|
||||
|
||||
tags = []
|
||||
run = wandb.init(
|
||||
name=run_name,
|
||||
project="rbs-gym",
|
||||
entity=None,
|
||||
tags=tags,
|
||||
config=vars(args),
|
||||
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
|
||||
monitor_gym=False, # auto-upload the videos of agents playing the game
|
||||
save_code=True, # optional
|
||||
)
|
||||
args.tensorboard_log = f"runs/{run_name}"
|
||||
|
||||
|
||||
exp_manager = ExperimentManager(
|
||||
args,
|
||||
args.algo,
|
||||
args.env,
|
||||
args.log_folder,
|
||||
args.tensorboard_log,
|
||||
args.n_timesteps,
|
||||
args.eval_freq,
|
||||
args.eval_episodes,
|
||||
args.save_freq,
|
||||
args.hyperparams,
|
||||
args.env_kwargs,
|
||||
args.trained_agent,
|
||||
args.optimize_hyperparameters,
|
||||
truncate_last_trajectory=args.truncate_last_trajectory,
|
||||
uuid_str=uuid_str,
|
||||
seed=args.seed,
|
||||
log_interval=args.log_interval,
|
||||
save_replay_buffer=args.save_replay_buffer,
|
||||
preload_replay_buffer=args.preload_replay_buffer,
|
||||
verbose=args.verbose,
|
||||
vec_env_type=args.vec_env,
|
||||
no_optim_plots=args.no_optim_plots,
|
||||
)
|
||||
|
||||
# Prepare experiment
|
||||
results = exp_manager.setup_experiment()
|
||||
if results is not None:
|
||||
model, saved_hyperparams = results
|
||||
if args.track:
|
||||
# we need to save the loaded hyperparameters
|
||||
args.saved_hyperparams = saved_hyperparams
|
||||
assert run is not None # make mypy happy
|
||||
run.config.setdefaults(vars(args))
|
||||
|
||||
# Normal training
|
||||
if model is not None:
|
||||
exp_manager.learn(model)
|
||||
exp_manager.save_trained_model(model)
|
||||
else:
|
||||
exp_manager.hyperparameters_optimization()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
0
env_manager/rbs_gym/resource/rbs_gym
Normal file
0
env_manager/rbs_gym/resource/rbs_gym
Normal file
|
@ -1,200 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import time
|
||||
import gym_gz_models
|
||||
import gym_gz
|
||||
from scenario import gazebo as scenario_gazebo
|
||||
from scenario import core as scenario_core
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
import numpy as np
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from rclpy.executors import MultiThreadedExecutor
|
||||
from rbs_skill_servers import CartesianControllerPublisher, TakePose
|
||||
from rclpy.action import ActionClient
|
||||
from control_msgs.action import GripperCommand
|
||||
|
||||
|
||||
class Spawner(Node):
|
||||
def __init__(self):
|
||||
super().__init__("spawner")
|
||||
self.gazebo = scenario_gazebo.GazeboSimulator(step_size=0.001,
|
||||
rtf=1.0,
|
||||
steps_per_run=1)
|
||||
self.cartesian_pose = self.create_publisher(
|
||||
PoseStamped,
|
||||
"/" + "arm0" + "/cartesian_motion_controller/target_frame", 10)
|
||||
|
||||
self.current_pose_sub = self.create_subscription(PoseStamped,
|
||||
"/arm0/cartesian_motion_controller/current_pose", self.callback, 10)
|
||||
|
||||
|
||||
self._action_client = ActionClient(self,
|
||||
GripperCommand,
|
||||
"/" + "arm0" + "/gripper_controller/gripper_cmd")
|
||||
|
||||
timer_period = 0.001 # seconds
|
||||
self.timer = self.create_timer(timer_period, self.timer_callback)
|
||||
self.ano_timer = self.create_timer(timer_period, self.another_timer)
|
||||
|
||||
scenario_gazebo.set_verbosity(scenario_gazebo.Verbosity_info)
|
||||
|
||||
self.gazebo.insert_world_from_sdf(
|
||||
"/home/bill-finger/rbs_ws/install/rbs_simulation/share/rbs_simulation/worlds/asm2.sdf")
|
||||
self.gazebo.initialize()
|
||||
|
||||
self.world = self.gazebo.get_world()
|
||||
self.current_pose: PoseStamped = PoseStamped()
|
||||
self.init_sim()
|
||||
self.cube = self.world.get_model("cube")
|
||||
self.stage = 0
|
||||
self.gripper_open = False
|
||||
|
||||
def callback(self, msg: PoseStamped):
|
||||
self.current_pose = msg
|
||||
|
||||
def timer_callback(self):
|
||||
self.gazebo.run()
|
||||
|
||||
def send_goal(self, goal: float):
|
||||
goal_msg = GripperCommand.Goal()
|
||||
goal_msg._command.position = goal
|
||||
goal_msg._command.max_effort = 1.0
|
||||
self._action_client.wait_for_server()
|
||||
self.gripper_open = not self.gripper_open
|
||||
|
||||
self._send_goal_future = self._action_client.send_goal_async(goal_msg)
|
||||
|
||||
self._send_goal_future.add_done_callback(self.goal_response_callback)
|
||||
|
||||
def goal_response_callback(self, future):
|
||||
goal_handle = future.result()
|
||||
if not goal_handle.accepted:
|
||||
self.get_logger().info('Goal rejected :(')
|
||||
return
|
||||
|
||||
self.get_logger().info('Goal accepted :)')
|
||||
|
||||
self._get_result_future = goal_handle.get_result_async()
|
||||
self._get_result_future.add_done_callback(self.get_result_callback)
|
||||
|
||||
def get_result_callback(self, future):
|
||||
result = future.result().result
|
||||
self.get_logger().info('Result: {0}'.format(result.position))
|
||||
|
||||
def another_timer(self):
|
||||
position_over_cube = np.array(self.cube.base_position()) + np.array([0, 0, 0.2])
|
||||
position_cube = np.array(self.cube.base_position()) + np.array([0, 0, 0.03])
|
||||
quat_xyzw = R.from_euler(seq="y", angles=180, degrees=True).as_quat()
|
||||
if self.stage == 0:
|
||||
if self.distance_to_target(position_over_cube, quat_xyzw) > 0.01:
|
||||
self.cartesian_pose.publish(self.get_pose(position_over_cube, quat_xyzw))
|
||||
if self.distance_to_target(position_over_cube, quat_xyzw) < 0.01:
|
||||
self.stage += 1
|
||||
if self.stage == 1:
|
||||
if self.distance_to_target(position_cube, quat_xyzw) > 0.01:
|
||||
if not self.gripper_open:
|
||||
self.send_goal(0.064)
|
||||
# rclpy.spin_until_future_complete(self, future)
|
||||
self.cartesian_pose.publish(self.get_pose(position_cube, quat_xyzw))
|
||||
if self.distance_to_target(position_cube, quat_xyzw) < 0.01:
|
||||
self.stage += 1
|
||||
|
||||
|
||||
|
||||
def distance_to_target(self, position, orientation):
|
||||
target_pose = self.get_pose(position, orientation)
|
||||
current_position = np.array([
|
||||
self.current_pose.pose.position.x,
|
||||
self.current_pose.pose.position.y,
|
||||
self.current_pose.pose.position.z
|
||||
])
|
||||
target_position = np.array([
|
||||
target_pose.pose.position.x,
|
||||
target_pose.pose.position.y,
|
||||
target_pose.pose.position.z
|
||||
])
|
||||
distance = np.linalg.norm(current_position - target_position)
|
||||
|
||||
return distance
|
||||
|
||||
def init_sim(self):
|
||||
# Create the simulator
|
||||
self.gazebo.gui()
|
||||
self.gazebo.run(paused=True)
|
||||
|
||||
self.world.to_gazebo().set_gravity((0, 0, -9.8))
|
||||
|
||||
self.world.insert_model("/home/bill-finger/rbs_ws/current.urdf")
|
||||
self.gazebo.run(paused=True)
|
||||
for model_name in self.world.model_names():
|
||||
model = self.world.get_model(model_name)
|
||||
print(f"Model: {model_name}")
|
||||
print(f" Base link: {model.base_frame()}")
|
||||
print("LINKS")
|
||||
for name in model.link_names():
|
||||
position = model.get_link(name).position()
|
||||
orientation_wxyz = np.asarray(model.get_link(name).orientation())
|
||||
orientation = R.from_quat(orientation_wxyz[[1, 2, 3, 0]]).as_euler("xyz")
|
||||
print(f" {name}:", (*position, *tuple(orientation)))
|
||||
print("JOINTS")
|
||||
for name in model.joint_names():
|
||||
print(f"{name}")
|
||||
|
||||
uri = lambda org, name: f"https://fuel.gazebosim.org/{org}/models/{name}"
|
||||
|
||||
# Download the cube SDF file
|
||||
cube_sdf = scenario_gazebo.get_model_file_from_fuel(
|
||||
uri=uri(org="openrobotics", name="wood cube 5cm"), use_cache=False
|
||||
)
|
||||
|
||||
# Sample a random position
|
||||
random_position = np.random.uniform(low=[-0.2, -0.2, 0.0], high=[-0.3, 0.2, 0.0])
|
||||
|
||||
# Get a unique name
|
||||
model_name = gym_gz.utils.scenario.get_unique_model_name(
|
||||
world=self.world, model_name="cube"
|
||||
)
|
||||
|
||||
# Insert the model
|
||||
assert self.world.insert_model(
|
||||
cube_sdf, scenario_core.Pose(random_position, [1.0, 0, 0, 0]), model_name
|
||||
)
|
||||
|
||||
model = self.world.get_model("rbs_arm")
|
||||
self.cube = self.world.get_model("cube")
|
||||
|
||||
ok_reset_pos = model.to_gazebo().reset_joint_positions(
|
||||
[0.0, -0.240, -3.142, 1.090, 0, 1.617, 0.0, 0.0, 0.0],
|
||||
[name for name in model.joint_names() if "_joint" in name]
|
||||
)
|
||||
if not ok_reset_pos:
|
||||
raise RuntimeError("Failed to reset the robot state")
|
||||
|
||||
|
||||
def get_pose(self, position, orientation) -> PoseStamped:
|
||||
msg = PoseStamped()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
msg.header.frame_id = "base_link"
|
||||
msg.pose.position.x = position[0]
|
||||
msg.pose.position.y = position[1]
|
||||
msg.pose.position.z = position[2]
|
||||
msg.pose.orientation.x = orientation[0]
|
||||
msg.pose.orientation.y = orientation[1]
|
||||
msg.pose.orientation.z = orientation[2]
|
||||
msg.pose.orientation.w = orientation[3]
|
||||
return msg
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
executor = MultiThreadedExecutor()
|
||||
my_node = Spawner()
|
||||
executor.add_node(my_node)
|
||||
executor.spin()
|
||||
my_node.gazebo.close()
|
||||
my_node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,138 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
import numpy as np
|
||||
import quaternion
|
||||
from geometry_msgs.msg import Twist
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
import tf2_ros
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import os
|
||||
|
||||
|
||||
class Converter(Node):
|
||||
"""Convert Twist messages to PoseStamped
|
||||
|
||||
Use this node to integrate twist messages into a moving target pose in
|
||||
Cartesian space. An initial TF lookup assures that the target pose always
|
||||
starts at the robot's end-effector.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("converter")
|
||||
|
||||
self.twist_topic = self.declare_parameter("twist_topic", "/cartesian_motion_controller/target_twist").value
|
||||
self.pose_topic = self.declare_parameter("pose_topic", "/cartesian_motion_controller/target_frame").value
|
||||
self.frame_id = self.declare_parameter("frame_id", "base_link").value
|
||||
self.end_effector = self.declare_parameter("end_effector", "gripper_grasp_point").value
|
||||
|
||||
self.tf_buffer = tf2_ros.Buffer()
|
||||
self.tf_listener = tf2_ros.TransformListener(self.tf_buffer, self)
|
||||
self.rot = np.quaternion(0, 0, 0, 1)
|
||||
self.pos = [0, 0, 0]
|
||||
|
||||
self.pub = self.create_publisher(PoseStamped, self.pose_topic, 3)
|
||||
self.sub = self.create_subscription(Twist, self.twist_topic, self.twist_cb, 1)
|
||||
self.last = time.time()
|
||||
|
||||
self.startup_done = False
|
||||
period = 1.0 / self.declare_parameter("publishing_rate", 100).value
|
||||
self.timer = self.create_timer(period, self.publish)
|
||||
|
||||
self.thread = threading.Thread(target=self.startup, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def startup(self):
|
||||
"""Make sure to start at the robot's current pose"""
|
||||
# Wait until we entered spinning in the main thread.
|
||||
time.sleep(1)
|
||||
try:
|
||||
start = self.tf_buffer.lookup_transform(
|
||||
target_frame=self.frame_id,
|
||||
source_frame=self.end_effector,
|
||||
time=rclpy.time.Time(),
|
||||
)
|
||||
|
||||
except (
|
||||
tf2_ros.InvalidArgumentException,
|
||||
tf2_ros.LookupException,
|
||||
tf2_ros.ConnectivityException,
|
||||
tf2_ros.ExtrapolationException,
|
||||
) as e:
|
||||
print(f"Startup failed: {e}")
|
||||
os._exit(1)
|
||||
|
||||
self.pos[0] = start.transform.translation.x
|
||||
self.pos[1] = start.transform.translation.y
|
||||
self.pos[2] = start.transform.translation.z
|
||||
self.rot.x = start.transform.rotation.x
|
||||
self.rot.y = start.transform.rotation.y
|
||||
self.rot.z = start.transform.rotation.z
|
||||
self.rot.w = start.transform.rotation.w
|
||||
self.startup_done = True
|
||||
|
||||
def twist_cb(self, data):
|
||||
"""Numerically integrate twist message into a pose
|
||||
|
||||
Use global self.frame_id as reference for the navigation commands.
|
||||
"""
|
||||
now = time.time()
|
||||
dt = now - self.last
|
||||
self.last = now
|
||||
|
||||
# Position update
|
||||
self.pos[0] += data.linear.x * dt
|
||||
self.pos[1] += data.linear.y * dt
|
||||
self.pos[2] += data.linear.z * dt
|
||||
|
||||
# Orientation update
|
||||
wx = data.angular.x
|
||||
wy = data.angular.y
|
||||
wz = data.angular.z
|
||||
|
||||
_, q = quaternion.integrate_angular_velocity(
|
||||
lambda _: (wx, wy, wz), 0, dt, self.rot
|
||||
)
|
||||
|
||||
self.rot = q[-1] # the last one is after dt passed
|
||||
|
||||
def publish(self):
|
||||
if not self.startup_done:
|
||||
return
|
||||
try:
|
||||
msg = PoseStamped()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
msg.header.frame_id = self.frame_id
|
||||
msg.pose.position.x = self.pos[0]
|
||||
msg.pose.position.y = self.pos[1]
|
||||
msg.pose.position.z = self.pos[2]
|
||||
msg.pose.orientation.x = self.rot.x
|
||||
msg.pose.orientation.y = self.rot.y
|
||||
msg.pose.orientation.z = self.rot.z
|
||||
msg.pose.orientation.w = self.rot.w
|
||||
|
||||
self.pub.publish(msg)
|
||||
except Exception:
|
||||
# Swallow 'publish() to closed topic' error.
|
||||
# This rarely happens on killing this node.
|
||||
pass
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = Converter()
|
||||
rclpy.spin(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
rclpy.shutdown()
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
sys.exit(1)
|
4
env_manager/rbs_gym/setup.cfg
Normal file
4
env_manager/rbs_gym/setup.cfg
Normal file
|
@ -0,0 +1,4 @@
|
|||
[develop]
|
||||
script_dir=$base/lib/rbs_gym
|
||||
[install]
|
||||
install_scripts=$base/lib/rbs_gym
|
36
env_manager/rbs_gym/setup.py
Normal file
36
env_manager/rbs_gym/setup.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
import os
|
||||
from glob import glob
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = "rbs_gym"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.0.0",
|
||||
packages=find_packages(exclude=["test"]),
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
|
||||
("share/" + package_name, ["package.xml"]),
|
||||
(
|
||||
os.path.join("share", package_name, "launch"),
|
||||
glob(os.path.join("launch", "*launch.[pxy][yma]*")),
|
||||
),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="narmak",
|
||||
maintainer_email="ur.narmak@gmail.com",
|
||||
description="TODO: Package description",
|
||||
license="Apache-2.0",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"train = rbs_gym.scripts.train:main",
|
||||
"spawner = rbs_gym.scripts.spawner:main",
|
||||
"velocity = rbs_gym.scripts.velocity:main",
|
||||
"test_agent = rbs_gym.scripts.test_agent:main",
|
||||
"evaluate = rbs_gym.scripts.evaluate:main",
|
||||
],
|
||||
},
|
||||
)
|
25
env_manager/rbs_gym/test/test_copyright.py
Normal file
25
env_manager/rbs_gym/test/test_copyright.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_copyright.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
# Remove the `skip` decorator once the source file(s) have a copyright header
|
||||
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
|
||||
@pytest.mark.copyright
|
||||
@pytest.mark.linter
|
||||
def test_copyright():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found errors'
|
25
env_manager/rbs_gym/test/test_flake8.py
Normal file
25
env_manager/rbs_gym/test/test_flake8.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2017 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_flake8.main import main_with_errors
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.flake8
|
||||
@pytest.mark.linter
|
||||
def test_flake8():
|
||||
rc, errors = main_with_errors(argv=[])
|
||||
assert rc == 0, \
|
||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
||||
'\n'.join(errors)
|
23
env_manager/rbs_gym/test/test_pep257.py
Normal file
23
env_manager/rbs_gym/test/test_pep257.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_pep257.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.linter
|
||||
@pytest.mark.pep257
|
||||
def test_pep257():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found code style errors / warnings'
|
|
@ -1,36 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.8)
|
||||
project(rbs_runtime)
|
||||
|
||||
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
add_compile_options(-Wall -Wextra -Wpedantic)
|
||||
endif()
|
||||
|
||||
# find dependencies
|
||||
find_package(ament_cmake REQUIRED)
|
||||
# uncomment the following section in order to fill in
|
||||
# further dependencies manually.
|
||||
# find_package(<dependency> REQUIRED)
|
||||
|
||||
ament_python_install_package(${PROJECT_NAME})
|
||||
|
||||
|
||||
install(PROGRAMS
|
||||
scripts/runtime.py
|
||||
DESTINATION lib/${PROJECT_NAME}
|
||||
)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
find_package(ament_lint_auto REQUIRED)
|
||||
# the following line skips the linter which checks for copyrights
|
||||
# comment the line when a copyright and license is added to all source files
|
||||
set(ament_cmake_copyright_FOUND TRUE)
|
||||
# the following line skips cpplint (only works in a git repo)
|
||||
# comment the line when this package is in a git repo and when
|
||||
# a copyright and license is added to all source files
|
||||
set(ament_cmake_cpplint_FOUND TRUE)
|
||||
ament_lint_auto_find_test_dependencies()
|
||||
endif()
|
||||
|
||||
install(DIRECTORY launch DESTINATION share/${PROJECT_NAME})
|
||||
|
||||
ament_package()
|
|
@ -1,139 +0,0 @@
|
|||
robot:
|
||||
name: rbs_arm
|
||||
urdf_string: ""
|
||||
spawn_position:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
spawn_quat_xyzw:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 1.0
|
||||
joint_positioins:
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
with_gripper: true
|
||||
gripper_jont_positions:
|
||||
- 0
|
||||
randomizer:
|
||||
pose: false
|
||||
spawn_volume:
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 0.0
|
||||
joint_positions: false
|
||||
joint_positions_std: 0.1
|
||||
joint_positions_above_object_spawn: false
|
||||
joint_positions_above_object_spawn_elevation: 0.2
|
||||
joint_positions_above_object_spawn_xy_randomness: 0.2
|
||||
|
||||
terrain:
|
||||
type: flat
|
||||
spawn_position:
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
spawn_quat_xyzw:
|
||||
- 0
|
||||
- 0
|
||||
- 0
|
||||
- 1
|
||||
size:
|
||||
- 1.5
|
||||
- 1.5
|
||||
model_rollouts_num: 1
|
||||
|
||||
light:
|
||||
type: sun
|
||||
direction:
|
||||
- 0.5
|
||||
- -0.25
|
||||
- -0.75
|
||||
random_minmax_elevation:
|
||||
- -0.15
|
||||
- -0.65
|
||||
color:
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
- 1.0
|
||||
distance: 1000.0
|
||||
visual: true
|
||||
radius: 25.0
|
||||
model_rollouts_num: 1
|
||||
|
||||
objects:
|
||||
- name: bishop
|
||||
type: ""
|
||||
relative_to: world
|
||||
position:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
orientation:
|
||||
- 1.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
static: false
|
||||
randomize:
|
||||
count: 0
|
||||
random_pose: false
|
||||
random_position: false
|
||||
random_orientation: false
|
||||
random_model: false
|
||||
random_spawn_position_segments: []
|
||||
random_spawn_position_update_workspace_centre: false
|
||||
random_spawn_volume:
|
||||
- 0.5
|
||||
- 0.5
|
||||
- 0.5
|
||||
models_rollouts_num: 0
|
||||
texture: []
|
||||
|
||||
camera:
|
||||
- name: robot_camera
|
||||
enable: true
|
||||
type: rgbd_camera
|
||||
relative_to: base_link
|
||||
width: 128
|
||||
height: 128
|
||||
image_format: R8G8B8
|
||||
update_rate: 10
|
||||
horizontal_fov: 1.0471975511965976
|
||||
vertical_fov: 1.0471975511965976
|
||||
clip_color:
|
||||
- 0.01
|
||||
- 1000.0
|
||||
clip_depth:
|
||||
- 0.05
|
||||
- 10.0
|
||||
noise_mean: null
|
||||
noise_stddev: null
|
||||
publish_color: false
|
||||
publish_depth: false
|
||||
publish_points: false
|
||||
spawn_position:
|
||||
- 0
|
||||
- 0
|
||||
- 1
|
||||
spawn_quat_xyzw:
|
||||
- 0
|
||||
- 0.70710678118
|
||||
- 0
|
||||
- 0.70710678118
|
||||
random_pose_rollouts_num: 1
|
||||
random_pose_mode: orbit
|
||||
random_pose_orbit_distance: 1.0
|
||||
random_pose_orbit_height_range:
|
||||
- 0.1
|
||||
- 0.7
|
||||
random_pose_orbit_ignore_arc_behind_robot: 0.39269908169872414
|
||||
random_pose_select_position_options: []
|
||||
random_pose_focal_point_z_offset: 0.0
|
||||
random_pose_rollout_counter: 0.0
|
|
@ -100,7 +100,7 @@ def launch_setup(context, *args, **kwargs):
|
|||
|
||||
rbs_runtime = Node(
|
||||
package="rbs_runtime",
|
||||
executable="runtime.py",
|
||||
executable="runtime",
|
||||
parameters=[robot_description, {"use_sim_time": True}],
|
||||
)
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
<maintainer email="ur.narmak@gmail.com">narmak</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
|
||||
<test_depend>ament_lint_auto</test_depend>
|
||||
<test_depend>ament_lint_common</test_depend>
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
|
|
|
@ -22,8 +22,8 @@ DEFAULT_SCENE: SceneData = SceneData(
|
|||
robot=RobotData(
|
||||
name="rbs_arm",
|
||||
with_gripper=True,
|
||||
joint_positions=[0, 0, 0, 0, 0, 0, 0],
|
||||
gripper_joint_positions=[0],
|
||||
joint_positions=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
gripper_joint_positions=0.0,
|
||||
),
|
||||
objects=[MeshObjectData("bishop", position=(0.0, 1.0, 0.3))],
|
||||
camera=[CameraData("robot_camera")],
|
0
env_manager/rbs_runtime/resource/rbs_runtime
Normal file
0
env_manager/rbs_runtime/resource/rbs_runtime
Normal file
|
@ -1,17 +0,0 @@
|
|||
from env_manager.models.configs import SceneData, RobotData, MeshObjectData, CameraData
|
||||
from dataclasses import asdict
|
||||
import json
|
||||
|
||||
scene: SceneData = SceneData(
|
||||
robot=RobotData(
|
||||
name="rbs_arm",
|
||||
with_gripper=True,
|
||||
joint_positions=[0, 0, 0, 0, 0, 0, 0],
|
||||
gripper_joint_positions=[0],
|
||||
),
|
||||
objects=[MeshObjectData("bishop", position=(0.0, 1.0, 0.3))],
|
||||
camera=[CameraData("robot_camera")],
|
||||
)
|
||||
|
||||
with open("scene_config.json", "w") as file:
|
||||
json.dump(asdict(scene), file)
|
4
env_manager/rbs_runtime/setup.cfg
Normal file
4
env_manager/rbs_runtime/setup.cfg
Normal file
|
@ -0,0 +1,4 @@
|
|||
[develop]
|
||||
script_dir=$base/lib/rbs_runtime
|
||||
[install]
|
||||
install_scripts=$base/lib/rbs_runtime
|
32
env_manager/rbs_runtime/setup.py
Normal file
32
env_manager/rbs_runtime/setup.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import os
|
||||
from glob import glob
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = "rbs_runtime"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.0.0",
|
||||
packages=find_packages(exclude=["test"]),
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
|
||||
("share/" + package_name, ["package.xml"]),
|
||||
(
|
||||
os.path.join("share", package_name, "launch"),
|
||||
glob(os.path.join("launch", "*launch.[pxy][yma]*")),
|
||||
),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="narmak",
|
||||
maintainer_email="ur.narmak@gmail.com",
|
||||
description="TODO: Package description",
|
||||
license="Apache-2.0",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"runtime = rbs_runtime.scripts.runtime:main",
|
||||
],
|
||||
},
|
||||
)
|
25
env_manager/rbs_runtime/test/test_copyright.py
Normal file
25
env_manager/rbs_runtime/test/test_copyright.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_copyright.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
# Remove the `skip` decorator once the source file(s) have a copyright header
|
||||
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
|
||||
@pytest.mark.copyright
|
||||
@pytest.mark.linter
|
||||
def test_copyright():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found errors'
|
25
env_manager/rbs_runtime/test/test_flake8.py
Normal file
25
env_manager/rbs_runtime/test/test_flake8.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Copyright 2017 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_flake8.main import main_with_errors
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.flake8
|
||||
@pytest.mark.linter
|
||||
def test_flake8():
|
||||
rc, errors = main_with_errors(argv=[])
|
||||
assert rc == 0, \
|
||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
||||
'\n'.join(errors)
|
23
env_manager/rbs_runtime/test/test_pep257.py
Normal file
23
env_manager/rbs_runtime/test/test_pep257.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_pep257.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.linter
|
||||
@pytest.mark.pep257
|
||||
def test_pep257():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found code style errors / warnings'
|
|
@ -1,186 +0,0 @@
|
|||
{
|
||||
"physics_rollouts_num": 0,
|
||||
"gravity": [
|
||||
0.0,
|
||||
0.0,
|
||||
-9.80665
|
||||
],
|
||||
"gravity_std": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.0232
|
||||
],
|
||||
"robot": {
|
||||
"name": "rbs_arm",
|
||||
"urdf_string": "",
|
||||
"spawn_position": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"spawn_quat_xyzw": [
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0
|
||||
],
|
||||
"joint_positioins": [
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"with_gripper": true,
|
||||
"gripper_jont_positions": [
|
||||
0
|
||||
],
|
||||
"randomizer": {
|
||||
"pose": false,
|
||||
"spawn_volume": [
|
||||
1.0,
|
||||
1.0,
|
||||
0.0
|
||||
],
|
||||
"joint_positions": false,
|
||||
"joint_positions_std": 0.1,
|
||||
"joint_positions_above_object_spawn": false,
|
||||
"joint_positions_above_object_spawn_elevation": 0.2,
|
||||
"joint_positions_above_object_spawn_xy_randomness": 0.2
|
||||
}
|
||||
},
|
||||
"terrain": {
|
||||
"name": "ground",
|
||||
"type": "flat",
|
||||
"spawn_position": [
|
||||
0,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"spawn_quat_xyzw": [
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1
|
||||
],
|
||||
"size": [
|
||||
1.5,
|
||||
1.5
|
||||
],
|
||||
"model_rollouts_num": 1
|
||||
},
|
||||
"light": {
|
||||
"type": "sun",
|
||||
"direction": [
|
||||
0.5,
|
||||
-0.25,
|
||||
-0.75
|
||||
],
|
||||
"random_minmax_elevation": [
|
||||
-0.15,
|
||||
-0.65
|
||||
],
|
||||
"color": [
|
||||
1.0,
|
||||
1.0,
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"distance": 1000.0,
|
||||
"visual": true,
|
||||
"radius": 25.0,
|
||||
"model_rollouts_num": 1
|
||||
},
|
||||
"objects": [
|
||||
{
|
||||
"name": "bishop",
|
||||
"type": "",
|
||||
"relative_to": "world",
|
||||
"position": [
|
||||
0.0,
|
||||
1.0,
|
||||
0.3
|
||||
],
|
||||
"orientation": [
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"static": false,
|
||||
"randomize": {
|
||||
"count": 0,
|
||||
"random_pose": false,
|
||||
"random_position": false,
|
||||
"random_orientation": false,
|
||||
"random_model": false,
|
||||
"random_spawn_position_segments": [],
|
||||
"random_spawn_position_update_workspace_centre": false,
|
||||
"random_spawn_volume": [
|
||||
0.5,
|
||||
0.5,
|
||||
0.5
|
||||
],
|
||||
"models_rollouts_num": 0
|
||||
},
|
||||
"texture": []
|
||||
}
|
||||
],
|
||||
"camera": [
|
||||
{
|
||||
"name": "robot_camera",
|
||||
"enable": true,
|
||||
"type": "rgbd_camera",
|
||||
"relative_to": "base_link",
|
||||
"width": 128,
|
||||
"height": 128,
|
||||
"image_format": "R8G8B8",
|
||||
"update_rate": 10,
|
||||
"horizontal_fov": 1.0471975511965976,
|
||||
"vertical_fov": 1.0471975511965976,
|
||||
"clip_color": [
|
||||
0.01,
|
||||
1000.0
|
||||
],
|
||||
"clip_depth": [
|
||||
0.05,
|
||||
10.0
|
||||
],
|
||||
"noise_mean": null,
|
||||
"noise_stddev": null,
|
||||
"publish_color": false,
|
||||
"publish_depth": false,
|
||||
"publish_points": false,
|
||||
"spawn_position": [
|
||||
0,
|
||||
0,
|
||||
1
|
||||
],
|
||||
"spawn_quat_xyzw": [
|
||||
0,
|
||||
0.70710678118,
|
||||
0,
|
||||
0.70710678118
|
||||
],
|
||||
"random_pose_rollouts_num": 1,
|
||||
"random_pose_mode": "orbit",
|
||||
"random_pose_orbit_distance": 1.0,
|
||||
"random_pose_orbit_height_range": [
|
||||
0.1,
|
||||
0.7
|
||||
],
|
||||
"random_pose_orbit_ignore_arc_behind_robot": 0.39269908169872414,
|
||||
"random_pose_select_position_options": [],
|
||||
"random_pose_focal_point_z_offset": 0.0,
|
||||
"random_pose_rollout_counter": 0.0
|
||||
}
|
||||
],
|
||||
"plugins": {
|
||||
"scene_broadcaster": false,
|
||||
"user_commands": false,
|
||||
"fts_broadcaster": false,
|
||||
"sensors_render_engine": "ogre2"
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue