diff --git a/env_manager/.gitignore b/env_manager/.gitignore
new file mode 100644
index 0000000..48734be
--- /dev/null
+++ b/env_manager/.gitignore
@@ -0,0 +1,178 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python,ros3
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,ros3
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+**/__pycache__/
+**/*.py[cod]
+**/*$py.class
+**/*.cpython-*.pyc
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+# NOTE: the 'ros3' gitignore template is not provided by the toptal API; ROS-specific ignore rules must be added manually.
+
+# End of https://www.toptal.com/developers/gitignore/api/python,ros3
diff --git a/env_manager/docs/about/img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png b/env_manager/docs/about/img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png
new file mode 100644
index 0000000..df47ede
Binary files /dev/null and b/env_manager/docs/about/img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png differ
diff --git a/env_manager/docs/about/img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png b/env_manager/docs/about/img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png
new file mode 100644
index 0000000..df5d2d8
Binary files /dev/null and b/env_manager/docs/about/img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png differ
diff --git a/env_manager/docs/about/img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png b/env_manager/docs/about/img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png
new file mode 100644
index 0000000..7c3dcba
Binary files /dev/null and b/env_manager/docs/about/img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png differ
diff --git a/env_manager/docs/about/img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png b/env_manager/docs/about/img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png
new file mode 100644
index 0000000..aa7b626
Binary files /dev/null and b/env_manager/docs/about/img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png differ
diff --git a/env_manager/docs/about/index.ru.md b/env_manager/docs/about/index.ru.md
new file mode 100644
index 0000000..ee8fa08
--- /dev/null
+++ b/env_manager/docs/about/index.ru.md
@@ -0,0 +1,55 @@
+# Модуль управления виртуальными средами
+---
+
+При управлении роботами в симуляторе Gazebo через фреймворк ROS2 возникает необходимость конфигурировать не только робота-манипулятора, но и саму сцену. Однако стандартный подход, основанный на конфигурационных файлах Gazebo, зачастую оказывается избыточным и недостаточно гибким для динамических сценариев, к которым относится обучение с подкреплением.
+
+**env_manager** — это пакет, предназначенный для конфигурирования сцен в симуляторе Gazebo, предоставляющий более удобный и гибкий подход к созданию и настройке симуляционных сред.
+
+## Возможности пакета
+
+С последнего обновления модуль был полностью переработан. Если ранее его функции ограничивались указанием объектов, находящихся в среде, для работы в ROS2, то теперь он предоставляет инструменты для:
+- полного конфигурирования сцены,
+- настройки объектов наблюдения для ROS2.
+
+Конфигурация осуществляется с использованием **датаклассов** или **YAML-файлов**, что соответствует декларативному подходу описания сцены. Это делает процесс настройки интуитивно понятным и легко масштабируемым. Пример файла описания сцены, а также файл с конфигурацией по умолчанию доступны [здесь](https://gitlab.com/solid-sinusoid/env_manager/-/blob/b425a1b012bc8320bba7b68e5481da187d64d76e/rbs_runtime/config/default-scene-config.yaml).
+
+## Возможности конфигурации
+
+Модуль поддерживает добавление различных типов объектов в сцену, включая:
+- **Модель**
+- **Меш**
+- **Бокс**
+- **Цилиндр**
+- **Сферу**
+
+Различие между "моделью" и "мешем" заключается в том, находится ли объект в библиотеке **rbs_assets_library** (подробнее о ней см. [соответствующий раздел](https://robossembler.org/docs/software/ros2#rbs_assets_library)). Дополнительно поддерживается **рандомизация объектов**, позволяющая случайным образом изменять их цвет и положение в сцене.
+
+Помимо объектов, с помощью пакета можно настраивать:
+- **Источники света**
+- **Сенсоры**
+- **Роботов**
+- **Рабочие поверхности**
+
+Каждый тип объекта обладает как параметрами размещения, так и параметрами рандомизации. Для камер предусмотрены настройки публикации данных:
+- изображения глубины
+- цветного изображения
+- облаков точек.
+
+Параметры рандомизации могут включать в себя положение и ориентацию в заданных пользователем пределах. Для рабочей поверхности также доступна рандомизация текстуры, а для робота — рандомизация его положения, в том числе конфигурации и расположения базы.
+
+## Архитектура и спецификации
+
+Основная структура модуля включает обертки для добавления объектов в сцену. Полная спецификация доступных параметров и взаимосвязей между классами представлена в папке конфигураций. Для каждой категории объектов используются отдельные датаклассы, что упрощает организацию и модификацию параметров.
+
+Диаграмма классов конфигурации сцены представлена ниже:
+
+
+*Диаграмма классов конфигурации сцены*
+
+## Примеры
+
+Ниже представлены различные сцены, созданные с использованием возможностей **env_manager**:
+
+| **Сценарий 1** | **Сценарий 2** | **Сценарий 3** |
+|-----------------|-----------------|-----------------|
+|  |  |  |
diff --git a/env_manager/docs/getting_started/getting_started.ru.md b/env_manager/docs/getting_started/getting_started.ru.md
new file mode 100644
index 0000000..35af9b9
--- /dev/null
+++ b/env_manager/docs/getting_started/getting_started.ru.md
@@ -0,0 +1,48 @@
+# Начало работы с rbs_gym
+
+Пакет входит в проект Robossembler. Для установки пакета необходимо произвести установку всего проекта по [инструкции](https://gitlab.com/robossembler/robossembler-ros2/-/blob/b109c97b5c093e665135179668cb2091e6708387/docs/ru/installation.md)
+
+## Запуск тестовой среды
+
+Для запуска тестовой среды необходимо выполнить команду
+
+```sh
+ros2 launch rbs_gym test_env.launch.py base_link_name:=base_link ee_link_name:=gripper_grasp_point control_space:=task control_strategy:=effort interactive:=false
+```
+
+Это продемонстрирует, что всё установилось и работает корректно. Для визуализации можно воспользоваться графическим клиентом Gazebo в новом терминале:
+```sh
+ign gazebo -g
+```
+
+## Запуск обучения тестовой среды
+
+Запуск обучения производится следующей командой:
+
+```sh
+ros2 launch rbs_gym train.launch.py base_link_name:=base_link ee_link_name:=gripper_grasp_point control_space:=task control_strategy:=effort interactive:=false
+```
+
+Команда запустит обучение алгоритмом SAC. Полный перечень аргументов можно посмотреть с помощью флага `--show-args`, применимого ко всем файлам запуска, запускаемым посредством команды `ros2 launch`.
+
+```sh
+ros2 launch rbs_gym train.launch.py --show-args
+```
+
+Метрики качества обучения можно наблюдать в папке `logs`, которая автоматически создастся в том месте, откуда вы запускали обучение агента. Для этого надо перейти в эту директорию и выполнить команду:
+
+```sh
+cd logs
+aim up
+```
+
+Это выведет в консоль ссылку на веб-интерфейс [Aim](https://aimstack.io/), который необходимо открыть в браузере.
+
+## Модификация сцены
+
+Обычно среда прочно связана со сценой, поэтому для модификации сцены часто необходимо вносить правки и в среду.
+
+Сцена задается как конфигурация [`env_manager`](../about/index.ru.md). Пример конфигурации для тестовой среды можно посмотреть [тут](../../rbs_gym/rbs_gym/envs/__init__.py).
+
+В настоящий момент функционал активно разрабатывается; в дальнейшем появится более удобный способ модификации сцены и параметров среды с использованием [Hydra](https://hydra.cc/).
+
diff --git a/env_manager/env_manager/LICENSE b/env_manager/env_manager/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/env_manager/env_manager/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/env_manager/env_manager/env_manager/__init__.py b/env_manager/env_manager/env_manager/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/env_manager/env_manager/env_manager/models/__init__.py b/env_manager/env_manager/env_manager/models/__init__.py
new file mode 100644
index 0000000..2cef2a1
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/__init__.py
@@ -0,0 +1,5 @@
+from .lights import *
+from .objects import *
+from .robots import *
+from .sensors import *
+from .terrains import *
diff --git a/env_manager/env_manager/env_manager/models/configs/__init__.py b/env_manager/env_manager/env_manager/models/configs/__init__.py
new file mode 100644
index 0000000..10201ce
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/__init__.py
@@ -0,0 +1,15 @@
+from .camera import CameraData
+from .light import LightData
+from .objects import (
+ BoxObjectData,
+ CylinderObjectData,
+ MeshData,
+ ModelData,
+ ObjectData,
+ PlaneObjectData,
+ SphereObjectData,
+ ObjectRandomizerData,
+)
+from .robot import RobotData
+from .terrain import TerrainData
+from .scene import SceneData
diff --git a/env_manager/env_manager/env_manager/models/configs/camera.py b/env_manager/env_manager/env_manager/models/configs/camera.py
new file mode 100644
index 0000000..c76e8d9
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/camera.py
@@ -0,0 +1,111 @@
+from dataclasses import dataclass, field
+
+import numpy as np
+
+
+@dataclass
+class CameraData:
+ """
+ CameraData stores the parameters and configuration settings for an RGB-D camera.
+
+ Attributes:
+ ----------
+ name : str
+ The name of the camera instance. Default is an empty string.
+ enable : bool
+ Boolean flag to enable or disable the camera. Default is True.
+ type : str
+ Type of the camera. Default is "rgbd_camera".
+ relative_to : str
+ The reference frame relative to which the camera is placed. Default is "base_link".
+
+ width : int
+ The image width (in pixels) captured by the camera. Default is 128.
+ height : int
+ The image height (in pixels) captured by the camera. Default is 128.
+ image_format : str
+ The image format used (e.g., "R8G8B8"). Default is "R8G8B8".
+ update_rate : int
+ The rate (in Hz) at which the camera provides updates. Default is 10 Hz.
+ horizontal_fov : float
+ The horizontal field of view (in radians). Default is pi / 3.0.
+ vertical_fov : float
+ The vertical field of view (in radians). Default is pi / 3.0.
+
+ clip_color : tuple[float, float]
+ The near and far clipping planes for the color camera. Default is (0.01, 1000.0).
+ clip_depth : tuple[float, float]
+ The near and far clipping planes for the depth camera. Default is (0.05, 10.0).
+
+ noise_mean : float | None
+ The mean value of the Gaussian noise applied to the camera's data. Default is None (no noise).
+ noise_stddev : float | None
+ The standard deviation of the Gaussian noise applied to the camera's data. Default is None (no noise).
+
+ publish_color : bool
+ Whether or not to publish color data from the camera. Default is False.
+ publish_depth : bool
+ Whether or not to publish depth data from the camera. Default is False.
+ publish_points : bool
+ Whether or not to publish point cloud data from the camera. Default is False.
+
+ spawn_position : tuple[float, float, float]
+ The initial spawn position (x, y, z) of the camera relative to the reference frame. Default is (0, 0, 1).
+ spawn_quat_xyzw : tuple[float, float, float, float]
+ The initial spawn orientation of the camera in quaternion (x, y, z, w). Default is (0, 0.70710678118, 0, 0.70710678118).
+
+ random_pose_rollouts_num : int
+ The number of random pose rollouts. Default is 1.
+ random_pose_mode : str
+ The mode for random pose generation (e.g., "orbit"). Default is "orbit".
+ random_pose_orbit_distance : float
+ The fixed distance from the object in "orbit" mode. Default is 1.0.
+ random_pose_orbit_height_range : tuple[float, float]
+ The range of vertical positions (z-axis) in "orbit" mode. Default is (0.1, 0.7).
+ random_pose_orbit_ignore_arc_behind_robot : float
+ The arc angle (in radians) behind the robot to ignore when generating random poses. Default is pi / 8.
+ random_pose_select_position_options : list[tuple[float, float, float]]
+ A list of preset position options for the camera in random pose mode. Default is an empty list.
+ random_pose_focal_point_z_offset : float
+ The offset in the z-direction of the focal point when generating random poses. Default is 0.0.
+ random_pose_rollout_counter : float
+ A counter tracking the number of rollouts completed for random poses. Default is 0.0.
+ """
+
+ name: str = field(default_factory=str)
+ enable: bool = field(default=True)
+ type: str = field(default="rgbd_camera")
+ relative_to: str = field(default="base_link")
+
+ width: int = field(default=128)
+ height: int = field(default=128)
+ image_format: str = field(default="R8G8B8")
+ update_rate: int = field(default=10)
+ horizontal_fov: float = field(default=np.pi / 3.0)
+ vertical_fov: float = field(default=np.pi / 3.0)
+
+ clip_color: tuple[float, float] = field(default=(0.01, 1000.0))
+ clip_depth: tuple[float, float] = field(default=(0.05, 10.0))
+
+ noise_mean: float | None = None
+ noise_stddev: float | None = None
+
+ publish_color: bool = field(default=False)
+ publish_depth: bool = field(default=False)
+ publish_points: bool = field(default=False)
+
+ spawn_position: tuple[float, float, float] = field(default=(0.0, 0.0, 1.0))
+ spawn_quat_xyzw: tuple[float, float, float, float] = field(
+ default=(0.0, 0.70710678118, 0.0, 0.70710678118)
+ )
+
+ random_pose_rollouts_num: int = field(default=1)
+ random_pose_mode: str = field(default="orbit")
+ random_pose_orbit_distance: float = field(default=1.0)
+ random_pose_orbit_height_range: tuple[float, float] = field(default=(0.1, 0.7))
+ random_pose_orbit_ignore_arc_behind_robot: float = field(default=np.pi / 8)
+ random_pose_select_position_options: list[tuple[float, float, float]] = field(
+ default_factory=list
+ )
+ random_pose_focal_point_z_offset: float = field(default=0.0)
+ random_pose_rollout_counter: float = field(default=0.0)
diff --git a/env_manager/env_manager/env_manager/models/configs/light.py b/env_manager/env_manager/env_manager/models/configs/light.py
new file mode 100644
index 0000000..ea15f52
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/light.py
@@ -0,0 +1,39 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class LightData:
+ """
+ LightData stores the configuration and settings for a light source in a simulation or rendering environment.
+
+ Attributes:
+ ----------
+ type : str
+ The type of the light source (e.g., "sun", "point", "spot"). Default is "sun".
+ direction : tuple[float, float, float]
+ The direction vector of the light source in (x, y, z). This is typically used for directional lights like the sun.
+ Default is (0.5, -0.25, -0.75).
+ random_minmax_elevation : tuple[float, float]
+ The minimum and maximum elevation angles (in radians) for randomizing the light's direction.
+ Default is (-0.15, -0.65).
+ color : tuple[float, float, float, float]
+ The RGBA color of the light source. Each value ranges from 0 to 1. Default is (1.0, 1.0, 1.0, 1.0), which represents white light.
+ distance : float
+ The effective distance of the light source. Default is 1000.0 units.
+ visual : bool
+ A flag indicating whether the light source is visually represented in the simulation (e.g., whether it casts visible rays).
+ Default is True.
+ radius : float
+ The radius of the light's influence (for point and spot lights). Default is 25.0 units.
+ model_rollouts_num : int
+ The number of rollouts for randomizing light configurations in different simulation runs. Default is 1.
+ """
+
+ type: str = field(default="sun")
+ direction: tuple[float, float, float] = field(default=(0.5, -0.25, -0.75))
+ random_minmax_elevation: tuple[float, float] = field(default=(-0.15, -0.65))
+ color: tuple[float, float, float, float] = field(default=(1.0, 1.0, 1.0, 1.0))
+ distance: float = field(default=1000.0)
+ visual: bool = field(default=True)
+ radius: float = field(default=25.0)
+ model_rollouts_num: int = field(default=1)
diff --git a/env_manager/env_manager/env_manager/models/configs/objects.py b/env_manager/env_manager/env_manager/models/configs/objects.py
new file mode 100644
index 0000000..167423d
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/objects.py
@@ -0,0 +1,198 @@
+from dataclasses import dataclass, field
+
+from env_manager.utils.types import Pose
+
+
@dataclass
class ObjectRandomizerData:
    """
    Parameters controlling how object properties are randomized in simulation.

    Attributes:
    ----------
    count : int
        Number of objects to randomize. Default is 0.
    random_pose : bool
        Randomize both position and orientation. Default is False.
    random_position : bool
        Randomize position only. Default is False.
    random_orientation : bool
        Randomize orientation only. Default is False.
    random_model : bool
        Randomize the object's model. Default is False.
    random_spawn_position_segments : list
        Segments within which an object may be randomly spawned.
        Default is an empty list.
    random_spawn_position_update_workspace_centre : bool
        Update the workspace centre during random spawning. Default is False.
    random_spawn_volume : tuple[float, float, float]
        Volume (x, y, z) within which objects may be spawned.
        Default is (0.5, 0.5, 0.5).
    models_rollouts_num : int
        Number of rollouts when randomizing models. Default is 0.
    random_color : bool
        Randomize the object's color. Default is False.
    """

    count: int = 0
    random_pose: bool = False
    random_position: bool = False
    random_orientation: bool = False
    random_model: bool = False
    random_spawn_position_segments: list = field(default_factory=list)
    random_spawn_position_update_workspace_centre: bool = False
    random_spawn_volume: tuple[float, float, float] = (0.5, 0.5, 0.5)
    models_rollouts_num: int = 0
    random_color: bool = False
+
+
@dataclass
class ObjectData:
    """
    ObjectData stores the base properties for any object in the simulation environment.

    Attributes:
    ----------
    name : str
        The name of the object. Default is "object".
    type : str
        The type of the object (e.g., "sphere", "box"). Default is an empty string.
    relative_to : str
        The reference frame relative to which the object is positioned. Default is "world".
    position : tuple[float, float, float]
        The position of the object in (x, y, z) coordinates. Default is (0.0, 0.0, 0.0).
    orientation : tuple[float, float, float, float]
        The orientation of the object as a quaternion. Default is (1.0, 0.0, 0.0, 0.0).
        NOTE(review): this default is the identity quaternion in (w, x, y, z)
        order, not (x, y, z, w) as previously documented — confirm the
        intended component order against the consumers of this field.
    color : tuple[float, float, float, float] | None
        Optional RGBA color of the object. None means no explicit color, in
        which case a consumer may choose one (e.g. the Mesh model picks a
        random color when None). Default is None.
    static : bool
        Flag indicating if the object is static in the simulation (immovable). Default is False.
    randomize : ObjectRandomizerData
        Object randomizer settings for generating randomized object properties. Default is an empty ObjectRandomizerData instance.
    """

    name: str = field(default="object")
    type: str = field(default_factory=str)
    relative_to: str = field(default="world")
    position: tuple[float, float, float] = field(default=(0.0, 0.0, 0.0))
    orientation: tuple[float, float, float, float] = field(default=(1.0, 0.0, 0.0, 0.0))
    color: tuple[float, float, float, float] | None = field(default=None)
    static: bool = field(default=False)
    randomize: ObjectRandomizerData = field(default_factory=ObjectRandomizerData)
+
+
@dataclass
class PrimitiveObjectData(ObjectData):
    """
    PrimitiveObjectData defines the base properties for primitive objects (e.g., spheres, boxes) in the simulation.

    The object's color is inherited from ObjectData (optional RGBA, None by
    default); this class only adds collision, visual and mass properties.

    Attributes:
    ----------
    collision : bool
        Flag indicating whether the object participates in collision detection. Default is True.
    visual : bool
        Flag indicating whether the object has a visual representation in the simulation. Default is True.
    mass : float
        The mass of the object. Default is 0.1.
    """

    collision: bool = field(default=True)
    visual: bool = field(default=True)
    mass: float = field(default=0.1)
+
+
@dataclass
class SphereObjectData(PrimitiveObjectData):
    """
    Properties specific to a spherical object in the simulation.

    Attributes:
    ----------
    radius : float
        Radius of the sphere. Default is 0.025.
    friction : float
        Friction coefficient against contacting surfaces. Default is 1.0.
    """

    radius: float = 0.025
    friction: float = 1.0
+
+
@dataclass
class PlaneObjectData(PrimitiveObjectData):
    """
    PlaneObjectData defines the specific properties for a planar object in the simulation.

    Attributes:
    ----------
    size : tuple[float, float]
        The size of the plane, defined by its width and length. Default is (1.0, 1.0).
    direction : tuple[float, float, float]
        The normal vector representing the direction the plane faces. Default is (0.0, 0.0, 1.0).
    friction : float
        The friction coefficient of the plane when in contact with other objects. Default is 1.0.
    """

    size: tuple[float, float] = field(default=(1.0, 1.0))
    direction: tuple[float, float, float] = field(default=(0.0, 0.0, 1.0))
    friction: float = field(default=1.0)
+
+
@dataclass
class CylinderObjectData(PrimitiveObjectData):
    """
    Properties specific to a cylindrical object in the simulation.

    Attributes:
    ----------
    radius : float
        Radius of the cylinder. Default is 0.025.
    length : float
        Length of the cylinder. Default is 0.1.
    friction : float
        Friction coefficient against contacting surfaces. Default is 1.0.
    """

    radius: float = 0.025
    length: float = 0.1
    friction: float = 1.0
+
+
@dataclass
class BoxObjectData(PrimitiveObjectData):
    """
    BoxObjectData defines the specific properties for a box-shaped object in the simulation.

    Attributes:
    ----------
    size : tuple[float, float, float]
        The dimensions of the box in (width, height, depth). Default is (0.05, 0.05, 0.05).
    friction : float
        The friction coefficient of the box when in contact with surfaces. Default is 1.0.
    """

    size: tuple[float, float, float] = field(default=(0.05, 0.05, 0.05))
    friction: float = field(default=1.0)
+
+
@dataclass
class ModelData(ObjectData):
    """
    ModelData defines the properties for a model-based object in the simulation.

    Attributes:
    ----------
    texture : list[float]
        A list of texture coordinates or texture properties applied to the model. Default is an empty list.
    """

    texture: list[float] = field(default_factory=list)
+
+
@dataclass
class MeshData(ModelData):
    """
    MeshData describes a mesh-backed model with physical properties.

    All fields default to "empty" values (0.0, (), "", default Pose); the Mesh
    model derives mass, center of mass and inertia from the collision mesh and
    density when `inertia` is not provided.

    NOTE(review): `inertia` defaults to an empty tuple via default_factory,
    which does not match the declared 6-element annotation — confirm that
    consumers treat an empty tuple as "not provided".
    """

    # Mass of the body; 0.0 means unset (may be derived from density).
    mass: float = field(default_factory=float)
    # Six independent components of the inertia tensor; empty tuple when unset.
    inertia: tuple[float, float, float, float, float, float] = field(
        default_factory=tuple
    )
    # Center-of-mass pose of the body (project-local Pose type).
    cm: Pose = field(default_factory=Pose)
    # Filepath of the collision mesh; empty string when unset.
    collision: str = field(default_factory=str)
    # Filepath of the visual mesh; empty string when unset.
    visual: str = field(default_factory=str)
    # Surface friction coefficient; 0.0 when unset.
    friction: float = field(default_factory=float)
    # Material density used to derive mass/inertia; 0.0 when unset.
    density: float = field(default_factory=float)
diff --git a/env_manager/env_manager/env_manager/models/configs/robot.py b/env_manager/env_manager/env_manager/models/configs/robot.py
new file mode 100644
index 0000000..b2f90ca
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/robot.py
@@ -0,0 +1,103 @@
+from dataclasses import dataclass, field
+from enum import Enum
+
+
@dataclass
class ToolData:
    """Base description of an end-effector tool: a name and a type tag."""

    name: str = ""
    type: str = ""
+
+
@dataclass
class GripperData(ToolData):
    """Marker base class for gripper tools; adds no fields beyond ToolData."""

    pass
+
+
@dataclass
class ParallelGripperData(GripperData):
    """Configuration marker for a parallel-jaw gripper; no extra fields yet."""

    pass
+
+
@dataclass
class MultifingerGripperData(GripperData):
    """Configuration marker for a multi-finger gripper; no extra fields yet."""

    pass
+
+
@dataclass
class VacuumGripperData(GripperData):
    """Configuration marker for a vacuum (suction) gripper; no extra fields yet."""

    pass
+
+
class GripperEnum(Enum):
    """Maps gripper type names to their configuration data classes.

    ``mulrifinger`` is a historical misspelling kept for backward
    compatibility; ``multifinger`` is the correctly spelled alias of the
    same member (Enum treats a repeated value as an alias, so
    ``GripperEnum.multifinger is GripperEnum.mulrifinger``).
    """

    parallel = ParallelGripperData
    mulrifinger = MultifingerGripperData
    # Correctly spelled alias for the member above.
    multifinger = MultifingerGripperData
    vacuum = VacuumGripperData
+
+
@dataclass
class RobotRandomizerData:
    """
    Parameters controlling how robot properties are randomized in simulation.

    Attributes:
    ----------
    pose : bool
        Randomize the robot's pose (position and orientation). Default is False.
    spawn_volume : tuple[float, float, float]
        Volume (x, y, z) within which the robot may be spawned.
        Default is (1.0, 1.0, 0.0).
    joint_positions : bool
        Randomize the robot's joint positions. Default is False.
    joint_positions_std : float
        Standard deviation used when randomizing joint positions.
        Default is 0.1.
    joint_positions_above_object_spawn : bool
        Randomize joint positions so the robot ends up above the object's
        spawn position. Default is False.
    joint_positions_above_object_spawn_elevation : float
        Elevation above the object's spawn position. Default is 0.2.
    joint_positions_above_object_spawn_xy_randomness : float
        Randomness of the x and y coordinates when placing the robot above
        the object's spawn position. Default is 0.2.
    """

    pose: bool = False
    spawn_volume: tuple[float, float, float] = (1.0, 1.0, 0.0)
    joint_positions: bool = False
    joint_positions_std: float = 0.1
    joint_positions_above_object_spawn: bool = False
    joint_positions_above_object_spawn_elevation: float = 0.2
    joint_positions_above_object_spawn_xy_randomness: float = 0.2
+
+
@dataclass
class RobotData:
    """
    Base configuration of a robot in the simulation.

    Attributes:
    ----------
    name : str
        Robot name. Default is an empty string.
    urdf_string : str
        Optional URDF description; overridden by the node parameters when set
        in the node configuration. Default is an empty string.
    spawn_position : tuple[float, float, float]
        Spawn position (x, y, z). Default is (0.0, 0.0, 0.0).
    spawn_quat_xyzw : tuple[float, float, float, float]
        Spawn orientation as a quaternion (x, y, z, w).
        Default is (0.0, 0.0, 0.0, 1.0).
    joint_positions : list[float]
        Initial joint positions. Default is an empty list.
    with_gripper : bool
        Whether the robot is equipped with a gripper. Default is False.
    gripper_joint_positions : list[float] | float
        Gripper joint positions, either a list of floats or a single float.
        Default is an empty list.
    randomizer : RobotRandomizerData
        Randomization settings for the robot. Default is a fresh
        RobotRandomizerData instance.
    """

    name: str = ""
    urdf_string: str = ""
    spawn_position: tuple[float, float, float] = (0.0, 0.0, 0.0)
    spawn_quat_xyzw: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0)
    joint_positions: list[float] = field(default_factory=list)
    with_gripper: bool = False
    gripper_joint_positions: list[float] | float = field(default_factory=list)
    randomizer: RobotRandomizerData = field(default_factory=RobotRandomizerData)
diff --git a/env_manager/env_manager/env_manager/models/configs/scene.py b/env_manager/env_manager/env_manager/models/configs/scene.py
new file mode 100644
index 0000000..94d078f
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/scene.py
@@ -0,0 +1,68 @@
+from dataclasses import dataclass, field
+
+from .camera import CameraData
+from .light import LightData
+from .objects import ObjectData
+from .robot import RobotData
+from .terrain import TerrainData
+
+
@dataclass
class PluginsData:
    """
    Configuration of the plugins used in the simulation environment.

    Attributes:
    ----------
    scene_broadcaster : bool
        Enable the scene broadcaster plugin. Default is False.
    user_commands : bool
        Enable the user commands plugin. Default is False.
    fts_broadcaster : bool
        Enable the force-torque sensor (FTS) broadcaster plugin.
        Default is False.
    sensors_render_engine : str
        Render engine used for sensors. Default is "ogre2".
    """

    scene_broadcaster: bool = False
    user_commands: bool = False
    fts_broadcaster: bool = False
    sensors_render_engine: str = "ogre2"
+
+
@dataclass
class SceneData:
    """
    Top-level configuration of the simulation scene.

    Attributes:
    ----------
    physics_rollouts_num : int
        Number of physics rollouts to perform. Default is 0.
    gravity : tuple[float, float, float]
        Gravitational acceleration vector (x, y, z).
        Default is (0.0, 0.0, -9.80665).
    gravity_std : tuple[float, float, float]
        Per-axis standard deviation of the gravitational acceleration.
        Default is (0.0, 0.0, 0.0232).
    robot : RobotData
        Configuration of the robot in the scene. Default is RobotData().
    terrain : TerrainData
        Configuration of the terrain. Default is TerrainData().
    light : LightData
        Configuration of the lighting. Default is LightData().
    objects : list[ObjectData]
        Objects present in the scene. Default is an empty list.
    camera : list[CameraData]
        Cameras present in the scene. Default is an empty list.
    plugins : PluginsData
        Configuration of the simulation plugins. Default is PluginsData().
    """

    physics_rollouts_num: int = 0
    gravity: tuple[float, float, float] = (0.0, 0.0, -9.80665)
    gravity_std: tuple[float, float, float] = (0.0, 0.0, 0.0232)
    robot: RobotData = field(default_factory=RobotData)
    terrain: TerrainData = field(default_factory=TerrainData)
    light: LightData = field(default_factory=LightData)
    objects: list[ObjectData] = field(default_factory=list)
    camera: list[CameraData] = field(default_factory=list)
    plugins: PluginsData = field(default_factory=PluginsData)
diff --git a/env_manager/env_manager/env_manager/models/configs/terrain.py b/env_manager/env_manager/env_manager/models/configs/terrain.py
new file mode 100644
index 0000000..f240132
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/configs/terrain.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass, field
+
+
@dataclass
class TerrainData:
    """
    TerrainData stores the configuration for the terrain in the simulation environment.

    Attributes:
    ----------
    name : str
        The name of the terrain. Default is "ground".
    type : str
        The type of terrain (e.g., "flat"). Default is "flat".
    spawn_position : tuple[float, float, float]
        The position where the terrain will be spawned in the simulation, represented as (x, y, z). Default is (0, 0, 0).
    spawn_quat_xyzw : tuple[float, float, float, float]
        The orientation of the terrain at spawn, represented as a quaternion (x, y, z, w). Default is (0, 0, 0, 1).
    size : tuple[float, float]
        The size of the terrain, represented as (width, length). Default is (1.5, 1.5).
    ambient : tuple[float, float, float, float]
        RGBA ambient color of the terrain material. Default is (0.8, 0.8, 0.8, 1.0).
    specular : tuple[float, float, float, float]
        RGBA specular color of the terrain material. Default is (0.8, 0.8, 0.8, 1.0).
    diffuse : tuple[float, float, float, float]
        RGBA diffuse color of the terrain material. Default is (0.8, 0.8, 0.8, 1.0).
    model_rollouts_num : int
        The number of rollouts for randomizing terrain models. Default is 1.
    """

    name: str = field(default="ground")
    type: str = field(default="flat")
    spawn_position: tuple[float, float, float] = field(default=(0, 0, 0))
    spawn_quat_xyzw: tuple[float, float, float, float] = field(default=(0, 0, 0, 1))
    size: tuple[float, float] = field(default=(1.5, 1.5))
    ambient: tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0))
    specular: tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0))
    diffuse: tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0))
    model_rollouts_num: int = field(default=1)
diff --git a/env_manager/env_manager/env_manager/models/lights/__init__.py b/env_manager/env_manager/env_manager/models/lights/__init__.py
new file mode 100644
index 0000000..f791750
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/lights/__init__.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from gym_gz.scenario.model_wrapper import ModelWrapper
+
+from .random_sun import RandomSun
+from .sun import Sun
+
+
# Enum of the supported light types
class LightType(Enum):
    SUN = "sun"
    RANDOM_SUN = "random_sun"
+
+
# Registry mapping each light type to the model class that builds it.
LIGHT_MODEL_CLASSES = {
    LightType.SUN: Sun,
    LightType.RANDOM_SUN: RandomSun,
}
+
+
def get_light_model_class(light_type: str) -> type[ModelWrapper]:
    """Return the model class registered for ``light_type``.

    Args:
        light_type: Name of the light type (a ``LightType`` value, e.g. "sun").

    Raises:
        ValueError: If ``light_type`` is not a known light type or has no
            registered model class.
    """
    try:
        light_enum = LightType(light_type)
        return LIGHT_MODEL_CLASSES[light_enum]
    except (KeyError, ValueError) as err:
        # LightType(...) raises ValueError for unknown names, while the
        # registry lookup raises KeyError for members without a class.
        # Normalize both into a single descriptive ValueError (previously
        # only KeyError was caught, so unknown names escaped with the enum's
        # generic message instead of this one).
        raise ValueError(f"Unsupported light type: {light_type}") from err
+
+
def is_light_type_randomizable(light_type: str) -> bool:
    """Tell whether ``light_type`` names a light that supports randomization."""
    try:
        return LightType(light_type) is LightType.RANDOM_SUN
    except ValueError:
        # Unknown names are simply not randomizable.
        return False
+
diff --git a/env_manager/env_manager/env_manager/models/lights/random_sun.py b/env_manager/env_manager/env_manager/models/lights/random_sun.py
new file mode 100644
index 0000000..e91f56c
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/lights/random_sun.py
@@ -0,0 +1,240 @@
+import numpy as np
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from numpy.random import RandomState
+from scenario import core as scenario
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class RandomSun(model_wrapper.ModelWrapper):
    """
    RandomSun is a class that creates a randomly positioned directional light (like the Sun) in a simulation world.

    Attributes:
    ----------
    world : scenario_gazebo.World
        The simulation world where the model is inserted.
    name : str
        The name of the sun model. Default is "sun".
    minmax_elevation : tuple[float, float]
        Minimum and maximum values for the random elevation angle of the sun. Default is (-0.15, -0.65).
    distance : float
        Distance from the origin to the sun in the simulation. Default is 800.0.
    visual : bool
        Flag indicating whether the sun should have a visual sphere in the simulation. Default is True.
    radius : float
        Radius of the visual sphere representing the sun. Default is 20.0.
    color_minmax_r : tuple[float, float]
        Range for the red component of the sun's light color. Default is (1.0, 1.0).
    color_minmax_g : tuple[float, float]
        Range for the green component of the sun's light color. Default is (1.0, 1.0).
    color_minmax_b : tuple[float, float]
        Range for the blue component of the sun's light color. Default is (1.0, 1.0).
    specular : float
        The specular factor for the sun's light. Default is 1.0.
    attenuation_minmax_range : tuple[float, float]
        Range for the light attenuation (falloff) distance. Default is (750.0, 15000.0).
    attenuation_minmax_constant : tuple[float, float]
        Range for the constant component of the light attenuation. Default is (0.5, 1.0).
    attenuation_minmax_linear : tuple[float, float]
        Range for the linear component of the light attenuation. Default is (0.001, 0.1).
    attenuation_minmax_quadratic : tuple[float, float]
        Range for the quadratic component of the light attenuation. Default is (0.0001, 0.01).
    np_random : RandomState | None
        Random state for generating random values. Default is None, in which case a new random generator is created.

    Raises:
    -------
    RuntimeError:
        Raised if the sun model fails to be inserted into the world.

    Methods:
    --------
    get_sdf():
        Generates the SDF string for the sun model.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "sun",
        minmax_elevation: tuple[float, float] = (-0.15, -0.65),
        distance: float = 800.0,
        visual: bool = True,
        radius: float = 20.0,
        color_minmax_r: tuple[float, float] = (1.0, 1.0),
        color_minmax_g: tuple[float, float] = (1.0, 1.0),
        color_minmax_b: tuple[float, float] = (1.0, 1.0),
        specular: float = 1.0,
        attenuation_minmax_range: tuple[float, float] = (750.0, 15000.0),
        attenuation_minmax_constant: tuple[float, float] = (0.5, 1.0),
        attenuation_minmax_linear: tuple[float, float] = (0.001, 0.1),
        attenuation_minmax_quadratic: tuple[float, float] = (0.0001, 0.01),
        np_random: RandomState | None = None,
        **kwargs,
    ):
        # NOTE(review): np.random.default_rng() returns numpy.random.Generator,
        # not RandomState — both expose uniform(), but the annotation above is
        # narrower than the runtime type; consider widening it.
        if np_random is None:
            np_random = np.random.default_rng()

        # Get a unique model name
        model_name = get_unique_model_name(world, name)

        # Get random yaw direction
        direction = np_random.uniform(-1.0, 1.0, (2,))
        # Normalize yaw direction
        direction = direction / np.linalg.norm(direction)

        # Get random elevation
        direction = np.append(
            direction,
            np_random.uniform(minmax_elevation[0], minmax_elevation[1]),
        )
        # Normalize again
        direction = direction / np.linalg.norm(direction)

        # Initial pose: place the model opposite to the light direction so it
        # "shines" toward the origin; the quaternion is the identity (w,x,y,z).
        initial_pose = scenario_core.Pose(
            (
                -direction[0] * distance,
                -direction[1] * distance,
                -direction[2] * distance,
            ),
            (1, 0, 0, 0),
        )

        # Create SDF string for the model
        sdf = self.get_sdf(
            model_name=model_name,
            direction=tuple(direction),
            visual=visual,
            radius=radius,
            color_minmax_r=color_minmax_r,
            color_minmax_g=color_minmax_g,
            color_minmax_b=color_minmax_b,
            attenuation_minmax_range=attenuation_minmax_range,
            attenuation_minmax_constant=attenuation_minmax_constant,
            attenuation_minmax_linear=attenuation_minmax_linear,
            attenuation_minmax_quadratic=attenuation_minmax_quadratic,
            specular=specular,
            np_random=np_random,
        )

        # Insert the model
        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError("Failed to insert " + model_name)

        # Get the model
        model = world.get_model(model_name)

        # Initialize base class
        model_wrapper.ModelWrapper.__init__(self, model=model)

    # NOTE(review): the first parameter of this @classmethod is named "self";
    # the conventional name is "cls" (compare Sun.get_sdf) — consider renaming.
    @classmethod
    def get_sdf(
        self,
        model_name: str,
        direction: tuple[float, float, float],
        visual: bool,
        radius: float,
        color_minmax_r: tuple[float, float],
        color_minmax_g: tuple[float, float],
        color_minmax_b: tuple[float, float],
        attenuation_minmax_range: tuple[float, float],
        attenuation_minmax_constant: tuple[float, float],
        attenuation_minmax_linear: tuple[float, float],
        attenuation_minmax_quadratic: tuple[float, float],
        specular: float,
        np_random: RandomState,
    ) -> str:
        """
        Generates the SDF string for the sun model.

        Args:
        -----
        model_name : str
            The name of the model.
        direction : Tuple[float, float, float]
            The direction of the sun's light.
        visual : bool
            If True, a visual representation of the sun will be created.
        radius : float
            The radius of the visual representation.
        color_minmax_r : Tuple[float, float]
            Range for the red component of the light color.
        color_minmax_g : Tuple[float, float]
            Range for the green component of the light color.
        color_minmax_b : Tuple[float, float]
            Range for the blue component of the light color.
        attenuation_minmax_range : Tuple[float, float]
            Range for light attenuation distance.
        attenuation_minmax_constant : Tuple[float, float]
            Range for the constant attenuation factor.
        attenuation_minmax_linear : Tuple[float, float]
            Range for the linear attenuation factor.
        attenuation_minmax_quadratic : Tuple[float, float]
            Range for the quadratic attenuation factor.
        specular : float
            The specular reflection factor for the light.
        np_random : RandomState
            The random number generator used to sample random values for the parameters.

        Returns:
        --------
        str:
            The SDF string for the sun model.
        """
        # Sample random values for parameters
        color_r = np_random.uniform(color_minmax_r[0], color_minmax_r[1])
        color_g = np_random.uniform(color_minmax_g[0], color_minmax_g[1])
        color_b = np_random.uniform(color_minmax_b[0], color_minmax_b[1])
        attenuation_range = np_random.uniform(
            attenuation_minmax_range[0], attenuation_minmax_range[1]
        )
        attenuation_constant = np_random.uniform(
            attenuation_minmax_constant[0], attenuation_minmax_constant[1]
        )
        attenuation_linear = np_random.uniform(
            attenuation_minmax_linear[0], attenuation_minmax_linear[1]
        )
        attenuation_quadratic = np_random.uniform(
            attenuation_minmax_quadratic[0], attenuation_minmax_quadratic[1]
        )

        # NOTE(review): the SDF template below appears to have lost its XML
        # markup (no <sdf>/<model>/<light> tags survive), so as written it is
        # unlikely to be valid SDF — verify this string against the original
        # source before relying on it.
        return f'''

        true


        {direction[0]} {direction[1]} {direction[2]}

        {attenuation_range}
        {attenuation_constant}
        {attenuation_linear}
        {attenuation_quadratic}

        {color_r} {color_g} {color_b} 1
        {specular*color_r} {specular*color_g} {specular*color_b} 1
        true

        {
            f"""



        {radius}



        {color_r} {color_g} {color_b} 1

        false

        """ if visual else ""
        }


        '''
diff --git a/env_manager/env_manager/env_manager/models/lights/sun.py b/env_manager/env_manager/env_manager/models/lights/sun.py
new file mode 100644
index 0000000..a10c565
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/lights/sun.py
@@ -0,0 +1,191 @@
+import numpy as np
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class Sun(model_wrapper.ModelWrapper):
    """
    Sun is a class that represents a directional light source in the simulation, similar to the Sun.

    It can have both visual and light properties, with customizable parameters such as color, direction, and attenuation.

    Attributes:
    ----------
    world : scenario_gazebo.World
        The Gazebo world where the sun model will be inserted.
    name : str
        The name of the sun model. Default is "sun".
    direction : tuple[float, float, float]
        The direction of the sunlight, normalized. Default is (0.5, -0.25, -0.75).
    color : tuple[float, float, float, float]
        The RGBA color values for the light's diffuse color. Default is (1.0, 1.0, 1.0, 1.0).
    distance : float
        The distance from the origin to the sun. Default is 800.0.
    visual : bool
        If True, a visual representation of the sun will be added. Default is True.
    radius : float
        The radius of the visual representation of the sun. Default is 20.0.
    specular : float
        The intensity of the specular reflection. Default is 1.0.
    attenuation_range : float
        The maximum range for the light attenuation. Default is 10000.0.
    attenuation_constant : float
        The constant attenuation factor. Default is 0.9.
    attenuation_linear : float
        The linear attenuation factor. Default is 0.01.
    attenuation_quadratic : float
        The quadratic attenuation factor. Default is 0.001.

    Raises:
    -------
    RuntimeError:
        If the sun model fails to be inserted into the Gazebo world.

    Methods:
    --------
    get_sdf() -> str:
        Generates the SDF string used to describe the sun model in the Gazebo simulation.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "sun",
        direction: tuple[float, float, float] = (0.5, -0.25, -0.75),
        color: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0),
        distance: float = 800.0,
        visual: bool = True,
        radius: float = 20.0,
        specular: float = 1.0,
        attenuation_range: float = 10000.0,
        attenuation_constant: float = 0.9,
        attenuation_linear: float = 0.01,
        attenuation_quadratic: float = 0.001,
        **kwargs,
    ):
        # Get a unique model name
        model_name = get_unique_model_name(world, name)

        # Normalize direction
        np_direction = np.array(direction)
        np_direction = np_direction / np.linalg.norm(np_direction)

        # Initial pose: place the model opposite to the light direction so it
        # "shines" toward the origin; the quaternion is the identity (w,x,y,z).
        initial_pose = scenario_core.Pose(
            (
                -np_direction[0] * distance,
                -np_direction[1] * distance,
                -np_direction[2] * distance,
            ),
            (1, 0, 0, 0),
        )

        # Create SDF string for the model
        sdf = self.get_sdf(
            model_name=model_name,
            direction=tuple(np_direction),
            color=color,
            visual=visual,
            radius=radius,
            specular=specular,
            attenuation_range=attenuation_range,
            attenuation_constant=attenuation_constant,
            attenuation_linear=attenuation_linear,
            attenuation_quadratic=attenuation_quadratic,
        )

        # Insert the model
        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError("Failed to insert " + model_name)

        # Get the model
        model = world.get_model(model_name)

        # Initialize base class
        model_wrapper.ModelWrapper.__init__(self, model=model)

    @classmethod
    def get_sdf(
        cls,
        model_name: str,
        direction: tuple[float, float, float],
        color: tuple[float, float, float, float],
        visual: bool,
        radius: float,
        specular: float,
        attenuation_range: float,
        attenuation_constant: float,
        attenuation_linear: float,
        attenuation_quadratic: float,
    ) -> str:
        """
        Generates the SDF string for the sun model.

        Args:
        -----
        model_name : str
            The name of the sun model.
        direction : tuple[float, float, float]
            The direction vector for the sunlight.
        color : tuple[float, float, float, float]
            The RGBA color values for the sunlight.
        visual : bool
            Whether to include a visual representation of the sun (a sphere).
        radius : float
            The radius of the visual sphere.
        specular : float
            The specular reflection intensity.
        attenuation_range : float
            The range of the light attenuation.
        attenuation_constant : float
            The constant factor for the light attenuation.
        attenuation_linear : float
            The linear factor for the light attenuation.
        attenuation_quadratic : float
            The quadratic factor for the light attenuation.

        Returns:
        --------
        str:
            The SDF string for the sun model.
        """

        # NOTE(review): the SDF template below appears to have lost its XML
        # markup (no <sdf>/<model>/<light> tags survive), so as written it is
        # unlikely to be valid SDF — verify this string against the original
        # source before relying on it.
        return f'''

        true


        {direction[0]} {direction[1]} {direction[2]}

        {attenuation_range}
        {attenuation_constant}
        {attenuation_linear}
        {attenuation_quadratic}

        {color[0]} {color[1]} {color[2]} 1
        {specular*color[0]} {specular*color[1]} {specular*color[2]} 1
        true

        {
            f"""



        {radius}



        {color[0]} {color[1]} {color[2]} 1

        false

        """ if visual else ""
        }


        '''
diff --git a/env_manager/env_manager/env_manager/models/objects/__init__.py b/env_manager/env_manager/env_manager/models/objects/__init__.py
new file mode 100644
index 0000000..a1a70c6
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/__init__.py
@@ -0,0 +1,53 @@
+from enum import Enum
+
+from gym_gz.scenario.model_wrapper import ModelWrapper
+
+from .model import Model
+from .mesh import Mesh
+from .primitives import Box, Cylinder, Plane, Sphere
+# from .random_object import RandomObject
+from .random_primitive import RandomPrimitive
+
+
class ObjectType(Enum):
    """Names of the object kinds that the environment manager can spawn."""

    BOX = "box"
    SPHERE = "sphere"
    CYLINDER = "cylinder"
    PLANE = "plane"
    RANDOM_PRIMITIVE = "random_primitive"
    RANDOM_MESH = "random_mesh"
    MODEL = "model"
    MESH = "mesh"
+
# Registry mapping each object type to the model class that builds it.
# NOTE(review): ObjectType.RANDOM_MESH has no entry while RandomObject is
# commented out, so requesting "random_mesh" currently fails through the
# KeyError path in get_object_model_class.
OBJECT_MODEL_CLASSES = {
    ObjectType.BOX: Box,
    ObjectType.SPHERE: Sphere,
    ObjectType.CYLINDER: Cylinder,
    ObjectType.PLANE: Plane,
    ObjectType.RANDOM_PRIMITIVE: RandomPrimitive,
    # ObjectType.RANDOM_MESH: RandomObject,
    ObjectType.MODEL: Model,
    ObjectType.MESH: Mesh
}


# Object types whose concrete model is (re)randomized between rollouts.
RANDOMIZABLE_TYPES = {
    ObjectType.RANDOM_PRIMITIVE,
    ObjectType.RANDOM_MESH,
}
+
+
def get_object_model_class(object_type: str) -> type[ModelWrapper]:
    """Return the model class registered for ``object_type``.

    Args:
        object_type: Name of the object type (an ``ObjectType`` value,
            e.g. "box").

    Raises:
        ValueError: If ``object_type`` is not a known type or has no
            registered model class (e.g. "random_mesh" while its class is
            disabled in the registry).
    """
    try:
        object_enum = ObjectType(object_type)
        return OBJECT_MODEL_CLASSES[object_enum]
    except (KeyError, ValueError) as err:
        # ObjectType(...) raises ValueError for unknown names, while the
        # registry lookup raises KeyError for members without a class.
        # Normalize both into a single descriptive ValueError (previously
        # only KeyError was caught, so unknown names escaped with the enum's
        # generic message instead of this one).
        raise ValueError(f"Unsupported object type: {object_type}") from err
+
+
def is_object_type_randomizable(object_type: str) -> bool:
    """Tell whether ``object_type`` names a type whose model is randomized."""
    try:
        return ObjectType(object_type) in RANDOMIZABLE_TYPES
    except ValueError:
        # Unknown names are simply not randomizable.
        return False
diff --git a/env_manager/env_manager/env_manager/models/objects/mesh.py b/env_manager/env_manager/env_manager/models/objects/mesh.py
new file mode 100644
index 0000000..f467dfd
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/mesh.py
@@ -0,0 +1,183 @@
+import numpy as np
+import trimesh
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+from trimesh import Trimesh
+
+from env_manager.utils.types import Point
+
# Six unique components of a symmetric 3x3 inertia tensor, in SDF order:
# (ixx, ixy, ixz, iyy, iyz, izz).
InertiaTensor = tuple[float, float, float, float, float, float]
+
+
class Mesh(model_wrapper.ModelWrapper):
    """A mesh-based object inserted into a Gazebo world.

    Inertial properties are either supplied explicitly (``mass``, ``inertia``,
    ``cm``) or computed from the collision mesh assuming a homogeneous
    material of the given ``density``.

    Args:
        world: Gazebo world the model is inserted into.
        name: Desired model name (made unique within the world).
        type: Object type tag; unused here, kept for interface compatibility.
        relative_to: Frame the pose is expressed in; unused here.
        position: Initial position (x, y, z).
        orientation: Initial orientation quaternion (w, x, y, z).
        color: RGBA color; a random opaque color is chosen when omitted.
        static: Whether the model is immovable.
        texture: Texture parameters; unused here, kept for compatibility.
        mass: Explicit mass; overwritten when inertia is derived from the mesh.
        inertia: Explicit inertia tensor (ixx, ixy, ixz, iyy, iyz, izz).
        cm: Explicit center of mass; defaults to the origin when ``inertia``
            is given without one.
        collision: Filepath of the collision mesh.
        visual: Filepath of the visual mesh.
        friction: Surface friction coefficient.
        density: Material density used to derive inertia from the mesh.

    Raises:
        RuntimeError: If inertia cannot be computed (no collision mesh or no
            density) or the model fails to be inserted into the world.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "object",
        type: str = "mesh",
        relative_to: str = "world",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        color: tuple[float, float, float, float] | None = None,
        static: bool = False,
        texture: list[float] | None = None,  # was a mutable default ([])
        mass: float = 0.0,
        inertia: InertiaTensor | None = None,
        cm: Point | None = None,
        collision: str = "",
        visual: str = "",
        friction: float = 0.0,
        density: float = 0.0,
        **kwargs,
    ):
        # Get a unique model name
        model_name = get_unique_model_name(world, name)

        # Initial pose
        initial_pose = scenario_core.Pose(position, orientation)

        # Derive inertial properties from the collision mesh only when they
        # were not supplied explicitly.  (The original condition raised
        # whenever `inertia` WAS provided, inverting the intended check.)
        if inertia is None:
            if collision and density:
                mass, cm, inertia = self.calculate_inertia(collision, density)
            else:
                raise RuntimeError(
                    f"Please provide collision mesh filepath for model {name} to calculate inertia"
                )
        if cm is None:
            # Explicit inertia without a center of mass: assume the origin so
            # get_sdf() does not index into None.
            cm = (0.0, 0.0, 0.0)

        if not color:
            color = tuple(np.random.uniform(0.0, 1.0, (3,)))
            color = (color[0], color[1], color[2], 1.0)

        # Create SDF string for the model
        sdf = self.get_sdf(
            name=name,
            static=static,
            collision=collision,
            visual=visual,
            mass=mass,
            inertia=inertia,
            friction=friction,
            color=color,
            center_of_mass=cm,
            gui_only=True,
        )

        # Insert the model
        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError(f"Failed to insert {model_name}")
        # Get the model
        model = world.get_model(model_name)

        # Initialize base class
        super().__init__(model=model)

    def calculate_inertia(self, file_path, density):
        """Compute mass, center of mass and inertia of a homogeneous mesh.

        Args:
            file_path: Path to a mesh file loadable by trimesh.
            density: Material density used to scale the mesh volume to mass.

        Returns:
            Tuple of ``(mass, center_of_mass, inertia_tensor)`` where the
            tensor holds (ixx, ixy, ixz, iyy, iyz, izz) taken from trimesh's
            3x3 moment-of-inertia matrix.

        Raises:
            RuntimeError: If the file does not load to a single Trimesh.
        """
        mesh = trimesh.load(file_path)

        if not isinstance(mesh, Trimesh):
            raise RuntimeError("Please provide correct stl mesh filepath")

        volume = mesh.volume
        mass: float = volume * density
        center_of_mass: Point = tuple(mesh.center_mass)
        inertia = mesh.moment_inertia
        # (Removed an unused eigenvalue computation from the original.)
        inertia_tensor: InertiaTensor = (
            inertia[0, 0],
            inertia[0, 1],
            inertia[0, 2],
            inertia[1, 1],
            inertia[1, 2],
            inertia[2, 2],
        )
        return mass, center_of_mass, inertia_tensor

    @classmethod
    def get_sdf(
        cls,
        name: str,
        static: bool,
        collision: str,
        visual: str,
        mass: float,
        inertia: InertiaTensor,
        friction: float,
        color: tuple[float, float, float, float],
        center_of_mass: Point,
        gui_only: bool,
    ) -> str:
        """Generate the SDF string for the mesh model.

        NOTE(review): the template below appears to have lost its SDF/XML
        tags (only interpolated values remain) — restore it from version
        control before relying on the generated SDF.  The values are kept
        exactly as found.

        Returns:
            The SDF string that defines the mesh model in Gazebo.
        """
        return f'''

            {"true" if static else "false"}

            {
                f"""



                {collision}





                {friction}
                {friction}
                0 0 0
                0.0
                0.0




                """ if collision else ""
            }
            {
                f"""



                {visual}



                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}

                {1.0 - color[3]}
                {'1false' if gui_only else ''}

                """ if visual else ""
            }

            {mass}
            {center_of_mass[0]} {center_of_mass[1]} {center_of_mass[2]} 0 0 0

            {inertia[0]}
            {inertia[1]}
            {inertia[2]}
            {inertia[3]}
            {inertia[4]}
            {inertia[5]}



        '''
diff --git a/env_manager/env_manager/env_manager/models/objects/model.py b/env_manager/env_manager/env_manager/models/objects/model.py
new file mode 100644
index 0000000..1f6de89
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/model.py
@@ -0,0 +1,75 @@
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from rbs_assets_library import get_model_file
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class Model(model_wrapper.ModelWrapper):
    """A library asset model placed into the Gazebo world.

    The SDF description is looked up by name in the ``rbs_assets_library``
    and inserted at the requested pose.

    Args:
        world (scenario_gazebo.World): World to insert the model into.
        name (str, optional): Asset/model name. Defaults to "object".
        position (tuple[float, float, float], optional): Spawn position.
            Defaults to (0, 0, 0).
        orientation (tuple[float, float, float, float], optional): Spawn
            orientation quaternion. Defaults to (1, 0, 0, 0).
        gui_only (bool, optional): Kept for interface compatibility; not used
            by this class. Defaults to False.
        **kwargs: Ignored extra arguments.

    Raises:
        RuntimeError: If Gazebo refuses to insert the model.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "object",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        gui_only: bool = False,
        **kwargs,
    ):
        # Resolve a collision-free name for this world.
        model_name = get_unique_model_name(world, name)

        # Pose at which the asset is spawned.
        spawn_pose = scenario_core.Pose(position, orientation)

        # The SDF comes straight from the asset-library lookup by asset name.
        sdf = self.get_sdf(model_name=name)

        if not world.to_gazebo().insert_model(sdf, spawn_pose, model_name):
            raise RuntimeError("Failed to insert " + model_name)

        # Hand the freshly inserted model over to the wrapper base class.
        super().__init__(model=world.get_model(model_name))

    @classmethod
    def get_sdf(cls, model_name: str) -> str:
        """Return the SDF file for *model_name* from the asset library."""
        return get_model_file(model_name)
diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/__init__.py b/env_manager/env_manager/env_manager/models/objects/primitives/__init__.py
new file mode 100644
index 0000000..0dcf311
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/primitives/__init__.py
@@ -0,0 +1,4 @@
+from .box import Box
+from .cylinder import Cylinder
+from .plane import Plane
+from .sphere import Sphere
diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/box.py b/env_manager/env_manager/env_manager/models/objects/primitives/box.py
new file mode 100644
index 0000000..fae02bb
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/primitives/box.py
@@ -0,0 +1,176 @@
+import numpy as np
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class Box(model_wrapper.ModelWrapper):
    """
    The Box class represents a 3D box model in the Gazebo simulation environment.
    It includes physical and visual properties such as size, mass, color, and collision properties.

    Attributes:
        world (scenario_gazebo.World): The Gazebo world where the box model will be inserted.
        name (str): The name of the box model. Default is "box".
        position (tuple[float, float, float]): The position of the box in the world. Default is (0, 0, 0).
        orientation (tuple[float, float, float, float]): The orientation of the box in quaternion format. Default is (1, 0, 0, 0).
        size (tuple[float, float, float]): The size of the box (width, length, height). Default is (0.05, 0.05, 0.05).
        mass (float): The mass of the box in kilograms. Default is 0.1.
        static (bool): If True, the box will be immovable and static. Default is False.
        collision (bool): If True, the box will have collision properties. Default is True.
        friction (float): The friction coefficient for the box’s surface. Default is 1.0.
        visual (bool): If True, the box will have a visual representation. Default is True.
        gui_only (bool): If True, the visual representation of the box will only appear in the GUI, not in the simulation physics. Default is False.
        color (tuple[float, float, float, float] | None): The RGBA color of the box.
            Defaults to None, in which case a random opaque color is chosen.

    Raises:
        RuntimeError:
            If the box model fails to be inserted into the Gazebo world.

    Methods:
    --------
    get_sdf() -> str:
        Generates the SDF string used to describe the box model in the Gazebo simulation.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "box",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        size: tuple[float, float, float] = (0.05, 0.05, 0.05),
        mass: float = 0.1,
        static: bool = False,
        collision: bool = True,
        friction: float = 1.0,
        visual: bool = True,
        gui_only: bool = False,
        color: tuple[float, float, float, float] | None = None,
        **kwargs,
    ):
        # Get a unique model name
        model_name = get_unique_model_name(world, name)

        # Initial pose
        initial_pose = scenario_core.Pose(position, orientation)

        # No color given: draw a random RGB triple and extend it to opaque RGBA.
        if not color:
            color = tuple(np.random.uniform(0.0, 1.0, (3,)))
            color = (color[0], color[1], color[2], 1.0)

        # Create SDF string for the model
        sdf = self.get_sdf(
            model_name=model_name,
            size=size,
            mass=mass,
            static=static,
            collision=collision,
            friction=friction,
            visual=visual,
            gui_only=gui_only,
            color=color,
        )

        # Insert the model
        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError(f"Failed to insert {model_name}")

        # Get the model
        model = world.get_model(model_name)

        # Initialize base class
        super().__init__(model=model)

    @classmethod
    def get_sdf(
        cls,
        model_name: str,
        size: tuple[float, float, float],
        mass: float,
        static: bool,
        collision: bool,
        friction: float,
        visual: bool,
        gui_only: bool,
        color: tuple[float, float, float, float],
    ) -> str:
        """
        Generates the SDF string for the box model.

        NOTE(review): the template below appears to have lost its SDF/XML tags
        (only interpolated values remain), and `{'1false' ...}` looks like two
        adjacent tag values fused together — restore the template from version
        control before relying on the generated SDF.

        The inertia entries are the standard solid-cuboid diagonal terms
        (e.g. ixx = m*(dy^2 + dz^2)/12) with zero products of inertia.

        Args:
            model_name (str): The name of the box model.
            size (tuple[float, float, float]): The dimensions of the box (width, length, height).
            mass (float): The mass of the box.
            static (bool): If True, the box will be static and immovable.
            collision (bool): If True, the box will have collision properties.
            friction (float): The friction coefficient for the box.
            visual (bool): If True, the box will have a visual representation.
            gui_only (bool): If True, the box's visual representation will only appear in the GUI and will not affect physics.
            color (tuple[float, float, float, float]): The RGBA color of the box.

        Returns:
            The SDF string that defines the box model in Gazebo.
        """
        return f'''

            {"true" if static else "false"}

            {
                f"""



                {size[0]} {size[1]} {size[2]}





                {friction}
                {friction}
                0 0 0
                0.0
                0.0




                """ if collision else ""
            }
            {
                f"""



                {size[0]} {size[1]} {size[2]}



                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}

                {1.0 - color[3]}
                {'1false' if gui_only else ''}

                """ if visual else ""
            }

            {mass}

            {(size[1]**2 + size[2]**2) * mass / 12}
            {(size[0]**2 + size[2]**2) * mass / 12}
            {(size[0]**2 + size[1]**2) * mass / 12}
            0.0
            0.0
            0.0



        '''
diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/cylinder.py b/env_manager/env_manager/env_manager/models/objects/primitives/cylinder.py
new file mode 100644
index 0000000..38df76f
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/primitives/cylinder.py
@@ -0,0 +1,173 @@
+import numpy as np
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class Cylinder(model_wrapper.ModelWrapper):
    """
    The Cylinder class represents a cylindrical model in the Gazebo simulation environment.

    Attributes:
        world (scenario_gazebo.World): The Gazebo world where the cylinder model will be inserted.
        name (str): The name of the cylinder model. Default is "cylinder".
        position (tuple[float, float, float]): The position of the cylinder in the world. Default is (0, 0, 0).
        orientation (tuple[float, float, float, float]): The orientation of the cylinder in quaternion format. Default is (1, 0, 0, 0).
        radius (float): The radius of the cylinder. Default is 0.025.
        length (float): The length/height of the cylinder. Default is 0.1.
        mass (float): The mass of the cylinder in kilograms. Default is 0.1.
        static (bool): If True, the cylinder will be immovable. Default is False.
        collision (bool): If True, the cylinder will have collision properties. Default is True.
        friction (float): The friction coefficient for the cylinder's surface. Default is 1.0.
        visual (bool): If True, the cylinder will have a visual representation. Default is True.
        gui_only (bool): If True, the visual representation of the cylinder will only appear in the GUI, not in the simulation physics. Default is False.
        color (tuple[float, float, float, float] | None): The RGBA color of the cylinder.
            Defaults to None, in which case a random opaque color is chosen.

    Raises:
        RuntimeError: If the cylinder model fails to be inserted into the Gazebo world.

    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "cylinder",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        radius: float = 0.025,
        length: float = 0.1,
        mass: float = 0.1,
        static: bool = False,
        collision: bool = True,
        friction: float = 1.0,
        visual: bool = True,
        gui_only: bool = False,
        color: tuple[float, float, float, float] | None = None,
        **kwargs,
    ):
        # Resolve a name that is unique within this world.
        model_name = get_unique_model_name(world, name)

        initial_pose = scenario_core.Pose(position, orientation)

        # No color given: draw a random RGB triple and extend it to opaque RGBA.
        if not color:
            color = tuple(np.random.uniform(0.0, 1.0, (3,)))
            color = (color[0], color[1], color[2], 1.0)

        sdf = self.get_sdf(
            model_name=model_name,
            radius=radius,
            length=length,
            mass=mass,
            static=static,
            collision=collision,
            friction=friction,
            visual=visual,
            gui_only=gui_only,
            color=color,
        )

        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError(f"Failed to insert {model_name}")

        model = world.get_model(model_name)

        super().__init__(model=model)

    @classmethod
    def get_sdf(
        cls,
        model_name: str,
        radius: float,
        length: float,
        mass: float,
        static: bool,
        collision: bool,
        friction: float,
        visual: bool,
        gui_only: bool,
        color: tuple[float, float, float, float],
    ) -> str:
        """
        Generates the SDF string for the cylinder model.

        NOTE(review): the template below appears to have lost its SDF/XML tags
        (only interpolated values remain) — restore it from version control
        before relying on the generated SDF.

        Args:
            model_name (str): The name of the model.
            radius (float): The radius of the cylinder.
            length (float): The length or height of the cylinder.
            mass (float): The mass of the cylinder in kilograms.
            static (bool): If True, the cylinder will remain immovable in the simulation.
            collision (bool): If True, adds collision properties to the cylinder.
            friction (float): The friction coefficient for the cylinder's surface.
            visual (bool): If True, a visual representation of the cylinder will be created.
            gui_only (bool): If True, the visual representation will only appear in the GUI, without impacting the simulation's physics.
            color (tuple[float, float, float, float]): The RGBA color of the cylinder, where each value is between 0 and 1.

        Returns:
            str: The SDF string representing the cylinder model
        """
        # Solid-cylinder inertia about the transverse axes: m*(3r^2 + h^2)/12.
        inertia_xx_yy = (3 * radius**2 + length**2) * mass / 12

        return f'''

            {"true" if static else "false"}

            {
                f"""



                {radius}
                {length}





                {friction}
                {friction}
                0 0 0
                0.0
                0.0




                """ if collision else ""
            }
            {
                f"""



                {radius}
                {length}



                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}

                {1.0-color[3]}
                {'1false' if gui_only else ''}

                """ if visual else ""
            }

            {mass}

            {inertia_xx_yy}
            {inertia_xx_yy}
            {(mass*radius**2)/2}
            0.0
            0.0
            0.0



        '''
diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/plane.py b/env_manager/env_manager/env_manager/models/objects/primitives/plane.py
new file mode 100644
index 0000000..ffbeefb
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/primitives/plane.py
@@ -0,0 +1,100 @@
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class Plane(model_wrapper.ModelWrapper):
    """
    The Plane class represents a plane model in the Gazebo simulation environment.
    It allows for defining a flat surface with collision and visual properties, as well as its orientation and friction settings.

    Note:
        The generated model is always static (the SDF emits ``true``
        unconditionally); the spawn pose is still taken from ``position`` and
        ``orientation``.

    Attributes:
        world (scenario_gazebo.World): The Gazebo world where the plane model will be inserted.
        name (str): The name of the plane model. Default is "plane".
        position (tuple[float, float, float]): The position of the plane in the world. Default is (0, 0, 0).
        orientation (tuple[float, float, float, float]): The orientation of the plane in quaternion format. Default is (1, 0, 0, 0).
        size (tuple[float, float]): The size of the plane along the x and y axes. Default is (1.0, 1.0).
        direction (tuple[float, float, float]): The normal vector representing the plane's direction. Default is (0.0, 0.0, 1.0), representing a horizontal plane.
        collision (bool): If True, the plane will have collision properties. Default is True.
        friction (float): The friction coefficient for the plane's surface. Default is 1.0.
        visual (bool): If True, the plane will have a visual representation. Default is True.

    Raises:
        RuntimeError: If the plane model fails to be inserted into the Gazebo world.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "plane",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        size: tuple[float, float] = (1.0, 1.0),
        direction: tuple[float, float, float] = (0.0, 0.0, 1.0),
        collision: bool = True,
        friction: float = 1.0,
        visual: bool = True,
        **kwargs,
    ):
        # Resolve a name that is unique within this world.
        model_name = get_unique_model_name(world, name)

        initial_pose = scenario_core.Pose(position, orientation)

        # NOTE(review): the template below appears to have lost its SDF/XML
        # tags (only interpolated values remain) — restore it from version
        # control before relying on the generated SDF.  Visual color is fixed
        # at light gray (0.8 0.8 0.8 1).
        sdf = f'''

            true

            {
                f"""



                {direction[0]} {direction[1]} {direction[2]}
                {size[0]} {size[1]}





                {friction}
                {friction}
                0 0 0
                0.0
                0.0




                """ if collision else ""
            }
            {
                f"""



                {direction[0]} {direction[1]} {direction[2]}
                {size[0]} {size[1]}



                0.8 0.8 0.8 1
                0.8 0.8 0.8 1
                0.8 0.8 0.8 1


                """ if visual else ""
            }


        '''

        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError(f"Failed to insert {model_name}")

        model = world.get_model(model_name)

        super().__init__(model=model)
diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/sphere.py b/env_manager/env_manager/env_manager/models/objects/primitives/sphere.py
new file mode 100644
index 0000000..1fa2761
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/primitives/sphere.py
@@ -0,0 +1,176 @@
+import numpy as np
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
class Sphere(model_wrapper.ModelWrapper):
    """
    A Sphere model for Gazebo simulation.

    This class represents a spherical object that can be inserted into a Gazebo world.
    The sphere can be customized by specifying its physical properties, such as radius, mass,
    collision parameters, friction, and visual attributes.

    Attributes:
        world (scenario_gazebo.World): The Gazebo world where the sphere model will be inserted.
        name (str): The name of the sphere model. Defaults to "sphere".
        position (tuple[float, float, float]): The position of the sphere in the world. Defaults to (0, 0, 0).
        orientation (tuple[float, float, float, float]): The orientation of the sphere in quaternion format. Defaults to (1, 0, 0, 0).
        radius (float): The radius of the sphere. Defaults to 0.025.
        mass (float): The mass of the sphere. Defaults to 0.1.
        static (bool): If True, the sphere will be static in the simulation. Defaults to False.
        collision (bool): If True, the sphere will have collision properties. Defaults to True.
        friction (float): The friction coefficient of the sphere's surface. Defaults to 1.0.
        visual (bool): If True, the sphere will have a visual representation. Defaults to True.
        gui_only (bool): If True, the visual will only appear in the GUI. Defaults to False.
        color (tuple[float, float, float, float] | None): The RGBA color of the sphere.
            Defaults to None, in which case a random opaque color is chosen.

    Raises:
        RuntimeError: If the sphere fails to be inserted into the Gazebo world.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "sphere",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        radius: float = 0.025,
        mass: float = 0.1,
        static: bool = False,
        collision: bool = True,
        friction: float = 1.0,
        visual: bool = True,
        gui_only: bool = False,
        color: tuple[float, float, float, float] | None = None,
        **kwargs,
    ):
        # Get a unique model name
        model_name = get_unique_model_name(world, name)

        # Initial pose
        initial_pose = scenario_core.Pose(position, orientation)

        # No color given: draw a random RGB triple and extend it to opaque RGBA.
        if not color:
            color = tuple(np.random.uniform(0.0, 1.0, (3,)))
            color = (color[0], color[1], color[2], 1.0)

        # Create SDF string for the model
        sdf = self.get_sdf(
            model_name=model_name,
            radius=radius,
            mass=mass,
            static=static,
            collision=collision,
            friction=friction,
            visual=visual,
            gui_only=gui_only,
            color=color,
        )

        # Insert the model
        ok_model = world.to_gazebo().insert_model_from_string(
            sdf, initial_pose, model_name
        )
        if not ok_model:
            raise RuntimeError(f"Failed to insert {model_name}")

        # Get the model
        model = world.get_model(model_name)

        # Initialize base class
        super().__init__(model=model)

    @classmethod
    def get_sdf(
        cls,
        model_name: str,
        radius: float,
        mass: float,
        static: bool,
        collision: bool,
        friction: float,
        visual: bool,
        gui_only: bool,
        color: tuple[float, float, float, float],
    ) -> str:
        """
        Generates the SDF (Simulation Description Format) string for the sphere model.

        NOTE(review): the template below appears to have lost its SDF/XML tags
        (only interpolated values remain) — restore it from version control
        before relying on the generated SDF.

        Args:
            model_name (str): The name of the model.
            radius (float): The radius of the sphere.
            mass (float): The mass of the sphere.
            static (bool): Whether the sphere is static in the simulation.
            collision (bool): Whether the sphere should have collision properties.
            friction (float): The friction coefficient for the sphere.
            visual (bool): Whether the sphere should have a visual representation.
            gui_only (bool): Whether the visual representation is only visible in the GUI.
            color (tuple[float, float, float, float]): The RGBA color of the sphere.

        Returns:
            str: The SDF string representing the sphere.
        """
        # Inertia is identical for all axes (solid sphere: 2/5 * m * r^2)
        inertia_xx_yy_zz = (mass * radius**2) * 2 / 5

        return f'''

            {"true" if static else "false"}

            {
                f"""



                {radius}





                {friction}
                {friction}
                0 0 0
                0.0
                0.0




                """ if collision else ""
            }
            {
                f"""



                {radius}



                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}
                {color[0]} {color[1]} {color[2]} {color[3]}

                {1.0 - color[3]}
                {'1false' if gui_only else ''}

                """ if visual else ""
            }

            {mass}

            {inertia_xx_yy_zz}
            {inertia_xx_yy_zz}
            {inertia_xx_yy_zz}
            0.0
            0.0
            0.0



        '''
diff --git a/env_manager/env_manager/env_manager/models/objects/random_object.py b/env_manager/env_manager/env_manager/models/objects/random_object.py
new file mode 100644
index 0000000..5f399c6
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/random_object.py
@@ -0,0 +1,84 @@
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from numpy.random import RandomState
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+from env_manager.models.utils import ModelCollectionRandomizer
+
+
class RandomObject(model_wrapper.ModelWrapper):
    """A model drawn at random from a model collection.

    A :class:`ModelCollectionRandomizer` selects one SDF file from the
    configured collection and that model is inserted into the world.

    Args:
        world (scenario_gazebo.World): World to insert the model into.
        name (str, optional): Base model name. Defaults to "object".
        position (tuple[float, float, float], optional): Spawn position.
            Defaults to (0, 0, 0).
        orientation (tuple[float, float, float, float], optional): Spawn
            orientation quaternion. Defaults to (1, 0, 0, 0).
        model_paths (str | None, optional): Paths to candidate model files;
            required — a RuntimeError is raised when None.
        owner (str, optional): Collection owner. Defaults to "GoogleResearch".
        collection (str, optional): Collection name. Defaults to
            "Google Scanned Objects".
        server (str, optional): Fuel server URL.
        server_version (str, optional): Server API version. Defaults to "1.0".
        unique_cache (bool, optional): Cache unique models. Defaults to False.
        reset_collection (bool, optional): Reset the collection before
            sampling. Defaults to False.
        np_random (RandomState | None, optional): RNG used for sampling.
        **kwargs: Ignored extra arguments.

    Raises:
        RuntimeError: If ``model_paths`` is unset or insertion fails.
    """

    def __init__(
        self,
        world: scenario_gazebo.World,
        name: str = "object",
        position: tuple[float, float, float] = (0, 0, 0),
        orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
        model_paths: str | None = None,
        owner: str = "GoogleResearch",
        collection: str = "Google Scanned Objects",
        server: str = "https://fuel.ignitionrobotics.org",
        server_version: str = "1.0",
        unique_cache: bool = False,
        reset_collection: bool = False,
        np_random: RandomState | None = None,
        **kwargs,
    ):
        # A model collection is mandatory — bail out before touching the world.
        if model_paths is None:
            raise RuntimeError("Set model path for continue")

        model_name = get_unique_model_name(world, name)
        spawn_pose = scenario_core.Pose(position, orientation)

        randomizer = ModelCollectionRandomizer(
            model_paths=model_paths,
            owner=owner,
            collection=collection,
            server=server,
            server_version=server_version,
            unique_cache=unique_cache,
            reset_collection=reset_collection,
            np_random=np_random,
        )

        # Sample one SDF file from the collection (randomizer defaults apply).
        sdf_file = randomizer.random_model()

        if not world.to_gazebo().insert_model(sdf_file, spawn_pose, model_name):
            raise RuntimeError("Failed to insert " + model_name)

        # Wrap the inserted model via the base class.
        super().__init__(model=world.get_model(model_name))
diff --git a/env_manager/env_manager/env_manager/models/objects/random_primitive.py b/env_manager/env_manager/env_manager/models/objects/random_primitive.py
new file mode 100644
index 0000000..58510b1
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/objects/random_primitive.py
@@ -0,0 +1,164 @@
+import numpy as np
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from numpy.random import RandomState
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+from . import Box, Cylinder, Sphere
+
+
+class RandomPrimitive(model_wrapper.ModelWrapper):
+ """
+ Represents a randomly generated primitive shape (box, cylinder, or sphere) in the Gazebo simulation environment.
+ This class allows for the insertion of various primitive models based on the user's specifications or randomly chosen.
+
+ Args:
+ world (scenario_gazebo.World): The Gazebo world where the primitive model will be inserted.
+ name (str, optional): The name of the primitive model. Defaults to "primitive".
+ use_specific_primitive ((str | None), optional): If specified, the exact type of primitive to create ('box', 'cylinder', or 'sphere'). Defaults to None, which will randomly select a primitive type.
+ position (tuple[float, float, float], optional): The position of the primitive in the world. Defaults to (0, 0, 0).
+ orientation (tuple[float, float, float, float], optional): The orientation of the primitive in quaternion format. Defaults to (1, 0, 0, 0).
+ static (bool, optional): If True, the primitive will be static and immovable. Defaults to False.
+ collision (bool, optional): If True, the primitive will have collision properties. Defaults to True.
+ visual (bool, optional): If True, the primitive will have a visual representation. Defaults to True.
+ gui_only (bool, optional): If True, the visual representation will only appear in the GUI and not in the simulation physics. Defaults to False.
+ np_random (RandomState | None, optional): An instance of RandomState for random number generation. If None, a default random generator will be used. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ RuntimeError: If the primitive model fails to be inserted into the Gazebo world.
+ TypeError: If the specified primitive type is not supported.
+ """
+
+ def __init__(
+ self,
+ world: scenario_gazebo.World,
+ name: str = "primitive",
+ use_specific_primitive: (str | None) = None,
+ position: tuple[float, float, float] = (0, 0, 0),
+ orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
+ static: bool = False,
+ collision: bool = True,
+ visual: bool = True,
+ gui_only: bool = False,
+ np_random: RandomState | None = None,
+ **kwargs,
+ ):
+ if np_random is None:
+ np_random = np.random.default_rng()
+
+ # Get a unique model name
+ model_name = get_unique_model_name(world, name)
+
+ # Initial pose
+ initial_pose = scenario_core.Pose(position, orientation)
+
+ # Create SDF string for the model
+ sdf = self.get_sdf(
+ model_name=model_name,
+ use_specific_primitive=use_specific_primitive,
+ static=static,
+ collision=collision,
+ visual=visual,
+ gui_only=gui_only,
+ np_random=np_random,
+ )
+
+ # Insert the model
+ ok_model = world.to_gazebo().insert_model_from_string(
+ sdf, initial_pose, model_name
+ )
+ if not ok_model:
+ raise RuntimeError("Failed to insert " + model_name)
+
+ # Get the model
+ model = world.get_model(model_name)
+
+ # Initialize base class
+ model_wrapper.ModelWrapper.__init__(self, model=model)
+
+ @classmethod
+ def get_sdf(
+ cls,
+ model_name: str,
+ use_specific_primitive: (str | None),
+ static: bool,
+ collision: bool,
+ visual: bool,
+ gui_only: bool,
+ np_random: RandomState,
+ ) -> str:
+ """
+ Generates the SDF (Simulation Description Format) string for the specified primitive model.
+
+ This method can create the SDF representation for a box, cylinder, or sphere based on the provided parameters.
+ If a specific primitive type is not provided, one will be randomly selected.
+
+ Args:
+ model_name (str): The name of the model being generated.
+ use_specific_primitive ((str | None)): The specific type of primitive to create ('box', 'cylinder', or 'sphere'). If None, a random primitive will be chosen.
+ static (bool): If True, the primitive will be static and immovable.
+ collision (bool): If True, the primitive will have collision properties.
+ visual (bool): If True, the primitive will have a visual representation.
+ gui_only (bool): If True, the visual representation will only appear in the GUI and will not affect physics.
+ np_random (RandomState): An instance of RandomState for random number generation.
+
+ Returns:
+ str: The SDF string that defines the specified primitive model, including its physical and visual properties.
+
+ Raises:
+ TypeError: If the specified primitive type is not supported.
+
+ """
+ if use_specific_primitive is not None:
+ primitive = use_specific_primitive
+ else:
+ primitive = np_random.choice(["box", "cylinder", "sphere"])
+
+ mass = np_random.uniform(0.05, 0.25)
+ friction = np_random.uniform(0.75, 1.5)
+ color = tuple(np_random.uniform(0.0, 1.0, (3,)))
+ color: tuple[float, float, float, float] = (color[0], color[1], color[2], 1.0)
+
+ if "box" == primitive:
+ return Box.get_sdf(
+ model_name=model_name,
+ size=tuple(np_random.uniform(0.04, 0.06, (3,))),
+ mass=mass,
+ static=static,
+ collision=collision,
+ friction=friction,
+ visual=visual,
+ gui_only=gui_only,
+ color=color,
+ )
+ elif "cylinder" == primitive:
+ return Cylinder.get_sdf(
+ model_name=model_name,
+ radius=np_random.uniform(0.01, 0.0375),
+ length=np_random.uniform(0.025, 0.05),
+ mass=mass,
+ static=static,
+ collision=collision,
+ friction=friction,
+ visual=visual,
+ gui_only=gui_only,
+ color=color,
+ )
+ elif "sphere" == primitive:
+ return Sphere.get_sdf(
+ model_name=model_name,
+ radius=np_random.uniform(0.01, 0.0375),
+ mass=mass,
+ static=static,
+ collision=collision,
+ friction=friction,
+ visual=visual,
+ gui_only=gui_only,
+ color=color,
+ )
+ else:
+ raise TypeError(
+ f"Type '{use_specific_primitive}' in not a supported primitive. Pleasure use 'box', 'cylinder' or 'sphere."
+ )
diff --git a/env_manager/env_manager/env_manager/models/robots/__init__.py b/env_manager/env_manager/env_manager/models/robots/__init__.py
new file mode 100644
index 0000000..c690e8c
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/robots/__init__.py
@@ -0,0 +1,16 @@
+from enum import Enum
+
+from .rbs_arm import RbsArm
+from .robot import RobotWrapper
+
+
+class RobotEnum(Enum):
+    # Identifiers of the supported robot models; the value is the string
+    # expected from configuration and used as a lookup key.
+    RBS_ARM = "rbs_arm"
+
+
+def get_robot_model_class(robot_model: str) -> type[RobotWrapper]:
+ model_mapping = {
+ RobotEnum.RBS_ARM.value: RbsArm,
+ }
+
+ return model_mapping.get(robot_model, RbsArm)
diff --git a/env_manager/env_manager/env_manager/models/robots/rbs_arm.py b/env_manager/env_manager/env_manager/models/robots/rbs_arm.py
new file mode 100644
index 0000000..d724946
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/robots/rbs_arm.py
@@ -0,0 +1,153 @@
+from gym_gz.utils.scenario import get_unique_model_name
+from rclpy.node import Node
+from robot_builder.parser.urdf import URDF_parser
+from robot_builder.elements.robot import Robot
+
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+from .robot import RobotWrapper
+
+
+class RbsArm(RobotWrapper):
+    """
+    A class representing a robotic arm built using the `robot_builder` library.
+
+    This class is responsible for creating a robotic arm model within a Gazebo simulation environment.
+    It allows for the manipulation of joint positions for both the arm and the gripper.
+
+    Attributes:
+    - DEFAULT_ARM_JOINT_POSITIONS (list[float]): The default joint positions for the arm.
+    - OPEN_GRIPPER_JOINT_POSITIONS (list[float]): The joint positions for the gripper in the open state.
+    - CLOSED_GRIPPER_JOINT_POSITIONS (list[float]): The joint positions for the gripper in the closed state.
+    - DEFAULT_GRIPPER_JOINT_POSITIONS (list[float]): The default joint positions for the gripper.
+
+    Args:
+    - world (scenario_gazebo.World): The Gazebo world where the robot model will be inserted.
+    - node (Node): The ROS2 node associated with the robotic arm.
+    - urdf_string (str): The URDF string defining the robot.
+    - name (str, optional): The name of the robotic arm. Defaults to "rbs_arm".
+    - position (tuple[float, float, float], optional): The initial position of the robot in the world. Defaults to (0.0, 0.0, 0.0).
+    - orientation (tuple[float, float, float, float], optional): The initial orientation of the robot in quaternion format. Defaults to (1.0, 0, 0, 0).
+    - initial_arm_joint_positions (list[float] | float, optional): The initial joint positions for the arm. Defaults to `DEFAULT_ARM_JOINT_POSITIONS`.
+    - initial_gripper_joint_positions (list[float] | float, optional): The initial joint positions for the gripper. Defaults to `DEFAULT_GRIPPER_JOINT_POSITIONS`.
+
+    Raises:
+    - RuntimeError: If the robot model fails to be inserted into the Gazebo world.
+    """
+
+    # Default 7-DoF arm configuration (radians).
+    DEFAULT_ARM_JOINT_POSITIONS: list[float] = [
+        0.0,
+        0.5,
+        3.14159,
+        1.5,
+        0.0,
+        1.4,
+        0.0,
+    ]
+
+    OPEN_GRIPPER_JOINT_POSITIONS: list[float] = [
+        0.064,
+    ]
+
+    CLOSED_GRIPPER_JOINT_POSITIONS: list[float] = [
+        0.0,
+    ]
+
+    # The gripper starts closed by default.
+    DEFAULT_GRIPPER_JOINT_POSITIONS: list[float] = CLOSED_GRIPPER_JOINT_POSITIONS
+
+    def __init__(
+        self,
+        world: scenario_gazebo.World,
+        node: Node,
+        urdf_string: str,
+        name: str = "rbs_arm",
+        position: tuple[float, float, float] = (0.0, 0.0, 0.0),
+        orientation: tuple[float, float, float, float] = (1.0, 0, 0, 0),
+        initial_arm_joint_positions: list[float] | float = DEFAULT_ARM_JOINT_POSITIONS,
+        initial_gripper_joint_positions: list[float]
+        | float = DEFAULT_GRIPPER_JOINT_POSITIONS,
+    ):
+        # NOTE(review): `node` is accepted but never stored or used in this
+        # constructor — confirm whether it is still needed in the interface.
+        # Get a unique model name
+        model_name = get_unique_model_name(world, name)
+
+        # Setup initial pose
+        initial_pose = scenario_core.Pose(position, orientation)
+
+        # Insert the model
+        ok_model = world.insert_model_from_string(urdf_string, initial_pose, model_name)
+        if not ok_model:
+            raise RuntimeError("Failed to insert " + model_name)
+
+        # Get the model
+        model = world.get_model(model_name)
+
+        # Parse robot to get metadata
+        self._robot: Robot = URDF_parser.load_string(urdf_string)
+
+        # A scalar is broadcast to every actuated gripper joint; a list is
+        # used as-is (no length check is performed here).
+        self.__initial_gripper_joint_positions = (
+            [float(initial_gripper_joint_positions)]
+            * len(self._robot.gripper_actuated_joint_names)
+            if isinstance(initial_gripper_joint_positions, (int, float))
+            else initial_gripper_joint_positions
+        )
+
+        # Same broadcast rule for the arm joints.
+        self.__initial_arm_joint_positions = (
+            [float(initial_arm_joint_positions)] * len(self._robot.actuated_joint_names)
+            if isinstance(initial_arm_joint_positions, (int, float))
+            else initial_arm_joint_positions
+        )
+
+        # Set initial joint configuration
+        self.set_initial_joint_positions(model)
+        super().__init__(model=model)
+
+    @property
+    def robot(self) -> Robot:
+        """Returns the robot metadata parsed from the URDF string.
+
+        Returns:
+            Robot: The robot instance containing metadata.
+        """
+        return self._robot
+
+    @property
+    def initial_arm_joint_positions(self) -> list[float]:
+        """Returns the initial joint positions for the arm.
+
+        Returns:
+            list[float]: The initial joint positions for the arm.
+        """
+        return self.__initial_arm_joint_positions
+
+    @property
+    def initial_gripper_joint_positions(self) -> list[float]:
+        """Returns the initial joint positions for the gripper.
+
+        Returns:
+            list[float]: The initial joint positions for the gripper.
+        """
+        return self.__initial_gripper_joint_positions
+
+    def set_initial_joint_positions(self, model):
+        """Sets the initial positions for the robot's joints.
+
+        This method resets the joint positions of both the arm and gripper to their specified initial values.
+
+        Args:
+            model: The model representation of the robot within the Gazebo environment.
+
+        Raises:
+            RuntimeError: If resetting the joint positions fails for any of the joints.
+        """
+        model = model.to_gazebo()
+
+        joint_position_data = [
+            (self.__initial_arm_joint_positions, self._robot.actuated_joint_names),
+            (self.__initial_gripper_joint_positions, self._robot.gripper_actuated_joint_names),
+        ]
+
+        for positions, joint_names in joint_position_data:
+            # NOTE(review): debug print left in — consider a logger instead.
+            print(f"Setting joint positions for {joint_names}: {positions}")
+            if not model.reset_joint_positions(positions, joint_names):
+                raise RuntimeError(f"Failed to set initial positions of {joint_names}'s joints")
diff --git a/env_manager/env_manager/env_manager/models/robots/robot.py b/env_manager/env_manager/env_manager/models/robots/robot.py
new file mode 100644
index 0000000..fdfa062
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/robots/robot.py
@@ -0,0 +1,95 @@
+from gym_gz.scenario import model_wrapper
+from robot_builder.elements.robot import Robot
+from scenario import gazebo as scenario_gazebo
+from abc import abstractmethod
+
+
+class RobotWrapper(model_wrapper.ModelWrapper):
+    """
+    An abstract base class for robot models in a Gazebo simulation.
+
+    This class extends the ModelWrapper from gym_gz and provides a structure for creating
+    robot models with specific configurations, including joint positions for arms and grippers.
+
+    Args:
+        world (scenario_gazebo.World, optional): The Gazebo world where the robot model will be inserted.
+        urdf_string (str, optional): The URDF (Unified Robot Description Format) string defining the robot.
+        name (str, optional): The name of the robot model. Defaults to None.
+        position (tuple[float, float, float], optional): The initial position of the robot in the world. Defaults to None.
+        orientation (tuple[float, float, float, float], optional): The initial orientation of the robot in quaternion format. Defaults to None.
+        initial_arm_joint_positions (list[float] | float, optional): The initial joint positions for the arm. Defaults to an empty list.
+        initial_gripper_joint_positions (list[float] | float, optional): The initial joint positions for the gripper. Defaults to an empty list.
+        model (model_wrapper.ModelWrapper, optional): An existing model instance to initialize the wrapper. Must be provided.
+        **kwargs: Additional keyword arguments.
+
+    Raises:
+        ValueError: If the model parameter is not provided.
+    """
+    # NOTE(review): this class does not inherit from abc.ABC, so the
+    # @abstractmethod markers below are not enforced at instantiation time —
+    # confirm whether ABCMeta was avoided deliberately (e.g. metaclass
+    # conflict with model_wrapper.ModelWrapper).
+    def __init__(
+        self,
+        world: scenario_gazebo.World | None = None,
+        urdf_string: str | None = None,
+        name: str | None = None,
+        position: tuple[float, float, float] | None = None,
+        orientation: tuple[float, float, float, float] | None = None,
+        initial_arm_joint_positions: list[float] | float = [],
+        initial_gripper_joint_positions: list[float] | float = [],
+        model: model_wrapper.ModelWrapper | None = None,
+        **kwargs,
+    ):
+        # NOTE(review): the `[]` defaults above are mutable default arguments;
+        # harmless here because every parameter except `model` is ignored in
+        # this constructor, but consider `None` defaults for safety.
+        if model is not None:
+            super().__init__(model=model)
+        else:
+            raise ValueError("Model should be defined for the parent class")
+
+    @property
+    @abstractmethod
+    def robot(self) -> Robot:
+        """Returns the robot instance containing metadata.
+
+        This property must be implemented by subclasses to return the specific robot metadata.
+
+        Returns:
+            Robot: The robot instance.
+        """
+        pass
+
+
+    @property
+    @abstractmethod
+    def initial_gripper_joint_positions(self) -> list[float]:
+        """Returns the initial joint positions for the gripper.
+
+        This property must be implemented by subclasses to provide the initial positions of the gripper joints.
+
+        Returns:
+            list[float]: The initial joint positions for the gripper.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def initial_arm_joint_positions(self) -> list[float]:
+        """Returns the initial joint positions for the arm.
+
+        This property must be implemented by subclasses to provide the initial positions of the arm joints.
+
+        Returns:
+            list[float]: The initial joint positions for the arm.
+        """
+        pass
+
+    @abstractmethod
+    def set_initial_joint_positions(self, model):
+        """Sets the initial positions for the robot's joints.
+
+        This method must be implemented by subclasses to reset the joint positions of the robot
+        to their specified initial values.
+
+        Args:
+            model: The model representation of the robot within the Gazebo environment.
+
+        Raises:
+            RuntimeError: If resetting the joint positions fails.
+        """
+        pass
diff --git a/env_manager/env_manager/env_manager/models/sensors/__init__.py b/env_manager/env_manager/env_manager/models/sensors/__init__.py
new file mode 100644
index 0000000..0cbbbe7
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/sensors/__init__.py
@@ -0,0 +1 @@
+from .camera import Camera
diff --git a/env_manager/env_manager/env_manager/models/sensors/camera.py b/env_manager/env_manager/env_manager/models/sensors/camera.py
new file mode 100644
index 0000000..63e1616
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/sensors/camera.py
@@ -0,0 +1,436 @@
+import os
+from threading import Thread
+
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+# from env_manager.models.utils import ModelCollectionRandomizer
+
+
+class Camera(model_wrapper.ModelWrapper):
+ """
+ Represents a camera model in a Gazebo simulation.
+
+ This class extends the ModelWrapper from gym_gz to define a camera model that can be inserted
+ into the Gazebo world. It supports different types of cameras and can bridge ROS2 topics for
+ camera data.
+
+ Args:
+ world (scenario_gazebo.World): The Gazebo world where the camera will be inserted.
+ name (str, optional): The name of the camera model. If None, a unique name is generated.
+ position (tuple[float, float, float], optional): The initial position of the camera. Defaults to (0, 0, 0).
+ orientation (tuple[float, float, float, float], optional): The initial orientation of the camera in quaternion. Defaults to (1, 0, 0, 0).
+ static (bool, optional): If True, the camera is a static model. Defaults to True.
+ camera_type (str, optional): The type of camera to create. Defaults to "rgbd_camera".
+ width (int, optional): The width of the camera image. Defaults to 212.
+ height (int, optional): The height of the camera image. Defaults to 120.
+ image_format (str, optional): The format of the camera image. Defaults to "R8G8B8".
+ update_rate (int, optional): The update rate for the camera sensor. Defaults to 15.
+ horizontal_fov (float, optional): The horizontal field of view for the camera. Defaults to 1.567821.
+ vertical_fov (float, optional): The vertical field of view for the camera. Defaults to 1.022238.
+ clip_color (tuple[float, float], optional): The near and far clipping distances for color. Defaults to (0.02, 1000.0).
+ clip_depth (tuple[float, float], optional): The near and far clipping distances for depth. Defaults to (0.02, 10.0).
+ noise_mean (float, optional): The mean of the noise added to the camera images. Defaults to None.
+ noise_stddev (float, optional): The standard deviation of the noise added to the camera images. Defaults to None.
+ ros2_bridge_color (bool, optional): If True, a ROS2 bridge for color images is created. Defaults to False.
+ ros2_bridge_depth (bool, optional): If True, a ROS2 bridge for depth images is created. Defaults to False.
+ ros2_bridge_points (bool, optional): If True, a ROS2 bridge for point cloud data is created. Defaults to False.
+ visibility_mask (int, optional): The visibility mask for the camera sensor. Defaults to 0.
+ visual (str, optional): The type of visual representation for the camera. Defaults to None.
+
+ Raises:
+ ValueError: If the visual mesh or textures cannot be found.
+ RuntimeError: If the camera model fails to insert into the Gazebo world.
+ """
+ def __init__(
+ self,
+ world: scenario_gazebo.World,
+ name: str | None = None,
+ position: tuple[float, float, float] = (0, 0, 0),
+ orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
+ static: bool = True,
+ camera_type: str = "rgbd_camera",
+ width: int = 212,
+ height: int = 120,
+ image_format: str = "R8G8B8",
+ update_rate: int = 15,
+ horizontal_fov: float = 1.567821,
+ vertical_fov: float = 1.022238,
+ clip_color: tuple[float, float] = (0.02, 1000.0),
+ clip_depth: tuple[float, float] = (0.02, 10.0),
+ noise_mean: float | None = None,
+ noise_stddev: float | None = None,
+ ros2_bridge_color: bool = False,
+ ros2_bridge_depth: bool = False,
+ ros2_bridge_points: bool = False,
+ visibility_mask: int = 0,
+ visual: str | None = None,
+ ):
+ # Get a unique model name
+ if name is not None:
+ model_name = get_unique_model_name(world, name)
+ else:
+ model_name = get_unique_model_name(world, camera_type)
+ self._model_name = model_name
+ self._camera_type = camera_type
+
+ # Initial pose
+ initial_pose = scenario_core.Pose(position, orientation)
+
+ use_mesh: bool = False
+ mesh_path_visual: str = ""
+
+ albedo_map = None
+ normal_map = None
+ roughness_map = None
+ metalness_map = None
+ # Get resources for visual (if enabled)
+ # if visual:
+ # if "intel_realsense_d435" == visual:
+ # use_mesh = True
+ #
+ # # Get path to the model and the important directories
+ # model_path = ModelCollectionRandomizer.get_collection_paths(
+ # owner="OpenRobotics",
+ # collection="",
+ # model_name="Intel RealSense D435",
+ # )[0]
+ #
+ # mesh_dir = os.path.join(model_path, "meshes")
+ # texture_dir = os.path.join(model_path, "materials", "textures")
+ #
+ # # Get path to the mesh
+ # mesh_path_visual = os.path.join(mesh_dir, "realsense.dae")
+ # # Make sure that it exists
+ # if not os.path.exists(mesh_path_visual):
+ # raise ValueError(
+ # f"Visual mesh '{mesh_path_visual}' for Camera model is not a valid file."
+ # )
+ #
+ # # Find PBR textures
+ # if texture_dir:
+ # # List all files
+ # texture_files = os.listdir(texture_dir)
+ #
+ # # Extract the appropriate files
+ # for texture in texture_files:
+ # texture_lower = texture.lower()
+ # if "basecolor" in texture_lower or "albedo" in texture_lower:
+ # albedo_map = os.path.join(texture_dir, texture)
+ # elif "normal" in texture_lower:
+ # normal_map = os.path.join(texture_dir, texture)
+ # elif "roughness" in texture_lower:
+ # roughness_map = os.path.join(texture_dir, texture)
+ # elif (
+ # "specular" in texture_lower or "metalness" in texture_lower
+ # ):
+ # metalness_map = os.path.join(texture_dir, texture)
+ #
+ # if not (albedo_map and normal_map and roughness_map and metalness_map):
+ # raise ValueError("Not all textures for Camera model were found.")
+
+ # Create SDF string for the model
+ sdf = f'''
+
+ {static}
+
+
+ {model_name}
+ true
+ {update_rate}
+
+
+ {width}
+ {height}
+ {image_format}
+
+ {horizontal_fov}
+ {vertical_fov}
+
+ {clip_color[0]}
+ {clip_color[1]}
+
+ {
+ f"""
+
+ {clip_depth[0]}
+ {clip_depth[1]}
+
+ """ if "rgbd" in model_name else ""
+ }
+ {
+ f"""
+ gaussian
+ {noise_mean}
+ {noise_stddev}
+ """ if noise_mean is not None and noise_stddev is not None else ""
+ }
+ {visibility_mask}
+
+ true
+
+ {
+ f"""
+
+ -0.01 0 0 0 1.5707963 0
+
+
+ 0.02
+ 0.02
+
+
+
+ 0.0 0.8 0.0
+ 0.0 0.8 0.0
+ 0.0 0.8 0.0
+
+
+
+ -0.05 0 0 0 0 0
+
+
+ 0.06 0.05 0.05
+
+
+
+ 0.0 0.8 0.0
+ 0.0 0.8 0.0
+ 0.0 0.8 0.0
+
+
+ """ if visual and not use_mesh else ""
+ }
+ {
+ f"""
+
+ 0.0615752
+
+ 9.108e-05
+ 0.0
+ 0.0
+ 2.51e-06
+ 0.0
+ 8.931e-05
+
+
+
+ 0 0 0 0 0 1.5707963
+
+
+ {mesh_path_visual}
+
+ RealSense
+
false
+
+
+
+
+ 1 1 1 1
+ 1 1 1 1
+
+
+ {albedo_map}
+ {normal_map}
+ {roughness_map}
+ {metalness_map}
+
+
+
+
+ """ if visual and use_mesh else ""
+ }
+
+
+ '''
+
+ # Insert the model
+ ok_model = world.to_gazebo().insert_model_from_string(
+ sdf, initial_pose, model_name
+ )
+ if not ok_model:
+ raise RuntimeError("Failed to insert " + model_name)
+
+ # Get the model
+ model = world.get_model(model_name)
+
+ # Initialize base class
+ model_wrapper.ModelWrapper.__init__(self, model=model)
+
+ if ros2_bridge_color or ros2_bridge_depth or ros2_bridge_points:
+ self.__threads = []
+ self.__threads.append(
+ Thread(
+ target=self.construct_ros2_bridge,
+ args=(
+ self.info_topic,
+ "sensor_msgs/msg/CameraInfo",
+ "gz.msgs.CameraInfo",
+ ),
+ daemon=True,
+ )
+ )
+ if ros2_bridge_color:
+ self.__threads.append(
+ Thread(
+ target=self.construct_ros2_bridge,
+ args=(
+ self.color_topic,
+ "sensor_msgs/msg/Image",
+ "gz.msgs.Image",
+ ),
+ daemon=True,
+ )
+ )
+
+ if ros2_bridge_depth:
+ self.__threads.append(
+ Thread(
+ target=self.construct_ros2_bridge,
+ args=(
+ self.depth_topic,
+ "sensor_msgs/msg/Image",
+ "gz.msgs.Image",
+ ),
+ daemon=True,
+ )
+ )
+
+ if ros2_bridge_points:
+ self.__threads.append(
+ Thread(
+ target=self.construct_ros2_bridge,
+ args=(
+ self.points_topic,
+ "sensor_msgs/msg/PointCloud2",
+ "gz.msgs.PointCloudPacked",
+ ),
+ daemon=True,
+ )
+ )
+
+ for thread in self.__threads:
+ thread.start()
+
+ def __del__(self):
+ """Cleans up threads when the Camera object is deleted."""
+ if hasattr(self, "__threads"):
+ for thread in self.__threads:
+ thread.join()
+
+ @classmethod
+ def construct_ros2_bridge(cls, topic: str, ros_msg: str, ign_msg: str):
+ """
+ Constructs and runs a ROS2 bridge command for a given topic.
+
+ Args:
+ topic (str): The topic to bridge.
+ ros_msg (str): The ROS2 message type to use.
+ ign_msg (str): The Ignition message type to use.
+ """
+ node_name = "parameter_bridge" + topic.replace("/", "_")
+ command = (
+ f"ros2 run ros_gz_bridge parameter_bridge {topic}@{ros_msg}[{ign_msg} "
+ + f"--ros-args --remap __node:={node_name} --ros-args -p use_sim_time:=true"
+ )
+ os.system(command)
+
+ @classmethod
+ def get_frame_id(cls, model_name: str) -> str:
+ """
+ Gets the frame ID for the camera model.
+
+ Args:
+ model_name (str): The name of the camera model.
+
+ Returns:
+ str: The frame ID.
+ """
+ return f"{model_name}/{model_name}_link/camera"
+
+ @property
+ def frame_id(self) -> str:
+ """Returns the frame ID of the camera."""
+ return self.get_frame_id(self._model_name)
+
+ @classmethod
+ def get_color_topic(cls, model_name: str, camera_type: str) -> str:
+ """
+ Gets the color topic for the camera.
+
+ Args:
+ model_name (str): The name of the camera model.
+ camera_type (str): The type of the camera.
+
+ Returns:
+ str: The color topic.
+ """
+ return f"/{model_name}/image" if "rgbd" in camera_type else f"/{model_name}"
+
+ @property
+ def color_topic(self) -> str:
+ """Returns the color topic for the camera."""
+ return self.get_color_topic(self._model_name, self._camera_type)
+
+ @classmethod
+ def get_depth_topic(cls, model_name: str, camera_type: str) -> str:
+ """
+ Gets the depth topic for the camera.
+
+ Args:
+ model_name (str): The name of the camera model.
+ camera_type (str): The type of the camera.
+
+ Returns:
+ str: The depth topic.
+ """
+ return (
+ f"/{model_name}/depth_image" if "rgbd" in camera_type else f"/{model_name}"
+ )
+
+ @property
+ def depth_topic(self) -> str:
+ """Returns the depth topic for the camera."""
+ return self.get_depth_topic(self._model_name, self._camera_type)
+
+ @classmethod
+ def get_points_topic(cls, model_name: str) -> str:
+ """
+ Gets the points topic for the camera.
+
+ Args:
+ model_name (str): The name of the camera model.
+
+ Returns:
+ str: /{model_name}/points.
+ """
+ return f"/{model_name}/points"
+
+ @property
+ def points_topic(self) -> str:
+ """Returns the points topic for the camera."""
+ return self.get_points_topic(self._model_name)
+
+ @property
+ def info_topic(self):
+ """Returns the camera info topic."""
+ return f"/{self._model_name}/camera_info"
+
+ @classmethod
+ def get_link_name(cls, model_name: str) -> str:
+ """
+ Gets the link name for the camera model.
+
+ Args:
+ model_name (str): The name of the camera model.
+
+ Returns:
+ str: {model_name}_link.
+ """
+ return f"{model_name}_link"
+
+ @property
+ def link_name(self) -> str:
+ """Returns the link name for the camera."""
+ return self.get_link_name(self._model_name)
diff --git a/env_manager/env_manager/env_manager/models/terrains/__init__.py b/env_manager/env_manager/env_manager/models/terrains/__init__.py
new file mode 100644
index 0000000..9829bf3
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/terrains/__init__.py
@@ -0,0 +1,25 @@
+# from gym_gz.scenario.model_wrapper import ModelWrapper
+
+from .ground import Ground
+
+
+# def get_terrain_model_class(terrain_type: str) -> type[ModelWrapper]:
+# # TODO: Refactor into enum
+#
+# if "flat" == terrain_type:
+# return Ground
+# elif "random_flat" == terrain_type:
+# return RandomGround
+# elif "lunar_heightmap" == terrain_type:
+# return LunarHeightmap
+# elif "lunar_surface" == terrain_type:
+# return LunarSurface
+# elif "random_lunar_surface" == terrain_type:
+# return RandomLunarSurface
+# else:
+# raise AttributeError(f"Unsupported terrain [{terrain_type}]")
+#
+#
+# def is_terrain_type_randomizable(terrain_type: str) -> bool:
+#
+# return "random_flat" == terrain_type or "random_lunar_surface" == terrain_type
diff --git a/env_manager/env_manager/env_manager/models/terrains/ground.py b/env_manager/env_manager/env_manager/models/terrains/ground.py
new file mode 100644
index 0000000..09044a3
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/terrains/ground.py
@@ -0,0 +1,104 @@
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+
+class Ground(model_wrapper.ModelWrapper):
+ """
+ Represents a ground model in a Gazebo simulation.
+
+ This class extends the ModelWrapper from gym_gz to define a ground model that can be
+ inserted into the Gazebo world. The ground is defined by its size, position, orientation,
+ and friction properties.
+
+ Args:
+ world (scenario_gazebo.World): The Gazebo world where the ground will be inserted.
+ name (str, optional): The name of the ground model. Defaults to "ground".
+ position (tuple[float, float, float], optional): The initial position of the ground. Defaults to (0, 0, 0).
+ orientation (tuple[float, float, float, float], optional): The initial orientation of the ground in quaternion. Defaults to (1, 0, 0, 0).
+ size (tuple[float, float], optional): The size of the ground plane. Defaults to (1.5, 1.5).
+ collision_thickness (float, optional): The thickness of the collision surface. Defaults to 0.05.
+ friction (float, optional): The friction coefficient for the ground surface. Defaults to 5.0.
+ ambient (tuple[float, float, float, float], optional): The ambient color of the material. Defaults to (0.8, 0.8, 0.8, 1.0).
+ specular (tuple[float, float, float, float], optional): The specular color of the material. Defaults to (0.8, 0.8, 0.8, 1.0).
+ diffuse (tuple[float, float, float, float], optional): The diffuse color of the material. Defaults to (0.8, 0.8, 0.8, 1.0).
+ **kwargs: Additional keyword arguments for future extensions.
+
+ Raises:
+ RuntimeError: If the ground model fails to insert into the Gazebo world.
+ """
+
+ def __init__(
+ self,
+ world: scenario_gazebo.World,
+ name: str = "ground",
+ position: tuple[float, float, float] = (0, 0, 0),
+ orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
+ size: tuple[float, float] = (1.5, 1.5),
+ collision_thickness=0.05,
+ friction: float = 5.0,
+ ambient: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0),
+ specular: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0),
+ diffuse: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0),
+ **kwargs,
+ ):
+ # Get a unique model name
+ model_name = get_unique_model_name(world, name)
+
+ # Initial pose
+ initial_pose = scenario_core.Pose(position, orientation)
+
+ # Create SDF string for the model with the provided material properties
+ sdf = f"""
+
+ true
+
+
+
+
+ 0 0 1
+ {size[0]} {size[1]}
+
+
+
+
+
+ {friction}
+ {friction}
+ 0 0 0
+ 0.0
+ 0.0
+
+
+
+
+
+
+
+ 0 0 1
+ {size[0]} {size[1]}
+
+
+
+ {' '.join(map(str, ambient))}
+ {' '.join(map(str, diffuse))}
+ {' '.join(map(str, specular))}
+
+
+
+
+ """
+
+ # Insert the model
+ ok_model = world.to_gazebo().insert_model_from_string(
+ sdf, initial_pose, model_name
+ )
+ if not ok_model:
+ raise RuntimeError("Failed to insert " + model_name)
+
+ # Get the model
+ model = world.get_model(model_name)
+
+ # Initialize base class
+ model_wrapper.ModelWrapper.__init__(self, model=model)
diff --git a/env_manager/env_manager/env_manager/models/terrains/random_ground.py b/env_manager/env_manager/env_manager/models/terrains/random_ground.py
new file mode 100644
index 0000000..948b376
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/terrains/random_ground.py
@@ -0,0 +1,137 @@
+import os
+import numpy as np
+
+from gym_gz.scenario import model_wrapper
+from gym_gz.utils.scenario import get_unique_model_name
+from scenario import core as scenario_core
+from scenario import gazebo as scenario_gazebo
+
+from rbs_assets_library import get_textures_path
+
+
+class RandomGround(model_wrapper.ModelWrapper):
+ def __init__(
+ self,
+ world: scenario_gazebo.World,
+ name: str = "random_ground",
+ position: tuple[float, float, float] = (0, 0, 0),
+ orientation: tuple[float, float, float, float] = (1, 0, 0, 0),
+ size: tuple[float, float] = (1.0, 1.0),
+ friction: float = 5.0,
+ texture_dir: str | None = None,
+ **kwargs,
+ ):
+
+ np_random = np.random.default_rng()
+
+ # Get a unique model name
+ model_name = get_unique_model_name(world, name)
+
+ # Initial pose
+ initial_pose = scenario_core.Pose(position, orientation)
+
+ # Get textures from ENV variable if not directly specified
+ if not texture_dir:
+ texture_dir = os.environ.get("TEXTURE_DIRS", default="")
+ if not texture_dir:
+ texture_dir = get_textures_path()
+
+ # Find random PBR texture
+ albedo_map = None
+ normal_map = None
+ roughness_map = None
+ metalness_map = None
+ if texture_dir:
+ if ":" in texture_dir:
+ textures = []
+ for d in texture_dir.split(":"):
+ textures.extend([os.path.join(d, f) for f in os.listdir(d)])
+ else:
+ # Get list of the available textures
+ textures = os.listdir(texture_dir)
+
+ # Choose a random texture from these
+ random_texture_dir = str(np_random.choice(textures))
+
+ random_texture_dir = texture_dir + random_texture_dir
+
+ # List all files
+ texture_files = os.listdir(random_texture_dir)
+
+ # Extract the appropriate files
+ for texture in texture_files:
+ texture_lower = texture.lower()
+ if "color" in texture_lower or "albedo" in texture_lower:
+ albedo_map = os.path.join(random_texture_dir, texture)
+ elif "normal" in texture_lower:
+ normal_map = os.path.join(random_texture_dir, texture)
+ elif "roughness" in texture_lower:
+ roughness_map = os.path.join(random_texture_dir, texture)
+ elif "specular" in texture_lower or "metalness" in texture_lower:
+ metalness_map = os.path.join(random_texture_dir, texture)
+
+ # Create SDF string for the model
+ sdf = f"""
+
+ true
+
+
+
+
+ 0 0 1
+ {size[0]} {size[1]}
+
+
+
+
+
+ {friction}
+ {friction}
+ 0 0 0
+ 0.0
+ 0.0
+
+
+
+
+
+
+
+ 0 0 1
+ {size[0]} {size[1]}
+
+
+
+ 1 1 1 1
+ 1 1 1 1
+ 1 1 1 1
+
+
+ {"%s"
+ % albedo_map if albedo_map is not None else ""}
+ {"%s"
+ % normal_map if normal_map is not None else ""}
+ {"%s"
+ % roughness_map if roughness_map is not None else ""}
+ {"%s"
+ % metalness_map if metalness_map is not None else ""}
+
+
+
+
+
+
+ """
+
+ # Insert the model
+ ok_model = world.to_gazebo().insert_model_from_string(
+ sdf, initial_pose, model_name
+ )
+ if not ok_model:
+ raise RuntimeError("Failed to insert " + model_name)
+
+ # Get the model
+ model = world.get_model(model_name)
+
+ # Initialize base class
+ model_wrapper.ModelWrapper.__init__(self, model=model)
diff --git a/env_manager/env_manager/env_manager/models/utils/__init__.py b/env_manager/env_manager/env_manager/models/utils/__init__.py
new file mode 100644
index 0000000..1c438a3
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/utils/__init__.py
@@ -0,0 +1,2 @@
+# from .model_collection_randomizer import ModelCollectionRandomizer
+from .xacro2sdf import xacro2sdf
diff --git a/env_manager/env_manager/env_manager/models/utils/model_collection_randomizer.py b/env_manager/env_manager/env_manager/models/utils/model_collection_randomizer.py
new file mode 100644
index 0000000..428a7a5
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/utils/model_collection_randomizer.py
@@ -0,0 +1,1266 @@
+import glob
+import os
+
+import numpy as np
+import trimesh
+from gym_gz.utils import logger
+from numpy.random import RandomState
+from pcg_gazebo.parsers import parse_sdf
+from pcg_gazebo.parsers.sdf import create_sdf_element
+from scenario import gazebo as scenario_gazebo
+
+# Note: only models with mesh geometry are supported
+
+
+class ModelCollectionRandomizer:
+ """
+ A class to randomize and manage paths for model collections in simulation environments.
+
+ This class allows for the selection and management of model paths, enabling features
+ like caching, blacklisting, and random sampling from a specified collection.
+
+ Attributes:
+ _class_model_paths (list[str] | None): Class-level cache for model paths.
+ __sdf_base_name (str): The base name for the SDF model file.
+ __configured_sdf_base_name (str): The base name for the modified SDF model file.
+ __blacklisted_base_name (str): The base name used for blacklisted models.
+ __collision_mesh_dir (str): Directory where collision meshes are stored.
+ __collision_mesh_file_type (str): The file type of the collision meshes.
+ __original_scale_base_name (str): The base name for original scale metadata files.
+ _unique_cache (bool): Whether to use a unique cache for model paths.
+ _enable_blacklisting (bool): Whether to enable blacklisting of unusable models.
+ _model_paths (list[str] | None): Instance-specific model paths if unique cache is enabled.
+ np_random (RandomState): Numpy random state for random sampling.
+ """
+
+ _class_model_paths = None
+ __sdf_base_name = "model.sdf"
+ __configured_sdf_base_name = "model_modified.sdf"
+ __blacklisted_base_name = "BLACKLISTED"
+ __collision_mesh_dir = "meshes/collision/"
+ __collision_mesh_file_type = "stl"
+ __original_scale_base_name = "original_scale.txt"
+
+ def __init__(
+ self,
+ model_paths=None,
+ owner="GoogleResearch",
+ collection="Google Scanned Objects",
+ server="https://fuel.ignitionrobotics.org",
+ server_version="1.0",
+ unique_cache=False,
+ reset_collection=False,
+ enable_blacklisting=True,
+ np_random: RandomState | None = None,
+ ):
+ """
+ Initializes the ModelCollectionRandomizer.
+
+ Args:
+ model_paths (list[str] | None): List of model paths to initialize with.
+ owner (str): The owner of the model collection (default is "GoogleResearch").
+ collection (str): The name of the model collection (default is "Google Scanned Objects").
+ server (str): The server URL for the model collection (default is "https://fuel.ignitionrobotics.org").
+ server_version (str): The server version to use (default is "1.0").
+ unique_cache (bool): Whether to use a unique cache for the instance (default is False).
+ reset_collection (bool): Whether to reset the class-level model paths (default is False).
+ enable_blacklisting (bool): Whether to enable blacklisting of models (default is True).
+ np_random (RandomState | None): Numpy random state for random operations (default is None).
+
+ Raises:
+ ValueError: If an invalid model path is provided.
+ """
+
+ # If enabled, the newly created objects of this class will use its own individual cache
+ # for model paths and must discover/download them on its own
+ self._unique_cache = unique_cache
+
+ # Flag that determines if models that cannot be used are blacklisted
+ self._enable_blacklisting = enable_blacklisting
+
+ # If enabled, the cache of the class used to store model paths among instances will be reset
+ if reset_collection and not self._unique_cache:
+ self._class_model_paths = None
+
+ # Get file path to all models from
+ # a) `model_paths` arg
+ # b) local cache owner (if `owner` has some models, i.e `collection` is already downloaded)
+ # c) Fuel collection (if `owner` has no models in local cache)
+ if model_paths is not None:
+ # Use arg
+ if self._unique_cache:
+ self._model_paths = model_paths
+ else:
+ self._class_model_paths = model_paths
+ else:
+ # Use local cache or Fuel
+ if self._unique_cache:
+ self._model_paths = self.get_collection_paths(
+ owner=owner,
+ collection=collection,
+ server=server,
+ server_version=server_version,
+ )
+ elif self._class_model_paths is None:
+ # Executed only once, unless the paths are reset with `reset_collection` arg
+ self._class_model_paths = self.get_collection_paths(
+ owner=owner,
+ collection=collection,
+ server=server,
+ server_version=server_version,
+ )
+
+ # Initialise rng with (with seed is desired)
+ if np_random is not None:
+ self.np_random = np_random
+ else:
+ self.np_random = np.random.default_rng()
+
+ @classmethod
+ def get_collection_paths(
+ cls,
+ owner="GoogleResearch",
+ collection="Google Scanned Objects",
+ server="https://fuel.ignitionrobotics.org",
+ server_version="1.0",
+ model_name: str = "",
+ ) -> list[str]:
+ """
+ Retrieves model paths from the local cache or downloads them from a specified server.
+
+ The method first checks the local cache for models belonging to the specified owner.
+ If no models are found, it attempts to download the models from the specified Fuel server.
+
+ Args:
+ cls: The class reference.
+ owner (str): The owner of the model collection (default is "GoogleResearch").
+ collection (str): The name of the model collection (default is "Google Scanned Objects").
+ server (str): The server URL for the model collection (default is "https://fuel.ignitionrobotics.org").
+ server_version (str): The server version to use (default is "1.0").
+ model_name (str): The name of the specific model to fetch (default is an empty string).
+
+ Returns:
+ list[str]: A list of paths to the retrieved models.
+
+ Raises:
+ RuntimeError: If the download command fails or if no models are found after the download.
+ """
+
+ # First check the local cache (for performance)
+ # Note: This unfortunately does not check if models belong to the specified collection
+ # TODO: Make sure models belong to the collection if sampled from local cache
+ model_paths = scenario_gazebo.get_local_cache_model_paths(
+ owner=owner, name=model_name
+ )
+ if len(model_paths) > 0:
+ return model_paths
+
+ # Else download the models from Fuel and then try again
+ if collection:
+ download_uri = "%s/%s/%s/collections/%s" % (
+ server,
+ server_version,
+ owner,
+ collection,
+ )
+ elif model_name:
+ download_uri = "%s/%s/%s/models/%s" % (
+ server,
+ server_version,
+ owner,
+ model_name,
+ )
+ download_command = 'ign fuel download -v 3 -t model -j %s -u "%s"' % (
+ os.cpu_count(),
+ download_uri,
+ )
+ os.system(download_command)
+
+ model_paths = scenario_gazebo.get_local_cache_model_paths(
+ owner=owner, name=model_name
+ )
+ if 0 == len(model_paths):
+ logger.error(
+ 'URI "%s" is not valid and does not contain any models that are \
+ owned by the owner of the collection'
+ % download_uri
+ )
+ pass
+
+ return model_paths
+
+ def random_model(
+ self,
+ min_scale=0.125,
+ max_scale=0.175,
+ min_mass=0.05,
+ max_mass=0.25,
+ min_friction=0.75,
+ max_friction=1.5,
+ decimation_fraction_of_visual=0.25,
+ decimation_min_faces=40,
+ decimation_max_faces=200,
+ max_faces=40000,
+ max_vertices=None,
+ component_min_faces_fraction=0.1,
+ component_max_volume_fraction=0.35,
+ fix_mtl_texture_paths=True,
+ skip_blacklisted=True,
+ return_sdf_path=True,
+ ) -> str:
+ """
+ Selects and configures a random model from the collection.
+
+ The method attempts to find a valid model, applying randomization
+ and returning the path to the configured SDF file or the model directory.
+
+ Args:
+ min_scale (float): Minimum scale for the model.
+ max_scale (float): Maximum scale for the model.
+ min_mass (float): Minimum mass for the model.
+ max_mass (float): Maximum mass for the model.
+ min_friction (float): Minimum friction coefficient.
+ max_friction (float): Maximum friction coefficient.
+ decimation_fraction_of_visual (float): Fraction of visual decimation.
+ decimation_min_faces (int): Minimum number of faces for decimation.
+ decimation_max_faces (int): Maximum number of faces for decimation.
+ max_faces (int): Maximum faces for the model.
+ max_vertices (int, optional): Maximum vertices for the model.
+ component_min_faces_fraction (float): Minimum face fraction for components.
+ component_max_volume_fraction (float): Maximum volume fraction for components.
+ fix_mtl_texture_paths (bool): Whether to fix MTL texture paths.
+ skip_blacklisted (bool): Whether to skip blacklisted models.
+ return_sdf_path (bool): If True, return the configured SDF file path.
+
+ Returns:
+ str: Path to the configured SDF file or model directory.
+
+ Raises:
+ RuntimeError: If a valid model cannot be found after multiple attempts.
+ """
+
+ # Loop until a model is found, checked for validity, configured and returned
+ # If any of these steps fail, sample another model and try again
+ # Note: Due to this behaviour, the function could stall if all models are invalid
+ # TODO: Add a simple timeout to random sampling of valid model (# of attempts or time-based)
+ while True:
+ # Get path to a random model from the collection
+ model_path = self.get_random_model_path()
+
+ # Check if the model is already blacklisted and skip if desired
+ if skip_blacklisted and self.is_blacklisted(model_path):
+ continue
+
+ # Check is the model is already configured
+ if self.is_configured(model_path):
+ # If so, break the loop
+ break
+
+ # Process the model and break loop only if it is valid
+ if self.process_model(
+ model_path,
+ decimation_fraction_of_visual=decimation_fraction_of_visual,
+ decimation_min_faces=decimation_min_faces,
+ decimation_max_faces=decimation_max_faces,
+ max_faces=max_faces,
+ max_vertices=max_vertices,
+ component_min_faces_fraction=component_min_faces_fraction,
+ component_max_volume_fraction=component_max_volume_fraction,
+ fix_mtl_texture_paths=fix_mtl_texture_paths,
+ ):
+ break
+
+ # Apply randomization
+ self.randomize_configured_model(
+ model_path,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ min_friction=min_friction,
+ max_friction=max_friction,
+ min_mass=min_mass,
+ max_mass=max_mass,
+ )
+
+ if return_sdf_path:
+ # Return path to the configured SDF file
+ return self.get_configured_sdf_path(model_path)
+ else:
+ # Return path to the model directory
+ return model_path
+
+ def process_all_models(
+ self,
+ decimation_fraction_of_visual=0.025,
+ decimation_min_faces=8,
+ decimation_max_faces=400,
+ max_faces=40000,
+ max_vertices=None,
+ component_min_faces_fraction=0.1,
+ component_max_volume_fraction=0.35,
+ fix_mtl_texture_paths=True,
+ ):
+ """
+ Processes all models in the collection, applying configuration and decimation.
+
+ This method iterates over each model in the collection, applying visual decimation
+ and checking for validity. If a model cannot be processed successfully, it is
+ blacklisted. The method tracks and prints the number of processed models and
+ the count of blacklisted models at the end of execution.
+
+ Args:
+ decimation_fraction_of_visual (float): Fraction of visual decimation to apply.
+ Default is 0.025.
+ decimation_min_faces (int): Minimum number of faces for decimation. Default is 8.
+ decimation_max_faces (int): Maximum number of faces for decimation. Default is 400.
+ max_faces (int): Maximum faces allowed for the model. Default is 40000.
+ max_vertices (int, optional): Maximum vertices allowed for the model. Default is None.
+ component_min_faces_fraction (float): Minimum face fraction for components during
+ processing. Default is 0.1.
+ component_max_volume_fraction (float): Maximum volume fraction for components during
+ processing. Default is 0.35.
+ fix_mtl_texture_paths (bool): Whether to fix MTL texture paths for the models.
+ Default is True.
+
+ Returns:
+ None: This method modifies the models in place and does not return a value.
+
+ Prints:
+ - The status of each processed model along with its index.
+ - The total number of blacklisted models after processing all models.
+
+ Raises:
+ None: The method does not raise exceptions but may log issues if models cannot be
+ processed.
+ """
+ if self._unique_cache:
+ model_paths = self._model_paths
+ else:
+ model_paths = self._class_model_paths
+
+ blacklist_model_counter = 0
+ for i in range(len(model_paths)):
+ if not self.process_model(
+ model_paths[i],
+ decimation_fraction_of_visual=decimation_fraction_of_visual,
+ decimation_min_faces=decimation_min_faces,
+ decimation_max_faces=decimation_max_faces,
+ max_faces=max_faces,
+ max_vertices=max_vertices,
+ component_min_faces_fraction=component_min_faces_fraction,
+ component_max_volume_fraction=component_max_volume_fraction,
+ fix_mtl_texture_paths=fix_mtl_texture_paths,
+ ):
+ blacklist_model_counter += 1
+
+ print('Processed model %i/%i "%s"' % (i, len(model_paths), model_paths[i]))
+
+ print("Number of blacklisted models: %i" % blacklist_model_counter)
+
+    def process_model(
+        self,
+        model_path,
+        decimation_fraction_of_visual=0.25,
+        decimation_min_faces=40,
+        decimation_max_faces=200,
+        max_faces=40000,
+        max_vertices=None,
+        component_min_faces_fraction=0.1,
+        component_max_volume_fraction=0.35,
+        fix_mtl_texture_paths=True,
+    ) -> bool:
+        """
+        Processes a model: builds decimated collision geometry, computes
+        inertial properties, and writes a configured SDF.
+
+        The model's SDF is parsed, each link's visuals are checked for
+        excessive or disconnected geometry (the model is blacklisted on
+        failure when blacklisting is enabled), a decimated collision mesh is
+        generated per visual, and accumulated mass/inertia/centre-of-mass are
+        written back into the link before exporting the configured SDF.
+
+        Args:
+            model_path (str): Path to the model to be processed.
+            decimation_fraction_of_visual (float): Fraction of visual faces
+                kept during collision decimation. Default is 0.25.
+            decimation_min_faces (int): Minimum faces after decimation.
+            decimation_max_faces (int): Maximum faces after decimation.
+            max_faces (int): Maximum faces for the whole model.
+            max_vertices (int, optional): Maximum vertices for the model.
+            component_min_faces_fraction (float): Minimum face fraction for
+                disconnected components to be considered significant.
+            component_max_volume_fraction (float): Maximum volume fraction
+                allowed for significant disconnected components.
+            fix_mtl_texture_paths (bool): Whether to fix texture paths in MTL
+                files for '.obj' meshes.
+
+        Returns:
+            bool: True if the model was processed successfully; False if any
+            validity check failed (the model may have been blacklisted).
+        """
+
+        # Parse the SDF of the model
+        sdf = parse_sdf(self.get_sdf_path(model_path))
+
+        # Process the model(s) contained in the SDF
+        for model in sdf.models:
+            # Process the link(s) of each model
+            for link in model.links:
+                # Get rid of the existing collisions prior to simplifying it
+                link.collisions.clear()
+
+                # Values for the total inertial properties of current link
+                # These values will be updated for each body that the link contains.
+                # NOTE(review): `total_inertia` is a plain nested list — in
+                # sum_inertial_properties a `+=` with a NumPy matrix would
+                # extend the list rather than add element-wise; verify that
+                # helper handles the accumulation correctly.
+                total_mass = 0.0
+                total_inertia = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
+                common_centre_of_mass = [0.0, 0.0, 0.0]
+
+                # Go through the visuals and process them
+                for visual in link.visuals:
+                    # Get path to the mesh of the link's visual
+                    mesh_path = self.get_mesh_path(model_path, visual)
+
+                    # If desired, fix texture path in 'mtl' files for '.obj' mesh format
+                    if fix_mtl_texture_paths:
+                        self.fix_mtl_texture_paths(
+                            model_path, mesh_path, model.attributes["name"]
+                        )
+
+                    # Load the mesh (without materials)
+                    mesh = trimesh.load(mesh_path, force="mesh", skip_materials=True)
+
+                    # Check if model has too much geometry (blacklist if needed)
+                    if not self.check_excessive_geometry(
+                        mesh, model_path, max_faces=max_faces, max_vertices=max_vertices
+                    ):
+                        return False
+
+                    # Check if model has disconnected geometry/components (blacklist if needed)
+                    if not self.check_disconnected_components(
+                        mesh,
+                        model_path,
+                        component_min_faces_fraction=component_min_faces_fraction,
+                        component_max_volume_fraction=component_max_volume_fraction,
+                    ):
+                        return False
+
+                    # Accumulate inertial properties contributed by this mesh
+                    (
+                        total_mass,
+                        total_inertia,
+                        common_centre_of_mass,
+                    ) = self.sum_inertial_properties(
+                        mesh, total_mass, total_inertia, common_centre_of_mass
+                    )
+
+                    # Add decimated collision geometry to the SDF
+                    self.add_collision(
+                        mesh,
+                        link,
+                        model_path,
+                        fraction_of_visual=decimation_fraction_of_visual,
+                        min_faces=decimation_min_faces,
+                        max_faces=decimation_max_faces,
+                    )
+
+                    # Write original scale (size) into the SDF
+                    # This is used for later reference during randomization (for scale limits)
+                    self.write_original_scale(mesh, model_path)
+
+                # Make sure the link has valid inertial properties (blacklist if needed)
+                if not self.check_inertial_properties(
+                    model_path, total_mass, total_inertia
+                ):
+                    return False
+
+                # Write inertial properties to the SDF of the link
+                self.write_inertial_properties(
+                    link, total_mass, total_inertia, common_centre_of_mass
+                )
+
+        # Write the configured SDF into a file
+        sdf.export_xml(self.get_configured_sdf_path(model_path))
+        return True
+
+ def add_collision(
+ self,
+ mesh,
+ link,
+ model_path,
+ fraction_of_visual=0.05,
+ min_faces=8,
+ max_faces=750,
+ friction=1.0,
+ ):
+ """
+ Adds collision geometry to a specified link in a model, using a simplified
+ version of the provided mesh.
+
+ This method creates a collision mesh by decimating the visual geometry of
+ the model, generating an SDF (Simulation Description Format) element for
+ the collision, and adding it to the specified link.
+
+ Args:
+ mesh (trimesh.Trimesh): The original mesh geometry from which the
+ collision geometry will be created.
+ link (Link): The link to which the collision geometry will be added.
+ model_path (str): The path to the model where the collision mesh will
+ be stored.
+ fraction_of_visual (float): The fraction of the visual mesh to retain
+ for the collision mesh. Default is 0.05.
+ min_faces (int): The minimum number of faces allowed in the collision
+ mesh after decimation. Default is 8.
+ max_faces (int): The maximum number of faces allowed in the collision
+ mesh after decimation. Default is 750.
+ friction (float): The friction coefficient to be applied to the collision
+ surface. Default is 1.0.
+
+ Returns:
+ None: This method does not return a value.
+
+ Notes:
+ - The collision mesh is generated through quadratic decimation of the
+ original visual mesh based on the specified parameters.
+ - The created collision mesh is stored in the same directory as the
+ model, and its path is relative to the model path.
+ - The method updates the link's SDF with the collision geometry and
+ surface properties.
+ """
+
+ # Determine name of path to the collision geometry
+ collision_name = (
+ link.attributes["name"] + "_collision_" + str(len(link.collisions))
+ )
+ collision_mesh_path = self.get_collision_mesh_path(model_path, collision_name)
+
+ # Determine number of faces to keep after the decimation
+ face_count = min(
+ max(fraction_of_visual * len(mesh.faces), min_faces), max_faces
+ )
+
+ # Simplify mesh via decimation
+ collision_mesh = mesh.simplify_quadratic_decimation(face_count)
+
+ # Export the collision mesh to the appropriate location
+ os.makedirs(os.path.dirname(collision_mesh_path), exist_ok=True)
+ collision_mesh.export(
+ collision_mesh_path, file_type=self.__collision_mesh_file_type
+ )
+
+ # Create collision SDF element
+ collision = create_sdf_element("collision")
+
+ # Add collision geometry to the SDF
+ collision.geometry.mesh = create_sdf_element("mesh")
+ collision.geometry.mesh.uri = os.path.relpath(
+ collision_mesh_path, start=model_path
+ )
+
+ # Add surface friction to the SDF of collision (default to 1 and randomize later)
+ collision.surface = create_sdf_element("surface")
+ collision.surface.friction = create_sdf_element("friction", "surface")
+ collision.surface.friction.ode = create_sdf_element("ode", "collision")
+ collision.surface.friction.ode.mu = friction
+ collision.surface.friction.ode.mu2 = friction
+
+ # Add it to the SDF of the link
+ collision_name = os.path.basename(collision_mesh_path).split(".")[0]
+ link.add_collision(collision_name, collision)
+
+ def sum_inertial_properties(
+ self, mesh, total_mass, total_inertia, common_centre_of_mass, density=1.0
+ ) -> tuple[float, float, float]:
+ # Arbitrary density is used here
+ # The mass will be randomized once it is fully computed for a link
+ mesh.density = density
+
+ # Tmp variable to store the mass of all previous geometry, used to determine centre of mass
+ mass_of_others = total_mass
+
+ # For each additional mesh, simply add the mass and inertia
+ total_mass += mesh.mass
+ total_inertia += mesh.moment_inertia
+
+ # Compute a common centre of mass between all previous geometry and the new mesh
+ common_centre_of_mass = [
+ mass_of_others * common_centre_of_mass[0] + mesh.mass * mesh.center_mass[0],
+ mass_of_others * common_centre_of_mass[1] + mesh.mass * mesh.center_mass[1],
+ mass_of_others * common_centre_of_mass[2] + mesh.mass * mesh.center_mass[2],
+ ] / total_mass
+
+ return total_mass, total_inertia, common_centre_of_mass
+
+ def randomize_configured_model(
+ self,
+ model_path,
+ min_scale=0.05,
+ max_scale=0.25,
+ min_mass=0.1,
+ max_mass=3.0,
+ min_friction=0.75,
+ max_friction=1.5,
+ ):
+ """
+ Randomizes the scale, mass, and friction properties of a configured model
+ in the specified SDF file.
+
+ This method modifies the properties of each link in the model, applying
+ random values within specified ranges for scale, mass, and friction.
+ The randomized values overwrite the original settings in the configured
+ SDF file.
+
+ Args:
+ model_path (str): The path to the model directory containing the
+ configured SDF file.
+ min_scale (float): The minimum scale factor for randomization.
+ Default is 0.05.
+ max_scale (float): The maximum scale factor for randomization.
+ Default is 0.25.
+ min_mass (float): The minimum mass for randomization. Default is 0.1.
+ max_mass (float): The maximum mass for randomization. Default is 3.0.
+ min_friction (float): The minimum friction coefficient for randomization.
+ Default is 0.75.
+ max_friction (float): The maximum friction coefficient for randomization.
+ Default is 1.5.
+
+ Returns:
+ None: This method does not return a value.
+
+ Notes:
+ - The configured SDF file is updated in place, meaning the original
+ properties are overwritten with the new randomized values.
+ - The randomization is applied to each link in the model, ensuring
+ a diverse range of physical properties.
+ """
+
+ # Get path to the configured SDF file
+ configured_sdf_path = self.get_configured_sdf_path(model_path)
+
+ # Parse the configured SDF that needs to be randomized
+ sdf = parse_sdf(configured_sdf_path)
+
+ # Process the model(s) contained in the SDF
+ for model in sdf.models:
+ # Process the link(s) of each model
+ for link in model.links:
+ # Randomize scale of the link
+ self.randomize_scale(
+ model_path, link, min_scale=min_scale, max_scale=max_scale
+ )
+
+ # Randomize inertial properties of the link
+ self.randomize_inertial(link, min_mass=min_mass, max_mass=max_mass)
+
+ # Randomize friction of the link
+ self.randomize_friction(
+ link, min_friction=min_friction, max_friction=max_friction
+ )
+
+ # Overwrite the configured SDF file with randomized values
+ sdf.export_xml(configured_sdf_path)
+
+ def randomize_scale(self, model_path, link, min_scale=0.05, max_scale=0.25):
+ """
+ Randomizes the scale of a link's geometry within a specified range.
+
+ This method modifies the scale of the visual and collision geometry
+ of a link in the model. The scale is randomized using a uniform
+ distribution within the given minimum and maximum scale values.
+ Additionally, the link's inertial properties (mass and inertia)
+ are recalculated based on the new scale.
+
+ Args:
+ model_path (str): The path to the model directory, used to
+ read the original scale of the mesh.
+ link (Link): The link object containing the visual and collision
+ geometries to be scaled.
+ min_scale (float): The minimum scale factor for randomization.
+ Default is 0.05.
+ max_scale (float): The maximum scale factor for randomization.
+ Default is 0.25.
+
+ Returns:
+ bool: Returns `False` if the link has more than one visual,
+ indicating that scaling is not supported. Returns `None`
+ if the scaling is successful.
+
+ Notes:
+ - This method currently supports only links that contain a
+ single mesh geometry.
+ - The mass of the link is scaled by the cube of the scale factor,
+ while the inertial properties are scaled by the fifth power of
+ the scale factor.
+ - The scale values are applied uniformly to all dimensions (x, y, z).
+ """
+
+ # Note: This function currently supports only scaling of links with single mesh geometry
+ if len(link.visuals) > 1:
+ return False
+
+ # Get a random scale for the size of mesh
+ random_scale = self.np_random.uniform(min_scale, max_scale)
+
+ # Determine a scale factor that will result in such scale for the size of mesh
+ original_mesh_scale = self.read_original_scale(model_path)
+ scale_factor = random_scale / original_mesh_scale
+
+ # Determine scale factor for inertial properties based on random scale and current scale
+ current_scale = link.visuals[0].geometry.mesh.scale.value[0]
+ inertial_scale_factor = scale_factor / current_scale
+
+ # Write scale factor into SDF for visual and collision geometry
+ link.visuals[0].geometry.mesh.scale = [scale_factor] * 3
+ link.collisions[0].geometry.mesh.scale = [scale_factor] * 3
+
+ # Recompute inertial properties according to the scale
+ link.inertial.pose.x *= inertial_scale_factor
+ link.inertial.pose.y *= inertial_scale_factor
+ link.inertial.pose.z *= inertial_scale_factor
+
+ # Mass is scaled n^3
+ link.mass = link.mass.value * inertial_scale_factor**3
+
+ # Inertia is scaled n^5
+ inertial_scale_factor_n5 = inertial_scale_factor**5
+ link.inertia.ixx = link.inertia.ixx.value * inertial_scale_factor_n5
+ link.inertia.iyy = link.inertia.iyy.value * inertial_scale_factor_n5
+ link.inertia.izz = link.inertia.izz.value * inertial_scale_factor_n5
+ link.inertia.ixy = link.inertia.ixy.value * inertial_scale_factor_n5
+ link.inertia.ixz = link.inertia.ixz.value * inertial_scale_factor_n5
+ link.inertia.iyz = link.inertia.iyz.value * inertial_scale_factor_n5
+
+ def randomize_inertial(
+ self, link, min_mass=0.1, max_mass=3.0
+ ) -> tuple[float, float]:
+ """
+ Randomizes the mass and updates the inertial properties of a link.
+
+ This method assigns a random mass to the link within the specified
+ range and recalculates its inertial properties based on the new mass.
+
+ Args:
+ link (Link): The link object whose mass and inertial properties
+ will be randomized.
+ min_mass (float): The minimum mass for randomization. Default is 0.1.
+ max_mass (float): The maximum mass for randomization. Default is 3.0.
+
+ Returns:
+ tuple[float, float]: A tuple containing the new mass and the
+ mass scale factor used to update the inertial properties.
+
+ Notes:
+ - The method modifies the link's mass and updates its
+ inertial properties (ixx, iyy, izz, ixy, ixz, iyz) based on the
+ ratio of the new mass to the original mass.
+ """
+
+ random_mass = self.np_random.uniform(min_mass, max_mass)
+ mass_scale_factor = random_mass / link.mass.value
+
+ link.mass = random_mass
+ link.inertia.ixx = link.inertia.ixx.value * mass_scale_factor
+ link.inertia.iyy = link.inertia.iyy.value * mass_scale_factor
+ link.inertia.izz = link.inertia.izz.value * mass_scale_factor
+ link.inertia.ixy = link.inertia.ixy.value * mass_scale_factor
+ link.inertia.ixz = link.inertia.ixz.value * mass_scale_factor
+ link.inertia.iyz = link.inertia.iyz.value * mass_scale_factor
+
+ def randomize_friction(self, link, min_friction=0.75, max_friction=1.5):
+ """
+ Randomizes the friction coefficients of a link's collision surfaces.
+
+ This method assigns random friction values to each collision surface
+ of the link within the specified range.
+
+ Args:
+ link (Link): The link object whose collision surfaces' friction
+ coefficients will be randomized.
+ min_friction (float): The minimum friction coefficient for
+ randomization. Default is 0.75.
+ max_friction (float): The maximum friction coefficient for
+ randomization. Default is 1.5.
+
+ Notes:
+ - The friction coefficients are applied to both the
+ 'mu' and 'mu2' attributes of the collision surface's friction
+ properties.
+ """
+
+ for collision in link.collisions:
+ random_friction = self.np_random.uniform(min_friction, max_friction)
+
+ collision.surface.friction.ode.mu = random_friction
+ collision.surface.friction.ode.mu2 = random_friction
+
+ def write_inertial_properties(self, link, mass, inertia, centre_of_mass):
+ """
+ Writes the specified mass and inertial properties to a link.
+
+ This method updates the link's mass, inertia tensor, and
+ centre of mass based on the provided values.
+
+ Args:
+ link (Link): The link object to be updated with new inertial properties.
+ mass (float): The mass to set for the link.
+ inertia (list[list[float]]): A 3x3 list representing the
+ inertia tensor of the link.
+ centre_of_mass (list[float]): A list containing the x, y, z
+ coordinates of the link's centre of mass.
+
+ Notes:
+ - The method directly modifies the link's mass and inertia properties,
+ as well as its inertial pose, which is set to the specified centre of mass.
+ """
+
+ link.mass = mass
+
+ link.inertia.ixx = inertia[0][0]
+ link.inertia.iyy = inertia[1][1]
+ link.inertia.izz = inertia[2][2]
+ link.inertia.ixy = inertia[0][1]
+ link.inertia.ixz = inertia[0][2]
+ link.inertia.iyz = inertia[1][2]
+
+ link.inertial.pose = [
+ centre_of_mass[0],
+ centre_of_mass[1],
+ centre_of_mass[2],
+ 0.0,
+ 0.0,
+ 0.0,
+ ]
+
+ def write_original_scale(self, mesh, model_path):
+ """
+ Writes the original scale of the mesh to a file.
+
+ This method records the scale of the provided mesh into a text file,
+ which can be used later for reference during randomization or scaling
+ operations.
+
+ Args:
+ mesh (Mesh): The mesh object whose original scale is to be recorded.
+ model_path (str): The path to the model directory where the
+ original scale file will be saved.
+
+ Notes:
+ - The scale is written as a string representation of the mesh's
+ scale property.
+ - The file is created or overwritten at the location returned
+ by the `get_original_scale_path` method.
+ """
+
+ file = open(self.get_original_scale_path(model_path), "w")
+ file.write(str(mesh.scale))
+ file.close()
+
+ def read_original_scale(self, model_path) -> float:
+ """
+ Reads the original scale of a model from a file.
+
+ This method retrieves the original scale value that was previously
+ saved for a specific model. The scale is expected to be stored as
+ a string in a file, and this method converts it back to a float.
+
+ Args:
+ model_path (str): The path to the model directory from which
+ to read the original scale.
+
+ Returns:
+ float: The original scale of the model.
+
+ Raises:
+ ValueError: If the contents of the scale file cannot be converted to a float.
+ """
+
+ file = open(self.get_original_scale_path(model_path), "r")
+ original_scale = file.read()
+ file.close()
+
+ return float(original_scale)
+
+ def check_excessive_geometry(
+ self, mesh, model_path, max_faces=40000, max_vertices=None
+ ) -> bool:
+ """
+ Checks if the mesh exceeds the specified geometry limits.
+
+ This method evaluates the number of faces and vertices in the
+ given mesh. If the mesh exceeds the defined limits for faces
+ or vertices, it blacklists the model and returns False.
+
+ Args:
+ mesh (Mesh): The mesh object to be checked for excessive geometry.
+ model_path (str): The path to the model directory for logging purposes.
+ max_faces (int, optional): The maximum allowed number of faces.
+ Defaults to 40000.
+ max_vertices (int, optional): The maximum allowed number of vertices.
+ Defaults to None.
+
+ Returns:
+ bool: True if the mesh does not exceed the limits; otherwise, False.
+ """
+
+ if max_faces is not None:
+ num_faces = len(mesh.faces)
+ if num_faces > max_faces:
+ self.blacklist_model(
+ model_path, reason="Excessive geometry (%d faces)" % num_faces
+ )
+ return False
+
+ if max_vertices is not None:
+ num_vertices = len(mesh.vertices)
+ if num_vertices > max_vertices:
+ self.blacklist_model(
+ model_path, reason="Excessive geometry (%d vertices)" % num_vertices
+ )
+ return False
+
+ return True
+
+ def check_disconnected_components(
+ self,
+ mesh,
+ model_path,
+ component_min_faces_fraction=0.05,
+ component_max_volume_fraction=0.1,
+ ) -> bool:
+ """
+ Checks for disconnected components within the mesh.
+
+ This method identifies connected components within the mesh. If
+ there are multiple components, it evaluates their relative volumes.
+ If any component exceeds the specified volume fraction threshold,
+ the model is blacklisted.
+
+ Args:
+ mesh (Mesh): The mesh object to be evaluated for disconnected components.
+ model_path (str): The path to the model directory for logging purposes.
+ component_min_faces_fraction (float, optional): The minimum fraction
+ of faces required for a component to be considered. Defaults to 0.05.
+ component_max_volume_fraction (float, optional): The maximum allowed
+ volume fraction for a component. Defaults to 0.1.
+
+ Returns:
+ bool: True if the mesh has no significant disconnected components;
+ otherwise, False.
+ """
+
+ # Get a list of all connected components inside the mesh
+ # Consider components only with `component_min_faces_fraction` percent faces
+ min_faces = round(component_min_faces_fraction * len(mesh.faces))
+ connected_components = trimesh.graph.connected_components(
+ mesh.face_adjacency, min_len=min_faces
+ )
+
+ # If more than 1 objects were detected, consider also relative volume of the meshes
+ if len(connected_components) > 1:
+ total_volume = mesh.volume
+
+ large_component_counter = 0
+ for component in connected_components:
+ submesh = mesh.copy()
+ mask = np.zeros(len(mesh.faces), dtype=np.bool)
+ mask[component] = True
+ submesh.update_faces(mask)
+
+ volume_fraction = submesh.volume / total_volume
+ if volume_fraction > component_max_volume_fraction:
+ large_component_counter += 1
+
+ if large_component_counter > 1:
+ self.blacklist_model(
+ model_path,
+ reason="Disconnected components (%d instances)"
+ % len(connected_components),
+ )
+ return False
+
+ return True
+
+ def check_inertial_properties(self, model_path, mass, inertia) -> bool:
+ """
+ Checks the validity of inertial properties for a model.
+
+ This method evaluates whether the mass and the principal moments of
+ inertia are valid. If any of these values are below a specified threshold,
+ the model is blacklisted.
+
+ Args:
+ model_path (str): The path to the model directory for logging purposes.
+ mass (float): The mass of the model to be checked.
+ inertia (list of list of float): A 2D list representing the inertia matrix.
+
+ Returns:
+ bool: True if the inertial properties are valid; otherwise, False.
+ """
+ if (
+ mass < 1e-10
+ or inertia[0][0] < 1e-10
+ or inertia[1][1] < 1e-10
+ or inertia[2][2] < 1e-10
+ ):
+ self.blacklist_model(model_path, reason="Invalid inertial properties")
+ return False
+
+ return True
+
+ def get_random_model_path(self) -> str:
+ """
+ Retrieves a random model path from the available models.
+
+ This method selects and returns a random model path. The selection
+ depends on whether a unique cache is being used or not.
+
+ Returns:
+ str: A randomly selected model path.
+ """
+ if self._unique_cache:
+ return self.np_random.choice(self._model_paths)
+ else:
+ return self.np_random.choice(self._class_model_paths)
+
+ def get_collision_mesh_path(self, model_path, collision_name) -> str:
+ """
+ Constructs the path for the collision mesh file.
+
+ This method generates and returns the file path for a collision
+ mesh based on the model path and the name of the collision.
+
+ Args:
+ model_path (str): The path to the model directory.
+ collision_name (str): The name of the collision geometry.
+
+ Returns:
+ str: The full path to the collision mesh file.
+ """
+ return os.path.join(
+ model_path,
+ self.__collision_mesh_dir,
+ collision_name + "." + self.__collision_mesh_file_type,
+ )
+
+ def get_sdf_path(self, model_path) -> str:
+ """
+ Constructs the path for the SDF (Simulation Description Format) file.
+
+ This method generates and returns the file path for the SDF file
+ corresponding to the given model path.
+
+ Args:
+ model_path (str): The path to the model directory.
+
+ Returns:
+ str: The full path to the SDF file.
+ """
+ return os.path.join(model_path, self.__sdf_base_name)
+
+ def get_configured_sdf_path(self, model_path) -> str:
+ """
+ Constructs the path for the configured SDF file.
+
+ This method generates and returns the file path for the configured
+ SDF file that includes any modifications made to the original model.
+
+ Args:
+ model_path (str): The path to the model directory.
+
+ Returns:
+ str: The full path to the configured SDF file.
+ """
+ return os.path.join(model_path, self.__configured_sdf_base_name)
+
+ def get_blacklisted_path(self, model_path) -> str:
+ """
+ Constructs the path for the blacklisted model file.
+
+ This method generates and returns the file path where information
+ about blacklisted models is stored.
+
+ Args:
+ model_path (str): The path to the model directory.
+
+ Returns:
+ str: The full path to the blacklisted model file.
+ """
+ return os.path.join(model_path, self.__blacklisted_base_name)
+
+ def get_mesh_path(self, model_path, visual_or_collision) -> str:
+ """
+ Constructs the path for the mesh associated with a visual or collision.
+
+ This method retrieves the URI from the specified visual or collision
+ object and constructs the full path to the corresponding mesh file.
+
+ Args:
+ model_path (str): The path to the model directory.
+ visual_or_collision (object): An object containing the geometry
+ information for either visual or collision.
+
+ Returns:
+ str: The full path to the mesh file.
+
+ Notes:
+ This method may require adjustments for specific collections or models.
+ """
+ # TODO: This might need fixing for certain collections/models
+ mesh_uri = visual_or_collision.geometry.mesh.uri.value
+ return os.path.join(model_path, mesh_uri)
+
+ def get_original_scale_path(self, model_path) -> str:
+ """
+ Constructs the path for the original scale file.
+
+ This method generates and returns the file path for the file
+ that stores the original scale of the model.
+
+ Args:
+ model_path (str): The path to the model directory.
+
+ Returns:
+ str: The full path to the original scale file.
+ """
+ return os.path.join(model_path, self.__original_scale_base_name)
+
+ def blacklist_model(self, model_path, reason="Unknown"):
+ """
+ Blacklists a model by writing its path and the reason to a file.
+
+ If blacklisting is enabled, this method writes the specified reason
+ for blacklisting the model to a designated file. It also logs the
+ blacklisting action.
+
+ Args:
+ model_path (str): The path to the model directory.
+ reason (str): The reason for blacklisting the model. Default is "Unknown".
+ """
+ if self._enable_blacklisting:
+ bl_file = open(self.get_blacklisted_path(model_path), "w")
+ bl_file.write(reason)
+ bl_file.close()
+ logger.warn(
+ '%s model "%s". Reason: %s.'
+ % (
+ "Blacklisting" if self._enable_blacklisting else "Skipping",
+ model_path,
+ reason,
+ )
+ )
+
+ def is_blacklisted(self, model_path) -> bool:
+ """
+ Checks if a model is blacklisted.
+
+ This method checks if the blacklisted file for a given model exists,
+ indicating that the model has been blacklisted.
+
+ Args:
+ model_path (str): The path to the model directory.
+
+ Returns:
+ bool: True if the model is blacklisted; otherwise, False.
+ """
+ return os.path.isfile(self.get_blacklisted_path(model_path))
+
+ def is_configured(self, model_path) -> bool:
+ """
+ Checks if a model is configured.
+
+ This method checks if the configured SDF file for a given model
+ exists, indicating that the model has been processed and configured.
+
+ Args:
+ model_path (str): The path to the model directory.
+
+ Returns:
+ bool: True if the model is configured; otherwise, False.
+ """
+ return os.path.isfile(self.get_configured_sdf_path(model_path))
+
+ def fix_mtl_texture_paths(self, model_path, mesh_path, model_name):
+ """
+ Fixes the texture paths in the associated MTL file of an OBJ model.
+
+ This method modifies the texture paths in the MTL file associated
+ with a given OBJ file to ensure they are correctly linked to
+ the model's textures. It also ensures that the texture files are
+ uniquely named by prepending the model's name to avoid conflicts.
+
+ Args:
+ model_path (str): The path to the model directory where textures are located.
+ mesh_path (str): The path to the OBJ file for which the MTL file needs fixing.
+ model_name (str): The name of the model, used to make texture file names unique.
+
+ Notes:
+ - This method scans the model directory for texture files, identifies
+ the relevant MTL file, and updates any texture paths in that MTL file.
+ - If a texture file's name does not include the model's name, the method
+ renames it to include the model's name.
+ - The method only processes MTL files linked to OBJ files.
+ """
+ # The `.obj` files use mtl
+ if mesh_path.endswith(".obj"):
+ # Find all textures located in the model path, used later to relative linking
+ texture_files = glob.glob(os.path.join(model_path, "**", "textures", "*.*"))
+
+ # Find location of mtl file, if any
+ mtllib_file = None
+ with open(mesh_path, "r") as file:
+ for line in file:
+ if "mtllib" in line:
+ mtllib_file = line.split(" ")[-1].strip()
+ break
+
+ if mtllib_file is not None:
+ mtllib_file = os.path.join(os.path.dirname(mesh_path), mtllib_file)
+
+ fin = open(mtllib_file, "r")
+ data = fin.read()
+ for line in data.splitlines():
+ if "map_" in line:
+ # Find the name of the texture/map in the mtl
+ map_file = line.split(" ")[-1].strip()
+
+ # Find the first match of the texture/map file
+ for texture_file in texture_files:
+ if os.path.basename(
+ texture_file
+ ) == map_file or os.path.basename(
+ texture_file
+ ) == os.path.basename(map_file):
+ # Make the file unique to the model (unless it already is)
+ if model_name in texture_file:
+ new_texture_file_name = texture_file
+ else:
+ new_texture_file_name = texture_file.replace(
+ map_file, model_name + "_" + map_file
+ )
+
+ os.rename(texture_file, new_texture_file_name)
+
+ # Apply the correct relative path
+ data = data.replace(
+ map_file,
+ os.path.relpath(
+ new_texture_file_name,
+ start=os.path.dirname(mesh_path),
+ ),
+ )
+ break
+ fin.close()
+
+ # Write in the correct data
+ fout = open(mtllib_file, "w")
+ fout.write(data)
+ fout.close()
diff --git a/env_manager/env_manager/env_manager/models/utils/xacro2sdf.py b/env_manager/env_manager/env_manager/models/utils/xacro2sdf.py
new file mode 100644
index 0000000..2a682ee
--- /dev/null
+++ b/env_manager/env_manager/env_manager/models/utils/xacro2sdf.py
@@ -0,0 +1,37 @@
+
+import subprocess
+import tempfile
+from typing import Dict, Optional, Tuple
+
+import xacro
+
+
def xacro2sdf(
    input_file_path: str, mappings: Dict, model_path_remap: Optional[Tuple[str, str]]
) -> str:
    """
    Convert a xacro (URDF variant) file to SDF and return it as a string.

    Args:
        input_file_path: Path to the input xacro file.
        mappings: Argument mappings passed to xacro; all values are converted
            to strings before processing.
        model_path_remap: Optional ``(old, new)`` pair replaced inside the
            resulting SDF so that meshes can be located by Ignition.

    Returns:
        The SDF document as a string.
    """

    # Convert all mapping values to strings without mutating the caller's dict
    # (the original implementation rewrote `mappings` in place)
    str_mappings = {key: str(value) for key, value in mappings.items()}

    # Convert xacro to URDF
    urdf_xml = xacro.process(input_file_name=input_file_path, mappings=str_mappings)

    # Create temporary file for URDF (`ign sdf -p` accepts only files)
    with tempfile.NamedTemporaryFile() as tmp_urdf:
        with open(tmp_urdf.name, "w") as urdf_file:
            urdf_file.write(urdf_xml)

        # Convert to SDF
        result = subprocess.run(
            ["ign", "sdf", "-p", tmp_urdf.name], stdout=subprocess.PIPE
        )
        sdf_xml = result.stdout.decode("utf-8")

    # Remap package name to model name, such that meshes can be located by Ignition
    if model_path_remap is not None:
        sdf_xml = sdf_xml.replace(model_path_remap[0], model_path_remap[1])

    return sdf_xml
diff --git a/env_manager/env_manager/env_manager/scene/__init__.py b/env_manager/env_manager/env_manager/scene/__init__.py
new file mode 100644
index 0000000..421848f
--- /dev/null
+++ b/env_manager/env_manager/env_manager/scene/__init__.py
@@ -0,0 +1 @@
+from .scene import Scene
diff --git a/env_manager/env_manager/env_manager/scene/scene.py b/env_manager/env_manager/env_manager/scene/scene.py
new file mode 100644
index 0000000..a42e4f8
--- /dev/null
+++ b/env_manager/env_manager/env_manager/scene/scene.py
@@ -0,0 +1,1480 @@
+from dataclasses import asdict
+
+import numpy as np
+import rclpy
+from geometry_msgs.msg import Point as RosPoint
+from geometry_msgs.msg import Quaternion as RosQuat
+from gym_gz.scenario.model_wrapper import ModelWrapper
+from rbs_assets_library import get_model_meshes_info
+from rcl_interfaces.srv import GetParameters
+from rclpy.node import Node, Parameter
+from scenario.bindings.gazebo import GazeboSimulator, PhysicsEngine_dart, World
+from scipy.spatial import distance
+from scipy.spatial.transform import Rotation
+from visualization_msgs.msg import Marker, MarkerArray
+
+from env_manager.models.terrains.random_ground import RandomGround
+
+from ..models import Box, Camera, Cylinder, Ground, Mesh, Model, Plane, Sphere, Sun
+from ..models.configs import CameraData, ObjectData, SceneData
+from ..models.robots import get_robot_model_class
+from ..utils import Tf2Broadcaster, Tf2DynamicBroadcaster
+from ..utils.conversions import quat_to_wxyz
+from ..utils.gazebo import (
+ transform_move_to_model_pose,
+ transform_move_to_model_position,
+)
+from ..utils.types import Point, Pose
+
+
+# TODO: Split scene randomizer and scene
+class Scene:
+ """
+ Manages the simulation scene during runtime, including object and environment randomization.
+
+ The Scene class initializes and manages various components in a Gazebo simulation
+ environment, such as robots, cameras, lights, and terrain. It also provides methods
+ for scene setup, parameter retrieval, and scene initialization.
+
+ Attributes:
+ gazebo (GazeboSimulator): An instance of the Gazebo simulator.
+ robot (RobotData): The robot data configuration.
+ cameras (list[CameraData]): List of camera configurations.
+ objects (list[ObjectData]): List of object configurations in the scene.
+ node (Node): The ROS 2 node for communication.
+ light (LightData): The light configuration for the scene.
+ terrain (TerrainData): The terrain configuration for the scene.
+ world (World): The Gazebo world instance.
+ """
+
    def __init__(
        self,
        node: Node,
        gazebo: GazeboSimulator,
        scene: SceneData,
        robot_urdf_string: str,
    ) -> None:
        """
        Initializes the Scene object with the necessary components and parameters.

        Args:
        - node (Node): The ROS 2 node for managing communication and parameters.
        - gazebo (GazeboSimulator): An instance of the Gazebo simulator to manage the simulation.
        - scene (SceneData): A data object containing configuration for the scene.
        - robot_urdf_string (str): The URDF string of the robot model. May be
          empty, in which case the URDF is fetched from the
          robot_state_publisher parameter service.
        """
        self.gazebo = gazebo
        self.robot = scene.robot
        self.cameras = scene.camera
        self.objects = scene.objects
        self.node = node
        self.light = scene.light
        self.terrain = scene.terrain
        # Internal flags controlling which world plugins get inserted later
        self.__scene_initialized = False
        self.__plugin_scene_broadcaster = True
        self.__plugin_user_commands = True
        self.__fts_enable = True
        self.__plugint_sensor_render_engine = "ogre2"
        self.world: World = gazebo.get_world()

        # Bookkeeping for spawned objects: unique names, model wrappers and positions
        self.__objects_unique_names = {}
        self.object_models: list[ModelWrapper] = []
        self.__object_positions: dict[str, Point] = {}
        self.__objects_workspace_centre: Point = (0.0, 0.0, 0.0)

        # Without an explicit URDF, fetch `robot_description` synchronously
        # from the robot_state_publisher parameter service (blocking call)
        if not robot_urdf_string:
            self.param_client = node.create_client(
                GetParameters, "/robot_state_publisher/get_parameters"
            )
            self.get_parameter_sync()
            # scene.robot.urdf_string = (
            #     node.get_parameter("robot_description")
            #     .get_parameter_value()
            #     .string_value
            # )
        else:
            scene.robot.urdf_string = robot_urdf_string

        self.tf2_broadcaster = Tf2Broadcaster(node=node)
        self.tf2_broadcaster_dyn = Tf2DynamicBroadcaster(node=node)

        # Axis-aligned workspace bounds derived from the centre and dimensions
        self.__workspace_center = (0.0, 0.0, 0.0)
        self.__workspace_dimensions = (1.5, 1.5, 1.5)

        half_width = self.__workspace_dimensions[0] / 2
        half_height = self.__workspace_dimensions[1] / 2
        half_depth = self.__workspace_dimensions[2] / 2

        self.__workspace_min_bound = (
            self.__workspace_center[0] - half_width,
            self.__workspace_center[1] - half_height,
            self.__workspace_center[2] - half_depth,
        )
        self.__workspace_max_bound = (
            self.__workspace_center[0] + half_width,
            self.__workspace_center[1] + half_height,
            self.__workspace_center[2] + half_depth,
        )

        # self.marker_array = MarkerArray()
+
+ def get_parameter_sync(self):
+ """
+ Retrieves the `robot_description` parameter synchronously, waiting for the asynchronous call to complete.
+ """
+ while not self.param_client.wait_for_service(timeout_sec=5.0):
+ self.node.get_logger().info(
+ "Service /robot_state_publisher/get_parameters is unavailable, waiting ..."
+ )
+
+ request = self.create_get_parameters_request()
+ future = self.param_client.call_async(request)
+
+ rclpy.spin_until_future_complete(self.node, future)
+
+ if future.result() is None:
+ raise RuntimeError("Failed to retrieve the robot_description parameter")
+
+ response = future.result()
+ if not response.values or not response.values[0].string_value:
+ raise RuntimeError("The robot_description parameter is missing or empty")
+
+ value = response.values[0]
+ self.node.get_logger().info(
+ "Succesfully got parameter response from robot_state_publisher"
+ )
+ param = Parameter(
+ "robot_description", Parameter.Type.STRING, value.string_value
+ )
+ self.robot.urdf_string = value.string_value
+ self.node.set_parameters([param])
+
+ def create_get_parameters_request(self) -> GetParameters.Request:
+ """
+ Creates a request to get the parameters from the robot state publisher.
+
+ Returns:
+ GetParameters.Request: A request object containing the parameter name to retrieve.
+ """
+ return GetParameters.Request(names=["robot_description"])
+
+ def init_scene(self):
+ """
+ Initializes the simulation scene by adding the terrain, lights, robots, and objects.
+
+ This method pauses the Gazebo simulation, sets up the physics engine, and
+ configures the environment and components as defined in the SceneData.
+
+ Note:
+ This method must be called after the Scene object has been initialized
+ and before running the simulation.
+ """
+
+ self.gazebo_paused_run()
+
+ self.init_world_plugins()
+
+ self.world.set_physics_engine(PhysicsEngine_dart)
+
+ self.add_terrain()
+ self.add_light()
+ self.add_robot()
+ for obj in self.objects:
+ self.add_object(obj)
+ for camera in self.cameras:
+ self.add_camera(camera)
+
+ self.__scene_initialized = True
+
+ def reset_scene(self):
+ """
+ Resets the scene to its initial state by resetting the robot joint positions
+ and randomizing the positions of objects and cameras.
+
+ This method performs the following actions:
+ - Resets the joint positions of the robot.
+ - Clears previously stored object positions.
+ - For each object in the scene, if randomization is enabled,
+ randomizes its position; otherwise, resets it to its original pose.
+ - Randomizes the pose of enabled cameras if their pose has expired.
+
+ Note:
+ This method assumes that the objects and cameras have been properly initialized
+ and that their properties (like randomization and pose expiration) are set correctly.
+ """
+ self.reset_robot_joint_positions()
+
+ object_randomized = set()
+ self.__object_positions.clear()
+ for object in self.objects:
+ if object.name in object_randomized:
+ continue
+ # TODO: Add color randomization (something that affects the runtime during the second spawn of an object)
+ # if object.randomize.random_color:
+ # self.randomize_object_color(object)
+ if object.randomize.random_pose:
+ self.randomize_object_position(object)
+ else:
+ self.reset_objects_pose(object)
+ object_randomized.add(object.name)
+
+ for camera in self.cameras:
+ if camera.enable and self.is_camera_pose_expired(camera):
+ self.randomize_camera_pose(camera=camera)
+
    def init_world_plugins(self):
        """
        Initializes and inserts world plugins into the Gazebo simulator for various functionalities.

        This method configures the following plugins based on the internal settings:
        - **SceneBroadcaster**: Inserts the scene broadcaster plugin to enable GUI clients to receive
          updates about the scene. This is only done if the scene broadcaster is enabled.
        - **UserCommands**: Inserts the user commands plugin to allow user interactions with the scene.
        - **Sensors**: Inserts the sensors plugin if any cameras are enabled. The rendering engine for
          sensors can be specified.
        - **ForceTorque**: Inserts the Force Torque Sensor plugin if enabled.

        The method pauses the Gazebo simulation after inserting plugins to ensure that changes take effect.

        Note:
            If plugins are already active, this method does not reinsert them.
        """
        # SceneBroadcaster
        if self.__plugin_scene_broadcaster:
            # Skip insertion when a broadcaster is already active for this world
            if not self.gazebo.scene_broadcaster_active(self.world.to_gazebo().name()):
                self.node.get_logger().info(
                    "Inserting world plugins for broadcasting scene to GUI clients..."
                )
                self.world.to_gazebo().insert_world_plugin(
                    "gz-sim-scene-broadcaster-system",
                    "gz::sim::systems::SceneBroadcaster",
                )

        self.gazebo_paused_run()

        # UserCommands
        if self.__plugin_user_commands:
            self.node.get_logger().info(
                "Inserting world plugins to enable user commands..."
            )
            self.world.to_gazebo().insert_world_plugin(
                "gz-sim-user-commands-system",
                "gz::sim::systems::UserCommands",
            )

        self.gazebo_paused_run()

        # Sensors: enabled when at least one configured camera is enabled
        self._camera_enable = False
        for camera in self.cameras:
            if camera.enable:
                self._camera_enable = True

        if self._camera_enable:
            self.node.get_logger().info(
                f"Inserting world plugins for sensors with {self.__plugint_sensor_render_engine} rendering engine..."
            )
            self.world.to_gazebo().insert_world_plugin(
                "libgz-sim-sensors-system.so",
                "gz::sim::systems::Sensors",
                # NOTE(review): these three fragments concatenate to just the
                # bare engine name with no XML element around it — presumably
                # this was meant to be wrapped in <render_engine>...</render_engine>;
                # confirm against the gz-sim Sensors system documentation.
                ""
                f"{self.__plugint_sensor_render_engine}"
                "",
            )

        if self.__fts_enable:
            self.world.to_gazebo().insert_world_plugin(
                "gz-sim-forcetorque-system",
                "gz::sim::systems::ForceTorque",
            )

        self.gazebo_paused_run()
+
+ def add_robot(self):
+ """
+ Configure and insert the robot into the simulation.
+
+ This method performs the following actions:
+ - Retrieves the robot model class based on the robot's name.
+ - Instantiates the robot model with specified parameters including position, orientation,
+ URDF string, and initial joint positions for both the arm and gripper.
+ - Ensures that the instantiated robot's name is unique and returns the actual name used.
+ - Enables contact detection for all gripper links (fingers) of the robot.
+ - Pauses the Gazebo simulation to apply the changes.
+ - Resets the robot's joints to their default positions.
+
+ Note:
+ The method assumes that the robot name is valid and that the associated model class
+ can be retrieved. If the robot's URDF string or joint positions are not correctly specified,
+ this may lead to errors during robot instantiation.
+
+ Raises:
+ RuntimeError: If the robot model class cannot be retrieved for the specified robot name.
+ """
+
+ robot_model_class = get_robot_model_class(self.robot.name)
+
+ # Instantiate robot class based on the selected model
+ self.robot_wrapper = robot_model_class(
+ world=self.world,
+ name=self.robot.name,
+ node=self.node,
+ position=self.robot.spawn_position,
+ orientation=quat_to_wxyz(self.robot.spawn_quat_xyzw),
+ urdf_string=self.robot.urdf_string,
+ initial_arm_joint_positions=self.robot.joint_positions,
+ initial_gripper_joint_positions=self.robot.gripper_joint_positions,
+ )
+
+ # The desired name is passed as arg on creation, however, a different name might be selected to be unique
+ # Therefore, return the name back to the task
+ self.robot_name = self.robot_wrapper.name()
+
+ # Enable contact detection for all gripper links (fingers)
+ robot_gazebo = self.robot_wrapper.to_gazebo()
+ robot_gazebo.enable_contacts()
+
+ self.gazebo_paused_run()
+
+ # Reset robot joints to the defaults
+ self.reset_robot_joint_positions()
+
    def add_camera(self, camera: CameraData):
        """
        Configure and insert a camera into the simulation, placing it with respect to the robot.

        The method performs the following steps:
        - Determines the appropriate position and orientation of the camera based on whether it is
          relative to the world or to the robot.
        - Creates an instance of the `Camera` class with the specified parameters, including
          position, orientation, type, dimensions, image formatting, update rate, field of view,
          clipping properties, noise characteristics, and ROS 2 bridge publishing options.
        - Pauses the Gazebo simulation to apply changes.
        - Attaches the camera to the robot if the camera's reference is not the world.
        - Broadcasts the transformation (TF) of the camera relative to the robot.

        Args:
        - camera (CameraData): The camera data containing configuration options.

        Raises:
        - Exception: If the camera cannot be attached to the robot.

        Notes:
            The method expects the camera data to be correctly configured. If the camera's
            `relative_to` field does not match any existing model in the world, it will raise an
            exception when attempting to attach the camera to the robot.
        """

        if self.world.to_gazebo().name() == camera.relative_to:
            # Pose is already expressed in the world frame
            camera_position = camera.spawn_position
            camera_quat_wxyz = quat_to_wxyz(camera.spawn_quat_xyzw)
        else:
            # Transform the pose of camera to be with respect to robot - but still represented in world reference frame for insertion into the world
            camera_position, camera_quat_wxyz = transform_move_to_model_pose(
                world=self.world,
                position=camera.spawn_position,
                quat=quat_to_wxyz(camera.spawn_quat_xyzw),
                target_model=self.robot_wrapper,
                target_link=camera.relative_to,
                xyzw=False,
            )

        # Create camera
        self.camera_wrapper = Camera(
            name=camera.name,
            world=self.world,
            position=camera_position,
            orientation=camera_quat_wxyz,
            camera_type=camera.type,
            width=camera.width,
            height=camera.height,
            image_format=camera.image_format,
            update_rate=camera.update_rate,
            horizontal_fov=camera.horizontal_fov,
            vertical_fov=camera.vertical_fov,
            clip_color=camera.clip_color,
            clip_depth=camera.clip_depth,
            noise_mean=camera.noise_mean,
            noise_stddev=camera.noise_stddev,
            ros2_bridge_color=camera.publish_color,
            ros2_bridge_depth=camera.publish_depth,
            ros2_bridge_points=camera.publish_points,
        )

        self.gazebo_paused_run()

        # Attach to robot if needed
        if self.world.to_gazebo().name() != camera.relative_to:
            if not self.robot_wrapper.to_gazebo().attach_link(
                camera.relative_to,
                self.camera_wrapper.name(),
                self.camera_wrapper.link_name,
            ):
                raise Exception("Cannot attach camera link to robot")

            # Remember that the camera is rigidly attached to a robot link
            self.__is_camera_attached = True

        self.gazebo_paused_run()

        # Broadcast tf so downstream consumers know the camera frame
        self.tf2_broadcaster.broadcast_tf(
            parent_frame_id=camera.relative_to,
            child_frame_id=self.camera_wrapper.frame_id,
            translation=camera.spawn_position,
            rotation=camera.spawn_quat_xyzw,
            xyzw=True,
        )
+
+ def add_object(self, obj: ObjectData):
+ """
+ Configure and insert an object into the simulation.
+
+ Args:
+ obj (ObjectData): The object data containing configuration options.
+
+ Raises:
+ NotImplementedError: If the specified type is not supported.
+ """
+ obj_dict = asdict(obj)
+ obj_dict["world"] = self.world
+
+ try:
+ object_wrapper = self._create_object_wrapper(obj, obj_dict)
+ self._register_object(obj, object_wrapper)
+ self._enable_contact_detection(object_wrapper)
+
+ self.node.get_logger().info(
+ f"Added object: {obj.name}, relative_to: {obj.relative_to}"
+ )
+ self.node.get_logger().info(
+ f"Position: {obj.position}, Orientation: {obj.orientation}"
+ )
+
+ self._initialize_tf_and_markers(obj, object_wrapper)
+
+ except Exception as ex:
+ self.node.get_logger().warn(f"Model could not be inserted. Reason: {ex}")
+
+ def _create_object_wrapper(self, obj: ObjectData, obj_dict: dict) -> ModelWrapper:
+ """Create the appropriate object wrapper based on the type."""
+ if obj.type != "":
+ return self._create_by_type(obj.type, obj_dict)
+ else:
+ return self._create_by_data_class(obj, obj_dict)
+
+ def _create_by_type(self, obj_type: str, obj_dict: dict) -> ModelWrapper:
+ match obj_type:
+ case "box":
+ return Box(**obj_dict)
+ case "plane":
+ return Plane(**obj_dict)
+ case "sphere":
+ return Sphere(**obj_dict)
+ case "cylinder":
+ return Cylinder(**obj_dict)
+ case "model":
+ return Model(**obj_dict)
+ case "mesh":
+ return Mesh(**obj_dict)
+ case _:
+ raise NotImplementedError(f"Unsupported object type: {obj_type}")
+
+ def _create_by_data_class(self, obj: ObjectData, obj_dict: dict) -> ModelWrapper:
+ from ..models.configs import (
+ BoxObjectData,
+ CylinderObjectData,
+ MeshData,
+ ModelData,
+ PlaneObjectData,
+ SphereObjectData,
+ )
+
+ match obj:
+ case BoxObjectData():
+ return Box(**obj_dict)
+ case PlaneObjectData():
+ return Plane(**obj_dict)
+ case SphereObjectData():
+ return Sphere(**obj_dict)
+ case CylinderObjectData():
+ return Cylinder(**obj_dict)
+ case ModelData():
+ return Model(**obj_dict)
+ case MeshData():
+ return Mesh(**obj_dict)
+ case _:
+ raise NotImplementedError("Unsupported object data class")
+
+ def _register_object(self, obj: ObjectData, wrapper: ModelWrapper):
+ model_name = wrapper.name()
+ self.__objects_unique_names.setdefault(obj.name, []).append(model_name)
+ self.__object_positions[model_name] = obj.position
+ self.object_models.append(wrapper.model)
+
+ def _enable_contact_detection(self, wrapper: ModelWrapper):
+ for link_name in wrapper.link_names():
+ link = wrapper.to_gazebo().get_link(link_name=link_name)
+ link.enable_contact_detection(True)
+
+ def _initialize_tf_and_markers(self, obj: ObjectData, wrapper: ModelWrapper):
+ self._setup_marker_array()
+ marker = self._create_marker(obj, wrapper, len(self.object_models))
+ if marker:
+ self.marker_array.markers.append(marker)
+ # Start broadcasting once when len of markers == 1
+ if len(self.marker_array.markers) == 1:
+ self._setup_tf_broadcaster()
+
+ def _setup_tf_broadcaster(self):
+ if not hasattr(self, "tf_broadcaster_and_markers"):
+ self.tf_broadcaster_and_markers = self.node.create_timer(
+ 0.1, self.publish_marker
+ )
+
+ def _setup_marker_array(self):
+ if not hasattr(self, "marker_array"):
+ self.marker_array = MarkerArray()
+ self.marker_publisher = self.node.create_publisher(
+ MarkerArray, "/markers", 10
+ )
+
+ def _create_marker(
+ self, obj: ObjectData, wrapper: ModelWrapper, id: int
+ ) -> Marker | None:
+ """Create a visualization marker based on the object type."""
+ marker = Marker()
+ marker.header.frame_id = obj.relative_to
+ marker.header.stamp = self.node.get_clock().now().to_msg()
+ marker.ns = "assembly"
+ marker.id = id
+
+ type_map = {
+ "mesh": Marker.MESH_RESOURCE,
+ "model": Marker.MESH_RESOURCE,
+ "sphere": Marker.SPHERE,
+ "box": Marker.CUBE,
+ "cylinder": Marker.CYLINDER,
+ }
+
+ if obj.type not in type_map:
+ self.node.get_logger().fatal(f"Unsupported object type: {obj.type}")
+ return None
+
+ marker.type = type_map[obj.type]
+ self._configure_marker_geometry_and_color(marker, obj)
+
+ model_position, model_orientation = self.rotate_around_parent(
+ wrapper.base_position(),
+ wrapper.base_orientation(),
+ (0.0, 0.0, 0.0),
+ (1.0, 0.0, 0.0, 0.0),
+ )
+
+ position = RosPoint()
+ orientation = RosQuat()
+ position.x = model_position[0]
+ position.y = model_position[1]
+ position.z = model_position[2]
+
+ orientation.w = model_orientation[0]
+ orientation.x = model_orientation[1]
+ orientation.y = model_orientation[2]
+ orientation.z = model_orientation[3]
+
+ marker.pose.position = position
+ marker.pose.orientation = orientation
+ marker.action = Marker.ADD
+
+ return marker
+
+ def _configure_marker_geometry_and_color(self, marker: Marker, obj: ObjectData):
+ """Set marker geometry and color."""
+ if obj.type in ["mesh", "model"]:
+ marker.mesh_resource = f"file://{get_model_meshes_info(obj.name)['visual']}"
+ marker.mesh_use_embedded_materials = True
+ marker.scale.x = marker.scale.y = marker.scale.z = 1.0
+ elif obj.type == "sphere":
+ marker.scale.x = marker.scale.y = marker.scale.z = obj.radius
+ elif obj.type == "box":
+ marker.scale.x, marker.scale.y, marker.scale.z = obj.size
+ elif obj.type == "cylinder":
+ marker.scale.x = marker.scale.y = obj.radius * 2
+ marker.scale.z = obj.length
+
+ marker.color.r, marker.color.g, marker.color.b, marker.color.a = (
+ 0.0,
+ 0.5,
+ 1.0,
+ 1.0,
+ )
+
+ def publish_marker(self):
+ if not self.marker_array:
+ return
+
+ for idx, model in enumerate(self.object_models):
+ if self.marker_array.markers[idx]:
+ model_position, model_orientation = (
+ model.base_position(),
+ model.base_orientation(),
+ )
+
+ position = RosPoint()
+ orientation = RosQuat()
+ position.x = model_position[0]
+ position.y = model_position[1]
+ position.z = model_position[2]
+
+ orientation.w = model_orientation[0]
+ orientation.x = model_orientation[1]
+ orientation.y = model_orientation[2]
+ orientation.z = model_orientation[3]
+ self.marker_array.markers[idx].pose.position = position
+ self.marker_array.markers[idx].pose.orientation = orientation
+
+ self.tf2_broadcaster_dyn.broadcast_tf(
+ parent_frame_id="world",
+ child_frame_id=model.base_frame(),
+ translation=model_position,
+ rotation=model_orientation,
+ xyzw=False,
+ )
+ self.marker_publisher.publish(self.marker_array)
+
+ def rotate_around_parent(
+ self, position, orientation, parent_position, parent_orientation
+ ):
+ """
+ Apply a 180-degree rotation around the parent's Z-axis to the object's pose.
+
+ Args:
+ position: Tuple (x, y, z) - the object's position relative to the parent.
+ orientation: Tuple (w, x, y, z) - the object's orientation relative to the parent.
+ parent_position: Tuple (x, y, z) - the parent's position.
+ parent_orientation: Tuple (w, x, y, z) - the parent's orientation.
+
+ Returns:
+ A tuple containing the updated position and orientation:
+ (new_position, new_orientation)
+ """
+ # Define a quaternion for 180-degree rotation around Z-axis (w, x, y, z)
+ rotation_180_z = (0.0, 0.0, 0.0, 1.0)
+ w1, x1, y1, z1 = orientation
+ # Кватернион для 180° поворота вокруг Z
+ w2, x2, y2, z2 = (0, 0, 0, 1)
+
+ # Умножение кватернионов
+ w3 = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
+ x3 = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
+ y3 = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
+ z3 = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
+ new_orientation = (w3, x3, y3, z3)
+
+ # Compute the new orientation: parent_orientation * rotation_180_z * orientation
+ # new_orientation = (
+ # parent_orientation[0] * rotation_180_z[0] - parent_orientation[1] * rotation_180_z[1] -
+ # parent_orientation[2] * rotation_180_z[2] - parent_orientation[3] * rotation_180_z[3],
+ # parent_orientation[0] * rotation_180_z[1] + parent_orientation[1] * rotation_180_z[0] +
+ # parent_orientation[2] * rotation_180_z[3] - parent_orientation[3] * rotation_180_z[2],
+ # parent_orientation[0] * rotation_180_z[2] - parent_orientation[1] * rotation_180_z[3] +
+ # parent_orientation[2] * rotation_180_z[0] + parent_orientation[3] * rotation_180_z[1],
+ # parent_orientation[0] * rotation_180_z[3] + parent_orientation[1] * rotation_180_z[2] -
+ # parent_orientation[2] * rotation_180_z[1] + parent_orientation[3] * rotation_180_z[0],
+ # )
+
+ # Apply rotation to the position
+ rotated_position = (
+ -position[0], # Rotate 180 degrees around Z-axis: x -> -x
+ -position[1], # y -> -y
+ position[2], # z remains unchanged
+ )
+
+ # Add parent position
+ new_position = (
+ parent_position[0] + rotated_position[0],
+ parent_position[1] + rotated_position[1],
+ parent_position[2] + rotated_position[2],
+ )
+
+ return new_position, new_orientation
+
+ def add_light(self):
+ """
+ Configure and insert a default light source into the simulation.
+
+ This method initializes and adds a light source to the Gazebo simulation based on the
+ specified light type. Currently, it supports a "sun" type light. If an unsupported
+ light type is specified, it raises an error.
+
+ Raises:
+ RuntimeError: If the specified light type is unsupported, including:
+ - "random_sun": Currently not implemented.
+ - Any other unrecognized light type.
+
+ Notes:
+ - The method constructs the light using the properties defined in the `self.light` object,
+ which should include direction, color, distance, visual properties, and radius,
+ as applicable for the specified light type.
+ - After adding the light to the simulation, the method pauses the Gazebo simulation
+ to ensure the scene is updated.
+ """
+
+ # Create light
+ match self.light.type:
+ case "sun":
+ self.light = Sun(
+ self.world,
+ direction=self.light.direction,
+ color=self.light.color,
+ distance=self.light.distance,
+ visual=self.light.visual,
+ radius=self.light.radius,
+ )
+ case "random_sun":
+ raise RuntimeError("random_sun is not supported yet")
+ case _:
+ raise RuntimeError(
+ f"Type of light [{self.light.type}] currently not supported"
+ )
+
+ self.gazebo_paused_run()
+
+ def add_terrain(self):
+ """
+ Configure and insert default terrain into the simulation.
+
+ This method initializes and adds a terrain element to the Gazebo simulation using
+ the properties specified in the `self.terrain` object. Currently, it creates a
+ ground plane, but there are plans to support additional terrain types in the future.
+
+ Raises:
+ NotImplementedError: If there are attempts to add terrain types other than
+ the default ground.
+
+ Notes:
+ - The method utilizes properties such as name, spawn position, spawn quaternion
+ (for orientation), and size defined in `self.terrain` to create the terrain.
+ - After the terrain is added, contact detection is enabled for all links associated
+ with the terrain object.
+ - The method pauses the Gazebo simulation after adding the terrain to ensure
+ changes are reflected in the environment.
+ """
+
+ if self.terrain.type == "flat":
+ self.terrain_wrapper = Ground(
+ self.world,
+ name=self.terrain.name,
+ position=self.terrain.spawn_position,
+ orientation=quat_to_wxyz(self.terrain.spawn_quat_xyzw),
+ size=self.terrain.size,
+ ambient=self.terrain.ambient,
+ diffuse=self.terrain.diffuse,
+ specular=self.terrain.specular,
+ )
+ elif self.terrain.type == "random_flat":
+ self.terrain_wrapper = RandomGround(
+ self.world,
+ name=self.terrain.name,
+ position=self.terrain.spawn_position,
+ orientation=quat_to_wxyz(self.terrain.spawn_quat_xyzw),
+ size=self.terrain.size,
+ )
+ else:
+ raise ValueError(
+ "Type of ground is not supported, supported type is: [ground, random_ground]"
+ )
+
+ # Enable contact detection
+ for link_name in self.terrain_wrapper.link_names():
+ link = self.terrain_wrapper.to_gazebo().get_link(link_name=link_name)
+ link.enable_contact_detection(True)
+
+ self.gazebo_paused_run()
+
+ # def randomize_object_color(self, object: ObjectData):
+ # # debugpy.listen(5678)
+ # # debugpy.wait_for_client()
+ # task_object_names = self.__objects_unique_names[object.name]
+ # for object_name in task_object_names:
+ # if not self.world.to_gazebo().remove_model(object_name):
+ # self.node.get_logger().error(f"Failed to remove {object_name}")
+ # raise RuntimeError(f"Failed to remove {object_name}")
+ #
+ # self.gazebo_paused_run()
+ # del self.__objects_unique_names[object.name]
+ # self.add_object(object)
+
+ def randomize_object_position(self, object: ObjectData):
+ """
+ Randomize the position and/or orientation of a specified object in the simulation.
+
+ This method retrieves the names of the object instances from the scene and updates
+ their positions and orientations based on the randomization settings defined in
+ the `object` parameter. The method applies random transformations to the object's
+ pose according to the specified flags in the `ObjectData` instance.
+
+ Args:
+ object (ObjectData): An instance of `ObjectData` containing the properties
+ of the object to be randomized.
+
+ Raises:
+ KeyError: If the object name is not found in the unique object names dictionary.
+
+ Notes:
+ - If `random_pose` is `True`, both position and orientation are randomized.
+ - If `random_orientation` is `True`, only the orientation is randomized, while
+ the position remains unchanged.
+ - If `random_position` is `True`, only the position is randomized, while the
+ orientation remains unchanged.
+ """
+ task_object_names = self.__objects_unique_names[object.name]
+ for object_name in task_object_names:
+ position, quat_random = self.get_random_object_pose(object)
+ obj = self.world.to_gazebo().get_model(object_name).to_gazebo()
+ if object.randomize.random_pose:
+ obj.reset_base_pose(position, quat_random)
+ elif object.randomize.random_orientation:
+ obj.reset_base_pose(object.position, quat_random)
+ elif object.randomize.random_position:
+ obj.reset_base_pose(position, object.orientation)
+
+ self.__object_positions[object_name] = position
+
+ self.gazebo_paused_run()
+
    def randomize_camera_pose(self, camera: CameraData):
        """
        Randomize the pose of a specified camera in the simulation.

        This method updates the camera's position and orientation based on the selected
        randomization mode. The new pose is computed relative to either the world frame or
        a specified target model (e.g., the robot). The camera can be detached from the robot
        during the pose update and reattached afterward if needed.

        Args:
            camera (CameraData): An instance of `CameraData` containing properties of the
                                 camera to be randomized.

        Raises:
            TypeError: If the randomization mode specified in `camera.random_pose_mode` is
                       invalid.
            Exception: If the camera cannot be detached from or attached to the robot.

        Notes:
            - Supported randomization modes for camera pose:
                - `"orbit"`: The camera is positioned in an orbital path around the object.
                - `"select_random"`: A random camera pose is selected from predefined options.
                - `"select_nearest"`: The camera is positioned based on the nearest predefined
                  camera pose.
        """
        # Get random camera pose, centered at object position (or center of object spawn box)
        if "orbit" == camera.random_pose_mode:
            camera_position, camera_quat_xyzw = self.get_random_camera_pose_orbit(
                camera
            )
        elif "select_random" == camera.random_pose_mode:
            (
                camera_position,
                camera_quat_xyzw,
            ) = self.get_random_camera_pose_sample_random(camera)
        elif "select_nearest" == camera.random_pose_mode:
            (
                camera_position,
                camera_quat_xyzw,
            ) = self.get_random_camera_pose_sample_nearest(camera)
        else:
            raise TypeError("Invalid mode for camera pose randomization.")

        if self.world.to_gazebo().name() == camera.relative_to:
            # Pose is already expressed in the world frame; use it directly
            # (only converting the quaternion layout from xyzw to wxyz).
            transformed_camera_position = camera_position
            transformed_camera_quat_wxyz = quat_to_wxyz(camera_quat_xyzw)
        else:
            # Transform the pose of camera to be with respect to robot - but represented in world reference frame for insertion into the world
            (
                transformed_camera_position,
                transformed_camera_quat_wxyz,
            ) = transform_move_to_model_pose(
                world=self.world,
                position=camera_position,
                quat=quat_to_wxyz(camera_quat_xyzw),
                target_model=self.robot_wrapper,
                target_link=camera.relative_to,
                xyzw=False,
            )

        # Detach camera if needed; a fixed joint would prevent moving the camera.
        if self.__is_camera_attached:
            if not self.robot_wrapper.to_gazebo().detach_link(
                camera.relative_to,
                self.camera_wrapper.name(),
                self.camera_wrapper.link_name,
            ):
                raise Exception("Cannot detach camera link from robot")

            self.gazebo_paused_run()

        # Move pose of the camera
        camera_gazebo = self.camera_wrapper.to_gazebo()
        camera_gazebo.reset_base_pose(
            transformed_camera_position, transformed_camera_quat_wxyz
        )

        self.gazebo_paused_run()

        # Attach to robot again if needed
        if self.__is_camera_attached:
            if not self.robot_wrapper.to_gazebo().attach_link(
                camera.relative_to,
                self.camera_wrapper.name(),
                self.camera_wrapper.link_name,
            ):
                raise Exception("Cannot attach camera link to robot")

            self.gazebo_paused_run()

        # Broadcast tf with the (untransformed) pose relative to the parent frame.
        self.tf2_broadcaster.broadcast_tf(
            parent_frame_id=camera.relative_to,
            child_frame_id=self.camera_wrapper.frame_id,
            translation=camera_position,
            rotation=camera_quat_xyzw,
            xyzw=True,
        )
+
    def get_random_camera_pose_orbit(self, camera: CameraData) -> Pose:
        """
        Generate a random camera pose in an orbital path around a focal point.

        This method computes a random position for the camera based on an orbital distance
        and height range. The camera's orientation is calculated to focus on a specified
        focal point. The generated position ensures that it does not fall within a specified
        arc behind the robot.

        Args:
            camera (CameraData): An instance of `CameraData` containing properties of the
                                 camera, including the orbital distance, height range,
                                 focal point offset, and arc behind the robot to ignore.

        Returns:
            Pose: A tuple containing the randomly generated position and orientation
                  (as a quaternion, xyzw) of the camera.

        Notes:
            - The method utilizes a random number generator to create a uniformly
              distributed position within specified bounds.
            - The computed camera orientation is represented in the quaternion format.
        """
        rng = np.random.default_rng()

        # Rejection sampling: draw a direction, normalize it onto the unit
        # sphere, and retry while it falls inside the excluded arc behind the
        # robot.
        while True:
            position = rng.uniform(
                low=(-1.0, -1.0, camera.random_pose_orbit_height_range[0]),
                high=(1.0, 1.0, camera.random_pose_orbit_height_range[1]),
            )

            # Project onto the unit sphere (direction only; scaled below).
            norm = np.linalg.norm(position)
            position /= norm

            if (
                abs(np.arctan2(position[0], position[1]) + np.pi / 2)
                > camera.random_pose_orbit_ignore_arc_behind_robot
            ):
                break

        # Pitch down/up toward the focal point, yaw to face the origin.
        rpy = np.array(
            [
                0.0,
                np.arctan2(
                    position[2] - camera.random_pose_focal_point_z_offset,
                    np.linalg.norm(position[:2]),
                ),
                np.arctan2(position[1], position[0]) + np.pi,
            ]
        )
        quat_xyzw = Rotation.from_euler("xyz", rpy).as_quat()

        # Scale the unit direction to the orbit radius and recenter over the
        # objects' workspace (XY only; Z keeps the sampled height).
        position *= camera.random_pose_orbit_distance
        position[:2] += self.__objects_workspace_centre[:2]

        return tuple(position), tuple(quat_xyzw)
+
+ def get_random_camera_pose_sample_random(self, camera: CameraData) -> Pose:
+ """
+ Select a random position from the predefined camera position options.
+
+ This method randomly selects one of the available camera position options
+ specified in the `camera` object and processes it to generate the camera's pose.
+
+ Args:
+ camera (CameraData): An instance of `CameraData` containing the available
+ position options for random selection.
+
+ Returns:
+ Pose: A tuple containing the selected camera position and its orientation
+ (as a quaternion).
+ """
+ # Select a random entry from the options
+ camera.random_pose_select_position_options
+
+ selection = camera.random_pose_select_position_options[
+ np.random.randint(len(camera.random_pose_select_position_options))
+ ]
+
+ return self.get_random_camera_pose_sample_process(camera, selection)
+
+ def get_random_camera_pose_sample_nearest(self, camera: CameraData) -> Pose:
+ """
+ Select the nearest entry from the predefined camera position options.
+
+ This method calculates the squared distances of the camera position options
+ from a central point and selects the nearest one to generate the camera's pose.
+
+ Args:
+ camera (CameraData): An instance of `CameraData` containing the position options
+ to choose from.
+
+ Returns:
+ Pose: A tuple containing the nearest camera position and its orientation
+ (as a quaternion).
+ """
+ # Select the nearest entry
+ dist_sqr = np.sum(
+ (
+ np.array(camera.random_pose_select_position_options)
+ - np.array(camera.random_pose_select_position_options)
+ )
+ ** 2,
+ axis=1,
+ )
+ nearest = camera.random_pose_select_position_options[np.argmin(dist_sqr)]
+
+ return self.get_random_camera_pose_sample_process(camera, nearest)
+
+ def get_random_camera_pose_sample_process(
+ self, camera: CameraData, position: Point
+ ) -> Pose:
+ """
+ Process the selected camera position to determine its orientation.
+
+ This method calculates the camera's orientation based on the selected position,
+ ensuring that the camera faces a focal point adjusted by a specified offset.
+
+ Args:
+ camera (CameraData): An instance of `CameraData` that contains properties
+ used for orientation calculations.
+ position (Point): A 3D point representing the camera's position.
+
+ Returns:
+ Pose: A tuple containing the camera's position and its orientation
+ (as a quaternion).
+
+ Notes:
+ - The orientation is computed using roll, pitch, and yaw angles based on
+ the position relative to the workspace center.
+ """
+ # Determine orientation such that camera faces the center
+ rpy = [
+ 0.0,
+ np.arctan2(
+ position[2] - camera.random_pose_focal_point_z_offset,
+ np.linalg.norm(
+ (
+ position[0] - self.__objects_workspace_centre[0],
+ position[1] - self.__objects_workspace_centre[1],
+ ),
+ 2,
+ ),
+ ),
+ np.arctan2(
+ position[1] - self.__objects_workspace_centre[1],
+ position[0] - self.__objects_workspace_centre[0],
+ )
+ + np.pi,
+ ]
+ quat_xyzw = Rotation.from_euler("xyz", rpy).as_quat()
+
+ return camera.spawn_position, tuple(quat_xyzw)
+
+ def reset_robot_pose(self):
+ """
+ Reset the robot's pose to a randomized position or its default spawn position.
+
+ If the robot's randomizer is enabled for pose randomization, this method will
+ generate a new random position within the specified spawn volume and assign
+ a random orientation. Otherwise, it will reset the robot's pose to its default
+ spawn position and orientation.
+
+ The new position is calculated by adding a random offset within the spawn
+ volume to the robot's initial spawn position. The orientation is generated
+ as a random rotation around the z-axis.
+
+ Raises:
+ Exception: If there are issues resetting the robot pose in the Gazebo simulation.
+
+ Notes:
+ - The position is a 3D vector represented as [x, y, z].
+ - The orientation is represented as a quaternion in the form (x, y, z, w).
+ """
+ if self.robot.randomizer.pose:
+ position = [
+ self.robot.spawn_position[0]
+ + np.random.RandomState().uniform(
+ -self.robot.randomizer.spawn_volume[0] / 2,
+ self.robot.randomizer.spawn_volume[0] / 2,
+ ),
+ self.robot.spawn_position[1]
+ + np.random.RandomState().uniform(
+ -self.robot.randomizer.spawn_volume[1] / 2,
+ self.robot.randomizer.spawn_volume[1] / 2,
+ ),
+ self.robot.spawn_position[2]
+ + np.random.RandomState().uniform(
+ -self.robot.randomizer.spawn_volume[2] / 2,
+ self.robot.randomizer.spawn_volume[2] / 2,
+ ),
+ ]
+ quat_xyzw = Rotation.from_euler(
+ "xyz", (0, 0, np.random.RandomState().uniform(-np.pi, np.pi))
+ ).as_quat()
+ quat_xyzw = tuple(quat_xyzw)
+ else:
+ position = self.robot.spawn_position
+ quat_xyzw = self.robot.spawn_quat_xyzw
+
+ gazebo_robot = self.robot_wrapper.to_gazebo()
+ gazebo_robot.reset_base_pose(position, quat_to_wxyz(quat_xyzw))
+
+ self.gazebo_paused_run()
+
+ def reset_robot_joint_positions(self):
+ """
+ Reset the robot's joint positions and velocities to their initial values.
+
+ This method retrieves the robot's initial arm joint positions and applies
+ them to the robot in the Gazebo simulation. If joint position randomization
+ is enabled, it adds Gaussian noise to the joint positions based on the
+ specified standard deviation.
+
+ Additionally, the method resets the velocities of both the arm and gripper
+ joints to zero. After updating the joint states, it executes an unpaused step
+ in Gazebo to process the modifications and update the joint states.
+
+ Raises:
+ RuntimeError: If there are issues resetting joint positions or velocities
+ in the Gazebo simulation, or if the Gazebo step fails.
+
+ Notes:
+ - This method assumes the robot has a gripper only if `self.robot.with_gripper` is true.
+ - Joint positions for the gripper are reset only if gripper actuated joint names are available.
+ """
+ gazebo_robot = self.robot_wrapper.to_gazebo()
+
+ arm_joint_positions = self.robot_wrapper.initial_arm_joint_positions
+
+ # Add normal noise if desired
+ if self.robot.randomizer.joint_positions:
+ for joint_position in arm_joint_positions:
+ joint_position += np.random.RandomState().normal(
+ loc=0.0, scale=self.robot.randomizer.joint_positions_std
+ )
+
+ # Arm joints - apply joint positions zero out velocities
+ if not gazebo_robot.reset_joint_positions(
+ arm_joint_positions, self.robot_wrapper.robot.actuated_joint_names
+ ):
+ raise RuntimeError("Failed to reset robot joint positions")
+ if not gazebo_robot.reset_joint_velocities(
+ [0.0] * len(self.robot_wrapper.robot.actuated_joint_names),
+ self.robot_wrapper.robot.actuated_joint_names,
+ ):
+ raise RuntimeError("Failed to reset robot joint velocities")
+
+ # Gripper joints - apply joint positions zero out velocities
+ if (
+ self.robot.with_gripper
+ and self.robot_wrapper.robot.gripper_actuated_joint_names
+ ):
+ if not gazebo_robot.reset_joint_positions(
+ self.robot_wrapper.initial_gripper_joint_positions,
+ self.robot_wrapper.robot.gripper_actuated_joint_names,
+ ):
+ raise RuntimeError("Failed to reset gripper joint positions")
+ if not gazebo_robot.reset_joint_velocities(
+ [0.0] * len(self.robot_wrapper.robot.gripper_actuated_joint_names),
+ self.robot_wrapper.robot.gripper_actuated_joint_names,
+ ):
+ raise RuntimeError("Failed to reset gripper joint velocities")
+
+ # Execute an unpaused run to process model modification and get new JointStates
+ if not self.gazebo.step():
+ raise RuntimeError("Failed to execute an unpaused Gazebo step")
+
+ # if self.robot.with_gripper:
+ # if (
+ # self.robot_wrapper.CLOSED_GRIPPER_JOINT_POSITIONS
+ # == self.robot.gripper_joint_positions
+ # ):
+ # self.gripper_controller.close()
+ # else:
+ # self.gripper.open()
+
+ def reset_objects_pose(self, object: ObjectData):
+ """
+ Reset the pose of specified objects in the simulation.
+
+ This method updates the base position and orientation of all instances of
+ the specified object type in the Gazebo simulation to their defined
+ positions and orientations.
+
+ Args:
+ object (ObjectData): The object data containing the name and the desired
+ position and orientation for the objects to be reset.
+
+ Raises:
+ KeyError: If the specified object's name is not found in the unique names
+ registry.
+
+ Notes:
+ - The method retrieves the list of unique names for the specified object
+ type and resets the pose of each instance to the provided position and
+ orientation from the `object` data.
+ """
+ task_object_names = self.__objects_unique_names[object.name]
+ for object_name in task_object_names:
+ obj = self.world.to_gazebo().get_model(object_name).to_gazebo()
+ obj.reset_base_pose(object.position, object.orientation)
+
+ self.gazebo_paused_run()
+
    def get_random_object_pose(
        self,
        obj: ObjectData,
        name: str = "",
        min_distance_to_other_objects: float = 0.2,
        min_distance_decay_factor: float = 0.95,
    ) -> Pose:
        """
        Generate a random pose for an object, ensuring it is sufficiently
        distanced from other objects.

        This method computes a random position within a specified spawn volume for
        the given object. It checks that the new position does not violate minimum
        distance constraints from existing objects in the simulation.

        Args:
            obj (ObjectData): The object data containing the base position,
                              randomization parameters, and relative positioning info.
            name (str, optional): The name of the object being positioned. Defaults to an empty string.
            min_distance_to_other_objects (float, optional): The minimum required distance
                                                             from other objects. Defaults to 0.2.
            min_distance_decay_factor (float, optional): Factor by which the minimum distance
                                                         will decay if the new position is
                                                         too close to existing objects. Defaults to 0.95.

        Returns:
            Pose: A tuple containing the randomly generated position (x, y, z) and a
                  randomly generated orientation (quaternion).

        Notes:
            - The method uses a uniform distribution to select random coordinates
              within the object's defined spawn volume.
            - The quaternion is generated to provide a random orientation for the object.
            - The method iteratively checks for proximity to other objects and adjusts the
              minimum distance if necessary, so the loop always terminates eventually.
        """
        is_too_close = True
        # The default value for randomization
        object_position = obj.position
        while is_too_close:
            # Uniform offset of +/- half the spawn volume around the base position.
            object_position = (
                obj.position[0]
                + np.random.uniform(
                    -obj.randomize.random_spawn_volume[0] / 2,
                    obj.randomize.random_spawn_volume[0] / 2,
                ),
                obj.position[1]
                + np.random.uniform(
                    -obj.randomize.random_spawn_volume[1] / 2,
                    obj.randomize.random_spawn_volume[1] / 2,
                ),
                obj.position[2]
                + np.random.uniform(
                    -obj.randomize.random_spawn_volume[2] / 2,
                    obj.randomize.random_spawn_volume[2] / 2,
                ),
            )

            if self.world.to_gazebo().name() != obj.relative_to:
                # Transform the pose of the object to be with respect to the robot -
                # but represented in the world reference frame for insertion into the world
                object_position = transform_move_to_model_position(
                    world=self.world,
                    position=object_position,
                    target_model=self.robot_wrapper,
                    target_link=obj.relative_to,
                )

            # Check if position is far enough from other objects; the required
            # distance decays on each failed attempt to guarantee progress.
            is_too_close = False
            for (
                existing_object_name,
                existing_object_position,
            ) in self.__object_positions.items():
                if existing_object_name == name:
                    # Do not compare to itself
                    continue
                if (
                    distance.euclidean(object_position, existing_object_position)
                    < min_distance_to_other_objects
                ):
                    min_distance_to_other_objects *= min_distance_decay_factor
                    is_too_close = True
                    break

        # Uniformly random orientation: a random 4-vector normalized to unit length.
        quat = np.random.uniform(-1, 1, 4)
        quat /= np.linalg.norm(quat)

        return object_position, tuple(quat)
+
+ def is_camera_pose_expired(self, camera: CameraData) -> bool:
+ """
+ Check if the camera pose has expired and needs to be randomized.
+
+ This method evaluates whether the camera's current pose has reached its
+ limit for randomization based on the specified rollout count. If the
+ camera has exceeded the allowed number of rollouts, it will reset the
+ counter and indicate that the pose should be randomized.
+
+ Args:
+ camera (CameraData): The camera data containing the current rollout
+ counter and the maximum number of rollouts
+ allowed before a new pose is required.
+
+ Returns:
+ bool: True if the camera pose needs to be randomized (i.e., it has
+ expired), False otherwise.
+ """
+
+ if camera.random_pose_rollouts_num == 0:
+ return False
+
+ camera.random_pose_rollout_counter += 1
+
+ if camera.random_pose_rollout_counter >= camera.random_pose_rollouts_num:
+ camera.random_pose_rollout_counter = 0
+ return True
+
+ return False
+
+ def gazebo_paused_run(self):
+ """
+ Execute a run in Gazebo while paused.
+
+ This method attempts to run the Gazebo simulation in a paused state.
+ If the operation fails, it raises a RuntimeError to indicate the failure.
+
+ Raises:
+ RuntimeError: If the Gazebo run operation fails.
+ """
+ if not self.gazebo.run(paused=True):
+ raise RuntimeError("Failed to execute a paused Gazebo run")
+
+ def check_object_outside_workspace(
+ self,
+ object_position: Point,
+ ) -> bool:
+ """
+ Check if an object is outside the defined workspace.
+
+ This method evaluates whether the given object position lies outside the
+ specified workspace boundaries, which are defined by a parallelepiped
+ determined by minimum and maximum bounds.
+
+ Args:
+ object_position (Point): A tuple or list representing the 3D position
+ of the object in the format (x, y, z).
+
+ Returns:
+ bool: True if the object is outside the workspace, False otherwise.
+ """
+
+ return (
+ object_position[0] < self.__workspace_min_bound[0]
+ or object_position[1] < self.__workspace_min_bound[1]
+ or object_position[2] < self.__workspace_min_bound[2]
+ or object_position[0] > self.__workspace_max_bound[0]
+ or object_position[1] > self.__workspace_max_bound[1]
+ or object_position[2] > self.__workspace_max_bound[2]
+ )
+
    def check_object_overlapping(
        self,
        object: ObjectData,
        allowed_penetration_depth: float = 0.001,
        terrain_allowed_penetration_depth: float = 0.002,
    ) -> bool:
        """
        Check for overlapping objects in the simulation and reset their positions if necessary.

        This method iterates through all objects and ensures that none of them are overlapping.
        If an object is found to be overlapping with another object (or the robot), it resets
        its position to a random valid pose. Collisions or overlaps with terrain are ignored
        (up to `terrain_allowed_penetration_depth`).

        Args:
            object (ObjectData): The object data to check for overlaps.
            allowed_penetration_depth (float, optional): The maximum allowed penetration depth
                                                         between objects. Defaults to 0.001.
            terrain_allowed_penetration_depth (float, optional): The maximum allowed penetration
                                                                 depth when in contact with terrain. Defaults to 0.002.

        Returns:
            bool: True if all objects are in valid positions, False if any had to be reset.
        """

        # Update the cached positions from the live simulation state (the first
        # link's position is used as the object's position).
        for object_name in self.__objects_unique_names[object.name]:
            model = self.world.get_model(object_name).to_gazebo()
            self.__object_positions[object_name] = model.get_link(
                link_name=model.link_names()[0]
            ).position()

        for object_name in self.__objects_unique_names[object.name]:
            obj = self.world.get_model(object_name).to_gazebo()

            # Make sure the object is inside workspace; returns early after the
            # first reset so the caller can re-run the check.
            if self.check_object_outside_workspace(
                self.__object_positions[object_name]
            ):
                position, quat_random = self.get_random_object_pose(
                    object,
                    name=object_name,
                )
                obj.reset_base_pose(position, quat_random)
                obj.reset_base_world_velocity([0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
                return False

            # Make sure the object is not intersecting other objects
            try:
                for contact in obj.contacts():
                    # Average penetration depth over all contact points.
                    depth = np.mean([point.depth for point in contact.points])
                    # Shallow contact with the terrain is expected and ignored.
                    if (
                        self.terrain_wrapper.name() in contact.body_b
                        and depth < terrain_allowed_penetration_depth
                    ):
                        continue
                    # Any contact with the robot, or a deep contact with
                    # anything else, triggers a reset to a new random pose.
                    if (
                        self.robot_wrapper.name() in contact.body_b
                        or depth > allowed_penetration_depth
                    ):
                        position, quat_random = self.get_random_object_pose(
                            object,
                            name=object_name,
                        )
                        obj.reset_base_pose(position, quat_random)
                        obj.reset_base_world_velocity([0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
                        return False
            except Exception as e:
                # Contact queries can fail transiently; log and keep going.
                self.node.get_logger().error(
                    f"Runtime error encountered while checking objects intersections: {e}"
                )

        return True
diff --git a/env_manager/env_manager/env_manager/utils/__init__.py b/env_manager/env_manager/env_manager/utils/__init__.py
new file mode 100644
index 0000000..88b3cc1
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/__init__.py
@@ -0,0 +1,4 @@
+from . import conversions, gazebo, logging, math, types
+from .tf2_broadcaster import Tf2Broadcaster, Tf2BroadcasterStandalone
+from .tf2_dynamic_broadcaster import Tf2DynamicBroadcaster
+from .tf2_listener import Tf2Listener, Tf2ListenerStandalone
diff --git a/env_manager/env_manager/env_manager/utils/conversions.py b/env_manager/env_manager/env_manager/utils/conversions.py
new file mode 100644
index 0000000..064c33d
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/conversions.py
@@ -0,0 +1,182 @@
+from typing import Tuple, Union
+
+import numpy
+# import open3d
+# import pyoctree
+import sensor_msgs
+from scipy.spatial.transform import Rotation
+
+# from sensor_msgs.msg import PointCloud2
+# from open3d.geometry import PointCloud
+from geometry_msgs.msg import Transform
+
+
+# def pointcloud2_to_open3d(
+# ros_point_cloud2: PointCloud2,
+# include_color: bool = False,
+# include_intensity: bool = False,
+# # Note: Order does not matter for DL, that's why channel swapping is disabled by default
+# fix_rgb_channel_order: bool = False,
+# ) -> PointCloud:
+#
+# # Create output Open3D PointCloud
+# open3d_pc = PointCloud()
+#
+# size = ros_point_cloud2.width * ros_point_cloud2.height
+# xyz_dtype = ">f4" if ros_point_cloud2.is_bigendian else " 3:
+# bgr = numpy.ndarray(
+# shape=(size, 3),
+# dtype=numpy.uint8,
+# buffer=ros_point_cloud2.data,
+# offset=ros_point_cloud2.fields[3].offset,
+# strides=(ros_point_cloud2.point_step, 1),
+# )
+# if fix_rgb_channel_order:
+# # Swap channels to gain rgb (faster than `bgr[:, [2, 1, 0]]`)
+# bgr[:, 0], bgr[:, 2] = bgr[:, 2], bgr[:, 0].copy()
+# open3d_pc.colors = open3d.utility.Vector3dVector(
+# (bgr[valid_points] / 255).astype(numpy.float64)
+# )
+# else:
+# open3d_pc.colors = open3d.utility.Vector3dVector(
+# numpy.zeros((len(valid_points), 3), dtype=numpy.float64)
+# )
+# # TODO: Update octree creator once L8 image format is supported in Ignition Gazebo
+# # elif include_intensity:
+# # # Faster approach, but only the first channel gets the intensity value (rest is 0)
+# # intensities = numpy.zeros((len(valid_points), 3), dtype=numpy.float64)
+# # intensities[:, [0]] = (
+# # numpy.ndarray(
+# # shape=(size, 1),
+# # dtype=numpy.uint8,
+# # buffer=ros_point_cloud2.data,
+# # offset=ros_point_cloud2.fields[3].offset,
+# # strides=(ros_point_cloud2.point_step, 1),
+# # )[valid_points]
+# # / 255
+# # ).astype(numpy.float64)
+# # open3d_pc.colors = open3d.utility.Vector3dVector(intensities)
+# # # # Slower approach, but all channels get the intensity value
+# # # intensities = numpy.ndarray(
+# # # shape=(size, 1),
+# # # dtype=numpy.uint8,
+# # # buffer=ros_point_cloud2.data,
+# # # offset=ros_point_cloud2.fields[3].offset,
+# # # strides=(ros_point_cloud2.point_step, 1),
+# # # )
+# # # open3d_pc.colors = open3d.utility.Vector3dVector(
+# # # numpy.tile(intensities[valid_points] / 255, (1, 3)).astype(numpy.float64)
+# # # )
+#
+# # Return the converted Open3D PointCloud
+# return open3d_pc
+
+
def transform_to_matrix(transform: "Transform") -> numpy.ndarray:
    """Convert a geometry_msgs/Transform into a 4x4 homogeneous matrix.

    Args:
        transform: ROS Transform message (translation + xyzw quaternion).

    Returns:
        numpy.ndarray: 4x4 homogeneous transformation matrix.
    """

    transform_matrix = numpy.zeros((4, 4))
    transform_matrix[3, 3] = 1.0

    # Bug fix: the original called open3d.geometry.get_rotation_matrix_from_quaternion,
    # but the open3d import at the top of this module is commented out, so this
    # raised NameError. Use scipy (already imported); scipy expects xyzw order.
    transform_matrix[0:3, 0:3] = Rotation.from_quat(
        [
            transform.rotation.x,
            transform.rotation.y,
            transform.rotation.z,
            transform.rotation.w,
        ]
    ).as_matrix()
    transform_matrix[0, 3] = transform.translation.x
    transform_matrix[1, 3] = transform.translation.y
    transform_matrix[2, 3] = transform.translation.z

    return transform_matrix
+
+
+# def open3d_point_cloud_to_octree_points(
+# open3d_point_cloud: PointCloud,
+# include_color: bool = False,
+# include_intensity: bool = False,
+# ) -> pyoctree.Points:
+#
+# octree_points = pyoctree.Points()
+#
+# if include_color:
+# features = numpy.reshape(numpy.asarray(open3d_point_cloud.colors), -1)
+# elif include_intensity:
+# features = numpy.asarray(open3d_point_cloud.colors)[:, 0]
+# else:
+# features = []
+#
+# octree_points.set_points(
+# # XYZ points
+# numpy.reshape(numpy.asarray(open3d_point_cloud.points), -1),
+# # Normals
+# numpy.reshape(numpy.asarray(open3d_point_cloud.normals), -1),
+# # Other features, e.g. color
+# features,
+# # Labels - not used
+# [],
+# )
+#
+# return octree_points
+
+
+def orientation_6d_to_quat(
+    v1: Tuple[float, float, float], v2: Tuple[float, float, float]
+) -> Tuple[float, float, float, float]:
+    """Convert a 6D orientation representation (two 3D column vectors)
+    into an xyzw quaternion.
+
+    NOTE(review): v2 is normalized but never re-orthogonalized against v1
+    (no Gram-Schmidt step), so the stacked matrix is a proper rotation
+    only when the inputs are already orthogonal; scipy approximates the
+    nearest rotation otherwise -- confirm callers guarantee orthogonality.
+    """
+
+    # Normalize vectors
+    col1 = v1 / numpy.linalg.norm(v1)
+    col2 = v2 / numpy.linalg.norm(v2)
+
+    # Find their orthogonal vector via cross product
+    col3 = numpy.cross(col1, col2)
+
+    # Stack into rotation matrix as columns, convert to quaternion and return
+    quat_xyzw = Rotation.from_matrix(numpy.array([col1, col2, col3]).T).as_quat()
+    return quat_xyzw
+
+
def orientation_quat_to_6d(
    quat_xyzw: Tuple[float, float, float, float]
) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
    """Convert an xyzw quaternion into its 6D orientation representation,
    i.e. the first two (already normalized) columns of the rotation matrix.
    """

    rotation_matrix = Rotation.from_quat(quat_xyzw).as_matrix()
    first_column = tuple(rotation_matrix[:, 0])
    second_column = tuple(rotation_matrix[:, 1])
    return (first_column, second_column)
+
+
def quat_to_wxyz(
    xyzw: tuple[float, float, float, float]
) -> tuple[float, float, float, float]:
    """Reorder an xyzw quaternion tuple into wxyz order."""

    x, y, z, w = xyzw
    return (w, x, y, z)
+
+
def quat_to_xyzw(
    wxyz: Union[numpy.ndarray, Tuple[float, float, float, float]]
) -> numpy.ndarray:
    """Reorder a wxyz quaternion into xyzw order.

    Tuples are reordered element-wise; numpy arrays via fancy indexing.
    """

    if isinstance(wxyz, tuple):
        w, x, y, z = wxyz
        return (x, y, z, w)

    # numpy fancy indexing returns a reordered copy
    return wxyz[[1, 2, 3, 0]]
diff --git a/env_manager/env_manager/env_manager/utils/gazebo.py b/env_manager/env_manager/env_manager/utils/gazebo.py
new file mode 100644
index 0000000..2786de8
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/gazebo.py
@@ -0,0 +1,262 @@
+from typing import Tuple, Union
+
+from gym_gz.scenario.model_wrapper import ModelWrapper
+from numpy import exp
+from scenario.bindings.core import Link, World
+from scipy.spatial.transform import Rotation
+
+from .conversions import quat_to_wxyz, quat_to_xyzw
+from .math import quat_mul
+from .types import Pose, Point, Quat
+
+
+#NOTE: Errors in pyright will be fixed only with pybind11 in the scenario module
+def get_model_pose(
+    world: World,
+    model: ModelWrapper | str,
+    link: Link | str | None = None,
+    xyzw: bool = False,
+) -> Pose:
+    """
+    Return pose of model's link. Orientation is represented as wxyz quaternion or xyzw based on the passed argument `xyzw`.
+
+    Args:
+        world: Simulation world, used to resolve models passed by name.
+        model: Model reference or model name.
+        link: Link reference or link name; the model's first link when None.
+        xyzw: Return the quaternion in xyzw order instead of wxyz.
+
+    Returns:
+        Pose: (position, quaternion) tuple in world coordinates.
+    """
+
+    if isinstance(model, str):
+        # Get reference to the model from its name if string is passed
+        model = world.to_gazebo().get_model(model).to_gazebo()
+
+    if link is None:
+        # Use the first link if not specified
+        link = model.get_link(link_name=model.link_names()[0])
+    elif isinstance(link, str):
+        # Get reference to the link from its name if string
+        link = model.get_link(link_name=link)
+
+    # Get position and orientation
+    position = link.position()
+    quat = link.orientation()
+
+    # Convert to xyzw order if desired
+    if xyzw:
+        quat = quat_to_xyzw(quat)
+
+    # Return pose of the model's link
+    return (
+        position,
+        quat,
+    )
+
+
def get_model_position(
    world: World,
    model: ModelWrapper | str,
    link: Link | str | None = None,
) -> Point:
    """Return the world-frame position of a model's link.

    The model and link may be given by reference or by name; when `link`
    is None, the model's first link is used.
    """

    if isinstance(model, str):
        # Resolve the model reference from its name
        model = world.to_gazebo().get_model(model).to_gazebo()

    if isinstance(link, str):
        # Resolve the link reference from its name
        resolved_link = model.get_link(link_name=link)
    elif link is None:
        # Fall back to the model's first link
        resolved_link = model.get_link(link_name=model.link_names()[0])
    else:
        resolved_link = link

    return resolved_link.position()
+
+
def get_model_orientation(
    world: World,
    model: ModelWrapper | str,
    link: Link | str | None = None,
    xyzw: bool = False,
) -> Quat:
    """Return the orientation quaternion of a model's link.

    wxyz ordering by default, xyzw when `xyzw=True`. When `link` is None
    the model's first link is used.
    """

    if isinstance(model, str):
        # Resolve the model reference from its name
        model = world.to_gazebo().get_model(model).to_gazebo()

    if isinstance(link, str):
        # Resolve the link reference from its name
        resolved_link = model.get_link(link_name=link)
    elif link is None:
        # Fall back to the model's first link
        resolved_link = model.get_link(link_name=model.link_names()[0])
    else:
        resolved_link = link

    orientation = resolved_link.orientation()
    return quat_to_xyzw(orientation) if xyzw else orientation
+
+
+def transform_move_to_model_pose(
+    world: World,
+    position: Point,
+    quat: Quat,
+    target_model: ModelWrapper | str,
+    target_link: Link | str | None = None,
+    xyzw: bool = False,
+) -> Pose:
+    """
+    Transform such that original `position` and `quat` are represented with respect to `target_model::target_link`.
+    The resulting pose is still represented in world coordinate system.
+
+    `quat` and the returned quaternion use the ordering selected by `xyzw`.
+    """
+
+    target_frame_position, target_frame_quat = get_model_pose(
+        world,
+        model=target_model,
+        link=target_link,
+        xyzw=True,
+    )
+
+    transformed_position = Rotation.from_quat(target_frame_quat).apply(position)
+    transformed_position = (
+        transformed_position[0] + target_frame_position[0],
+        transformed_position[1] + target_frame_position[1],
+        transformed_position[2] + target_frame_position[2],
+    )
+
+    # The frame quaternion above was fetched in xyzw order; convert it to
+    # wxyz when the caller works in wxyz so quat_mul sees one consistent order.
+    if not xyzw:
+        target_frame_quat = quat_to_wxyz(target_frame_quat)
+    transformed_quat = quat_mul(quat, target_frame_quat, xyzw=xyzw)
+
+    return (transformed_position, transformed_quat)
+
+
def transform_move_to_model_position(
    world: World,
    position: Point,
    target_model: ModelWrapper | str,
    target_link: Link | str | None = None,
) -> Point:
    """Rotate and translate `position` by the pose of
    `target_model::target_link`; the result stays in world coordinates.
    """

    frame_position, frame_quat_xyzw = get_model_pose(
        world,
        model=target_model,
        link=target_link,
        xyzw=True,
    )

    rotated = Rotation.from_quat(frame_quat_xyzw).apply(position)
    return (
        frame_position[0] + rotated[0],
        frame_position[1] + rotated[1],
        frame_position[2] + rotated[2],
    )
+
+
def transform_move_to_model_orientation(
    world: World,
    quat: Quat,
    target_model: ModelWrapper | str,
    target_link: Link | str | None = None,
    xyzw: bool = False,
) -> Quat:
    """Compose `quat` with the orientation of `target_model::target_link`,
    using the quaternion ordering selected by `xyzw` throughout.
    """

    frame_quat = get_model_orientation(
        world,
        model=target_model,
        link=target_link,
        xyzw=xyzw,
    )

    return quat_mul(quat, frame_quat, xyzw=xyzw)
+
+
+def transform_change_reference_frame_pose(
+    world: World,
+    position: Point,
+    quat: Quat,
+    target_model: ModelWrapper | str,
+    target_link: Link | str | None,
+    xyzw: bool = False,
+) -> Pose:
+    """
+    Change reference frame of original `position` and `quat` from world coordinate system to `target_model::target_link` coordinate system.
+
+    `quat` and the returned quaternion use the ordering selected by `xyzw`.
+    """
+
+    target_frame_position, target_frame_quat = get_model_pose(
+        world,
+        model=target_model,
+        link=target_link,
+        xyzw=True,
+    )
+
+    transformed_position = (
+        position[0] - target_frame_position[0],
+        position[1] - target_frame_position[1],
+        position[2] - target_frame_position[2],
+    )
+    transformed_position = Rotation.from_quat(target_frame_quat).apply(
+        transformed_position, inverse=True
+    )
+
+    # The frame quaternion above was fetched in xyzw order; convert it to
+    # wxyz when the caller works in wxyz so quat_mul sees one consistent order.
+    if not xyzw:
+        target_frame_quat = quat_to_wxyz(target_frame_quat)
+    transformed_quat = quat_mul(target_frame_quat, quat, xyzw=xyzw)
+
+    return (tuple(transformed_position), transformed_quat)
+
+
def transform_change_reference_frame_position(
    world: World,
    position: Point,
    target_model: ModelWrapper | str,
    target_link: Link | str | None = None,
) -> Point:
    """Express a world-frame `position` in the coordinate system of
    `target_model::target_link`.
    """

    frame_position, frame_quat_xyzw = get_model_pose(
        world,
        model=target_model,
        link=target_link,
        xyzw=True,
    )

    offset = (
        position[0] - frame_position[0],
        position[1] - frame_position[1],
        position[2] - frame_position[2],
    )
    local_position = Rotation.from_quat(frame_quat_xyzw).apply(offset, inverse=True)
    return tuple(local_position)
+
+
def transform_change_reference_frame_orientation(
    world: World,
    quat: Quat,
    target_model: ModelWrapper | str,
    target_link: Link | str | None = None,
    xyzw: bool = False,
) -> Quat:
    """Express a world-frame orientation `quat` relative to the frame of
    `target_model::target_link`, with the ordering selected by `xyzw`.
    """

    frame_quat = get_model_orientation(
        world,
        model=target_model,
        link=target_link,
        xyzw=xyzw,
    )

    return quat_mul(frame_quat, quat, xyzw=xyzw)
diff --git a/env_manager/env_manager/env_manager/utils/helper.py b/env_manager/env_manager/env_manager/utils/helper.py
new file mode 100644
index 0000000..a5bfd08
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/helper.py
@@ -0,0 +1,488 @@
+import numpy as np
+from gym_gz.scenario.model_wrapper import ModelWrapper
+from rclpy.node import Node
+from scenario.bindings.core import World
+from scipy.spatial.transform import Rotation
+
+from .conversions import orientation_6d_to_quat
+from .gazebo import (
+ get_model_orientation,
+ get_model_pose,
+ get_model_position,
+ transform_change_reference_frame_orientation,
+ transform_change_reference_frame_pose,
+ transform_change_reference_frame_position,
+)
+from .math import quat_mul
+from .tf2_listener import Tf2Listener
+from .types import Point, Pose, PoseRpy, Quat, Rpy
+
+
+# Helper functions #
+def get_relative_ee_position(self, translation: Point) -> Point:
+    """Compute an absolute end-effector target position from a relative
+    translation.
+
+    `self` must provide `scale_relative_translation()`, `get_ee_position()`,
+    `restrict_position_goal_to_workspace()` and the boolean attribute
+    `_restrict_position_goal_to_workspace`.
+    """
+    # Scale relative action to metric units
+    translation = self.scale_relative_translation(translation)
+    # Get current position
+    current_position = self.get_ee_position()
+    # Compute target position
+    target_position = (
+        current_position[0] + translation[0],
+        current_position[1] + translation[1],
+        current_position[2] + translation[2],
+    )
+
+    # Restrict target position to a limited workspace, if desired
+    if self._restrict_position_goal_to_workspace:
+        target_position = self.restrict_position_goal_to_workspace(target_position)
+
+    return target_position
+
+
def get_relative_ee_orientation(
    self,
    rotation: float | Quat | PoseRpy,
    representation: str = "quat",
) -> Quat:
    """Compute a new target end-effector orientation by applying a relative
    rotation to the current one.

    Args:
        self: Environment object providing `get_ee_orientation()` and
            `scale_relative_rotation()`.
        rotation: Relative rotation: an xyzw quaternion (4-tuple), a 6D
            representation (6-tuple) or a yaw angle (float), matching
            `representation`.
        representation: One of "quat", "6d" or "z".

    Returns:
        Quat: Normalized xyzw quaternion of the target orientation.

    Raises:
        ValueError: If the tuple length does not match the representation.
        TypeError: If the rotation type does not match the representation.
    """
    current_quat_xyzw = self.get_ee_orientation()

    if representation == "z":
        # Keep only the yaw of the current orientation (with a pi flip
        # around x, i.e. end effector pointing down)
        current_yaw = Rotation.from_quat(current_quat_xyzw).as_euler("xyz")[2]
        current_quat_xyzw = Rotation.from_euler(
            "xyz", [np.pi, 0, current_yaw]
        ).as_quat()

    if isinstance(rotation, tuple):
        if len(rotation) == 4 and representation == "quat":
            relative_quat_xyzw = rotation
        elif len(rotation) == 6 and representation == "6d":
            vectors = (rotation[0:3], rotation[3:6])
            relative_quat_xyzw = orientation_6d_to_quat(vectors[0], vectors[1])
        else:
            raise ValueError("Invalid rotation tuple length for representation.")
    elif isinstance(rotation, float) and representation == "z":
        rotation = self.scale_relative_rotation(rotation)
        relative_quat_xyzw = Rotation.from_euler("xyz", [0, 0, rotation]).as_quat()
    else:
        raise TypeError("Invalid type for rotation or representation.")

    # Compose current and relative rotations, then normalize the result.
    # Bug fix: the composed quaternion was previously discarded -- the
    # relative quaternion was normalized and returned instead.
    target_quat_xyzw = quat_mul(tuple(current_quat_xyzw), tuple(relative_quat_xyzw))
    return normalize_quaternion(tuple(target_quat_xyzw))
+
+
def normalize_quaternion(
    target_quat_xyzw: tuple[float, float, float, float],
) -> tuple[float, float, float, float]:
    """Return the unit-length version of the given xyzw quaternion."""

    as_array = np.asarray(target_quat_xyzw, dtype=float)
    return tuple(as_array / np.linalg.norm(as_array))
+
+
+def scale_relative_translation(self, translation: Point) -> Point:
+    """Scale each translation component by the env's translation factor.
+
+    NOTE(review): this module-level function accesses
+    `self.__scaling_factor_translation` without class-private name
+    mangling; if the owning class assigns that attribute inside its class
+    body (where mangling applies), this lookup will fail -- confirm the
+    attribute name against the owning class.
+    """
+    return (
+        self.__scaling_factor_translation * translation[0],
+        self.__scaling_factor_translation * translation[1],
+        self.__scaling_factor_translation * translation[2],
+    )
+
+
+def scale_relative_rotation(
+    self, rotation: float | Rpy | np.floating | np.ndarray
+) -> float | Rpy:
+    """Scale a relative rotation (scalar or per-axis tuple) by the env's
+    rotation factor.
+
+    NOTE(review): this module-level function accesses
+    `self.__scaling_factor_rotation` without class-private name mangling;
+    if the owning class assigns that attribute inside its class body
+    (where mangling applies), this lookup will fail -- confirm the
+    attribute name against the owning class.
+    """
+    scaling_factor = self.__scaling_factor_rotation
+
+    if isinstance(rotation, (int, float, np.floating)):
+        return scaling_factor * rotation
+
+    return tuple(scaling_factor * r for r in rotation)
+
+
def restrict_position_goal_to_workspace(self, position: "Point") -> "Point":
    """Clamp a position to the environment's workspace bounds.

    `self` must provide `workspace_min_bound` and `workspace_max_bound`
    as xyz tuples.
    """

    low = self.workspace_min_bound
    high = self.workspace_max_bound
    return tuple(
        min(high[axis], max(low[axis], position[axis])) for axis in range(3)
    )
+
+
+# def restrict_servo_translation_to_workspace(
+# self, translation: tuple[float, float, float]
+# ) -> tuple[float, float, float]:
+# current_ee_position = self.get_ee_position()
+#
+# translation = tuple(
+# 0.0
+# if (
+# current_ee_position[i] > self.workspace_max_bound[i]
+# and translation[i] > 0.0
+# )
+# or (
+# current_ee_position[i] < self.workspace_min_bound[i]
+# and translation[i] < 0.0
+# )
+# else translation[i]
+# for i in range(3)
+# )
+#
+# return translation
+
+
+def get_ee_pose(
+    node: Node,
+    world: World,
+    robot_ee_link_name: str,
+    robot_name: str,
+    robot_arm_base_link_name: str,
+    tf2_listener: Tf2Listener,
+) -> Pose | None:
+    """
+    Return the current pose of the end effector with respect to arm base link.
+
+    Tries Gazebo first; on failure falls back to a single (non-retrying)
+    tf2 lookup, and finally to an identity pose with an error log.
+
+    Returns:
+        ((x, y, z), (qx, qy, qz, qw)) tuple, xyzw quaternion order.
+    """
+
+    try:
+        robot_model = world.to_gazebo().get_model(robot_name).to_gazebo()
+        ee_position, ee_quat_xyzw = get_model_pose(
+            world=world,
+            model=robot_model,
+            link=robot_ee_link_name,
+            xyzw=True,
+        )
+        return transform_change_reference_frame_pose(
+            world=world,
+            position=ee_position,
+            quat=ee_quat_xyzw,
+            target_model=robot_model,
+            target_link=robot_arm_base_link_name,
+            xyzw=True,
+        )
+    except Exception as e:
+        node.get_logger().warn(
+            f"Cannot get end effector pose from Gazebo ({e}), using tf2..."
+        )
+        transform = tf2_listener.lookup_transform_sync(
+            source_frame=robot_ee_link_name,
+            target_frame=robot_arm_base_link_name,
+            retry=False,
+        )
+        if transform is not None:
+            return (
+                (
+                    transform.translation.x,
+                    transform.translation.y,
+                    transform.translation.z,
+                ),
+                (
+                    transform.rotation.x,
+                    transform.rotation.y,
+                    transform.rotation.z,
+                    transform.rotation.w,
+                ),
+            )
+        else:
+            node.get_logger().error(
+                "Cannot get pose of the end effector (default values are returned)"
+            )
+            return (
+                (0.0, 0.0, 0.0),
+                (0.0, 0.0, 0.0, 1.0),
+            )
+
+
+def get_ee_position(
+    world: World,
+    robot_name: str,
+    robot_ee_link_name: str,
+    robot_arm_base_link_name: str,
+    node: Node,
+    tf2_listener: Tf2Listener,
+) -> Point:
+    """
+    Return the current position of the end effector with respect to arm base link.
+
+    Tries Gazebo first; on failure falls back to a single (non-retrying)
+    tf2 lookup, and finally to (0, 0, 0) with an error log.
+    """
+
+    try:
+        robot_model = world.to_gazebo().get_model(robot_name).to_gazebo()
+        ee_position = get_model_position(
+            world=world,
+            model=robot_model,
+            link=robot_ee_link_name,
+        )
+        return transform_change_reference_frame_position(
+            world=world,
+            position=ee_position,
+            target_model=robot_model,
+            target_link=robot_arm_base_link_name,
+        )
+    except Exception as e:
+        node.get_logger().debug(
+            f"Cannot get end effector position from Gazebo ({e}), using tf2..."
+        )
+        transform = tf2_listener.lookup_transform_sync(
+            source_frame=robot_ee_link_name,
+            target_frame=robot_arm_base_link_name,
+            retry=False,
+        )
+        if transform is not None:
+            return (
+                transform.translation.x,
+                transform.translation.y,
+                transform.translation.z,
+            )
+        else:
+            node.get_logger().error(
+                "Cannot get position of the end effector (default values are returned)"
+            )
+            return (0.0, 0.0, 0.0)
+
+
+def get_ee_orientation(
+    world: World,
+    robot_name: str,
+    robot_ee_link_name: str,
+    robot_arm_base_link_name: str,
+    node: Node,
+    tf2_listener: Tf2Listener,
+) -> Quat:
+    """
+    Return the current xyzw quaternion of the end effector with respect to arm base link.
+
+    Tries Gazebo first; on failure falls back to a single (non-retrying)
+    tf2 lookup, and finally to the identity quaternion with an error log.
+    """
+
+    try:
+        robot_model = world.to_gazebo().get_model(robot_name).to_gazebo()
+        ee_quat_xyzw = get_model_orientation(
+            world=world,
+            model=robot_model,
+            link=robot_ee_link_name,
+            xyzw=True,
+        )
+        return transform_change_reference_frame_orientation(
+            world=world,
+            quat=ee_quat_xyzw,
+            target_model=robot_model,
+            target_link=robot_arm_base_link_name,
+            xyzw=True,
+        )
+    except Exception as e:
+        node.get_logger().warn(
+            f"Cannot get end effector orientation from Gazebo ({e}), using tf2..."
+        )
+        transform = tf2_listener.lookup_transform_sync(
+            source_frame=robot_ee_link_name,
+            target_frame=robot_arm_base_link_name,
+            retry=False,
+        )
+        if transform is not None:
+            return (
+                transform.rotation.x,
+                transform.rotation.y,
+                transform.rotation.z,
+                transform.rotation.w,
+            )
+        else:
+            node.get_logger().error(
+                "Cannot get orientation of the end effector (default values are returned)"
+            )
+            return (0.0, 0.0, 0.0, 1.0)
+
+
+def get_object_position(
+    object_model: ModelWrapper | str,
+    node: Node,
+    world: World,
+    robot_name: str,
+    robot_arm_base_link_name: str,
+) -> Point:
+    """
+    Return the current position of an object with respect to arm base link.
+    Note: Only simulated objects are currently supported.
+
+    Returns (0, 0, 0) with an error log when the lookup fails.
+    """
+
+    try:
+        object_position = get_model_position(
+            world=world,
+            model=object_model,
+        )
+        return transform_change_reference_frame_position(
+            world=world,
+            position=object_position,
+            target_model=robot_name,
+            target_link=robot_arm_base_link_name,
+        )
+    except Exception as e:
+        node.get_logger().error(
+            f"Cannot get position of {object_model} object (default values are returned): {e}"
+        )
+        return (0.0, 0.0, 0.0)
+
+
+def get_object_positions(
+    node: Node,
+    world: World,
+    object_names: list[str],
+    robot_name: str,
+    robot_arm_base_link_name: str,
+) -> dict[str, tuple[float, float, float]]:
+    """
+    Return the current position of all objects with respect to arm base link.
+    Note: Only simulated objects are currently supported.
+
+    On failure the positions gathered so far are dropped and an (possibly
+    empty) dict is returned with an error log.
+    """
+
+    object_positions = {}
+
+    try:
+        robot_model = world.to_gazebo().get_model(robot_name).to_gazebo()
+        robot_arm_base_link = robot_model.get_link(link_name=robot_arm_base_link_name)
+        for object_name in object_names:
+            object_position = get_model_position(
+                world=world,
+                model=object_name,
+            )
+            object_positions[object_name] = transform_change_reference_frame_position(
+                world=world,
+                position=object_position,
+                target_model=robot_model,
+                target_link=robot_arm_base_link,
+            )
+    except Exception as e:
+        node.get_logger().error(
+            f"Cannot get positions of all objects (empty Dict is returned): {e}"
+        )
+
+    return object_positions
+
+
+# def substitute_special_frame(self, frame_id: str) -> str:
+# if "arm_base_link" == frame_id:
+# return self.robot_arm_base_link_name
+# elif "base_link" == frame_id:
+# return self.robot_base_link_name
+# elif "end_effector" == frame_id:
+# return self.robot_ee_link_name
+# elif "world" == frame_id:
+# try:
+# # In Gazebo, where multiple worlds are allowed
+# return self.world.to_gazebo().name()
+# except Exception as e:
+# self.get_logger().warn(f"")
+# # Otherwise (e.g. real world)
+# return "rbs_gym_world"
+# else:
+# return frame_id
+
+
+def move_to_initial_joint_configuration(self):
+    """Placeholder: moving to the initial joint configuration is currently
+    a no-op; the original MoveIt2-based logic is kept below for reference.
+    """
+    pass
+
+    # self.moveit2.move_to_configuration(self.initial_arm_joint_positions)
+
+    # if (
+    #     self.robot_model_class.CLOSED_GRIPPER_JOINT_POSITIONS
+    #     == self.initial_gripper_joint_positions
+    # ):
+    #     self.gripper.reset_close()
+    # else:
+    #     self.gripper.reset_open()
+
+
def check_terrain_collision(
    world: "World",
    robot_name: str,
    terrain_name: str,
    robot_base_link_name: str,
    robot_arm_link_names: list[str],
    robot_gripper_link_names: list[str],
) -> bool:
    """
    Returns true if robot links are in collision with the ground.

    Contacts whose second body belongs to the robot are inspected; the two
    characters after the robot name in the scoped body name are assumed to
    be the "::" separator.
    """
    prefix_length = len(robot_name)
    terrain = world.get_model(terrain_name)

    for contact in terrain.contacts():
        scoped_body = contact.body_b

        if not (
            scoped_body.startswith(robot_name) and len(scoped_body) > prefix_length
        ):
            continue

        # Strip "<robot_name>::" to recover the bare link name
        link_name = scoped_body[prefix_length + 2 :]

        if link_name == robot_base_link_name:
            continue
        if link_name in robot_arm_link_names or link_name in robot_gripper_link_names:
            return True

    return False
+
+
def check_all_objects_outside_workspace(
    workspace: Pose,
    object_positions: dict[str, Point],
) -> bool:
    """
    Returns True if all objects are outside the workspace.

    `workspace` is a (min_bound, max_bound) pair of xyz tuples.
    """

    return all(
        check_object_outside_workspace(workspace, position)
        for position in object_positions.values()
    )
+
+
def check_object_outside_workspace(
    workspace: "Pose", object_position: "Point"
) -> bool:
    """
    Returns True if the object is outside the workspace.

    `workspace` is a (min_bound, max_bound) pair of xyz tuples.
    """
    min_bound, max_bound = workspace
    for value, low, high in zip(object_position, min_bound, max_bound):
        if value < low or value > high:
            return True
    return False
+
+
+# def add_parameter_overrides(self, parameter_overrides: Dict[str, any]):
+# self.add_task_parameter_overrides(parameter_overrides)
+# self.add_randomizer_parameter_overrides(parameter_overrides)
+#
+#
+# def add_task_parameter_overrides(self, parameter_overrides: Dict[str, any]):
+# self.__task_parameter_overrides.update(parameter_overrides)
+#
+#
+# def add_randomizer_parameter_overrides(self, parameter_overrides: Dict[str, any]):
+# self._randomizer_parameter_overrides.update(parameter_overrides)
+#
+#
+# def __consume_parameter_overrides(self):
+# for key, value in self.__task_parameter_overrides.items():
+# if hasattr(self, key):
+# setattr(self, key, value)
+# elif hasattr(self, f"_{key}"):
+# setattr(self, f"_{key}", value)
+# elif hasattr(self, f"__{key}"):
+# setattr(self, f"__{key}", value)
+# else:
+# self.get_logger().error(f"Override '{key}' is not supperted by the task.")
+#
+# self.__task_parameter_overrides.clear()
diff --git a/env_manager/env_manager/env_manager/utils/logging.py b/env_manager/env_manager/env_manager/utils/logging.py
new file mode 100644
index 0000000..e9b8c09
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/logging.py
@@ -0,0 +1,25 @@
+from typing import Union
+
+from gymnasium import logger as gym_logger
+from gym_gz.utils import logger as gym_ign_logger
+
+
def set_log_level(log_level: Union[int, str]):
    """
    Set log level for (Gym) Ignition.

    Args:
        log_level: Either a numeric gymnasium logger level (e.g.
            `gym_logger.WARN`) or a level name such as "info"/"warning".
            Unrecognized names fall back to "DISABLED".
    """

    if not isinstance(log_level, int):
        name = str(log_level).upper()

        # gymnasium's logger uses WARN rather than WARNING
        if name == "WARNING":
            name = "WARN"
        elif name not in ("DEBUG", "INFO", "WARN", "ERROR", "DISABLED"):
            name = "DISABLED"

        # Bug fix: getattr() was previously applied to integer inputs as
        # well, which raises TypeError (attribute names must be strings).
        log_level = getattr(gym_logger, name)

    gym_ign_logger.set_level(
        level=log_level,
        scenario_level=log_level,
    )
diff --git a/env_manager/env_manager/env_manager/utils/math.py b/env_manager/env_manager/env_manager/utils/math.py
new file mode 100644
index 0000000..291cdf5
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/math.py
@@ -0,0 +1,44 @@
+import numpy as np
+
+
def quat_mul(
    quat_0: tuple[float, float, float, float],
    quat_1: tuple[float, float, float, float],
    xyzw: bool = True,
) -> tuple[float, float, float, float]:
    """
    Hamilton product of two quaternions.

    Both inputs and the result use xyzw ordering by default, or wxyz
    ordering when `xyzw=False`.
    """
    if xyzw:
        x0, y0, z0, w0 = quat_0
        x1, y1, z1, w1 = quat_1
    else:
        w0, x0, y0, z0 = quat_0
        w1, x1, y1, z1 = quat_1

    # Component-wise Hamilton product
    x = x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0
    y = -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0
    z = x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0
    w = -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0

    return (x, y, z, w) if xyzw else (w, x, y, z)
+
+
def distance_to_nearest_point(
    origin: tuple[float, float, float], points: list[tuple[float, float, float]]
) -> float:
    """Return the Euclidean distance from `origin` to the closest point."""

    offsets = np.asarray(points) - np.asarray(origin)
    return np.linalg.norm(offsets, axis=1).min()
+
+
def get_nearest_point(
    origin: tuple[float, float, float], points: list[tuple[float, float, float]]
) -> tuple[float, float, float]:
    """Return the point from `points` that is closest to `origin`."""

    distances = np.linalg.norm(np.asarray(points) - np.asarray(origin), axis=1)
    return points[int(distances.argmin())]
diff --git a/env_manager/env_manager/env_manager/utils/tf2_broadcaster.py b/env_manager/env_manager/env_manager/utils/tf2_broadcaster.py
new file mode 100644
index 0000000..27fb52a
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/tf2_broadcaster.py
@@ -0,0 +1,74 @@
+import sys
+from typing import Tuple
+
+import rclpy
+from geometry_msgs.msg import TransformStamped
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+from tf2_ros import StaticTransformBroadcaster
+
+
+class Tf2Broadcaster:
+    """Publish static transforms via tf2 using an existing ROS 2 node."""
+
+    def __init__(
+        self,
+        node: Node,
+    ):
+        """
+        Args:
+            node: ROS 2 node used for the broadcaster and its clock.
+        """
+
+        self._node = node
+        self.__tf2_broadcaster = StaticTransformBroadcaster(node=self._node)
+        self._transform_stamped = TransformStamped()
+
+    def broadcast_tf(
+        self,
+        parent_frame_id: str,
+        child_frame_id: str,
+        translation: Tuple[float, float, float],
+        rotation: Tuple[float, float, float, float],
+        xyzw: bool = True,
+    ):
+        """
+        Broadcast a static transform from `parent_frame_id` to
+        `child_frame_id`, stamped with the node's current time.
+
+        `rotation` is interpreted as xyzw by default, or wxyz when
+        `xyzw=False`.
+        """
+
+        self._transform_stamped.header.frame_id = parent_frame_id
+        self._transform_stamped.child_frame_id = child_frame_id
+
+        self._transform_stamped.header.stamp = self._node.get_clock().now().to_msg()
+
+        self._transform_stamped.transform.translation.x = float(translation[0])
+        self._transform_stamped.transform.translation.y = float(translation[1])
+        self._transform_stamped.transform.translation.z = float(translation[2])
+
+        if xyzw:
+            self._transform_stamped.transform.rotation.x = float(rotation[0])
+            self._transform_stamped.transform.rotation.y = float(rotation[1])
+            self._transform_stamped.transform.rotation.z = float(rotation[2])
+            self._transform_stamped.transform.rotation.w = float(rotation[3])
+        else:
+            self._transform_stamped.transform.rotation.w = float(rotation[0])
+            self._transform_stamped.transform.rotation.x = float(rotation[1])
+            self._transform_stamped.transform.rotation.y = float(rotation[2])
+            self._transform_stamped.transform.rotation.z = float(rotation[3])
+
+        self.__tf2_broadcaster.sendTransform(self._transform_stamped)
+
+
+class Tf2BroadcasterStandalone(Node, Tf2Broadcaster):
+    """Tf2Broadcaster that owns its ROS 2 node, initialising rclpy and
+    enabling `use_sim_time` by default."""
+
+    def __init__(
+        self,
+        node_name: str = "env_manager_tf_broadcaster",
+        use_sim_time: bool = True,
+    ):
+
+        try:
+            rclpy.init()
+        except Exception as e:
+            # rclpy.init() raises if a context already exists; only abort
+            # when ROS is genuinely not usable
+            if not rclpy.ok():
+                sys.exit(f"ROS 2 context could not be initialised: {e}")
+
+        Node.__init__(self, node_name)
+        self.set_parameters(
+            [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)]
+        )
+
+        Tf2Broadcaster.__init__(self, node=self)
diff --git a/env_manager/env_manager/env_manager/utils/tf2_dynamic_broadcaster.py b/env_manager/env_manager/env_manager/utils/tf2_dynamic_broadcaster.py
new file mode 100644
index 0000000..4737754
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/tf2_dynamic_broadcaster.py
@@ -0,0 +1,74 @@
+import sys
+from typing import Tuple
+
+import rclpy
+from geometry_msgs.msg import TransformStamped
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+from tf2_ros import TransformBroadcaster
+
+
+class Tf2DynamicBroadcaster:
+    """Publish (non-static) transforms via tf2 using an existing ROS 2 node."""
+
+    def __init__(
+        self,
+        node: Node,
+    ):
+        """
+        Args:
+            node: ROS 2 node used for the broadcaster and its clock.
+        """
+
+        self._node = node
+        self.__tf2_broadcaster = TransformBroadcaster(node=self._node)
+        self._transform_stamped = TransformStamped()
+
+    def broadcast_tf(
+        self,
+        parent_frame_id: str,
+        child_frame_id: str,
+        translation: Tuple[float, float, float],
+        rotation: Tuple[float, float, float, float],
+        xyzw: bool = True,
+    ):
+        """
+        Broadcast a transform from `parent_frame_id` to `child_frame_id`,
+        stamped with the node's current time.
+
+        `rotation` is interpreted as xyzw by default, or wxyz when
+        `xyzw=False`.
+        """
+
+        self._transform_stamped.header.frame_id = parent_frame_id
+        self._transform_stamped.child_frame_id = child_frame_id
+
+        self._transform_stamped.header.stamp = self._node.get_clock().now().to_msg()
+
+        self._transform_stamped.transform.translation.x = float(translation[0])
+        self._transform_stamped.transform.translation.y = float(translation[1])
+        self._transform_stamped.transform.translation.z = float(translation[2])
+
+        if xyzw:
+            self._transform_stamped.transform.rotation.x = float(rotation[0])
+            self._transform_stamped.transform.rotation.y = float(rotation[1])
+            self._transform_stamped.transform.rotation.z = float(rotation[2])
+            self._transform_stamped.transform.rotation.w = float(rotation[3])
+        else:
+            self._transform_stamped.transform.rotation.w = float(rotation[0])
+            self._transform_stamped.transform.rotation.x = float(rotation[1])
+            self._transform_stamped.transform.rotation.y = float(rotation[2])
+            self._transform_stamped.transform.rotation.z = float(rotation[3])
+
+        self.__tf2_broadcaster.sendTransform(self._transform_stamped)
+
+
+class Tf2BroadcasterStandalone(Node, Tf2DynamicBroadcaster):
+    """Tf2DynamicBroadcaster that owns its ROS 2 node.
+
+    NOTE(review): this class name duplicates
+    `tf2_broadcaster.Tf2BroadcasterStandalone` (likely a copy-paste;
+    `Tf2DynamicBroadcasterStandalone` was probably intended). It is not
+    re-exported from the package `__init__`, so renaming should be safe --
+    confirm no external imports before doing so.
+    """
+
+    def __init__(
+        self,
+        node_name: str = "env_manager_tf_broadcaster",
+        use_sim_time: bool = True,
+    ):
+
+        try:
+            rclpy.init()
+        except Exception as e:
+            # rclpy.init() raises if a context already exists; only abort
+            # when ROS is genuinely not usable
+            if not rclpy.ok():
+                sys.exit(f"ROS 2 context could not be initialised: {e}")
+
+        Node.__init__(self, node_name)
+        self.set_parameters(
+            [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)]
+        )
+
+        Tf2DynamicBroadcaster.__init__(self, node=self)
diff --git a/env_manager/env_manager/env_manager/utils/tf2_listener.py b/env_manager/env_manager/env_manager/utils/tf2_listener.py
new file mode 100644
index 0000000..ef22bb0
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/tf2_listener.py
@@ -0,0 +1,74 @@
+import sys
+from typing import Optional
+
+import rclpy
+from geometry_msgs.msg import Transform
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+from tf2_ros import Buffer, TransformListener
+
+
+class Tf2Listener:
+ def __init__(
+ self,
+ node: Node,
+ ):
+
+ self._node = node
+
+ # Create tf2 buffer and listener for transform lookup
+ self.__tf2_buffer = Buffer()
+ TransformListener(buffer=self.__tf2_buffer, node=node)
+
+ def lookup_transform_sync(
+ self, target_frame: str, source_frame: str, retry: bool = True
+ ) -> Optional[Transform]:
+
+ try:
+ return self.__tf2_buffer.lookup_transform(
+ target_frame=target_frame,
+ source_frame=source_frame,
+ time=rclpy.time.Time(),
+ ).transform
+ except:
+ if retry:
+ while rclpy.ok():
+ if self.__tf2_buffer.can_transform(
+ target_frame=target_frame,
+ source_frame=source_frame,
+ time=rclpy.time.Time(),
+ timeout=rclpy.time.Duration(seconds=1, nanoseconds=0),
+ ):
+ return self.__tf2_buffer.lookup_transform(
+ target_frame=target_frame,
+ source_frame=source_frame,
+ time=rclpy.time.Time(),
+ ).transform
+
+ self._node.get_logger().warn(
+ f'Lookup of transform from "{source_frame}"'
+ f' to "{target_frame}" failed, retrying...'
+ )
+ else:
+ return None
+
+
+class Tf2ListenerStandalone(Node, Tf2Listener):
+    def __init__(
+        self,
+        node_name: str = "rbs_gym_tf_listener",
+        use_sim_time: bool = True,
+    ):
+        # Standalone variant: owns its own rclpy context and acts as its own node.
+        try:
+            rclpy.init()
+        except Exception as e:
+            if not rclpy.ok():  # ignore "already initialised"; abort only if the context is unusable
+                sys.exit(f"ROS 2 context could not be initialised: {e}")
+
+        Node.__init__(self, node_name)
+        self.set_parameters(
+            [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)]
+        )
+
+        Tf2Listener.__init__(self, node=self)
diff --git a/env_manager/env_manager/env_manager/utils/types.py b/env_manager/env_manager/env_manager/utils/types.py
new file mode 100644
index 0000000..a502bd6
--- /dev/null
+++ b/env_manager/env_manager/env_manager/utils/types.py
@@ -0,0 +1,5 @@
+Point = tuple[float, float, float]  # Cartesian position (x, y, z)
+Quat = tuple[float, float, float, float]  # orientation quaternion; (x, y, z, w) order assumed -- confirm at call sites
+Rpy = tuple[float, float, float]  # orientation as roll, pitch, yaw angles
+Pose = tuple[Point, Quat]  # position plus quaternion orientation
+PoseRpy = tuple[Point, Rpy]  # position plus roll/pitch/yaw orientation
diff --git a/env_manager/env_manager/package.nix b/env_manager/env_manager/package.nix
new file mode 100644
index 0000000..49bca32
--- /dev/null
+++ b/env_manager/env_manager/package.nix
@@ -0,0 +1,31 @@
+# Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0
+# Copyright 2025 None
+# Distributed under the terms of the Apache-2.0 license
+{
+  lib,
+  buildRosPackage,
+  ament-copyright,
+  ament-flake8,
+  ament-pep257,
+  gym-gz-ros-python,
+  pythonPackages,
+  scenario,
+}:
+let
+  rbs-assets-library = pythonPackages.callPackage ../../../../rbs_assets_library/default.nix {}; # out-of-tree dep, path relative to this file
+in
+buildRosPackage rec {
+  pname = "ros-jazzy-env-manager";
+  version = "0.0.0";
+
+  src = ./.;
+
+  buildType = "ament_python"; # pure-Python ament package (see setup.py)
+  checkInputs = [ament-copyright ament-flake8 ament-pep257 gym-gz-ros-python pythonPackages.pytest scenario];
+  propagatedBuildInputs = [gym-gz-ros-python rbs-assets-library];
+
+  meta = {
+    description = "TODO: Package description";
+    license = with lib.licenses; [asl20];
+  };
+}
diff --git a/env_manager/env_manager/package.xml b/env_manager/env_manager/package.xml
new file mode 100644
index 0000000..5018612
--- /dev/null
+++ b/env_manager/env_manager/package.xml
@@ -0,0 +1,21 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>env_manager</name>
+  <version>0.0.0</version>
+  <description>TODO: Package description</description>
+  <maintainer email="ur.narmak@gmail.com">narmak</maintainer>
+  <license>Apache-2.0</license>
+
+  <depend>scenario</depend>
+  <depend>gym-gz-ros-python</depend>
+
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
+
+  <export>
+    <build_type>ament_python</build_type>
+  </export>
+</package>
diff --git a/env_manager/env_manager/resource/env_manager b/env_manager/env_manager/resource/env_manager
new file mode 100644
index 0000000..e69de29
diff --git a/env_manager/env_manager/setup.cfg b/env_manager/env_manager/setup.cfg
new file mode 100644
index 0000000..36219d7
--- /dev/null
+++ b/env_manager/env_manager/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/env_manager
+[install]
+install_scripts=$base/lib/env_manager
diff --git a/env_manager/env_manager/setup.py b/env_manager/env_manager/setup.py
new file mode 100644
index 0000000..3e62b74
--- /dev/null
+++ b/env_manager/env_manager/setup.py
@@ -0,0 +1,25 @@
+from setuptools import find_packages, setup
+
+package_name = 'env_manager'
+
+setup(
+    name=package_name,
+    version='0.0.0',
+    packages=find_packages(exclude=['test']),
+    data_files=[  # ament resource-index marker + manifest, required for ROS 2 package discovery
+        ('share/ament_index/resource_index/packages',
+            ['resource/' + package_name]),
+        ('share/' + package_name, ['package.xml']),
+    ],
+    install_requires=['setuptools', 'trimesh'],
+    zip_safe=True,
+    maintainer='narmak',
+    maintainer_email='ur.narmak@gmail.com',
+    description='TODO: Package description',
+    license='Apache-2.0',
+    tests_require=['pytest'],  # NOTE(review): deprecated in modern setuptools; kept for ament tooling
+    entry_points={
+        'console_scripts': [
+        ],
+    },
+)
diff --git a/env_manager/env_manager/test/test_copyright.py b/env_manager/env_manager/test/test_copyright.py
new file mode 100644
index 0000000..97a3919
--- /dev/null
+++ b/env_manager/env_manager/test/test_copyright.py
@@ -0,0 +1,25 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_copyright.main import main
+import pytest
+
+
+# Remove the `skip` decorator once the source file(s) have a copyright header
+@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
+@pytest.mark.copyright
+@pytest.mark.linter
+def test_copyright():
+    rc = main(argv=['.', 'test'])  # ament_copyright CLI; non-zero rc means missing/invalid headers
+    assert rc == 0, 'Found errors'
diff --git a/env_manager/env_manager/test/test_flake8.py b/env_manager/env_manager/test/test_flake8.py
new file mode 100644
index 0000000..27ee107
--- /dev/null
+++ b/env_manager/env_manager/test/test_flake8.py
@@ -0,0 +1,25 @@
+# Copyright 2017 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_flake8.main import main_with_errors
+import pytest
+
+
+@pytest.mark.flake8
+@pytest.mark.linter
+def test_flake8():
+    rc, errors = main_with_errors(argv=[])  # lint the package with ament's flake8 configuration
+    assert rc == 0, \
+        'Found %d code style errors / warnings:\n' % len(errors) + \
+        '\n'.join(errors)
diff --git a/env_manager/env_manager/test/test_pep257.py b/env_manager/env_manager/test/test_pep257.py
new file mode 100644
index 0000000..b234a38
--- /dev/null
+++ b/env_manager/env_manager/test/test_pep257.py
@@ -0,0 +1,23 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_pep257.main import main
+import pytest
+
+
+@pytest.mark.linter
+@pytest.mark.pep257
+def test_pep257():
+    rc = main(argv=['.', 'test'])  # docstring style check (PEP 257) via ament_pep257
+    assert rc == 0, 'Found code style errors / warnings'
diff --git a/env_manager/env_manager_interfaces/CMakeLists.txt b/env_manager/env_manager_interfaces/CMakeLists.txt
new file mode 100644
index 0000000..a0a5c75
--- /dev/null
+++ b/env_manager/env_manager_interfaces/CMakeLists.txt
@@ -0,0 +1,38 @@
+cmake_minimum_required(VERSION 3.8)
+project(env_manager_interfaces)
+
+if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+  add_compile_options(-Wall -Wextra -Wpedantic)
+endif()
+
+# find dependencies
+find_package(ament_cmake REQUIRED)
+find_package(lifecycle_msgs REQUIRED)
+find_package(rosidl_default_generators REQUIRED)
+# uncomment the following section in order to fill in
+# further dependencies manually.
+# find_package(<dependency> REQUIRED)
+
+rosidl_generate_interfaces(${PROJECT_NAME}
+  srv/StartEnv.srv
+  srv/ConfigureEnv.srv
+  srv/LoadEnv.srv
+  srv/UnloadEnv.srv
+  srv/ResetEnv.srv
+  msg/EnvState.msg
+  DEPENDENCIES lifecycle_msgs
+)
+
+if(BUILD_TESTING)
+  find_package(ament_lint_auto REQUIRED)
+  # the following line skips the linter which checks for copyrights
+  # comment the line when a copyright and license is added to all source files
+  set(ament_cmake_copyright_FOUND TRUE)
+  # the following line skips cpplint (only works in a git repo)
+  # comment the line when this package is in a git repo and when
+  # a copyright and license is added to all source files
+  set(ament_cmake_cpplint_FOUND TRUE)
+  ament_lint_auto_find_test_dependencies()
+endif()
+
+ament_package()
diff --git a/env_manager/env_manager_interfaces/LICENSE b/env_manager/env_manager_interfaces/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/env_manager/env_manager_interfaces/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/env_manager/env_manager_interfaces/msg/EnvState.msg b/env_manager/env_manager_interfaces/msg/EnvState.msg
new file mode 100644
index 0000000..fa4c82e
--- /dev/null
+++ b/env_manager/env_manager_interfaces/msg/EnvState.msg
@@ -0,0 +1,4 @@
+string name
+string type
+string plugin_name
+lifecycle_msgs/State state
diff --git a/env_manager/env_manager_interfaces/package.nix b/env_manager/env_manager_interfaces/package.nix
new file mode 100644
index 0000000..3402f72
--- /dev/null
+++ b/env_manager/env_manager_interfaces/package.nix
@@ -0,0 +1,23 @@
+# Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0
+
+# Copyright 2025 None
+# Distributed under the terms of the Apache-2.0 license
+
+{ lib, buildRosPackage, ament-cmake, ament-lint-auto, ament-lint-common, builtin-interfaces, lifecycle-msgs, rosidl-default-generators, rosidl-default-runtime }:
+buildRosPackage rec {
+  pname = "ros-jazzy-env-manager-interfaces";
+  version = "0.0.0";
+
+  src = ./.;
+
+  buildType = "ament_cmake"; # CMake-based interface package (msg/srv generation)
+  buildInputs = [ ament-cmake rosidl-default-generators ];
+  checkInputs = [ ament-lint-auto ament-lint-common ];
+  propagatedBuildInputs = [ builtin-interfaces lifecycle-msgs rosidl-default-runtime ];
+  nativeBuildInputs = [ ament-cmake rosidl-default-generators ]; # duplicated in buildInputs by ros2nix; harmless
+
+  meta = {
+    description = "TODO: Package description";
+    license = with lib.licenses; [ asl20 ];
+  };
+}
diff --git a/env_manager/env_manager_interfaces/package.xml b/env_manager/env_manager_interfaces/package.xml
new file mode 100644
index 0000000..4fb3920
--- /dev/null
+++ b/env_manager/env_manager_interfaces/package.xml
@@ -0,0 +1,27 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>env_manager_interfaces</name>
+  <version>0.0.0</version>
+  <description>TODO: Package description</description>
+  <maintainer email="solid-sinusoid@todo.todo">solid-sinusoid</maintainer>
+  <license>Apache-2.0</license>
+
+  <buildtool_depend>ament_cmake</buildtool_depend>
+  <buildtool_depend>rosidl_default_generators</buildtool_depend>
+
+  <build_depend>builtin_interfaces</build_depend>
+  <build_depend>lifecycle_msgs</build_depend>
+
+  <exec_depend>builtin_interfaces</exec_depend>
+  <exec_depend>lifecycle_msgs</exec_depend>
+  <exec_depend>rosidl_default_runtime</exec_depend>
+
+  <test_depend>ament_lint_auto</test_depend>
+  <test_depend>ament_lint_common</test_depend>
+
+  <member_of_group>rosidl_interface_packages</member_of_group>
+  <export>
+    <build_type>ament_cmake</build_type>
+  </export>
+</package>
diff --git a/env_manager/env_manager_interfaces/srv/ConfigureEnv.srv b/env_manager/env_manager_interfaces/srv/ConfigureEnv.srv
new file mode 100644
index 0000000..4292382
--- /dev/null
+++ b/env_manager/env_manager_interfaces/srv/ConfigureEnv.srv
@@ -0,0 +1,3 @@
+string name
+---
+bool ok
\ No newline at end of file
diff --git a/env_manager/env_manager_interfaces/srv/LoadEnv.srv b/env_manager/env_manager_interfaces/srv/LoadEnv.srv
new file mode 100644
index 0000000..d6d46cc
--- /dev/null
+++ b/env_manager/env_manager_interfaces/srv/LoadEnv.srv
@@ -0,0 +1,4 @@
+string name
+string type
+---
+bool ok
\ No newline at end of file
diff --git a/env_manager/env_manager_interfaces/srv/ResetEnv.srv b/env_manager/env_manager_interfaces/srv/ResetEnv.srv
new file mode 100644
index 0000000..5d0cbc4
--- /dev/null
+++ b/env_manager/env_manager_interfaces/srv/ResetEnv.srv
@@ -0,0 +1,3 @@
+
+---
+bool ok
diff --git a/env_manager/env_manager_interfaces/srv/StartEnv.srv b/env_manager/env_manager_interfaces/srv/StartEnv.srv
new file mode 100644
index 0000000..d6d46cc
--- /dev/null
+++ b/env_manager/env_manager_interfaces/srv/StartEnv.srv
@@ -0,0 +1,4 @@
+string name
+string type
+---
+bool ok
\ No newline at end of file
diff --git a/env_manager/env_manager_interfaces/srv/UnloadEnv.srv b/env_manager/env_manager_interfaces/srv/UnloadEnv.srv
new file mode 100644
index 0000000..4292382
--- /dev/null
+++ b/env_manager/env_manager_interfaces/srv/UnloadEnv.srv
@@ -0,0 +1,3 @@
+string name
+---
+bool ok
\ No newline at end of file
diff --git a/env_manager/rbs_gym/LICENSE b/env_manager/rbs_gym/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/env_manager/rbs_gym/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/env_manager/rbs_gym/hyperparams/sac.yml b/env_manager/rbs_gym/hyperparams/sac.yml
new file mode 100644
index 0000000..d372fc7
--- /dev/null
+++ b/env_manager/rbs_gym/hyperparams/sac.yml
@@ -0,0 +1,42 @@
+# Reach
+Reach-Gazebo-v0:
+ policy: "MlpPolicy"
+ policy_kwargs:
+ n_critics: 2
+ net_arch: [128, 64]
+ n_timesteps: 200000
+ buffer_size: 25000
+ learning_starts: 5000
+ batch_size: 512
+ learning_rate: lin_0.0002
+ gamma: 0.95
+ tau: 0.001
+ ent_coef: "auto_0.1"
+ target_entropy: "auto"
+ train_freq: [1, "episode"]
+ gradient_steps: 100
+ noise_type: "normal"
+ noise_std: 0.025
+ use_sde: False
+ optimize_memory_usage: False
+
+Reach-ColorImage-Gazebo-v0:
+ policy: "CnnPolicy"
+ policy_kwargs:
+ n_critics: 2
+ net_arch: [128, 128]
+ n_timesteps: 50000
+ buffer_size: 25000
+ learning_starts: 5000
+ batch_size: 32
+ learning_rate: lin_0.0002
+ gamma: 0.95
+ tau: 0.0005
+ ent_coef: "auto_0.1"
+ target_entropy: "auto"
+ train_freq: [1, "episode"]
+ gradient_steps: 100
+ noise_type: "normal"
+ noise_std: 0.025
+ use_sde: False
+ optimize_memory_usage: False
diff --git a/env_manager/rbs_gym/hyperparams/td3.yml b/env_manager/rbs_gym/hyperparams/td3.yml
new file mode 100644
index 0000000..6f57e0a
--- /dev/null
+++ b/env_manager/rbs_gym/hyperparams/td3.yml
@@ -0,0 +1,39 @@
+Reach-Gazebo-v0:
+ policy: "MlpPolicy"
+ policy_kwargs:
+ n_critics: 2
+ net_arch: [128, 64]
+ n_timesteps: 200000
+ buffer_size: 25000
+ learning_starts: 5000
+ batch_size: 512
+ learning_rate: lin_0.0002
+ gamma: 0.95
+ tau: 0.001
+ train_freq: [1, "episode"]
+ gradient_steps: 100
+ target_policy_noise: 0.1
+ target_noise_clip: 0.2
+ noise_type: "normal"
+ noise_std: 0.025
+ optimize_memory_usage: False
+
+Reach-ColorImage-Gazebo-v0:
+ policy: "CnnPolicy"
+ policy_kwargs:
+ n_critics: 2
+ net_arch: [128, 128]
+ n_timesteps: 50000
+ buffer_size: 25000
+ learning_starts: 5000
+ batch_size: 32
+ learning_rate: lin_0.0002
+ gamma: 0.95
+ tau: 0.0005
+ train_freq: [1, "episode"]
+ gradient_steps: 100
+ target_policy_noise: 0.1
+ target_noise_clip: 0.2
+ noise_type: "normal"
+ noise_std: 0.025
+ optimize_memory_usage: True
diff --git a/env_manager/rbs_gym/hyperparams/tqc.yml b/env_manager/rbs_gym/hyperparams/tqc.yml
new file mode 100644
index 0000000..7e483af
--- /dev/null
+++ b/env_manager/rbs_gym/hyperparams/tqc.yml
@@ -0,0 +1,46 @@
+# Reach
+Reach-Gazebo-v0:
+ policy: "MlpPolicy"
+ policy_kwargs:
+ n_quantiles: 25
+ n_critics: 2
+ net_arch: [128, 64]
+ n_timesteps: 200000
+ buffer_size: 25000
+ learning_starts: 5000
+ batch_size: 512
+ learning_rate: lin_0.0002
+ gamma: 0.95
+ tau: 0.001
+ ent_coef: "auto_0.1"
+ target_entropy: "auto"
+ top_quantiles_to_drop_per_net: 2
+ train_freq: [1, "episode"]
+ gradient_steps: 100
+ noise_type: "normal"
+ noise_std: 0.025
+ use_sde: False
+ optimize_memory_usage: False
+
+Reach-ColorImage-Gazebo-v0:
+ policy: "CnnPolicy"
+ policy_kwargs:
+ n_quantiles: 25
+ n_critics: 2
+ net_arch: [128, 128]
+ n_timesteps: 50000
+ buffer_size: 25000
+ learning_starts: 5000
+ batch_size: 32
+ learning_rate: lin_0.0002
+ gamma: 0.95
+ tau: 0.0005
+ ent_coef: "auto_0.1"
+ target_entropy: "auto"
+ top_quantiles_to_drop_per_net: 2
+ train_freq: [1, "episode"]
+ gradient_steps: 100
+ noise_type: "normal"
+ noise_std: 0.025
+ use_sde: False
+ optimize_memory_usage: True
diff --git a/env_manager/rbs_gym/launch/evaluate.launch.py b/env_manager/rbs_gym/launch/evaluate.launch.py
new file mode 100644
index 0000000..4e69f23
--- /dev/null
+++ b/env_manager/rbs_gym/launch/evaluate.launch.py
@@ -0,0 +1,426 @@
+from launch import LaunchContext, LaunchDescription
+from launch.actions import (
+ DeclareLaunchArgument,
+ IncludeLaunchDescription,
+ OpaqueFunction,
+ SetEnvironmentVariable,
+ TimerAction
+)
+from launch.launch_description_sources import PythonLaunchDescriptionSource
+from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
+from launch_ros.substitutions import FindPackageShare
+from launch_ros.actions import Node
+import os
+from os import cpu_count
+from ament_index_python.packages import get_package_share_directory
+
+def launch_setup(context, *args, **kwargs):
+ # Initialize Arguments
+ robot_type = LaunchConfiguration("robot_type")
+ # General arguments
+ with_gripper_condition = LaunchConfiguration("with_gripper")
+ controllers_file = LaunchConfiguration("controllers_file")
+ cartesian_controllers = LaunchConfiguration("cartesian_controllers")
+ description_package = LaunchConfiguration("description_package")
+ description_file = LaunchConfiguration("description_file")
+ robot_name = LaunchConfiguration("robot_name")
+ start_joint_controller = LaunchConfiguration("start_joint_controller")
+ initial_joint_controller = LaunchConfiguration("initial_joint_controller")
+ launch_simulation = LaunchConfiguration("launch_sim")
+ launch_moveit = LaunchConfiguration("launch_moveit")
+ launch_task_planner = LaunchConfiguration("launch_task_planner")
+ launch_perception = LaunchConfiguration("launch_perception")
+ moveit_config_package = LaunchConfiguration("moveit_config_package")
+ moveit_config_file = LaunchConfiguration("moveit_config_file")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ sim_gazebo = LaunchConfiguration("sim_gazebo")
+ hardware = LaunchConfiguration("hardware")
+ env_manager = LaunchConfiguration("env_manager")
+ launch_controllers = LaunchConfiguration("launch_controllers")
+ gazebo_gui = LaunchConfiguration("gazebo_gui")
+ gripper_name = LaunchConfiguration("gripper_name")
+
+ # training arguments
+ env = LaunchConfiguration("env")
+ algo = LaunchConfiguration("algo")
+ num_threads = LaunchConfiguration("num_threads")
+ seed = LaunchConfiguration("seed")
+ log_folder = LaunchConfiguration("log_folder")
+ verbose = LaunchConfiguration("verbose")
+ # use_sim_time = LaunchConfiguration("use_sim_time")
+ log_level = LaunchConfiguration("log_level")
+ env_kwargs = LaunchConfiguration("env_kwargs")
+
+ n_episodes = LaunchConfiguration("n_episodes")
+ exp_id = LaunchConfiguration("exp_id")
+ load_best = LaunchConfiguration("load_best")
+ load_checkpoint = LaunchConfiguration("load_checkpoint")
+ stochastic = LaunchConfiguration("stochastic")
+ reward_log = LaunchConfiguration("reward_log")
+ norm_reward = LaunchConfiguration("norm_reward")
+ no_render = LaunchConfiguration("no_render")
+
+
+ sim_gazebo = LaunchConfiguration("sim_gazebo")
+ launch_simulation = LaunchConfiguration("launch_sim")
+
+ initial_joint_controllers_file_path = os.path.join(
+ get_package_share_directory('rbs_arm'), 'config', 'controllers.yaml'
+ )
+
+ single_robot_setup = IncludeLaunchDescription(
+ PythonLaunchDescriptionSource([
+ PathJoinSubstitution([
+ FindPackageShare('rbs_bringup'),
+ "launch",
+ "rbs_robot.launch.py"
+ ])
+ ]),
+ launch_arguments={
+ "env_manager": env_manager,
+ "with_gripper": with_gripper_condition,
+ "gripper_name": gripper_name,
+ # "controllers_file": controllers_file,
+ "robot_type": robot_type,
+ "controllers_file": initial_joint_controllers_file_path,
+ "cartesian_controllers": cartesian_controllers,
+ "description_package": description_package,
+ "description_file": description_file,
+ "robot_name": robot_name,
+ "start_joint_controller": start_joint_controller,
+ "initial_joint_controller": initial_joint_controller,
+ "launch_simulation": launch_simulation,
+ "launch_moveit": launch_moveit,
+ "launch_task_planner": launch_task_planner,
+ "launch_perception": launch_perception,
+ "moveit_config_package": moveit_config_package,
+ "moveit_config_file": moveit_config_file,
+ "use_sim_time": use_sim_time,
+ "sim_gazebo": sim_gazebo,
+ "hardware": hardware,
+ "launch_controllers": launch_controllers,
+ # "gazebo_gui": gazebo_gui
+ }.items()
+ )
+
+ args = [
+ "--env",
+ env,
+ "--env-kwargs",
+ env_kwargs,
+ "--algo",
+ algo,
+ "--seed",
+ seed,
+ "--num-threads",
+ num_threads,
+ "--n-episodes",
+ n_episodes,
+ "--log-folder",
+ log_folder,
+ "--exp-id",
+ exp_id,
+ "--load-best",
+ load_best,
+ "--load-checkpoint",
+ load_checkpoint,
+ "--stochastic",
+ stochastic,
+ "--reward-log",
+ reward_log,
+ "--norm-reward",
+ norm_reward,
+ "--no-render",
+ no_render,
+ "--verbose",
+ verbose,
+ "--ros-args",
+ "--log-level",
+ log_level,
+ ]
+
+ rl_task = Node(
+ package="rbs_gym",
+ executable="evaluate",
+ output="log",
+ arguments=args,
+ parameters=[{"use_sim_time": use_sim_time}],
+ )
+
+
+
+ delay_robot_control_stack = TimerAction(
+ period=10.0,
+ actions=[single_robot_setup]
+ )
+
+ nodes_to_start = [
+ # env,
+ rl_task,
+ delay_robot_control_stack
+ ]
+ return nodes_to_start
+
+
+def generate_launch_description():
+ declared_arguments = []
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_type",
+ description="Type of robot by name",
+ choices=["rbs_arm","ur3", "ur3e", "ur5", "ur5e", "ur10", "ur10e", "ur16e"],
+ default_value="rbs_arm",
+ )
+ )
+ # General arguments
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "controllers_file",
+ default_value="controllers_gazebosim.yaml",
+ description="YAML file with the controllers configuration.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_package",
+ default_value="rbs_arm",
+ description="Description package with robot URDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom description.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_file",
+ default_value="rbs_arm_modular.xacro",
+ description="URDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_name",
+ default_value="arm0",
+ description="Name for robot, used to apply namespace for specific robot in multirobot setup",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "start_joint_controller",
+ default_value="false",
+ description="Enable headless mode for robot control",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "initial_joint_controller",
+ default_value="joint_trajectory_controller",
+ description="Robot controller to start.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_package",
+ default_value="rbs_arm",
+ description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom moveit config.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_file",
+ default_value="rbs_arm.srdf.xacro",
+ description="MoveIt SRDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_sim_time",
+ default_value="true",
+ description="Make MoveIt to use simulation time.\
+ This is needed for the trajectory planning in simulation.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "gripper_name",
+ default_value="rbs_gripper",
+ choices=["rbs_gripper", ""],
+ description="choose gripper by name (leave empty if hasn't)",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("with_gripper",
+ default_value="true",
+ description="With gripper or not?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("sim_gazebo",
+ default_value="true",
+ description="Gazebo Simulation")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("env_manager",
+ default_value="false",
+ description="Launch env_manager?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_sim",
+ default_value="true",
+ description="Launch simulator (Gazebo)?\
+ Most general arg")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_moveit",
+ default_value="false",
+ description="Launch moveit?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_perception",
+ default_value="false",
+ description="Launch perception?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_task_planner",
+ default_value="false",
+ description="Launch task_planner?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("cartesian_controllers",
+ default_value="true",
+ description="Load cartesian\
+ controllers?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("hardware",
+ choices=["gazebo", "mock"],
+ default_value="gazebo",
+ description="Choose your hardware_interface")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_controllers",
+ default_value="true",
+ description="Launch controllers?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("gazebo_gui",
+ default_value="true",
+ description="Launch gazebo with gui?")
+ )
+ # training arguments
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env",
+ default_value="Reach-Gazebo-v0",
+ description="Environment ID",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env_kwargs",
+ default_value="",
+ description="Optional keyword argument to pass to the env constructor.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "vec_env",
+ default_value="dummy",
+ description="Type of VecEnv to use (dummy or subproc).",
+ ))
+ # Algorithm and training
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "algo",
+ default_value="sac",
+ description="RL algorithm to use during the training.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "num_threads",
+ default_value="-1",
+ description="Number of threads for PyTorch (-1 to use default).",
+ ))
+ # Random seed
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "seed",
+ default_value="84",
+ description="Random generator seed.",
+ ))
+ # Logging
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_folder",
+ default_value="logs",
+ description="Path to the log directory.",
+ ))
+ # Verbosity
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "verbose",
+ default_value="1",
+ description="Verbose mode (0: no output, 1: INFO).",
+ ))
+ # HER specifics
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_level",
+ default_value="error",
+ description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
+ ))
+
+
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_episodes",
+ default_value="200",
+ description="Number of evaluation episodes.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "exp_id",
+ default_value="0",
+ description="Experiment ID (default: 0: latest, -1: no exp folder).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "load_best",
+ default_value="false",
+ description="Load best model instead of last model if available.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "load_checkpoint",
+ default_value="0",
+ description="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "stochastic",
+ default_value="false",
+ description="Use stochastic actions instead of deterministic.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "reward_log",
+ default_value="reward_logs",
+ description="Where to log reward.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "norm_reward",
+ default_value="false",
+ description="Normalize reward if applicable (trained with VecNormalize)",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "no_render",
+ default_value="true",
+ description="Do not render the environment (useful for tests).",
+ ))
+
+
+
+ env_variables = [
+ SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"),
+ SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_count() // 2))
+ ]
+
+ return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)] + env_variables)
diff --git a/env_manager/rbs_gym/launch/optimize.launch.py b/env_manager/rbs_gym/launch/optimize.launch.py
new file mode 100644
index 0000000..7e06a9c
--- /dev/null
+++ b/env_manager/rbs_gym/launch/optimize.launch.py
@@ -0,0 +1,519 @@
+from launch import LaunchDescription
+from launch.actions import (
+ DeclareLaunchArgument,
+ IncludeLaunchDescription,
+ OpaqueFunction,
+ TimerAction
+)
+from launch.launch_description_sources import PythonLaunchDescriptionSource
+from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
+from launch_ros.substitutions import FindPackageShare
+from launch_ros.actions import Node
+import os
+from os import cpu_count
+from ament_index_python.packages import get_package_share_directory
+
+def launch_setup(context, *args, **kwargs):
+ # Initialize Arguments
+ robot_type = LaunchConfiguration("robot_type")
+ # General arguments
+ with_gripper_condition = LaunchConfiguration("with_gripper")
+ controllers_file = LaunchConfiguration("controllers_file")
+ cartesian_controllers = LaunchConfiguration("cartesian_controllers")
+ description_package = LaunchConfiguration("description_package")
+ description_file = LaunchConfiguration("description_file")
+ robot_name = LaunchConfiguration("robot_name")
+ start_joint_controller = LaunchConfiguration("start_joint_controller")
+ initial_joint_controller = LaunchConfiguration("initial_joint_controller")
+ launch_simulation = LaunchConfiguration("launch_sim")
+ launch_moveit = LaunchConfiguration("launch_moveit")
+ launch_task_planner = LaunchConfiguration("launch_task_planner")
+ launch_perception = LaunchConfiguration("launch_perception")
+ moveit_config_package = LaunchConfiguration("moveit_config_package")
+ moveit_config_file = LaunchConfiguration("moveit_config_file")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ sim_gazebo = LaunchConfiguration("sim_gazebo")
+ hardware = LaunchConfiguration("hardware")
+ env_manager = LaunchConfiguration("env_manager")
+ launch_controllers = LaunchConfiguration("launch_controllers")
+ gripper_name = LaunchConfiguration("gripper_name")
+
+ # training arguments
+ env = LaunchConfiguration("env")
+ env_kwargs = LaunchConfiguration("env_kwargs")
+ algo = LaunchConfiguration("algo")
+ hyperparams = LaunchConfiguration("hyperparams")
+ n_timesteps = LaunchConfiguration("n_timesteps")
+ num_threads = LaunchConfiguration("num_threads")
+ seed = LaunchConfiguration("seed")
+ preload_replay_buffer = LaunchConfiguration("preload_replay_buffer")
+ log_folder = LaunchConfiguration("log_folder")
+ tensorboard_log = LaunchConfiguration("tensorboard_log")
+ log_interval = LaunchConfiguration("log_interval")
+ uuid = LaunchConfiguration("uuid")
+ eval_episodes = LaunchConfiguration("eval_episodes")
+ verbose = LaunchConfiguration("verbose")
+ truncate_last_trajectory = LaunchConfiguration("truncate_last_trajectory")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ log_level = LaunchConfiguration("log_level")
+ sampler = LaunchConfiguration("sampler")
+ pruner = LaunchConfiguration("pruner")
+ n_trials = LaunchConfiguration("n_trials")
+ n_startup_trials = LaunchConfiguration("n_startup_trials")
+ n_evaluations = LaunchConfiguration("n_evaluations")
+ n_jobs = LaunchConfiguration("n_jobs")
+ storage = LaunchConfiguration("storage")
+ study_name = LaunchConfiguration("study_name")
+
+
+ sim_gazebo = LaunchConfiguration("sim_gazebo")
+ launch_simulation = LaunchConfiguration("launch_sim")
+
+ initial_joint_controllers_file_path = os.path.join(
+ get_package_share_directory('rbs_arm'), 'config', 'controllers.yaml'
+ )
+
+ single_robot_setup = IncludeLaunchDescription(
+ PythonLaunchDescriptionSource([
+ PathJoinSubstitution([
+ FindPackageShare('rbs_bringup'),
+ "launch",
+ "rbs_robot.launch.py"
+ ])
+ ]),
+ launch_arguments={
+ "env_manager": env_manager,
+ "with_gripper": with_gripper_condition,
+ "gripper_name": gripper_name,
+ # "controllers_file": controllers_file,  # duplicate key: overridden by the absolute path below
+ "robot_type": robot_type,
+ "controllers_file": initial_joint_controllers_file_path,
+ "cartesian_controllers": cartesian_controllers,
+ "description_package": description_package,
+ "description_file": description_file,
+ "robot_name": robot_name,
+ "start_joint_controller": start_joint_controller,
+ "initial_joint_controller": initial_joint_controller,
+ "launch_simulation": launch_simulation,
+ "launch_moveit": launch_moveit,
+ "launch_task_planner": launch_task_planner,
+ "launch_perception": launch_perception,
+ "moveit_config_package": moveit_config_package,
+ "moveit_config_file": moveit_config_file,
+ "use_sim_time": use_sim_time,
+ "sim_gazebo": sim_gazebo,
+ "hardware": hardware,
+ "launch_controllers": launch_controllers,
+ # "gazebo_gui": gazebo_gui
+ }.items()
+ )
+
+ args = [
+ "--env",
+ env,
+ "--env-kwargs",
+ env_kwargs,
+ "--algo",
+ algo,
+ "--seed",
+ seed,
+ "--num-threads",
+ num_threads,
+ "--n-timesteps",
+ n_timesteps,
+ "--preload-replay-buffer",
+ preload_replay_buffer,
+ "--log-folder",
+ log_folder,
+ "--tensorboard-log",
+ tensorboard_log,
+ "--log-interval",
+ log_interval,
+ "--uuid",
+ uuid,
+ "--optimize-hyperparameters",
+ "True",
+ "--sampler",
+ sampler,
+ "--pruner",
+ pruner,
+ "--n-trials",
+ n_trials,
+ "--n-startup-trials",
+ n_startup_trials,
+ "--n-evaluations",
+ n_evaluations,
+ "--n-jobs",
+ n_jobs,
+ "--storage",
+ storage,
+ "--study-name",
+ study_name,
+ "--eval-episodes",
+ eval_episodes,
+ "--verbose",
+ verbose,
+ "--truncate-last-trajectory",
+ truncate_last_trajectory,
+ "--ros-args",
+ "--log-level",
+ log_level,
+ ]
+
+ rl_task = Node(
+ package="rbs_gym",
+ executable="train",
+ output="log",
+ arguments = args,
+ parameters=[{"use_sim_time": True}]
+ )
+
+
+ delay_robot_control_stack = TimerAction(
+ period=10.0,
+ actions=[single_robot_setup]
+ )
+
+ nodes_to_start = [
+ rl_task,
+ delay_robot_control_stack
+ ]
+ return nodes_to_start
+
+
+def generate_launch_description():
+ declared_arguments = []
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_type",
+ description="Type of robot by name",
+ choices=["rbs_arm","ur3", "ur3e", "ur5", "ur5e", "ur10", "ur10e", "ur16e"],
+ default_value="rbs_arm",
+ )
+ )
+ # General arguments
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "controllers_file",
+ default_value="controllers_gazebosim.yaml",
+ description="YAML file with the controllers configuration.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_package",
+ default_value="rbs_arm",
+ description="Description package with robot URDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom description.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_file",
+ default_value="rbs_arm_modular.xacro",
+ description="URDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_name",
+ default_value="arm0",
+ description="Name for robot, used to apply namespace for specific robot in multirobot setup",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "start_joint_controller",
+ default_value="false",
+ description="Enable headless mode for robot control",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "initial_joint_controller",
+ default_value="joint_trajectory_controller",
+ description="Robot controller to start.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_package",
+ default_value="rbs_arm",
+ description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom moveit config.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_file",
+ default_value="rbs_arm.srdf.xacro",
+ description="MoveIt SRDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_sim_time",
+ default_value="true",
+ description="Make MoveIt to use simulation time.\
+ This is needed for the trajectory planning in simulation.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "gripper_name",
+ default_value="rbs_gripper",
+ choices=["rbs_gripper", ""],
+ description="choose gripper by name (leave empty if hasn't)",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("with_gripper",
+ default_value="true",
+ description="With gripper or not?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("sim_gazebo",
+ default_value="true",
+ description="Gazebo Simulation")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("env_manager",
+ default_value="false",
+ description="Launch env_manager?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_sim",
+ default_value="true",
+ description="Launch simulator (Gazebo)?\
+ Most general arg")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_moveit",
+ default_value="false",
+ description="Launch moveit?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_perception",
+ default_value="false",
+ description="Launch perception?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_task_planner",
+ default_value="false",
+ description="Launch task_planner?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("cartesian_controllers",
+ default_value="true",
+ description="Load cartesian\
+ controllers?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("hardware",
+ choices=["gazebo", "mock"],
+ default_value="gazebo",
+ description="Choose your hardware_interface")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("launch_controllers",
+ default_value="true",
+ description="Launch controllers?")
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument("gazebo_gui",
+ default_value="true",
+ description="Launch gazebo with gui?")
+ )
+ # training arguments
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env",
+ default_value="Reach-Gazebo-v0",
+ description="Environment ID",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env_kwargs",
+ default_value="",
+ description="Optional keyword argument to pass to the env constructor.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "vec_env",
+ default_value="dummy",
+ description="Type of VecEnv to use (dummy or subproc).",
+ ))
+ # Algorithm and training
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "algo",
+ default_value="sac",
+ description="RL algorithm to use during the training.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_timesteps",
+ default_value="-1",
+ description="Overwrite the number of timesteps.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "hyperparams",
+ default_value="",
+ description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "num_threads",
+ default_value="-1",
+ description="Number of threads for PyTorch (-1 to use default).",
+ ))
+ # Continue training an already trained agent
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "trained_agent",
+ default_value="",
+ description="Path to a pretrained agent to continue training.",
+ ))
+ # Random seed
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "seed",
+ default_value="84",
+ description="Random generator seed.",
+ ))
+ # Saving of model
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "save_freq",
+ default_value="10000",
+ description="Save the model every n steps (if negative, no checkpoint).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "save_replay_buffer",
+ default_value="False",
+ description="Save the replay buffer too (when applicable).",
+ ))
+ # Pre-load a replay buffer and start training on it
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "preload_replay_buffer",
+ default_value="",
+ description="Path to a replay buffer that should be preloaded before starting the training process.",
+ ))
+ # Logging
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_folder",
+ default_value="logs",
+ description="Path to the log directory.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "tensorboard_log",
+ default_value="tensorboard_logs",
+ description="Tensorboard log dir.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_interval",
+ default_value="-1",
+ description="Override log interval (default: -1, no change).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "uuid",
+ default_value="False",
+ description="Ensure that the run has a unique ID.",
+ ))
+
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "sampler",
+ default_value="tpe",
+ description="Sampler to use when optimizing hyperparameters (random, tpe or skopt).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "pruner",
+ default_value="median",
+ description="Pruner to use when optimizing hyperparameters (halving, median or none).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_trials",
+ default_value="10",
+ description="Number of trials for optimizing hyperparameters.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_startup_trials",
+ default_value="5",
+ description="Number of trials before using optuna sampler.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_evaluations",
+ default_value="2",
+ description="Number of evaluations for hyperparameter optimization.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_jobs",
+ default_value="1",
+ description="Number of parallel jobs when optimizing hyperparameters.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "storage",
+ default_value="",
+ description="Database storage path if distributed optimization should be used.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "study_name",
+ default_value="",
+ description="Study name for distributed optimization.",
+ ))
+
+ # Evaluation
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "eval_freq",
+ default_value="-1",
+ description="Evaluate the agent every n steps (if negative, no evaluation).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "eval_episodes",
+ default_value="5",
+ description="Number of episodes to use for evaluation.",
+ ))
+ # Verbosity
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "verbose",
+ default_value="1",
+ description="Verbose mode (0: no output, 1: INFO).",
+ ))
+ # HER specifics
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "truncate_last_trajectory",
+ default_value="True",
+ description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer."
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_level",
+ default_value="error",
+ description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
+ ))
+
+# env_variables = [
+# SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"),
+# SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_count() // 2))
+# ]
+
+ return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)])
diff --git a/env_manager/rbs_gym/launch/test_env.launch.py b/env_manager/rbs_gym/launch/test_env.launch.py
new file mode 100644
index 0000000..9cfbdde
--- /dev/null
+++ b/env_manager/rbs_gym/launch/test_env.launch.py
@@ -0,0 +1,507 @@
+import os
+from os import cpu_count
+
+import xacro
+import yaml
+from ament_index_python.packages import get_package_share_directory
+from launch import LaunchDescription
+from launch.actions import (
+ DeclareLaunchArgument,
+ IncludeLaunchDescription,
+ OpaqueFunction,
+ SetEnvironmentVariable,
+ TimerAction,
+)
+from launch.launch_description_sources import PythonLaunchDescriptionSource
+from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
+from launch_ros.actions import Node
+from launch_ros.substitutions import FindPackageShare
+from robot_builder.external.ros2_control import ControllerManager
+from robot_builder.parser.urdf import URDF_parser
+
+
+def launch_setup(context, *args, **kwargs):
+ # Initialize Arguments
+ robot_type = LaunchConfiguration("robot_type")
+ # General arguments
+ with_gripper_condition = LaunchConfiguration("with_gripper")
+ description_package = LaunchConfiguration("description_package")
+ description_file = LaunchConfiguration("description_file")
+ use_moveit = LaunchConfiguration("use_moveit")
+ moveit_config_package = LaunchConfiguration("moveit_config_package")
+ moveit_config_file = LaunchConfiguration("moveit_config_file")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ scene_config_file = LaunchConfiguration("scene_config_file").perform(context)
+
+ ee_link_name = LaunchConfiguration("ee_link_name").perform(context)
+ base_link_name = LaunchConfiguration("base_link_name").perform(context)
+
+ control_space = LaunchConfiguration("control_space").perform(context)
+ control_strategy = LaunchConfiguration("control_strategy").perform(context)
+ use_rbs_utils = LaunchConfiguration("use_rbs_utils")
+
+ interactive = LaunchConfiguration("interactive").perform(context)
+
+ # training arguments
+ env = LaunchConfiguration("env")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ log_level = LaunchConfiguration("log_level")
+ env_kwargs = LaunchConfiguration("env_kwargs")
+
+ description_package_abs_path = get_package_share_directory(
+ description_package.perform(context)
+ )
+
+ xacro_file = os.path.join(
+ description_package_abs_path,
+ "urdf",
+ description_file.perform(context),
+ )
+
+ if not scene_config_file == "":
+ config_file = {"config_file": scene_config_file}
+ else:
+ config_file = {}
+
+ description_package_abs_path = get_package_share_directory(
+ description_package.perform(context)
+ )
+
+ controllers_file = os.path.join(
+ description_package_abs_path, "config", "controllers.yaml"
+ )
+
+ xacro_file = os.path.join(
+ description_package_abs_path,
+ "urdf",
+ description_file.perform(context),
+ )
+
+ # xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml"
+ xacro_config_file = os.path.join(
+ description_package_abs_path, "urdf", "xacro_args.yaml"
+ )
+
+ # TODO: hide this to another place
+ # Load xacro_args
+ def param_constructor(loader, node, local_vars):
+ value = loader.construct_scalar(node)
+ return LaunchConfiguration(value).perform(
+ local_vars.get("context", "Launch context if not defined")
+ )
+
+ def variable_constructor(loader, node, local_vars):
+ value = loader.construct_scalar(node)
+ return local_vars.get(value, f"Variable '{value}' not found")
+
+ def load_xacro_args(yaml_file, local_vars):
+ # Get valut from ros2 argument
+ yaml.add_constructor(
+ "!param", lambda loader, node: param_constructor(loader, node, local_vars)
+ )
+
+ # Get value from local variable in this code
+ # The local variable should be initialized before the loader was called
+ yaml.add_constructor(
+ "!variable",
+ lambda loader, node: variable_constructor(loader, node, local_vars),
+ )
+
+ with open(yaml_file, "r") as file:
+ return yaml.load(file, Loader=yaml.FullLoader)
+
+ mappings_data = load_xacro_args(xacro_config_file, locals())
+
+ robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data)
+
+ robot_description_semantic_content = ""
+
+ robot_description_content = robot_description_doc.toprettyxml(indent=" ")
+ robot_description = {"robot_description": robot_description_content}
+
+ # Parse robot and configure controller's file for ControllerManager
+ robot = URDF_parser.load_string(
+ robot_description_content, ee_link_name=ee_link_name
+ )
+ ControllerManager.save_to_yaml(
+ robot, description_package_abs_path, "controllers.yaml"
+ )
+
+ single_robot_setup = IncludeLaunchDescription(
+ PythonLaunchDescriptionSource(
+ [
+ PathJoinSubstitution(
+ [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"]
+ )
+ ]
+ ),
+ launch_arguments={
+ "with_gripper": with_gripper_condition,
+ "controllers_file": controllers_file,
+ "robot_type": robot_type,
+ "description_package": description_package,
+ "description_file": description_file,
+ "robot_name": robot_type,
+ "use_moveit": "false",
+ "moveit_config_package": moveit_config_package,
+ "moveit_config_file": moveit_config_file,
+ "use_sim_time": "true",
+ "use_skills": "false",
+ "use_controllers": "true",
+ "robot_description": robot_description_content,
+ "robot_description_semantic": robot_description_semantic_content,
+ "base_link_name": base_link_name,
+ "ee_link_name": ee_link_name,
+ "control_space": control_space,
+ "control_strategy": control_strategy,
+ "use_rbs_utils": use_rbs_utils,
+ "interactive_control": "false",
+ }.items(),
+ )
+
+ args = [
+ "--env",
+ env,
+ "--env-kwargs",
+ env_kwargs,
+ "--ros-args",
+ "--log-level",
+ log_level,
+ ]
+
+ rl_task = Node(
+ package="rbs_gym",
+ executable="test_agent",
+ output="log",
+ arguments=args,
+ parameters=[{"use_sim_time": True}, robot_description],
+ )
+
+ clock_bridge = Node(
+ package="ros_gz_bridge",
+ executable="parameter_bridge",
+ arguments=["/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock"],
+ output="screen",
+ )
+
+ delay_robot_control_stack = TimerAction(period=10.0, actions=[single_robot_setup])
+
+ nodes_to_start = [
+ # env,
+ rl_task,
+ clock_bridge,
+ delay_robot_control_stack,
+ ]
+ return nodes_to_start
+
+
+def generate_launch_description():
+ declared_arguments = []
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_type",
+ description="Type of robot by name",
+ choices=[
+ "rbs_arm",
+ "ar4",
+ "ur3",
+ "ur3e",
+ "ur5",
+ "ur5e",
+ "ur10",
+ "ur10e",
+ "ur16e",
+ ],
+ default_value="rbs_arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_package",
+ default_value="rbs_arm",
+ description="Description package with robot URDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom description.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_file",
+ default_value="rbs_arm_modular.xacro",
+ description="URDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_name",
+ default_value="arm0",
+ description="Name for robot, used to apply namespace for specific robot in multirobot setup",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_package",
+ default_value="rbs_arm",
+ description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom moveit config.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_file",
+ default_value="rbs_arm.srdf.xacro",
+ description="MoveIt SRDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_sim_time",
+ default_value="true",
+ description="Make MoveIt to use simulation time.\
+ This is needed for the trajectory planing in simulation.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "with_gripper", default_value="true", description="With gripper or not?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_moveit", default_value="false", description="Launch moveit?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "launch_perception", default_value="false", description="Launch perception?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_controllers",
+ default_value="true",
+ description="Launch controllers?",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "scene_config_file",
+ default_value="",
+ description="Path to a scene configuration file",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "ee_link_name",
+ default_value="",
+ description="End effector name of robot arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "base_link_name",
+ default_value="",
+ description="Base link name if robot arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "control_space",
+ default_value="task",
+ choices=["task", "joint"],
+ description="Specify the control space for the robot (e.g., task space).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "control_strategy",
+ default_value="position",
+ choices=["position", "velocity", "effort"],
+ description="Specify the control strategy (e.g., position control).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "interactive",
+ default_value="true",
+ description="Wheter to run the motion_control_handle controller",
+ ),
+ )
+
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_rbs_utils",
+ default_value="false",
+ description="Wheter to use rbs_utils",
+ ),
+ )
+ # training arguments
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env",
+ default_value="Reach-Gazebo-v0",
+ description="Environment ID",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env_kwargs",
+ default_value="",
+ description="Optional keyword argument to pass to the env constructor.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "vec_env",
+ default_value="dummy",
+ description="Type of VecEnv to use (dummy or subproc).",
+ )
+ )
+ # Algorithm and training
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "algo",
+ default_value="sac",
+ description="RL algorithm to use during the training.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_timesteps",
+ default_value="-1",
+ description="Overwrite the number of timesteps.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "hyperparams",
+ default_value="",
+ description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "num_threads",
+ default_value="-1",
+ description="Number of threads for PyTorch (-1 to use default).",
+ )
+ )
+ # Continue training an already trained agent
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "trained_agent",
+ default_value="",
+ description="Path to a pretrained agent to continue training.",
+ )
+ )
+ # Random seed
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "seed",
+ default_value="-1",
+ description="Random generator seed.",
+ )
+ )
+ # Saving of model
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "save_freq",
+ default_value="10000",
+ description="Save the model every n steps (if negative, no checkpoint).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "save_replay_buffer",
+ default_value="False",
+ description="Save the replay buffer too (when applicable).",
+ )
+ )
+ # Pre-load a replay buffer and start training on it
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "preload_replay_buffer",
+ default_value="",
+ description="Path to a replay buffer that should be preloaded before starting the training process.",
+ )
+ )
+ # Logging
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_folder",
+ default_value="logs",
+ description="Path to the log directory.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "tensorboard_log",
+ default_value="tensorboard_logs",
+ description="Tensorboard log dir.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_interval",
+ default_value="-1",
+ description="Override log interval (default: -1, no change).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "uuid",
+ default_value="False",
+ description="Ensure that the run has a unique ID.",
+ )
+ )
+ # Evaluation
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "eval_freq",
+ default_value="-1",
+ description="Evaluate the agent every n steps (if negative, no evaluation).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "eval_episodes",
+ default_value="5",
+ description="Number of episodes to use for evaluation.",
+ )
+ )
+ # Verbosity
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "verbose",
+ default_value="1",
+ description="Verbose mode (0: no output, 1: INFO).",
+ )
+ )
+ # HER specifics
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "truncate_last_trajectory",
+ default_value="True",
+ description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after) reloading the replay buffer.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_level",
+ default_value="error",
+ description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
+ )
+ )
+ cpu_c = cpu_count()
+ if cpu_c is not None:
+ env_variables = [
+ SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"),
+ SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_c // 2)),
+ ]
+ else:
+ env_variables = [
+ SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"),
+ ]
+
+ return LaunchDescription(
+ declared_arguments + [OpaqueFunction(function=launch_setup)] + env_variables
+ )
diff --git a/env_manager/rbs_gym/launch/train.launch.py b/env_manager/rbs_gym/launch/train.launch.py
new file mode 100644
index 0000000..abea297
--- /dev/null
+++ b/env_manager/rbs_gym/launch/train.launch.py
@@ -0,0 +1,524 @@
+from launch import LaunchDescription
+from launch.actions import (
+ DeclareLaunchArgument,
+ IncludeLaunchDescription,
+ OpaqueFunction,
+ SetEnvironmentVariable,
+ TimerAction
+)
+from launch.launch_description_sources import PythonLaunchDescriptionSource
+from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
+from launch_ros.substitutions import FindPackageShare
+from launch_ros.actions import Node
+import os
+from os import cpu_count
+from ament_index_python.packages import get_package_share_directory
+import yaml
+import xacro
+
+def launch_setup(context, *args, **kwargs):
+ # Initialize Arguments
+ robot_type = LaunchConfiguration("robot_type")
+ # General arguments
+ with_gripper_condition = LaunchConfiguration("with_gripper")
+ description_package = LaunchConfiguration("description_package")
+ description_file = LaunchConfiguration("description_file")
+ use_moveit = LaunchConfiguration("use_moveit")
+ moveit_config_package = LaunchConfiguration("moveit_config_package")
+ moveit_config_file = LaunchConfiguration("moveit_config_file")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ scene_config_file = LaunchConfiguration("scene_config_file").perform(context)
+
+ ee_link_name = LaunchConfiguration("ee_link_name").perform(context)
+ base_link_name = LaunchConfiguration("base_link_name").perform(context)
+
+ control_space = LaunchConfiguration("control_space").perform(context)
+ control_strategy = LaunchConfiguration("control_strategy").perform(context)
+
+ interactive = LaunchConfiguration("interactive").perform(context)
+
+ # training arguments
+ env = LaunchConfiguration("env")
+ algo = LaunchConfiguration("algo")
+ hyperparams = LaunchConfiguration("hyperparams")
+ n_timesteps = LaunchConfiguration("n_timesteps")
+ num_threads = LaunchConfiguration("num_threads")
+ seed = LaunchConfiguration("seed")
+ trained_agent = LaunchConfiguration("trained_agent")
+ save_freq = LaunchConfiguration("save_freq")
+ save_replay_buffer = LaunchConfiguration("save_replay_buffer")
+ preload_replay_buffer = LaunchConfiguration("preload_replay_buffer")
+ log_folder = LaunchConfiguration("log_folder")
+ tensorboard_log = LaunchConfiguration("tensorboard_log")
+ log_interval = LaunchConfiguration("log_interval")
+ uuid = LaunchConfiguration("uuid")
+ eval_freq = LaunchConfiguration("eval_freq")
+ eval_episodes = LaunchConfiguration("eval_episodes")
+ verbose = LaunchConfiguration("verbose")
+ truncate_last_trajectory = LaunchConfiguration("truncate_last_trajectory")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ log_level = LaunchConfiguration("log_level")
+ env_kwargs = LaunchConfiguration("env_kwargs")
+
+ track = LaunchConfiguration("track")
+
+ description_package_abs_path = get_package_share_directory(
+ description_package.perform(context)
+ )
+
+
+ xacro_file = os.path.join(
+ description_package_abs_path,
+ "urdf",
+ description_file.perform(context),
+ )
+
+ if not scene_config_file == "":
+ config_file = {"config_file": scene_config_file}
+ else:
+ config_file = {}
+
+ description_package_abs_path = get_package_share_directory(
+ description_package.perform(context)
+ )
+
+ controllers_file = os.path.join(
+ description_package_abs_path, "config", "controllers.yaml"
+ )
+
+ xacro_file = os.path.join(
+ description_package_abs_path,
+ "urdf",
+ description_file.perform(context),
+ )
+
+ # xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml"
+ xacro_config_file = os.path.join(
+ description_package_abs_path,
+ "urdf",
+ "xacro_args.yaml"
+ )
+
+
+ # TODO: hide this to another place
+ # Load xacro_args
+ def param_constructor(loader, node, local_vars):
+ value = loader.construct_scalar(node)
+ return LaunchConfiguration(value).perform(
+ local_vars.get("context", "Launch context if not defined")
+ )
+
+ def variable_constructor(loader, node, local_vars):
+ value = loader.construct_scalar(node)
+ return local_vars.get(value, f"Variable '{value}' not found")
+
+ def load_xacro_args(yaml_file, local_vars):
+ # Get valut from ros2 argument
+ yaml.add_constructor(
+ "!param", lambda loader, node: param_constructor(loader, node, local_vars)
+ )
+
+ # Get value from local variable in this code
+ # The local variable should be initialized before the loader was called
+ yaml.add_constructor(
+ "!variable",
+ lambda loader, node: variable_constructor(loader, node, local_vars),
+ )
+
+ with open(yaml_file, "r") as file:
+ return yaml.load(file, Loader=yaml.FullLoader)
+
+ mappings_data = load_xacro_args(xacro_config_file, locals())
+
+ robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data)
+
+ robot_description_semantic_content = ""
+
+
+ robot_description_content = robot_description_doc.toprettyxml(indent=" ")
+ robot_description = {"robot_description": robot_description_content}
+
+ single_robot_setup = IncludeLaunchDescription(
+ PythonLaunchDescriptionSource(
+ [
+ PathJoinSubstitution(
+ [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"]
+ )
+ ]
+ ),
+ launch_arguments={
+ "with_gripper": with_gripper_condition,
+ "controllers_file": controllers_file,
+ "robot_type": robot_type,
+ "description_package": description_package,
+ "description_file": description_file,
+ "robot_name": robot_type,
+ "use_moveit": use_moveit,
+ "moveit_config_package": moveit_config_package,
+ "moveit_config_file": moveit_config_file,
+ "use_sim_time": use_sim_time,
+ "use_controllers": "true",
+ "robot_description": robot_description_content,
+ "robot_description_semantic": robot_description_semantic_content,
+ "base_link_name": base_link_name,
+ "ee_link_name": ee_link_name,
+ "control_space": control_space,
+ "control_strategy": control_strategy,
+ "interactive_control": interactive,
+ }.items(),
+ )
+
+ args = [
+ "--env",
+ env,
+ "--env-kwargs",
+ env_kwargs,
+ "--algo",
+ algo,
+ "--hyperparams",
+ hyperparams,
+ "--n-timesteps",
+ n_timesteps,
+ "--num-threads",
+ num_threads,
+ "--seed",
+ seed,
+ "--trained-agent",
+ trained_agent,
+ "--save-freq",
+ save_freq,
+ "--save-replay-buffer",
+ save_replay_buffer,
+ "--preload-replay-buffer",
+ preload_replay_buffer,
+ "--log-folder",
+ log_folder,
+ "--tensorboard-log",
+ tensorboard_log,
+ "--log-interval",
+ log_interval,
+ "--uuid",
+ uuid,
+ "--eval-freq",
+ eval_freq,
+ "--eval-episodes",
+ eval_episodes,
+ "--verbose",
+ verbose,
+ "--track",
+ track,
+ "--truncate-last-trajectory",
+ truncate_last_trajectory,
+ "--ros-args",
+ "--log-level",
+ log_level,
+ ]
+
+ clock_bridge = Node(
+ package='ros_gz_bridge',
+ executable='parameter_bridge',
+ arguments=['/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock'],
+ output='screen')
+
+ rl_task = Node(
+ package="rbs_gym",
+ executable="train",
+ output="log",
+ arguments=args,
+ parameters=[{"use_sim_time": True}]
+ )
+
+
+ delay_robot_control_stack = TimerAction(
+ period=20.0,
+ actions=[single_robot_setup]
+ )
+
+ nodes_to_start = [
+ # env,
+ rl_task,
+ clock_bridge,
+ delay_robot_control_stack
+ ]
+ return nodes_to_start
+
+
+def generate_launch_description():
+ declared_arguments = []
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_type",
+ description="Type of robot by name",
+ choices=[
+ "rbs_arm",
+ "ar4",
+ "ur3",
+ "ur3e",
+ "ur5",
+ "ur5e",
+ "ur10",
+ "ur10e",
+ "ur16e",
+ ],
+ default_value="rbs_arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_package",
+ default_value="rbs_arm",
+ description="Description package with robot URDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom description.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_file",
+ default_value="rbs_arm_modular.xacro",
+ description="URDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_name",
+ default_value="arm0",
+ description="Name for robot, used to apply namespace for specific robot in multirobot setup",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_package",
+ default_value="rbs_arm",
+ description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom moveit config.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_file",
+ default_value="rbs_arm.srdf.xacro",
+ description="MoveIt SRDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_sim_time",
+ default_value="true",
+ description="Make MoveIt to use simulation time.\
+ This is needed for the trajectory planing in simulation.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "with_gripper", default_value="true", description="With gripper or not?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_moveit", default_value="false", description="Launch moveit?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "launch_perception", default_value="false", description="Launch perception?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_controllers",
+ default_value="true",
+ description="Launch controllers?",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "scene_config_file",
+ default_value="",
+ description="Path to a scene configuration file",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "ee_link_name",
+ default_value="",
+ description="End effector name of robot arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "base_link_name",
+ default_value="",
+ description="Base link name if robot arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "control_space",
+ default_value="task",
+ choices=["task", "joint"],
+ description="Specify the control space for the robot (e.g., task space).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "control_strategy",
+ default_value="position",
+ choices=["position", "velocity", "effort"],
+ description="Specify the control strategy (e.g., position control).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "interactive",
+ default_value="true",
+ description="Wheter to run the motion_control_handle controller",
+ ),
+ )
+ # training arguments
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env",
+ default_value="Reach-Gazebo-v0",
+ description="Environment ID",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "env_kwargs",
+ default_value="",
+ description="Optional keyword argument to pass to the env constructor.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "vec_env",
+ default_value="dummy",
+ description="Type of VecEnv to use (dummy or subproc).",
+ ))
+ # Algorithm and training
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "algo",
+ default_value="sac",
+ description="RL algorithm to use during the training.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "n_timesteps",
+ default_value="-1",
+ description="Overwrite the number of timesteps.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "hyperparams",
+ default_value="",
+ description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "num_threads",
+ default_value="-1",
+ description="Number of threads for PyTorch (-1 to use default).",
+ ))
+ # Continue training an already trained agent
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "trained_agent",
+ default_value="",
+ description="Path to a pretrained agent to continue training.",
+ ))
+ # Random seed
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "seed",
+ default_value="-1",
+ description="Random generator seed.",
+ ))
+ # Saving of model
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "save_freq",
+ default_value="10000",
+ description="Save the model every n steps (if negative, no checkpoint).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "save_replay_buffer",
+ default_value="False",
+ description="Save the replay buffer too (when applicable).",
+ ))
+ # Pre-load a replay buffer and start training on it
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "preload_replay_buffer",
+ default_value="",
+ description="Path to a replay buffer that should be preloaded before starting the training process.",
+ ))
+ # Logging
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_folder",
+ default_value="logs",
+ description="Path to the log directory.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "tensorboard_log",
+ default_value="tensorboard_logs",
+ description="Tensorboard log dir.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_interval",
+ default_value="-1",
+ description="Override log interval (default: -1, no change).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "uuid",
+ default_value="False",
+ description="Ensure that the run has a unique ID.",
+ ))
+ # Evaluation
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "eval_freq",
+ default_value="-1",
+ description="Evaluate the agent every n steps (if negative, no evaluation).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "eval_episodes",
+ default_value="5",
+ description="Number of episodes to use for evaluation.",
+ ))
+ # Verbosity
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "verbose",
+ default_value="1",
+ description="Verbose mode (0: no output, 1: INFO).",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "truncate_last_trajectory",
+ default_value="True",
+ description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after) reloading the replay buffer."
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "log_level",
+ default_value="error",
+ description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
+ ))
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "track",
+ default_value="true",
+ description="The level of logging that is applied to all ROS 2 nodes launched by this script.",
+ ))
+
+ env_variables = [
+ SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"),
+ SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_count() // 2))
+ ]
+
+ return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)] + env_variables)
diff --git a/env_manager/rbs_gym/package.nix b/env_manager/rbs_gym/package.nix
new file mode 100644
index 0000000..6411afe
--- /dev/null
+++ b/env_manager/rbs_gym/package.nix
@@ -0,0 +1,20 @@
+# Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0
+
+# Copyright 2025 None
+# Distributed under the terms of the Apache-2.0 license
+
+# Nix derivation for the rbs_gym ROS 2 (Jazzy) package.
+{ lib, buildRosPackage, ament-copyright, ament-flake8, ament-pep257, pythonPackages }:
+buildRosPackage rec {
+  pname = "ros-jazzy-rbs-gym";
+  version = "0.0.0";
+
+  # Build from the directory containing this file.
+  src = ./.;
+
+  buildType = "ament_python";
+  # Linters and pytest are only required for the check phase.
+  checkInputs = [ ament-copyright ament-flake8 ament-pep257 pythonPackages.pytest ];
+
+  meta = {
+    description = "TODO: Package description";
+    license = with lib.licenses; [ asl20 ];
+  };
+}
diff --git a/env_manager/rbs_gym/package.xml b/env_manager/rbs_gym/package.xml
new file mode 100644
index 0000000..c11a38c
--- /dev/null
+++ b/env_manager/rbs_gym/package.xml
@@ -0,0 +1,18 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>rbs_gym</name>
+  <version>0.0.0</version>
+  <description>TODO: Package description</description>
+  <!-- TODO(review): maintainer email reconstructed with the ros2 default
+       placeholder; the XML markup of this file was stripped in extraction. -->
+  <maintainer email="narmak@todo.todo">narmak</maintainer>
+  <license>Apache-2.0</license>
+
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
+
+  <export>
+    <build_type>ament_python</build_type>
+  </export>
+</package>
diff --git a/env_manager/rbs_gym/rbs_gym/__init__.py b/env_manager/rbs_gym/rbs_gym/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/env_manager/rbs_gym/rbs_gym/envs/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/__init__.py
new file mode 100644
index 0000000..f0db77a
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/__init__.py
@@ -0,0 +1,184 @@
+# Note: The `open3d` and `stable_baselines3` modules must be imported prior to `gym_gz`
+from env_manager.models.configs import (
+ RobotData,
+ SceneData,
+ SphereObjectData,
+ TerrainData,
+)
+from env_manager.models.configs.objects import ObjectRandomizerData
+from env_manager.models.configs.robot import RobotRandomizerData
+from env_manager.models.configs.scene import PluginsData
+# import open3d # isort:skip
+import stable_baselines3 # isort:skip
+
+# Note: If installed, `tensorflow` module must be imported before `gym_gz`/`scenario`
+# Otherwise, protobuf version incompatibility will cause an error
+try:
+ from importlib.util import find_spec
+
+ if find_spec("tensorflow") is not None:
+ import tensorflow
+except:
+ pass
+
+from os import environ, path
+from typing import Any, Dict, Tuple
+
+import numpy as np
+from ament_index_python.packages import get_package_share_directory
+from gymnasium.envs.registration import register
+from rbs_assets_library import get_world_file
+
+from rbs_gym.utils.utils import str2bool
+
+from . import tasks
+
+######################
+# Runtime Entrypoint #
+######################
+
+# Entry point passed to every task `register()` call below; gym_gz's
+# GazeboRuntime drives the task inside a Gazebo simulation.
+RBS_ENVS_TASK_ENTRYPOINT: str = "gym_gz.runtimes.gazebo_runtime:GazeboRuntime"
+
+
+###################
+# Robot Specifics #
+###################
+# Default robot model to use in the tasks where robot can be static
+RBS_ENVS_ROBOT_MODEL: str = "rbs_arm"
+
+######################
+# Datasets and paths #
+######################
+# Path to directory containing base SDF worlds
+RBS_ENVS_WORLDS_DIR: str = path.join(get_package_share_directory("rbs_gym"), "worlds")
+
+
+###########
+# Presets #
+###########
+# Gravity preset for Earth
+# GRAVITY_EARTH: Tuple[float, float, float] = (0.0, 0.0, -9.80665)
+# GRAVITY_EARTH_STD: Tuple[float, float, float] = (0.0, 0.0, 0.0232)
+
+############################
+# Additional Configuration #
+############################
+# BROADCAST_GUI: bool = str2bool(
+# environ.get("RBS_ENVS_BROADCAST_INTERACTIVE_GUI", default=True)
+# )
+
+
+#########
+# Reach #
+#########
+# Maximum number of agent steps per episode for every Reach variant.
+REACH_MAX_EPISODE_STEPS: int = 100
+# Keyword arguments forwarded to the Reach task classes via `register()`.
+# NOTE(review): units of agent_rate/scaling factors are not visible here —
+# presumably Hz and unitless gains; confirm against the task implementation.
+REACH_KWARGS: Dict[str, Any] = {
+    "agent_rate": 4.0,
+    "robot_model": RBS_ENVS_ROBOT_MODEL,
+    "workspace_frame_id": "world",
+    "workspace_centre": (-0.45, 0.0, 0.35),
+    "workspace_volume": (0.7, 0.7, 0.7),
+    "ignore_new_actions_while_executing": False,
+    "use_servo": True,
+    "scaling_factor_translation": 8.0,
+    "scaling_factor_rotation": 3.0,
+    "restrict_position_goal_to_workspace": False,
+    "enable_gripper": False,
+    "sparse_reward": False,
+    "collision_reward": -10.0,
+    "act_quick_reward": -0.01,
+    "required_accuracy": 0.05,
+    "num_threads": 3,
+}
+# Simulation-side kwargs for the Gazebo wrappers; real_time_factor is set as
+# high as float32 allows so the simulation runs as fast as the host can.
+REACH_KWARGS_SIM: Dict[str, Any] = {
+    "physics_rate": 100,
+    "real_time_factor": float(np.finfo(np.float32).max),
+    "world": get_world_file("ungravitational"),
+}
+
+
+# Entry point for the Gazebo environment randomizer used by the "-Gazebo-"
+# environment ids below.
+REACH_RANDOMIZER: str = "rbs_gym.envs.randomizers:ManipulationGazeboEnvRandomizer"
+
+# Scene handed to the randomizer: the rbs_arm robot with randomized joint
+# positions plus a single randomly-posed, non-colliding target sphere.
+SCENE_CONFIGURATION: SceneData = SceneData(
+    physics_rollouts_num=0,
+    robot=RobotData(
+        name="rbs_arm",
+        joint_positions=[0.0, 0.5, 3.14159, 1.5, 0.0, 1.4, 0.0],
+        with_gripper=True,
+        gripper_joint_positions=0.00,
+        randomizer=RobotRandomizerData(joint_positions=True),
+    ),
+    objects=[
+        SphereObjectData(
+            name="sphere",
+            type="sphere",
+            relative_to="base_link",
+            position=(0.0, 0.3, 0.5),
+            static=True,
+            collision=False,
+            color=(0.0, 1.0, 0.0, 0.8),
+            randomize=ObjectRandomizerData(random_pose=True, models_rollouts_num=2),
+        )
+    ],
+    plugins=PluginsData(
+        scene_broadcaster=True, user_commands=True, fts_broadcaster=True
+    ),
+)
+
+
+# Task
+register(
+ id="Reach-v0",
+ entry_point=RBS_ENVS_TASK_ENTRYPOINT,
+ max_episode_steps=REACH_MAX_EPISODE_STEPS,
+ kwargs={
+ "task_cls": tasks.Reach,
+ **REACH_KWARGS,
+ },
+)
+register(
+ id="Reach-ColorImage-v0",
+ entry_point=RBS_ENVS_TASK_ENTRYPOINT,
+ max_episode_steps=REACH_MAX_EPISODE_STEPS,
+ kwargs={
+ "task_cls": tasks.ReachColorImage,
+ **REACH_KWARGS,
+ },
+)
+register(
+ id="Reach-DepthImage-v0",
+ entry_point=RBS_ENVS_TASK_ENTRYPOINT,
+ max_episode_steps=REACH_MAX_EPISODE_STEPS,
+ kwargs={
+ "task_cls": tasks.ReachDepthImage,
+ **REACH_KWARGS,
+ },
+)
+
+# Gazebo wrapper
+register(
+ id="Reach-Gazebo-v0",
+ entry_point=REACH_RANDOMIZER,
+ max_episode_steps=REACH_MAX_EPISODE_STEPS,
+ kwargs={"env": "Reach-v0", **REACH_KWARGS_SIM, "scene_args": SCENE_CONFIGURATION},
+)
+register(
+ id="Reach-ColorImage-Gazebo-v0",
+ entry_point=REACH_RANDOMIZER,
+ max_episode_steps=REACH_MAX_EPISODE_STEPS,
+ kwargs={
+ "env": "Reach-ColorImage-v0",
+ **REACH_KWARGS_SIM,
+ "scene_args": SCENE_CONFIGURATION,
+ },
+)
+register(
+ id="Reach-DepthImage-Gazebo-v0",
+ entry_point=REACH_RANDOMIZER,
+ max_episode_steps=REACH_MAX_EPISODE_STEPS,
+ kwargs={
+ "env": "Reach-DepthImage-v0",
+ **REACH_KWARGS_SIM,
+ "scene_args": SCENE_CONFIGURATION,
+ },
+)
diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/control/__init__.py
new file mode 100644
index 0000000..b551c3d
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/control/__init__.py
@@ -0,0 +1,3 @@
+from .cartesian_force_controller import CartesianForceController
+from .grippper_controller import GripperController
+from .joint_effort_controller import JointEffortController
diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_force_controller.py b/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_force_controller.py
new file mode 100644
index 0000000..3c60500
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_force_controller.py
@@ -0,0 +1,41 @@
+from typing import Optional
+from geometry_msgs.msg import WrenchStamped
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+
+class CartesianForceController:
+ def __init__(self, node, namespace: Optional[str] = "") -> None:
+ self.node = node
+ self.publisher = node.create_publisher(WrenchStamped,
+ namespace + "/cartesian_force_controller/target_wrench", 10)
+ self.timer = node.create_timer(0.1, self.timer_callback)
+ self.publish = False
+ self._target_wrench = WrenchStamped()
+
+ @property
+ def target_wrench(self) -> WrenchStamped:
+ return self._target_wrench
+
+ @target_wrench.setter
+ def target_wrench(self, wrench: WrenchStamped):
+ self._target_wrench = wrench
+
+ def timer_callback(self):
+ if self.publish:
+ self.publisher.publish(self._target_wrench)
+
+
+class CartesianForceControllerStandalone(Node, CartesianForceController):
+ def __init__(self, node_name:str = "rbs_gym_controller", use_sim_time: bool = True):
+
+ try:
+ rclpy.init()
+ except Exception as e:
+ if not rclpy.ok():
+ sys.exit(f"ROS 2 context could not be initialised: {e}")
+
+ Node.__init__(self, node_name)
+ self.set_parameters(
+ [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)]
+ )
+ CartesianForceController.__init__(self, node=self)
diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_velocity_controller.py b/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_velocity_controller.py
new file mode 100644
index 0000000..b5a12c0
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_velocity_controller.py
@@ -0,0 +1,121 @@
+import rclpy
+from rclpy.node import Node
+import numpy as np
+import quaternion
+from geometry_msgs.msg import Twist
+from geometry_msgs.msg import PoseStamped
+import tf2_ros
+import sys
+import time
+import threading
+import os
+
+
+class VelocityController:
+ """Convert Twist messages to PoseStamped
+
+ Use this node to integrate twist messages into a moving target pose in
+ Cartesian space. An initial TF lookup assures that the target pose always
+ starts at the robot's end-effector.
+ """
+
+ def __init__(self, node: Node, topic_pose: str, topic_twist: str, base_frame: str, ee_frame: str):
+ self.node = node
+
+ self._frame_id = base_frame
+ self._end_effector = ee_frame
+
+ self.tf_buffer = tf2_ros.Buffer()
+ self.tf_listener = tf2_ros.TransformListener(self.tf_buffer, self)
+ self.rot = np.quaternion(0, 0, 0, 1)
+ self.pos = [0, 0, 0]
+
+ self.pub = node.create_publisher(PoseStamped, topic_pose, 3)
+ self.sub = node.create_subscription(Twist, topic_twist, self.twist_cb, 1)
+ self.last = time.time()
+
+ self.startup_done = False
+ period = 1.0 / node.declare_parameter("publishing_rate", 100).value
+ self.timer = node.create_timer(period, self.publish)
+ self.publish_it = False
+
+ self.thread = threading.Thread(target=self.startup, daemon=True)
+ self.thread.start()
+
+ def startup(self):
+ """Make sure to start at the robot's current pose"""
+ # Wait until we entered spinning in the main thread.
+ time.sleep(1)
+ try:
+ start = self.tf_buffer.lookup_transform(
+ target_frame=self._frame_id,
+ source_frame=self._end_effector,
+ time=rclpy.time.Time(),
+ )
+
+ except (
+ tf2_ros.InvalidArgumentException,
+ tf2_ros.LookupException,
+ tf2_ros.ConnectivityException,
+ tf2_ros.ExtrapolationException,
+ ) as e:
+ print(f"Startup failed: {e}")
+ os._exit(1)
+
+ self.pos[0] = start.transform.translation.x
+ self.pos[1] = start.transform.translation.y
+ self.pos[2] = start.transform.translation.z
+ self.rot.x = start.transform.rotation.x
+ self.rot.y = start.transform.rotation.y
+ self.rot.z = start.transform.rotation.z
+ self.rot.w = start.transform.rotation.w
+ self.startup_done = True
+
+ def twist_cb(self, data):
+ """Numerically integrate twist message into a pose
+
+ Use global self.frame_id as reference for the navigation commands.
+ """
+ now = time.time()
+ dt = now - self.last
+ self.last = now
+
+ # Position update
+ self.pos[0] += data.linear.x * dt
+ self.pos[1] += data.linear.y * dt
+ self.pos[2] += data.linear.z * dt
+
+ # Orientation update
+ wx = data.angular.x
+ wy = data.angular.y
+ wz = data.angular.z
+
+ _, q = quaternion.integrate_angular_velocity(
+ lambda _: (wx, wy, wz), 0, dt, self.rot
+ )
+
+ self.rot = q[-1] # the last one is after dt passed
+
+ def publish(self):
+ if not self.startup_done:
+ return
+ if not self.publish_it:
+ return
+ try:
+ msg = PoseStamped()
+ msg.header.stamp = self.get_clock().now().to_msg()
+ msg.header.frame_id = self.frame_id
+ msg.pose.position.x = self.pos[0]
+ msg.pose.position.y = self.pos[1]
+ msg.pose.position.z = self.pos[2]
+ msg.pose.orientation.x = self.rot.x
+ msg.pose.orientation.y = self.rot.y
+ msg.pose.orientation.z = self.rot.z
+ msg.pose.orientation.w = self.rot.w
+
+ self.pub.publish(msg)
+ except Exception:
+ # Swallow 'publish() to closed topic' error.
+ # This rarely happens on killing this node.
+ pass
+
diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/grippper_controller.py b/env_manager/rbs_gym/rbs_gym/envs/control/grippper_controller.py
new file mode 100644
index 0000000..79a9e2a
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/control/grippper_controller.py
@@ -0,0 +1,50 @@
+from typing import Optional
+from control_msgs.action import GripperCommand
+from rclpy.action import ActionClient
+
+
+class GripperController:
+ def __init__(self, node,
+ open_position: Optional[float] = 0.0,
+ close_position: Optional[float] = 0.0,
+ max_effort: Optional[float] = 0.0,
+ namespace: Optional[str] = ""):
+ self._action_client = ActionClient(node, GripperCommand,
+ namespace + "/gripper_controller/gripper_cmd")
+ self._open_position = open_position
+ self._close_position = close_position
+ self._max_effort = max_effort
+ self.is_executed = False
+
+ def open(self):
+ self.send_goal(self._open_position)
+
+ def close(self):
+ self.send_goal(self._close_position)
+
+ def send_goal(self, goal: float):
+ goal_msg = GripperCommand.Goal()
+ goal_msg._command.position = goal
+ goal_msg._command.max_effort = self._max_effort
+ self._action_client.wait_for_server()
+ self._send_goal_future = self._action_client.send_goal_async(goal_msg)
+ self._send_goal_future.add_done_callback(self.goal_response_callback)
+
+ def goal_response_callback(self, future):
+ goal_handle = future.result()
+ if not goal_handle.accepted:
+ self.get_logger().info('Goal rejected :(')
+ return
+
+ self.get_logger().info('Goal accepted :)')
+
+ self._get_result_future = goal_handle.get_result_async()
+ self._get_result_future.add_done_callback(self.get_result_callback)
+
+ def get_result_callback(self, future):
+ result = future.result().result
+ self.get_logger().info(f"Gripper position: {result.position}")
+
+ def wait_until_executed(self):
+ while not self.is_executed:
+ pass
diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/joint_effort_controller.py b/env_manager/rbs_gym/rbs_gym/envs/control/joint_effort_controller.py
new file mode 100644
index 0000000..b58ff8a
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/control/joint_effort_controller.py
@@ -0,0 +1,25 @@
+from typing import Optional
+from std_msgs.msg import Float64MultiArray
+
+class JointEffortController:
+ def __init__(self, node, namespace: Optional[str] = "") -> None:
+ self.node = node
+ self.publisher = node.create_publisher(Float64MultiArray,
+ namespace + "/joint_effort_controller/commands", 10)
+ # self.timer = node.create_timer(0.1, self.timer_callback)
+ # self.publish = True
+ self._effort_array = Float64MultiArray()
+
+ @property
+ def target_effort(self) -> Float64MultiArray:
+ return self._effort_array
+
+ @target_effort.setter
+ def target_effort(self, data: Float64MultiArray):
+ self._effort_array = data
+
+ # def timer_callback(self):
+ # if self.publish:
+ # self.publisher.publish(self._target_wrench)
+
+
diff --git a/env_manager/rbs_gym/rbs_gym/envs/observation/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/observation/__init__.py
new file mode 100644
index 0000000..50b9243
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/observation/__init__.py
@@ -0,0 +1,4 @@
+from .camera_subscriber import CameraSubscriber, CameraSubscriberStandalone
+from .twist_subscriber import TwistSubscriber
+from .joint_states import JointStates
+# from .octree import OctreeCreator
diff --git a/env_manager/rbs_gym/rbs_gym/envs/observation/camera_subscriber.py b/env_manager/rbs_gym/rbs_gym/envs/observation/camera_subscriber.py
new file mode 100644
index 0000000..a60df57
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/observation/camera_subscriber.py
@@ -0,0 +1,118 @@
+import sys
+from threading import Lock, Thread
+from typing import Optional, Union
+
+import rclpy
+from rclpy.callback_groups import CallbackGroup
+from rclpy.executors import SingleThreadedExecutor
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+from rclpy.qos import (
+ QoSDurabilityPolicy,
+ QoSHistoryPolicy,
+ QoSProfile,
+ QoSReliabilityPolicy,
+)
+from sensor_msgs.msg import Image, PointCloud2
+
+
+class CameraSubscriber:
+ def __init__(
+ self,
+ node: Node,
+ topic: str,
+ is_point_cloud: bool,
+ callback_group: Optional[CallbackGroup] = None,
+ ):
+
+ self._node = node
+
+ # Prepare the subscriber
+ if is_point_cloud:
+ camera_msg_type = PointCloud2
+ else:
+ camera_msg_type = Image
+ self.__observation = camera_msg_type()
+ self._node.create_subscription(
+ msg_type=camera_msg_type,
+ topic=topic,
+ callback=self.observation_callback,
+ qos_profile=QoSProfile(
+ reliability=QoSReliabilityPolicy.RELIABLE,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ history=QoSHistoryPolicy.KEEP_LAST,
+ depth=1,
+ ),
+ callback_group=callback_group,
+ )
+ self.__observation_mutex = Lock()
+ self.__new_observation_available = False
+
+ def observation_callback(self, msg):
+ """
+ Callback for getting observation.
+ """
+
+ self.__observation_mutex.acquire()
+ self.__observation = msg
+ self.__new_observation_available = True
+ self._node.get_logger().debug("New observation received.")
+ self.__observation_mutex.release()
+
+ def get_observation(self) -> Union[PointCloud2, Image]:
+ """
+ Get the last received observation.
+ """
+
+ self.__observation_mutex.acquire()
+ observation = self.__observation
+ self.__observation_mutex.release()
+ return observation
+
+ def reset_new_observation_checker(self):
+ """
+ Reset checker of new observations, i.e. `self.new_observation_available()`
+ """
+
+ self.__observation_mutex.acquire()
+ self.__new_observation_available = False
+ self.__observation_mutex.release()
+
+ @property
+ def new_observation_available(self):
+ """
+ Check if new observation is available since `self.reset_new_observation_checker()` was called
+ """
+
+ return self.__new_observation_available
+
+
+class CameraSubscriberStandalone(Node, CameraSubscriber):
+    # Standalone variant: owns its own ROS 2 node and spins it on a daemon
+    # executor thread, so it works outside an already-spinning node.
+    def __init__(
+        self,
+        topic: str,
+        is_point_cloud: bool,
+        node_name: str = "rbs_gym_camera_sub",
+        use_sim_time: bool = True,
+    ):
+
+        try:
+            rclpy.init()
+        except Exception as e:
+            # rclpy.init() raises when a context already exists; only abort
+            # when ROS really is unusable.
+            if not rclpy.ok():
+                sys.exit(f"ROS 2 context could not be initialised: {e}")
+
+        # Initialise the Node base first: the CameraSubscriber mixin needs a
+        # working node (create_subscription, get_logger).
+        Node.__init__(self, node_name)
+        self.set_parameters(
+            [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)]
+        )
+
+        CameraSubscriber.__init__(
+            self, node=self, topic=topic, is_point_cloud=is_point_cloud
+        )
+
+        # Spin the node in a separate thread
+        self._executor = SingleThreadedExecutor()
+        self._executor.add_node(self)
+        self._executor_thread = Thread(target=self._executor.spin, daemon=True, args=())
+        self._executor_thread.start()
diff --git a/env_manager/rbs_gym/rbs_gym/envs/observation/joint_states.py b/env_manager/rbs_gym/rbs_gym/envs/observation/joint_states.py
new file mode 100644
index 0000000..bc1a097
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/observation/joint_states.py
@@ -0,0 +1,107 @@
+from array import array
+from threading import Lock
+from typing import Optional
+
+from rclpy.callback_groups import CallbackGroup
+from rclpy.node import Node
+from rclpy.qos import (
+ QoSDurabilityPolicy,
+ QoSHistoryPolicy,
+ QoSProfile,
+ QoSReliabilityPolicy,
+)
+
+from sensor_msgs.msg import JointState
+
+class JointStates:
+ def __init__(
+ self,
+ node: Node,
+ topic: str,
+ callback_group: Optional[CallbackGroup] = None,
+ ):
+
+ self._node = node
+
+ self.__observation = JointState()
+ self._node.create_subscription(
+ msg_type=JointState,
+ topic=topic,
+ callback=self.observation_callback,
+ qos_profile=QoSProfile(
+ reliability=QoSReliabilityPolicy.RELIABLE,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ history=QoSHistoryPolicy.KEEP_LAST,
+ depth=1,
+ ),
+ callback_group=callback_group,
+ )
+ self.__observation_mutex = Lock()
+ self.__new_observation_available = False
+
+ self.__observation.position
+
+ def observation_callback(self, msg):
+ """
+ Callback for getting observation.
+ """
+
+ self.__observation_mutex.acquire()
+ self.__observation = msg
+ self.__new_observation_available = True
+ self._node.get_logger().debug("New observation received.")
+ self.__observation_mutex.release()
+
+ def get_observation(self) -> JointState:
+ """
+ Get the last received observation.
+ """
+ self.__observation_mutex.acquire()
+ observation = self.__observation
+ self.__observation_mutex.release()
+ return observation
+
+ def get_positions(self) -> array:
+ """
+ Get the last recorded observation position
+ """
+ self.__observation_mutex.acquire()
+ observation = self.__observation.position
+ self.__observation_mutex.release()
+ return observation
+
+ def get_velocities(self) -> array:
+ """
+ Get the last recorded observation velocity
+ """
+ self.__observation_mutex.acquire()
+ observation = self.__observation.velocity
+ self.__observation_mutex.release()
+ return observation
+
+ def get_efforts(self) -> array:
+ """
+ Get the last recorded observation effort
+ """
+ self.__observation_mutex.acquire()
+ observation = self.__observation.effort
+ self.__observation_mutex.release()
+ return observation
+
+ def reset_new_observation_checker(self):
+ """
+ Reset checker of new observations, i.e. `self.new_observation_available()`
+ """
+
+ self.__observation_mutex.acquire()
+ self.__new_observation_available = False
+ self.__observation_mutex.release()
+
+ @property
+ def new_observation_available(self):
+ """
+ Check if new observation is available since `self.reset_new_observation_checker()` was called
+ """
+
+ return self.__new_observation_available
+
diff --git a/env_manager/rbs_gym/rbs_gym/envs/observation/octree.py b/env_manager/rbs_gym/rbs_gym/envs/observation/octree.py
new file mode 100644
index 0000000..55caaaf
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/observation/octree.py
@@ -0,0 +1,211 @@
+from typing import List, Tuple
+
+import numpy as np
+import ocnn
+import open3d
+import torch
+from rclpy.node import Node
+from sensor_msgs.msg import PointCloud2
+
+from env_manager.utils import Tf2Listener, conversions
+
+
+class OctreeCreator:
+ def __init__(
+ self,
+ node: Node,
+ tf2_listener: Tf2Listener,
+ reference_frame_id: str,
+ min_bound: Tuple[float, float, float] = (-1.0, -1.0, -1.0),
+ max_bound: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ include_color: bool = False,
+ # Note: For efficiency, the first channel of RGB is used for intensity
+ include_intensity: bool = False,
+ depth: int = 4,
+ full_depth: int = 2,
+ adaptive: bool = False,
+ adp_depth: int = 4,
+ normals_radius: float = 0.05,
+ normals_max_nn: int = 10,
+ node_dis: bool = True,
+ node_feature: bool = False,
+ split_label: bool = False,
+ th_normal: float = 0.1,
+ th_distance: float = 2.0,
+ extrapolate: bool = False,
+ save_pts: bool = False,
+ key2xyz: bool = False,
+ debug_draw: bool = False,
+ debug_write_octree: bool = False,
+ ):
+
+ self._node = node
+
+ # Listener of tf2 transforms is shared with the owner
+ self.__tf2_listener = tf2_listener
+
+ # Parameters
+ self._reference_frame_id = reference_frame_id
+ self._min_bound = min_bound
+ self._max_bound = max_bound
+ self._include_color = include_color
+ self._include_intensity = include_intensity
+ self._normals_radius = normals_radius
+ self._normals_max_nn = normals_max_nn
+ self._debug_draw = debug_draw
+ self._debug_write_octree = debug_write_octree
+
+ # Create a converter between points and octree
+ self._points_to_octree = ocnn.Points2Octree(
+ depth=depth,
+ full_depth=full_depth,
+ node_dis=node_dis,
+ node_feature=node_feature,
+ split_label=split_label,
+ adaptive=adaptive,
+ adp_depth=adp_depth,
+ th_normal=th_normal,
+ th_distance=th_distance,
+ extrapolate=extrapolate,
+ save_pts=save_pts,
+ key2xyz=key2xyz,
+ bb_min=min_bound,
+ bb_max=max_bound,
+ )
+
+ def __call__(self, ros_point_cloud2: PointCloud2) -> torch.Tensor:
+
+ # Convert to Open3D PointCloud
+ open3d_point_cloud = conversions.pointcloud2_to_open3d(
+ ros_point_cloud2=ros_point_cloud2,
+ include_color=self._include_color,
+ include_intensity=self._include_intensity,
+ )
+
+ # Preprocess point cloud (transform to robot frame, crop to workspace and estimate normals)
+ open3d_point_cloud = self.preprocess_point_cloud(
+ open3d_point_cloud=open3d_point_cloud,
+ camera_frame_id=ros_point_cloud2.header.frame_id,
+ reference_frame_id=self._reference_frame_id,
+ min_bound=self._min_bound,
+ max_bound=self._max_bound,
+ normals_radius=self._normals_radius,
+ normals_max_nn=self._normals_max_nn,
+ )
+
+ # Draw if needed
+ if self._debug_draw:
+ open3d.visualization.draw_geometries(
+ [
+ open3d_point_cloud,
+ open3d.geometry.TriangleMesh.create_coordinate_frame(
+ size=0.2, origin=[0.0, 0.0, 0.0]
+ ),
+ ],
+ point_show_normal=True,
+ )
+
+ # Construct octree from such point cloud
+ octree = self.construct_octree(
+ open3d_point_cloud,
+ include_color=self._include_color,
+ include_intensity=self._include_intensity,
+ )
+
+ # Write if needed
+ if self._debug_write_octree:
+ ocnn.write_octree(octree, "octree.octree")
+
+ return octree
+
+ def preprocess_point_cloud(
+ self,
+ open3d_point_cloud: open3d.geometry.PointCloud,
+ camera_frame_id: str,
+ reference_frame_id: str,
+ min_bound: List[float],
+ max_bound: List[float],
+ normals_radius: float,
+ normals_max_nn: int,
+ ) -> open3d.geometry.PointCloud:
+
+ # Check if point cloud has any points
+ if not open3d_point_cloud.has_points():
+ self._node.get_logger().warn(
+ "Point cloud has no points. Pre-processing skipped."
+ )
+ return open3d_point_cloud
+
+ # Get transformation from camera to robot and use it to transform point
+ # cloud into robot's base coordinate frame
+ if camera_frame_id != reference_frame_id:
+ transform = self.__tf2_listener.lookup_transform_sync(
+ target_frame=reference_frame_id, source_frame=camera_frame_id
+ )
+ transform_mat = conversions.transform_to_matrix(transform=transform)
+ open3d_point_cloud = open3d_point_cloud.transform(transform_mat)
+
+ # Crop point cloud to include only the workspace
+ open3d_point_cloud = open3d_point_cloud.crop(
+ bounding_box=open3d.geometry.AxisAlignedBoundingBox(
+ min_bound=min_bound, max_bound=max_bound
+ )
+ )
+
+ # Check if any points remain in the area after cropping
+ if not open3d_point_cloud.has_points():
+ self._node.get_logger().warn(
+ "Point cloud has no points after cropping it to the workspace volume."
+ )
+ return open3d_point_cloud
+
+ # Estimate normal vector for each cloud point and orient these towards the camera
+ open3d_point_cloud.estimate_normals(
+ search_param=open3d.geometry.KDTreeSearchParamHybrid(
+ radius=normals_radius, max_nn=normals_max_nn
+ ),
+ fast_normal_computation=True,
+ )
+
+ open3d_point_cloud.orient_normals_towards_camera_location(
+ camera_location=transform_mat[0:3, 3]
+ )
+
+ return open3d_point_cloud
+
+ def construct_octree(
+ self,
+ open3d_point_cloud: open3d.geometry.PointCloud,
+ include_color: bool,
+ include_intensity: bool,
+ ) -> torch.Tensor:
+
+ # In case the point cloud has no points, add a single point
+ # This is a workaround because I was not able to create an empty octree without getting a segfault
+ # TODO: Figure out a better way of making an empty octree (it does not occur if setup correctly, so probably not worth it)
+ if not open3d_point_cloud.has_points():
+ open3d_point_cloud.points.append(
+ (
+ (self._min_bound[0] + self._max_bound[0]) / 2,
+ (self._min_bound[1] + self._max_bound[1]) / 2,
+ (self._min_bound[2] + self._max_bound[2]) / 2,
+ )
+ )
+ open3d_point_cloud.normals.append((0.0, 0.0, 0.0))
+ if include_color or include_intensity:
+ open3d_point_cloud.colors.append((0.0, 0.0, 0.0))
+
+ # Convert open3d point cloud into octree points
+ octree_points = conversions.open3d_point_cloud_to_octree_points(
+ open3d_point_cloud=open3d_point_cloud,
+ include_color=include_color,
+ include_intensity=include_intensity,
+ )
+
+ # Convert octree points into 1D Tensor (via ndarray)
+ # Note: Copy of points here is necessary as ndarray would otherwise be immutable
+ octree_points_ndarray = np.frombuffer(np.copy(octree_points.buffer()), np.uint8)
+ octree_points_tensor = torch.from_numpy(octree_points_ndarray)
+
+ # Finally, create an octree from the points
+ return self._points_to_octree(octree_points_tensor)
diff --git a/env_manager/rbs_gym/rbs_gym/envs/observation/twist_subscriber.py b/env_manager/rbs_gym/rbs_gym/envs/observation/twist_subscriber.py
new file mode 100644
index 0000000..6834a37
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/observation/twist_subscriber.py
@@ -0,0 +1,81 @@
+import sys
+from threading import Lock, Thread
+from typing import Optional, Union
+
+import rclpy
+from rclpy.callback_groups import CallbackGroup
+from rclpy.executors import SingleThreadedExecutor
+from rclpy.node import Node
+from rclpy.parameter import Parameter
+from rclpy.qos import (
+ QoSDurabilityPolicy,
+ QoSHistoryPolicy,
+ QoSProfile,
+ QoSReliabilityPolicy,
+)
+from geometry_msgs.msg import TwistStamped
+
+class TwistSubscriber:
+ def __init__(
+ self,
+ node: Node,
+ topic: str,
+ callback_group: Optional[CallbackGroup] = None,
+ ):
+
+ self._node = node
+
+ self.__observation = TwistStamped()
+ self._node.create_subscription(
+ msg_type=TwistStamped,
+ topic=topic,
+ callback=self.observation_callback,
+ qos_profile=QoSProfile(
+ reliability=QoSReliabilityPolicy.RELIABLE,
+ durability=QoSDurabilityPolicy.VOLATILE,
+ history=QoSHistoryPolicy.KEEP_LAST,
+ depth=1,
+ ),
+ callback_group=callback_group,
+ )
+ self.__observation_mutex = Lock()
+ self.__new_observation_available = False
+
+ def observation_callback(self, msg):
+ """
+ Callback for getting observation.
+ """
+
+ self.__observation_mutex.acquire()
+ self.__observation = msg
+ self.__new_observation_available = True
+ self._node.get_logger().debug("New observation received.")
+ self.__observation_mutex.release()
+
+ def get_observation(self) -> TwistStamped:
+ """
+ Get the last received observation.
+ """
+
+ self.__observation_mutex.acquire()
+ observation = self.__observation
+ self.__observation_mutex.release()
+ return observation
+
+ def reset_new_observation_checker(self):
+ """
+ Reset checker of new observations, i.e. `self.new_observation_available()`
+ """
+
+ self.__observation_mutex.acquire()
+ self.__new_observation_available = False
+ self.__observation_mutex.release()
+
+ @property
+ def new_observation_available(self):
+ """
+ Check if new observation is available since `self.reset_new_observation_checker()` was called
+ """
+
+ return self.__new_observation_available
+
diff --git a/env_manager/rbs_gym/rbs_gym/envs/randomizers/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/randomizers/__init__.py
new file mode 100644
index 0000000..9beec96
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/randomizers/__init__.py
@@ -0,0 +1 @@
+from .manipulation import ManipulationGazeboEnvRandomizer
diff --git a/env_manager/rbs_gym/rbs_gym/envs/randomizers/manipulation.py b/env_manager/rbs_gym/rbs_gym/envs/randomizers/manipulation.py
new file mode 100644
index 0000000..5777a2c
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/randomizers/manipulation.py
@@ -0,0 +1,345 @@
+import abc
+
+from env_manager.models.configs import (
+ SceneData,
+)
+from env_manager.scene import Scene
+from env_manager.utils.logging import set_log_level
+from gym_gz import randomizers
+from gym_gz.randomizers import gazebo_env_randomizer
+from gym_gz.randomizers.gazebo_env_randomizer import MakeEnvCallable
+from scenario import gazebo as scenario
+from scenario.bindings.gazebo import GazeboSimulator
+
+from rbs_gym.envs import tasks
+
+SupportedTasks = tasks.Reach | tasks.ReachColorImage | tasks.ReachDepthImage
+
+
+class ManipulationGazeboEnvRandomizer(
+ gazebo_env_randomizer.GazeboEnvRandomizer,
+ randomizers.abc.PhysicsRandomizer,
+ randomizers.abc.TaskRandomizer,
+ abc.ABC,
+):
+ """
+ Basic randomizer of environments for robotic manipulation inside Ignition Gazebo. This randomizer
+ also populates the simulated world with robot, terrain, lighting and other entities.
+ """
+
+ POST_RANDOMIZATION_MAX_STEPS = 50
+
+ metadata = {"render_modes": ["human"]}
+
+ def __init__(
+ self,
+ env: MakeEnvCallable,
+ scene_args: SceneData = SceneData(),
+ **kwargs,
+ ):
+ """
+ Initialise the randomizer and its base classes.
+
+ Args:
+ env: Callable that constructs the wrapped environment.
+ scene_args: Scene configuration; `physics_rollouts_num` must be 0.
+ **kwargs: Forwarded to the task constructor (extended here with
+ per-camera metadata under the "camera_info" key).
+
+ Raises:
+ TypeError: If `scene_args.physics_rollouts_num` is non-zero.
+ """
+ # NOTE(review): `SceneData()` as a default argument is evaluated once at
+ # import time and shared across instances — confirm SceneData is treated
+ # as read-only here (otherwise use `scene_args: SceneData | None = None`).
+ self._scene_data = scene_args
+ # TODO (a lot of work): Implement proper physics randomization.
+ if scene_args.physics_rollouts_num != 0:
+ raise TypeError(
+ "Proper physics randomization at each reset is not yet implemented. Please set `physics_rollouts_num=0`."
+ )
+
+ # Update kwargs before passing them to the task constructor (some tasks might need them)
+ # TODO: update logic when will need choose between cameras
+ cameras: list[dict[str, str | int]] = []
+ for camera in self._scene_data.camera:
+ camera_info: dict[str, str | int] = {}
+ camera_info["name"] = camera.name
+ camera_info["type"] = camera.type
+ camera_info["width"] = camera.width
+ camera_info["height"] = camera.height
+ cameras.append(camera_info)
+
+ kwargs.update({"camera_info": cameras})
+
+ # Initialize base classes
+ randomizers.abc.TaskRandomizer.__init__(self)
+ randomizers.abc.PhysicsRandomizer.__init__(
+ self, randomize_after_rollouts_num=scene_args.physics_rollouts_num
+ )
+ gazebo_env_randomizer.GazeboEnvRandomizer.__init__(
+ self, env=env, physics_randomizer=self, **kwargs
+ )
+
+ # Set on the first `randomize_task()` call, once `init_env()` has run
+ self.__env_initialised = False
+
+ ##########################
+ # PhysicsRandomizer impl #
+ ##########################
+
+ def init_physics_preset(self, task: SupportedTasks):
+ """Initialise physics parameters (currently only gravity)."""
+ self.set_gravity(task=task)
+
+ def randomize_physics(self, task: SupportedTasks, **kwargs):
+ """Re-sample physics parameters (currently only gravity) for a new rollout."""
+ self.set_gravity(task=task)
+
+ def set_gravity(self, task: SupportedTasks):
+ """
+ Sample a gravity vector from per-axis normal distributions and apply it
+ to the world.
+
+ Raises:
+ RuntimeError: If the simulator rejects the new gravity vector.
+ """
+ # NOTE(review): `self._gravity` / `self._gravity_std` are not assigned
+ # anywhere in this class's visible code — presumably set by a subclass
+ # or mixin; confirm they exist before this is called.
+ if not task.world.to_gazebo().set_gravity(
+ (
+ task.np_random.normal(loc=self._gravity[0], scale=self._gravity_std[0]),
+ task.np_random.normal(loc=self._gravity[1], scale=self._gravity_std[1]),
+ task.np_random.normal(loc=self._gravity[2], scale=self._gravity_std[2]),
+ )
+ ):
+ raise RuntimeError("Failed to set the gravity")
+
+ def get_engine(self):
+ """Return the physics engine to use (DART)."""
+ return scenario.PhysicsEngine_dart
+
+ #######################
+ # TaskRandomizer impl #
+ #######################
+
+ def randomize_task(self, task: SupportedTasks, **kwargs):
+ """
+ Randomization of the task, which is called on each reset of the environment.
+ Note that this randomizer reset is called before `reset_task()`.
+
+ Args:
+ task: The task instance being randomized.
+ **kwargs: Must contain a "gazebo" entry holding the `GazeboSimulator`.
+
+ Raises:
+ ValueError: If no "gazebo" entry is present in kwargs.
+ RuntimeError: If the "gazebo" entry is not a `GazeboSimulator`.
+ """
+
+ # Get gazebo instance associated with the task
+ if "gazebo" not in kwargs:
+ raise ValueError("Randomizer does not have access to the gazebo interface")
+ if isinstance(kwargs["gazebo"], GazeboSimulator):
+ gazebo: GazeboSimulator = kwargs["gazebo"]
+ else:
+ raise RuntimeError("Provide GazeboSimulator instance")
+
+ # Initialise the environment on the first iteration
+ if not self.__env_initialised:
+ self.init_env(task=task, gazebo=gazebo)
+ self.__env_initialised = True
+
+ # Perform pre-randomization steps
+ self.pre_randomization(task=task)
+
+ # Reset the scene to its (randomized) initial state
+ self._scene.reset_scene()
+
+ # Perform post-randomization steps
+ # TODO: Something in post-randomization causes GUI client to freeze during reset (the simulation server still works fine)
+ self.post_randomization(task, gazebo)
+
+ ###################
+ # Randomizer impl #
+ ###################
+
+ # Initialisation #
+ def init_env(self, task: SupportedTasks, gazebo: scenario.GazeboSimulator):
+ """
+ Initialise an instance of the environment before the very first iteration
+ """
+
+ # Build the scene from the task's "robot_description" ROS parameter
+ self._scene = Scene(
+ task,
+ gazebo,
+ self._scene_data,
+ task.get_parameter("robot_description").get_parameter_value().string_value,
+ )
+
+ # Init Scene for task
+ task.scene = self._scene
+ # Set log level for (Gym) Ignition
+ set_log_level(log_level=task.get_logger().get_effective_level().name)
+
+ self._scene.init_scene()
+
+ # Pre-randomization #
+ def pre_randomization(self, task: SupportedTasks):
+ """
+ Perform steps that are required before randomization is performed.
+
+ For each object configured with more than one spawn-position segment,
+ pick a random segment and a random point on it, and use that point as
+ the object's new spawn position.
+ """
+
+ for obj in self._scene.objects:
+ # If desired, select random spawn position from the segments
+ # It is performed here because object spawn position might be of interest also for robot and camera randomization
+
+ segments_len = len(obj.randomize.random_spawn_position_segments)
+ if segments_len > 1:
+ # Randomly select a segment between two points
+ # NOTE(review): `randint` exists on the legacy
+ # `numpy.random.RandomState`; if `task.np_random` is a
+ # `numpy.random.Generator`, the call must be `integers` —
+ # confirm which RNG the task exposes.
+ start_index = task.np_random.randint(segments_len - 1)
+ segment = (
+ obj.randomize.random_spawn_position_segments[start_index],
+ obj.randomize.random_spawn_position_segments[start_index + 1],
+ )
+
+ # Randomly select a point on the segment and use it as the new object spawn position
+ intersect = task.np_random.random()
+ # Linear interpolation: position = start + t * (end - start)
+ direction = (
+ segment[1][0] - segment[0][0],
+ segment[1][1] - segment[0][1],
+ segment[1][2] - segment[0][2],
+ )
+ obj.position = (
+ segment[0][0] + intersect * direction[0],
+ segment[0][1] + intersect * direction[1],
+ segment[0][2] + intersect * direction[2],
+ )
+
+ # TODO: add bounding box with multiple objects
+
+ # # Update also the workspace centre (and bounding box) if desired
+ # if self._object_random_spawn_position_update_workspace_centre:
+ # task.workspace_centre = (
+ # self._object_spawn_position[0],
+ # self._object_spawn_position[1],
+ # # Z workspace is currently kept the same on purpose
+ # task.workspace_centre[2],
+ # )
+ # workspace_volume_half = (
+ # task.workspace_volume[0] / 2,
+ # task.workspace_volume[1] / 2,
+ # task.workspace_volume[2] / 2,
+ # )
+ # task.workspace_min_bound = (
+ # task.workspace_centre[0] - workspace_volume_half[0],
+ # task.workspace_centre[1] - workspace_volume_half[1],
+ # task.workspace_centre[2] - workspace_volume_half[2],
+ # )
+ # task.workspace_max_bound = (
+ # task.workspace_centre[0] + workspace_volume_half[0],
+ # task.workspace_centre[1] + workspace_volume_half[1],
+ # task.workspace_centre[2] + workspace_volume_half[2],
+ # )
+
+ def post_randomization(
+ self, task: SupportedTasks, gazebo: scenario.GazeboSimulator
+ ):
+ """
+ Perform steps required once randomization is complete and the simulation can be stepped unpaused.
+
+ Steps the simulator until randomly-posed objects no longer overlap and,
+ if the task has camera subscribers or an enabled gripper, until fresh
+ observations have arrived.
+ """
+
+ def perform_gazebo_step():
+ # One unpaused simulation step; raises on simulator failure.
+ if not gazebo.step():
+ raise RuntimeError("Failed to execute an unpaused Gazebo step")
+
+ def wait_for_new_observations():
+ # NOTE(review): this loop has no upper bound — if a camera never
+ # publishes, it spins forever. Consider capping it at
+ # POST_RANDOMIZATION_MAX_STEPS.
+ attempts = 0
+ while True:
+ attempts += 1
+ if attempts % self.POST_RANDOMIZATION_MAX_STEPS == 0:
+ task.get_logger().debug(
+ f"Waiting for new joint state after reset. Iteration #{attempts}..."
+ )
+ else:
+ task.get_logger().debug("Waiting for new joint state after reset.")
+
+ perform_gazebo_step()
+
+ # If camera_sub is defined, ensure all new observations are available
+ if hasattr(task, "camera_subs") and not all(
+ sub.new_observation_available for sub in task.camera_subs
+ ):
+ continue
+
+ return # Observations are ready
+
+ # NOTE(review): `attempts` is never incremented in this scope, so the
+ # final `POST_RANDOMIZATION_MAX_STEPS == attempts` check below can
+ # never be true — confirm whether it should count the overlap retries.
+ attempts = 0
+ processed_objects = set()
+
+ # Early exit if the maximum number of steps is already reached
+ if self.POST_RANDOMIZATION_MAX_STEPS == 0:
+ task.get_logger().error(
+ "Robot keeps falling through the terrain. There is something wrong..."
+ )
+ return
+
+ # Ensure no objects are overlapping
+ for obj in self._scene.objects:
+ if not obj.randomize.random_pose or obj.name in processed_objects:
+ continue
+
+ # Try repositioning until no overlap or maximum attempts reached
+ for _ in range(self.POST_RANDOMIZATION_MAX_STEPS):
+ task.get_logger().debug(
+ f"Checking overlap for {obj.name}, attempt {attempts + 1}"
+ )
+ # NOTE(review): a truthy return from `check_object_overlapping`
+ # is treated as "no overlap" here — verify the helper's contract
+ # matches this (the name suggests the opposite).
+ if self._scene.check_object_overlapping(obj):
+ processed_objects.add(obj.name)
+ break # No overlap, move to next object
+ else:
+ task.get_logger().debug(
+ f"Objects overlapping, trying new positions for {obj.name}"
+ )
+
+ perform_gazebo_step()
+
+ else:
+ task.get_logger().warn(
+ f"Could not place {obj.name} without overlap after {self.POST_RANDOMIZATION_MAX_STEPS} attempts"
+ )
+ continue # Move to next object in case of failure
+
+ # Execute steps until new observations are available
+ if hasattr(task, "camera_subs") or task._enable_gripper:
+ wait_for_new_observations()
+
+ # Final check if observations are not available within the maximum steps
+ if self.POST_RANDOMIZATION_MAX_STEPS == attempts:
+ task.get_logger().error("Cannot obtain new observation.")
+
+ # =============================
+ # Additional features and debug
+ # =============================
+
+ # def visualise_workspace(
+ # self,
+ # task: SupportedTasks,
+ # gazebo: scenario.GazeboSimulator,
+ # color: Tuple[float, float, float, float] = (0, 1, 0, 0.8),
+ # ):
+ # # Insert a translucent box visible only in simulation with no physical interactions
+ # models.Box(
+ # world=task.world,
+ # name="_workspace_volume",
+ # position=self._object_spawn_position,
+ # orientation=(0, 0, 0, 1),
+ # size=task.workspace_volume,
+ # collision=False,
+ # visual=True,
+ # gui_only=True,
+ # static=True,
+ # color=color,
+ # )
+ # # Execute a paused run to process model insertion
+ # if not gazebo.run(paused=True):
+ # raise RuntimeError("Failed to execute a paused Gazebo run")
+ #
+ # def visualise_spawn_volume(
+ # self,
+ # task: SupportedTasks,
+ # gazebo: scenario.GazeboSimulator,
+ # color: Tuple[float, float, float, float] = (0, 0, 1, 0.8),
+ # color_with_height: Tuple[float, float, float, float] = (1, 0, 1, 0.7),
+ # ):
+ # # Insert translucent boxes visible only in simulation with no physical interactions
+ # models.Box(
+ # world=task.world,
+ # name="_object_random_spawn_volume",
+ # position=self._object_spawn_position,
+ # orientation=(0, 0, 0, 1),
+ # size=self._object_random_spawn_volume,
+ # collision=False,
+ # visual=True,
+ # gui_only=True,
+ # static=True,
+ # color=color,
+ # )
+ # models.Box(
+ # world=task.world,
+ # name="_object_random_spawn_volume_with_height",
+ # position=self._object_spawn_position,
+ # orientation=(0, 0, 0, 1),
+ # size=self._object_random_spawn_volume,
+ # collision=False,
+ # visual=True,
+ # gui_only=True,
+ # static=True,
+ # color=color_with_height,
+ # )
+ # # Execute a paused run to process model insertion
+ # if not gazebo.run(paused=True):
+ # raise RuntimeError("Failed to execute a paused Gazebo run")
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/__init__.py
new file mode 100644
index 0000000..15d114f
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/__init__.py
@@ -0,0 +1,4 @@
+# from .curriculums import *
+# from .grasp import *
+# from .grasp_planetary import *
+from .reach import *
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/__init__.py
new file mode 100644
index 0000000..3e3828e
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/__init__.py
@@ -0,0 +1 @@
+from .grasp import GraspCurriculum
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/common.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/common.py
new file mode 100644
index 0000000..5caf7bc
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/common.py
@@ -0,0 +1,700 @@
+from __future__ import annotations
+
+import enum
+import itertools
+import math
+from collections import deque
+from typing import Callable, Deque, Dict, Optional, Tuple, Type
+
+import numpy as np
+from gym_gz.base.task import Task
+from gym_gz.utils.typing import Reward
+from tf2_ros.buffer_interface import TypeException
+
+INFO_MEAN_STEP_KEY: str = "__mean_step__"
+INFO_MEAN_EPISODE_KEY: str = "__mean_episode__"
+
+
+@enum.unique
+class CurriculumStage(enum.Enum):
+ """
+ Ordered enum that represents stages of a curriculum for RL task.
+
+ Members are expected to be numbered contiguously from 1 so that
+ `first()`, `last()`, `next()` and `previous()` can navigate by value.
+ """
+
+ # NOTE(review): `self` in these classmethods is the class itself
+ # (conventionally named `cls`).
+ @classmethod
+ def first(self) -> CurriculumStage:
+ """Return the first (lowest-valued) stage."""
+
+ return self(1)
+
+ @classmethod
+ def last(self) -> CurriculumStage:
+ """Return the last (highest-valued) stage."""
+
+ return self(len(self))
+
+ def next(self) -> Optional[CurriculumStage]:
+ """Return the following stage, or None if this is the last one."""
+
+ next_value = self.value + 1
+
+ if next_value > self.last().value:
+ return None
+ else:
+ # NOTE(review): here `self` is an enum *member*, not the class;
+ # plain `Enum` members are not callable, so `self(next_value)`
+ # may raise TypeError — `type(self)(next_value)` is the
+ # conventional form. Confirm whether `next()` is ever exercised.
+ return self(next_value)
+
+ def previous(self) -> Optional[CurriculumStage]:
+ """Return the preceding stage, or None if this is the first one."""
+
+ previous_value = self.value - 1
+
+ if previous_value < self.first().value:
+ return None
+ else:
+ # NOTE(review): same member-call concern as in `next()` above.
+ return self(previous_value)
+
+
+class StageRewardCurriculum:
+ """
+ Curriculum that begins to compute rewards for a stage once all previous stages are complete.
+
+ Concrete subclasses must provide a `get_reward_<STAGE_NAME>` method for
+ every member of the stage enum passed to `__init__`.
+ """
+
+ # Dictionary key for the stage-independent (persistent) reward bucket
+ PERSISTENT_ID: str = "PERSISTENT"
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ curriculum_stage: Type[CurriculumStage],
+ stage_reward_multiplier: float,
+ dense_reward: bool = False,
+ **kwargs,
+ ):
+ """
+ Args:
+ curriculum_stage: Concrete `CurriculumStage` enum with the stages.
+ stage_reward_multiplier: Geometric per-stage weight; stage i is
+ weighted by `stage_reward_multiplier ** (i - 1)`.
+ dense_reward: Unsupported; must be False.
+
+ Raises:
+ TypeException: If the stage enum has no members.
+ ValueError: If `dense_reward` is True.
+ """
+
+ # NOTE(review): `TypeException` is imported from tf2_ros; a built-in
+ # `TypeError`/`ValueError` would be the conventional choice here.
+ if 0 == len(curriculum_stage):
+ raise TypeException(f"{curriculum_stage} has length of 0")
+
+ self.__use_dense_reward = dense_reward
+ if self.__use_dense_reward:
+ raise ValueError(
+ "Dense reward is currently not implemented for any curriculum"
+ )
+
+ # Setup internals
+ # Map each stage to its `get_reward_<STAGE_NAME>` method on the subclass
+ self._stage_type = curriculum_stage
+ self._stage_reward_functions: Dict[curriculum_stage, Callable] = {
+ curriculum_stage(stage): getattr(self, f"get_reward_{stage.name}")
+ for stage in iter(curriculum_stage)
+ }
+ self.__stage_reward_multipliers: Dict[curriculum_stage, float] = {
+ curriculum_stage(stage): stage_reward_multiplier ** (stage.value - 1)
+ for stage in iter(curriculum_stage)
+ }
+
+ # Per-episode bookkeeping (cleared in `reset_task()`)
+ self.stages_completed_this_episode: Dict[curriculum_stage, bool] = {
+ curriculum_stage(stage): False for stage in iter(curriculum_stage)
+ }
+ self.__stages_rewards_this_episode: Dict[curriculum_stage, float] = {
+ curriculum_stage(stage): 0.0 for stage in iter(curriculum_stage)
+ }
+ # Extra bucket for the stage-independent (persistent) reward
+ self.__stages_rewards_this_episode[self.PERSISTENT_ID] = 0.0
+
+ self.__episode_succeeded: bool = False
+ self.__episode_failed: bool = False
+
+ def get_reward(self, only_last_stage: bool = False, **kwargs) -> Reward:
+ """
+ Compute the reward for the current step, summed over active stages.
+
+ Args:
+ only_last_stage: If True, compute only the final stage's reward.
+ **kwargs: Forwarded to the per-stage reward functions.
+
+ Returns:
+ Total multiplier-weighted reward for this step.
+ """
+
+ reward = 0.0
+
+ # Determine the stage at which to start computing reward [performance - done stages give no reward]
+ if only_last_stage:
+ first_stage_to_process = self._stage_type.last()
+ else:
+ # NOTE(review): if every stage is already completed, this loop never
+ # breaks and `first_stage_to_process` is unbound (NameError below).
+ # The episode normally ends once the last stage completes, but
+ # confirm this path cannot be reached after success.
+ for stage in iter(self._stage_type):
+ if not self.stages_completed_this_episode[stage]:
+ first_stage_to_process = stage
+ break
+
+ # Iterate over all stages that might need to be processed
+ for stage in range(first_stage_to_process.value, len(self._stage_type) + 1):
+ stage = self._stage_type(stage)
+
+ # Compute reward for the current stage
+ stage_reward = self._stage_reward_functions[stage](**kwargs)
+ # Multiply by the reward multiplier
+ stage_reward *= self.__stage_reward_multipliers[stage]
+ # Add to the total step reward
+ reward += stage_reward
+ # Add reward to the list for info
+ self.__stages_rewards_this_episode[stage] += stage_reward
+
+ # Break if stage is not yet completed [performance - next stages won't give any reward]
+ if not self.stages_completed_this_episode[stage]:
+ break
+
+ # If the last stage is complete, the episode has succeeded
+ self.__episode_succeeded = self.stages_completed_this_episode[
+ self._stage_type.last()
+ ]
+ if self.__episode_succeeded:
+ return reward
+
+ # Add persistent reward that is added regardless of the episode (unless task already succeeded)
+ persistent_reward = self.get_persistent_reward(**kwargs)
+ # Add to the total step reward
+ reward += persistent_reward
+ # Add reward to the list for info
+ self.__stages_rewards_this_episode[self.PERSISTENT_ID] += persistent_reward
+
+ return reward
+
+ def is_done(self) -> bool:
+ """
+ Return True if the episode has ended (success or failure) and fire the
+ corresponding `on_episode_*` hook.
+ """
+
+ if self.__episode_succeeded:
+ # The episode ended with success
+ self.on_episode_success()
+ return True
+ elif self.__episode_failed:
+ # The episode ended due to failure
+ self.on_episode_failure()
+ return True
+ else:
+ # Otherwise, the episode is not yet done
+ return False
+
+ def get_info(self) -> Dict:
+ """
+ Return logging info: success flag, the highest stage reached this
+ episode, and per-stage (plus persistent) accumulated rewards.
+ """
+
+ # Whether the episode succeeded
+ info = {
+ "is_success": self.__episode_succeeded,
+ }
+
+ # What stage was reached during this episode so far
+ # (the first incomplete stage; the last stage if all are complete)
+ for stage in iter(self._stage_type):
+ reached_stage = stage
+ if not self.stages_completed_this_episode[stage]:
+ break
+ info.update(
+ {
+ f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}ep_reached_stage_mean": reached_stage.value,
+ }
+ )
+
+ # Rewards for the individual stages
+ info.update(
+ {
+ f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}ep_rew_mean_{stage.value}_{stage.name.lower()}": self.__stages_rewards_this_episode[
+ stage
+ ]
+ for stage in iter(self._stage_type)
+ }
+ )
+ info.update(
+ {
+ f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}ep_rew_mean_{self.PERSISTENT_ID.lower()}": self.__stages_rewards_this_episode[
+ self.PERSISTENT_ID
+ ]
+ }
+ )
+
+ return info
+
+ def reset_task(self):
+ """
+ Reset per-episode curriculum state; call at the start of each episode.
+ An episode that ended neither in success nor failure is treated as a
+ timeout.
+ """
+
+ if not (self.__episode_succeeded or self.__episode_failed):
+ # The episode ended due to timeout
+ self.on_episode_timeout()
+
+ # Reset internals
+ self.stages_completed_this_episode = dict.fromkeys(
+ self.stages_completed_this_episode, False
+ )
+ self.__stages_rewards_this_episode = dict.fromkeys(
+ self.__stages_rewards_this_episode, 0.0
+ )
+ self.__stages_rewards_this_episode[self.PERSISTENT_ID] = 0.0
+ self.__episode_succeeded = False
+ self.__episode_failed = False
+
+ @property
+ def episode_succeeded(self) -> bool:
+ """Whether the current episode has already succeeded."""
+
+ return self.__episode_succeeded
+
+ @episode_succeeded.setter
+ def episode_succeeded(self, value: bool):
+
+ self.__episode_succeeded = value
+
+ @property
+ def episode_failed(self) -> bool:
+ """Whether the current episode has already failed."""
+
+ return self.__episode_failed
+
+ @episode_failed.setter
+ def episode_failed(self, value: bool):
+
+ self.__episode_failed = value
+
+ @property
+ def use_dense_reward(self) -> bool:
+ """Whether dense rewards were requested (currently always False)."""
+
+ return self.__use_dense_reward
+
+ def get_persistent_reward(self, **kwargs) -> float:
+ """
+ Virtual method. Stage-independent reward added every step unless the
+ episode has already succeeded; subclasses may override.
+ """
+
+ reward = 0.0
+
+ return reward
+
+ def on_episode_success(self):
+ """
+ Virtual method. Hook called once when an episode ends in success.
+ """
+
+ pass
+
+ def on_episode_failure(self):
+ """
+ Virtual method. Hook called once when an episode ends in failure.
+ """
+
+ pass
+
+ def on_episode_timeout(self):
+ """
+ Virtual method. Hook called once when an episode ends by timeout.
+ """
+
+ pass
+
+
+class SuccessRateImpl:
+ """
+ Moving average over the success rate of last N episodes.
+
+ Until N samples are collected the average is exact; afterwards it behaves
+ as an exponential moving average with weight (N-1)/N per update.
+ """
+
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ initial_success_rate: float = 0.0,
+ rolling_average_n: int = 100,
+ **kwargs,
+ ):
+ """
+ Args:
+ initial_success_rate: Starting value of the average.
+ rolling_average_n: Window size N of the rolling average.
+ """
+
+ self.__success_rate = initial_success_rate
+ self.__rolling_average_n = rolling_average_n
+
+ # Setup internals
+ self.__previous_success_rate_weight: int = 0
+ self.__collected_samples: int = 0
+
+ def get_info(self) -> Dict:
+ """Return logging info containing the current success rate."""
+
+ info = {
+ f"{self.INFO_CURRICULUM_PREFIX}_success_rate": self.__success_rate,
+ }
+
+ return info
+
+ def update_success_rate(self, is_success: bool):
+ """Fold the outcome of one finished episode into the rolling average."""
+
+ # Until `rolling_average_n` is reached, use number of collected samples during computations
+ if self.__collected_samples < self.__rolling_average_n:
+ self.__previous_success_rate_weight = self.__collected_samples
+ self.__collected_samples += 1
+
+ # Weighted average of the previous estimate and the new sample
+ self.__success_rate = (
+ self.__previous_success_rate_weight * self.__success_rate
+ + float(is_success)
+ ) / self.__collected_samples
+
+ @property
+ def success_rate(self) -> float:
+ """Current success-rate estimate in [0, 1]."""
+
+ return self.__success_rate
+
+
+class WorkspaceScaleCurriculum:
+ """
+ Curriculum that increases the workspace size as the success rate increases.
+
+ The XY extent of the workspace grows linearly with the success rate,
+ clamped to [min_workspace_scale, 1.0]; the Z extent stays constant.
+ """
+
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ task: Task,
+ success_rate_impl: SuccessRateImpl,
+ min_workspace_scale: float,
+ max_workspace_volume: Tuple[float, float, float],
+ max_workspace_scale_success_rate_threshold: float,
+ **kwargs,
+ ):
+ """
+ Args:
+ task: Task whose workspace parameters are overridden.
+ success_rate_impl: Shared success-rate tracker driving the scale.
+ min_workspace_scale: Lower bound of the XY scale factor.
+ max_workspace_volume: Workspace volume (x, y, z) at full scale.
+ max_workspace_scale_success_rate_threshold: Success rate at which
+ the scale reaches 1.0.
+ """
+
+ self.__task = task
+ self.__success_rate_impl = success_rate_impl
+ self.__min_workspace_scale = min_workspace_scale
+ self.__max_workspace_volume = max_workspace_volume
+ self.__max_workspace_scale_success_rate_threshold = (
+ max_workspace_scale_success_rate_threshold
+ )
+
+ def get_info(self) -> Dict:
+ """Return logging info with the current workspace scale."""
+
+ # NOTE(review): `__workspace_scale` is first assigned in
+ # `__update_workspace_size()`; calling `get_info()` before the first
+ # `reset_task()` raises AttributeError — confirm the call order.
+ info = {
+ f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}workspace_scale": self.__workspace_scale,
+ }
+
+ return info
+
+ def reset_task(self):
+ """Recompute the workspace size at the start of each episode."""
+
+ # Update workspace size
+ self.__update_workspace_size()
+
+ def __update_workspace_size(self):
+ """Scale the workspace with the success rate and push it to the task."""
+
+ # Scale grows linearly with success rate, clamped to [min_scale, 1.0]
+ self.__workspace_scale = min(
+ 1.0,
+ max(
+ self.__min_workspace_scale,
+ self.__success_rate_impl.success_rate
+ / self.__max_workspace_scale_success_rate_threshold,
+ ),
+ )
+
+ workspace_volume_new = (
+ self.__workspace_scale * self.__max_workspace_volume[0],
+ self.__workspace_scale * self.__max_workspace_volume[1],
+ # Z workspace is currently kept the same on purpose
+ self.__max_workspace_volume[2],
+ )
+ workspace_volume_half_new = (
+ workspace_volume_new[0] / 2,
+ workspace_volume_new[1] / 2,
+ workspace_volume_new[2] / 2,
+ )
+ # Recompute the axis-aligned bounds around the (fixed) workspace centre
+ workspace_min_bound_new = (
+ self.__task.workspace_centre[0] - workspace_volume_half_new[0],
+ self.__task.workspace_centre[1] - workspace_volume_half_new[1],
+ self.__task.workspace_centre[2] - workspace_volume_half_new[2],
+ )
+ workspace_max_bound_new = (
+ self.__task.workspace_centre[0] + workspace_volume_half_new[0],
+ self.__task.workspace_centre[1] + workspace_volume_half_new[1],
+ self.__task.workspace_centre[2] + workspace_volume_half_new[2],
+ )
+
+ self.__task.add_task_parameter_overrides(
+ {
+ "workspace_volume": workspace_volume_new,
+ "workspace_min_bound": workspace_min_bound_new,
+ "workspace_max_bound": workspace_max_bound_new,
+ }
+ )
+
+
+class ObjectSpawnVolumeScaleCurriculum:
+ """
+ Curriculum that increases the object spawn volume as the success rate increases.
+
+ All three axes of the spawn volume scale linearly with the success rate,
+ clamped to [min_object_spawn_volume_scale, 1.0].
+ """
+
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ task: Task,
+ success_rate_impl: SuccessRateImpl,
+ min_object_spawn_volume_scale: float,
+ max_object_spawn_volume: Tuple[float, float, float],
+ max_object_spawn_volume_scale_success_rate_threshold: float,
+ **kwargs,
+ ):
+ """
+ Args:
+ task: Task whose randomizer parameters are overridden.
+ success_rate_impl: Shared success-rate tracker driving the scale.
+ min_object_spawn_volume_scale: Lower bound of the scale factor.
+ max_object_spawn_volume: Spawn volume (x, y, z) at full scale.
+ max_object_spawn_volume_scale_success_rate_threshold: Success rate
+ at which the scale reaches 1.0.
+ """
+
+ self.__task = task
+ self.__success_rate_impl = success_rate_impl
+ self.__min_object_spawn_volume_scale = min_object_spawn_volume_scale
+ self.__max_object_spawn_volume = max_object_spawn_volume
+ self.__max_object_spawn_volume_scale_success_rate_threshold = (
+ max_object_spawn_volume_scale_success_rate_threshold
+ )
+
+ def get_info(self) -> Dict:
+ """Return logging info with the current spawn-volume scale."""
+
+ # NOTE(review): `__object_spawn_volume_scale` is first assigned in
+ # `__update_object_spawn_volume_size()`; calling `get_info()` before
+ # the first `reset_task()` raises AttributeError — confirm call order.
+ info = {
+ f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}object_spawn_volume_scale": self.__object_spawn_volume_scale,
+ }
+
+ return info
+
+ def reset_task(self):
+ """Recompute the spawn volume at the start of each episode."""
+
+ # Update object_spawn_volume size
+ self.__update_object_spawn_volume_size()
+
+ def __update_object_spawn_volume_size(self):
+ """Scale the spawn volume with the success rate and push it to the task."""
+
+ # Scale grows linearly with success rate, clamped to [min_scale, 1.0]
+ self.__object_spawn_volume_scale = min(
+ 1.0,
+ max(
+ self.__min_object_spawn_volume_scale,
+ self.__success_rate_impl.success_rate
+ / self.__max_object_spawn_volume_scale_success_rate_threshold,
+ ),
+ )
+
+ object_spawn_volume_volume_new = (
+ self.__object_spawn_volume_scale * self.__max_object_spawn_volume[0],
+ self.__object_spawn_volume_scale * self.__max_object_spawn_volume[1],
+ self.__object_spawn_volume_scale * self.__max_object_spawn_volume[2],
+ )
+
+ self.__task.add_randomizer_parameter_overrides(
+ {
+ "object_random_spawn_volume": object_spawn_volume_volume_new,
+ }
+ )
+
+
+class ObjectCountCurriculum:
+ """
+ Curriculum that increases the number of objects as the success rate increases.
+
+ The count grows linearly from `object_count_min` to `object_count_max` as
+ the success rate approaches the configured threshold.
+ """
+
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ task: Task,
+ success_rate_impl: SuccessRateImpl,
+ object_count_min: int,
+ object_count_max: int,
+ max_object_count_success_rate_threshold: float,
+ **kwargs,
+ ):
+ """
+ Args:
+ task: Task whose randomizer parameters are overridden.
+ success_rate_impl: Shared success-rate tracker driving the count.
+ object_count_min: Object count at zero success rate.
+ object_count_max: Object count at/above the threshold.
+ max_object_count_success_rate_threshold: Success rate at which the
+ maximum object count is reached.
+
+ Raises:
+ Exception: If `object_count_min` exceeds `object_count_max`.
+ """
+
+ self.__task = task
+ self.__success_rate_impl = success_rate_impl
+ self.__object_count_min = object_count_min
+ self.__object_count_max = object_count_max
+ self.__max_object_count_success_rate_threshold = (
+ max_object_count_success_rate_threshold
+ )
+
+ self.__object_count_min_max_diff = object_count_max - object_count_min
+ if self.__object_count_min_max_diff < 0:
+ # NOTE(review): a ValueError would be more conventional than a bare
+ # Exception here.
+ raise Exception(
+ "'object_count_min' cannot be larger than 'object_count_max'"
+ )
+
+ def get_info(self) -> Dict:
+ """Return logging info with the current object count."""
+
+ # NOTE(review): `__object_count` is first assigned in
+ # `__update_object_count()`; calling `get_info()` before the first
+ # `reset_task()` raises AttributeError — confirm call order.
+ info = {
+ f"{self.INFO_CURRICULUM_PREFIX}object_count": self.__object_count,
+ }
+
+ return info
+
+ def reset_task(self):
+ """Recompute the object count at the start of each episode."""
+
+ # Update object count
+ self.__update_object_count()
+
+ def __update_object_count(self):
+ """Derive the object count from the success rate and push it to the task."""
+
+ # Linear interpolation between min and max, capped at the maximum
+ self.__object_count = min(
+ self.__object_count_max,
+ math.floor(
+ self.__object_count_min
+ + (
+ self.__success_rate_impl.success_rate
+ / self.__max_object_count_success_rate_threshold
+ )
+ * self.__object_count_min_max_diff
+ ),
+ )
+
+ self.__task.add_randomizer_parameter_overrides(
+ {
+ "object_count": self.__object_count,
+ }
+ )
+
+
+class ArmStuckChecker:
+ """
+ Checker for arm getting stuck.
+
+ Keeps a rolling window of the last `arm_stuck_n_steps` arm joint positions
+ and reports "stuck" when the joint-space movement over the whole window
+ stays below a configured norm threshold.
+ """
+
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ task: Task,
+ arm_stuck_n_steps: int,
+ arm_stuck_min_joint_difference_norm: float,
+ **kwargs,
+ ):
+ """
+ Args:
+ task: Task providing `moveit2.joint_state` and arm joint names.
+ arm_stuck_n_steps: Window length (number of past readings kept).
+ arm_stuck_min_joint_difference_norm: Minimum joint-space movement
+ (L2 norm) below which the arm is considered stuck.
+ """
+
+ self.__task = task
+ self.__arm_stuck_min_joint_difference_norm = arm_stuck_min_joint_difference_norm
+
+ # List of previous joint positions (used to compute difference norm with an older previous reading)
+ self.__previous_joint_positions: Deque[np.ndarray] = deque(
+ [], maxlen=arm_stuck_n_steps
+ )
+ # Counter of how many times the robot got stuck
+ self.__robot_stuck_total_counter: int = 0
+
+ # Initialize list of indices for the arm.
+ # It is assumed that these indices do not change during the operation
+ self.__arm_joint_indices = None
+
+ def get_info(self) -> Dict:
+ """Return logging info with the total stuck count."""
+
+ info = {
+ f"{self.INFO_CURRICULUM_PREFIX}robot_stuck_count": self.__robot_stuck_total_counter,
+ }
+
+ return info
+
+ def reset_task(self):
+ """Clear the position window and seed it with the current reading."""
+
+ self.__previous_joint_positions.clear()
+
+ joint_positions = self.__get_arm_joint_positions()
+ if joint_positions is not None:
+ self.__previous_joint_positions.append(joint_positions)
+
+ def is_robot_stuck(self) -> bool:
+ """
+ Return True if the arm has barely moved over the whole window.
+ """
+
+ # Get current position and append to the list of previous ones
+ current_joint_positions = self.__get_arm_joint_positions()
+ if current_joint_positions is not None:
+ self.__previous_joint_positions.append(current_joint_positions)
+
+ # Stop checking if there is not yet enough entries in the list
+ if (
+ len(self.__previous_joint_positions)
+ < self.__previous_joint_positions.maxlen
+ ):
+ return False
+
+ # NOTE(review): if `current_joint_positions` is None (no joint state)
+ # while the window is already full, the `len()` call below raises
+ # TypeError — confirm whether this path is possible at runtime.
+ # Make sure the length of joint position matches
+ if len(current_joint_positions) != len(self.__previous_joint_positions[0]):
+ return False
+
+ # Compute joint difference norm only with the `t - arm_stuck_n_steps` entry first (performance reason)
+ joint_difference_norm = np.linalg.norm(
+ current_joint_positions - self.__previous_joint_positions[0]
+ )
+
+ # If the difference is large enough, the arm does not appear to be stuck, so skip computing all other entries
+ if joint_difference_norm > self.__arm_stuck_min_joint_difference_norm:
+ return False
+
+ # If it is too small, consider all other entries as well
+ joint_difference_norms = np.linalg.norm(
+ current_joint_positions
+ - list(itertools.islice(self.__previous_joint_positions, 1, None)),
+ axis=1,
+ )
+
+ # Return true (stuck) if all joint difference entries are too small
+ is_stuck = all(
+ joint_difference_norms < self.__arm_stuck_min_joint_difference_norm
+ )
+ self.__robot_stuck_total_counter += int(is_stuck)
+ return is_stuck
+
+ def __get_arm_joint_positions(self) -> Optional[np.ndarray]:
+ """
+ Return the arm joint positions from the latest joint state, or None if
+ no joint state has been received yet.
+ """
+
+ joint_state = self.__task.moveit2.joint_state
+
+ if joint_state is None:
+ return None
+
+ # Cache the indices of the arm joints within the joint-state message
+ if self.__arm_joint_indices is None:
+ self.__arm_joint_indices = [
+ i
+ for i, joint_name in enumerate(joint_state.name)
+ if joint_name in self.__task.robot_arm_joint_names
+ ]
+
+ return np.take(joint_state.position, self.__arm_joint_indices)
+
+
+class AttributeCurriculum:
+ """
+ Curriculum that increases the value of an attribute (e.g. requirement) as the success rate increases.
+ Currently support only attributes that are increasing.
+ """
+
+ INFO_CURRICULUM_PREFIX: str = "curriculum/"
+
+ def __init__(
+ self,
+ success_rate_impl: SuccessRateImpl,
+ attribute_owner: Type,
+ attribute_name: str,
+ initial_value: float,
+ target_value: float,
+ target_value_threshold: float,
+ **kwargs,
+ ):
+ """
+ Args:
+ success_rate_impl: Shared success-rate tracker driving the value.
+ attribute_owner: Object whose attribute is updated in place.
+ attribute_name: Name of the attribute (also tried with `_`/`__`
+ prefixes).
+ initial_value: Attribute value at zero success rate.
+ target_value: Attribute value at/above the threshold.
+ target_value_threshold: Success rate at which the target is reached.
+ """
+
+ self.__success_rate_impl = success_rate_impl
+ self.__attribute_owner = attribute_owner
+ self.__attribute_name = attribute_name
+ self.__initial_value = initial_value
+ self.__target_value_threshold = target_value_threshold
+
+ # Initialise current value of the attribute
+ self.__current_value = initial_value
+
+ # Store difference for faster computations
+ self.__value_diff = target_value - initial_value
+
+ def get_info(self) -> Dict:
+ """Return logging info with the attribute's current value."""
+
+ info = {
+ f"{self.INFO_CURRICULUM_PREFIX}{self.__attribute_name}": self.__current_value,
+ }
+
+ return info
+
+ def reset_task(self):
+ """Recompute the attribute value at the start of each episode."""
+
+ # Update the tracked attribute
+ self.__update_attribute()
+
+ def __update_attribute(self):
+ """Interpolate the attribute with the success rate and set it on the owner."""
+
+ # NOTE(review): the scale is floored at `initial_value` (a value-domain
+ # quantity) rather than at a dedicated minimum scale — confirm this is
+ # intended and not a copy-paste from the scale-based curricula.
+ scale = min(
+ 1.0,
+ max(
+ self.__initial_value,
+ self.__success_rate_impl.success_rate / self.__target_value_threshold,
+ ),
+ )
+
+ self.__current_value = self.__initial_value + (scale * self.__value_diff)
+
+ if hasattr(self.__attribute_owner, self.__attribute_name):
+ setattr(self.__attribute_owner, self.__attribute_name, self.__current_value)
+ elif hasattr(self.__attribute_owner, f"_{self.__attribute_name}"):
+ setattr(
+ self.__attribute_owner,
+ f"_{self.__attribute_name}",
+ self.__current_value,
+ )
+ # NOTE(review): due to Python name mangling, a `__attr` defined inside a
+ # class is stored as `_ClassName__attr`, so this plain `__{name}` lookup
+ # will rarely match — confirm which owners this branch is meant for.
+ elif hasattr(self.__attribute_owner, f"__{self.__attribute_name}"):
+ setattr(
+ self.__attribute_owner,
+ f"__{self.__attribute_name}",
+ self.__current_value,
+ )
+ else:
+ raise Exception(
+ f"Attribute owner '{self.__attribute_owner}' does not have any attribute named {self.__attribute_name}."
+ )
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/grasp.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/grasp.py
new file mode 100644
index 0000000..bd6a751
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/grasp.py
@@ -0,0 +1,341 @@
+from typing import Dict, List, Tuple
+
+from gym_gz.base.task import Task
+
+from rbs_gym.envs.tasks.curriculums.common import *
+from rbs_gym.envs.utils.math import distance_to_nearest_point
+
+
+class GraspStage(CurriculumStage):
+    """
+    Ordered enum that represents stages of a curriculum for Grasp (and GraspPlanetary) task.
+
+    The numeric values define the progression order:
+    REACH -> TOUCH -> GRASP -> LIFT.
+    """
+
+    REACH = 1
+    TOUCH = 2
+    GRASP = 3
+    LIFT = 4
+
+
+class GraspCurriculum(
+    StageRewardCurriculum,
+    SuccessRateImpl,
+    WorkspaceScaleCurriculum,
+    ObjectSpawnVolumeScaleCurriculum,
+    ObjectCountCurriculum,
+    ArmStuckChecker,
+):
+    """
+    Curriculum learning implementation for grasp task that provides termination (success/fail) and reward for each stage of the task.
+    """
+
+    def __init__(
+        self,
+        task: Task,
+        stages_base_reward: float,
+        reach_required_distance: float,
+        lift_required_height: float,
+        persistent_reward_each_step: float,
+        persistent_reward_terrain_collision: float,
+        persistent_reward_all_objects_outside_workspace: float,
+        persistent_reward_arm_stuck: float,
+        enable_stage_reward_curriculum: bool,
+        enable_workspace_scale_curriculum: bool,
+        enable_object_spawn_volume_scale_curriculum: bool,
+        enable_object_count_curriculum: bool,
+        reach_required_distance_min: Optional[float] = None,
+        reach_required_distance_max: Optional[float] = None,
+        reach_required_distance_max_threshold: Optional[float] = None,
+        lift_required_height_min: Optional[float] = None,
+        lift_required_height_max: Optional[float] = None,
+        lift_required_height_max_threshold: Optional[float] = None,
+        **kwargs,
+    ):
+        """
+        Configure the grasp curriculum.
+
+        Persistent per-step rewards are forced negative (they are penalties).
+        The `*_min`/`*_max`/`*_max_threshold` triples optionally enable an
+        `AttributeCurriculum` that tightens the corresponding requirement as
+        the success rate grows; when min == max the sub-curriculum is disabled.
+        """
+
+        # Each base is initialised explicitly (not via super()) so that every
+        # mixin receives the shared kwargs plus its specific arguments.
+        StageRewardCurriculum.__init__(self, curriculum_stage=GraspStage, **kwargs)
+        SuccessRateImpl.__init__(self, **kwargs)
+        WorkspaceScaleCurriculum.__init__(
+            self, task=task, success_rate_impl=self, **kwargs
+        )
+        ObjectSpawnVolumeScaleCurriculum.__init__(
+            self, task=task, success_rate_impl=self, **kwargs
+        )
+        ObjectCountCurriculum.__init__(
+            self, task=task, success_rate_impl=self, **kwargs
+        )
+        ArmStuckChecker.__init__(self, task=task, **kwargs)
+
+        # Grasp task/environment that will be used to extract information from the scene
+        self.__task = task
+
+        # Parameters
+        self.__stages_base_reward = stages_base_reward
+        # Public (non-mangled) so AttributeCurriculum can update them by name.
+        self.reach_required_distance = reach_required_distance
+        self.lift_required_height = lift_required_height
+        self.__persistent_reward_each_step = persistent_reward_each_step
+        self.__persistent_reward_terrain_collision = persistent_reward_terrain_collision
+        self.__persistent_reward_all_objects_outside_workspace = (
+            persistent_reward_all_objects_outside_workspace
+        )
+        self.__persistent_reward_arm_stuck = persistent_reward_arm_stuck
+        self.__enable_stage_reward_curriculum = enable_stage_reward_curriculum
+        self.__enable_workspace_scale_curriculum = enable_workspace_scale_curriculum
+        self.__enable_object_spawn_volume_scale_curriculum = (
+            enable_object_spawn_volume_scale_curriculum
+        )
+        self.__enable_object_count_curriculum = enable_object_count_curriculum
+
+        # Make sure that the persistent rewards for each step are negative
+        if self.__persistent_reward_each_step > 0.0:
+            self.__persistent_reward_each_step *= -1.0
+        if self.__persistent_reward_terrain_collision > 0.0:
+            self.__persistent_reward_terrain_collision *= -1.0
+        if self.__persistent_reward_all_objects_outside_workspace > 0.0:
+            self.__persistent_reward_all_objects_outside_workspace *= -1.0
+        if self.__persistent_reward_arm_stuck > 0.0:
+            self.__persistent_reward_arm_stuck *= -1.0
+
+        # Setup curriculum for Reach distance requirement (if enabled)
+        reach_required_distance_min = (
+            reach_required_distance_min
+            if reach_required_distance_min is not None
+            else reach_required_distance
+        )
+        reach_required_distance_max = (
+            reach_required_distance_max
+            if reach_required_distance_max is not None
+            else reach_required_distance
+        )
+        reach_required_distance_max_threshold = (
+            reach_required_distance_max_threshold
+            if reach_required_distance_max_threshold is not None
+            else 0.5
+        )
+        self.__reach_required_distance_curriculum_enabled = (
+            not reach_required_distance_min == reach_required_distance_max
+        )
+        if self.__reach_required_distance_curriculum_enabled:
+            self.__reach_required_distance_curriculum = AttributeCurriculum(
+                success_rate_impl=self,
+                attribute_owner=self,
+                attribute_name="reach_required_distance",
+                initial_value=reach_required_distance_min,
+                target_value=reach_required_distance_max,
+                target_value_threshold=reach_required_distance_max_threshold,
+            )
+
+        # Setup curriculum for Lift height requirement (if enabled)
+        lift_required_height_min = (
+            lift_required_height_min
+            if lift_required_height_min is not None
+            else lift_required_height
+        )
+        lift_required_height_max = (
+            lift_required_height_max
+            if lift_required_height_max is not None
+            else lift_required_height
+        )
+        lift_required_height_max_threshold = (
+            lift_required_height_max_threshold
+            if lift_required_height_max_threshold is not None
+            else 0.5
+        )
+        # Offset Lift height requirement by the robot base offset
+        # NOTE(review): the offset local `lift_required_height` is never written
+        # back to `self.lift_required_height` (assigned above, before the
+        # offset), so the attribute keeps the un-offset value unless the
+        # sub-curriculum below overwrites it — confirm this is intended.
+        lift_required_height += task.robot_model_class.BASE_LINK_Z_OFFSET
+        lift_required_height_min += task.robot_model_class.BASE_LINK_Z_OFFSET
+        lift_required_height_max += task.robot_model_class.BASE_LINK_Z_OFFSET
+        # NOTE(review): `lift_required_height_max_threshold` is a success-rate
+        # threshold (defaults to 0.5), yet it is shifted by a geometric Z
+        # offset here — this looks unintended; verify against the reach
+        # threshold above, which is not offset.
+        lift_required_height_max_threshold += task.robot_model_class.BASE_LINK_Z_OFFSET
+
+        self.__lift_required_height_curriculum_enabled = (
+            not lift_required_height_min == lift_required_height_max
+        )
+        if self.__lift_required_height_curriculum_enabled:
+            self.__lift_required_height_curriculum = AttributeCurriculum(
+                success_rate_impl=self,
+                attribute_owner=self,
+                attribute_name="lift_required_height",
+                initial_value=lift_required_height_min,
+                target_value=lift_required_height_max,
+                target_value_threshold=lift_required_height_max_threshold,
+            )
+
+    def get_reward(self) -> Reward:
+        """Compute the step reward by delegating to the stage-reward curriculum."""
+
+        if self.__enable_stage_reward_curriculum:
+            # Try to get reward from each stage
+            return StageRewardCurriculum.get_reward(
+                self,
+                ee_position=self.__task.get_ee_position(),
+                object_positions=self.__task.get_object_positions(),
+                touched_objects=self.__task.get_touched_objects(),
+                grasped_objects=self.__task.get_grasped_objects(),
+            )
+        else:
+            # If curriculum for stages is disabled, compute reward only for the last stage
+            return StageRewardCurriculum.get_reward(
+                self,
+                only_last_stage=True,
+                object_positions=self.__task.get_object_positions(),
+                grasped_objects=self.__task.get_grasped_objects(),
+            )
+
+    def is_done(self) -> bool:
+        """Episode termination is fully determined by the stage curriculum."""
+
+        return StageRewardCurriculum.is_done(self)
+
+    def get_info(self) -> Dict:
+        """Aggregate the info dicts of all enabled curriculum components."""
+
+        info = StageRewardCurriculum.get_info(self)
+        info.update(SuccessRateImpl.get_info(self))
+        if self.__enable_workspace_scale_curriculum:
+            info.update(WorkspaceScaleCurriculum.get_info(self))
+        if self.__enable_object_spawn_volume_scale_curriculum:
+            info.update(ObjectSpawnVolumeScaleCurriculum.get_info(self))
+        if self.__enable_object_count_curriculum:
+            info.update(ObjectCountCurriculum.get_info(self))
+        # Truthiness gate: a penalty of 0.0 disables the arm-stuck checker info.
+        if self.__persistent_reward_arm_stuck:
+            info.update(ArmStuckChecker.get_info(self))
+        if self.__reach_required_distance_curriculum_enabled:
+            info.update(self.__reach_required_distance_curriculum.get_info())
+        if self.__lift_required_height_curriculum_enabled:
+            info.update(self.__lift_required_height_curriculum.get_info())
+
+        return info
+
+    def reset_task(self):
+        """Reset all enabled curriculum components at the start of an episode."""
+
+        StageRewardCurriculum.reset_task(self)
+        if self.__enable_workspace_scale_curriculum:
+            WorkspaceScaleCurriculum.reset_task(self)
+        if self.__enable_object_spawn_volume_scale_curriculum:
+            ObjectSpawnVolumeScaleCurriculum.reset_task(self)
+        if self.__enable_object_count_curriculum:
+            ObjectCountCurriculum.reset_task(self)
+        # Truthiness gate: a penalty of 0.0 disables the arm-stuck checker reset.
+        if self.__persistent_reward_arm_stuck:
+            ArmStuckChecker.reset_task(self)
+        if self.__reach_required_distance_curriculum_enabled:
+            self.__reach_required_distance_curriculum.reset_task()
+        if self.__lift_required_height_curriculum_enabled:
+            self.__lift_required_height_curriculum.reset_task()
+
+    def on_episode_success(self):
+        """Register a successful episode in the moving success rate."""
+
+        self.update_success_rate(is_success=True)
+
+    def on_episode_failure(self):
+        """Register a failed episode in the moving success rate."""
+
+        self.update_success_rate(is_success=False)
+
+    def on_episode_timeout(self):
+        """A timeout counts as a failure for the success rate."""
+
+        self.update_success_rate(is_success=False)
+
+    def get_reward_REACH(
+        self,
+        ee_position: Tuple[float, float, float],
+        object_positions: Dict[str, Tuple[float, float, float]],
+        **kwargs,
+    ) -> float:
+        """
+        Reward for the REACH stage: granted once the end effector is within
+        `reach_required_distance` of the nearest object.
+
+        Returns:
+            The stage base reward on success, otherwise 0.0.
+        """
+
+        # No objects in the scene -> nothing to reach.
+        if not object_positions:
+            return 0.0
+
+        nearest_object_distance = distance_to_nearest_point(
+            origin=ee_position, points=list(object_positions.values())
+        )
+
+        self.__task.get_logger().debug(
+            f"[Curriculum] Distance to nearest object: {nearest_object_distance}"
+        )
+        if nearest_object_distance < self.reach_required_distance:
+            self.__task.get_logger().info(
+                f"[Curriculum] An object is now closer than the required distance of {self.reach_required_distance}"
+            )
+            self.stages_completed_this_episode[GraspStage.REACH] = True
+            return self.__stages_base_reward
+        else:
+            return 0.0
+
+ def get_reward_TOUCH(self, touched_objects: List[str], **kwargs) -> float:
+
+ if touched_objects:
+ self.__task.get_logger().info(
+ f"[Curriculum] Touched objects: {touched_objects}"
+ )
+ self.stages_completed_this_episode[GraspStage.TOUCH] = True
+ return self.__stages_base_reward
+ else:
+ return 0.0
+
+    def get_reward_GRASP(self, grasped_objects: List[str], **kwargs) -> float:
+        """Reward for the GRASP stage: granted once any object is detected as grasped."""
+
+        if grasped_objects:
+            self.__task.get_logger().info(
+                f"[Curriculum] Grasped objects: {grasped_objects}"
+            )
+            self.stages_completed_this_episode[GraspStage.GRASP] = True
+            return self.__stages_base_reward
+        else:
+            return 0.0
+
+ def get_reward_LIFT(
+ self,
+ object_positions: Dict[str, Tuple[float, float, float]],
+ grasped_objects: List[str],
+ **kwargs,
+ ) -> float:
+
+ if not (grasped_objects or object_positions):
+ return 0.0
+
+ for grasped_object in grasped_objects:
+ grasped_object_height = object_positions[grasped_object][2]
+
+ self.__task.get_logger().debug(
+ f"[Curriculum] Height of grasped object '{grasped_objects}': {grasped_object_height}"
+ )
+ if grasped_object_height > self.lift_required_height:
+ self.__task.get_logger().info(
+ f"[Curriculum] Lifted object: {grasped_object}"
+ )
+ self.stages_completed_this_episode[GraspStage.LIFT] = True
+ return self.__stages_base_reward
+
+ return 0.0
+
+    def get_persistent_reward(
+        self, object_positions: Dict[str, Tuple[float, float, float]], **kwargs
+    ) -> float:
+        """
+        Stage-independent penalties applied every step.
+
+        All persistent reward terms were forced negative in `__init__`; a value
+        of 0.0 disables the corresponding check (truthiness gates below).
+        May set `self.episode_failed` as a side effect.
+        """
+
+        # Subtract a small reward each step to provide incentive to act quickly
+        reward = self.__persistent_reward_each_step
+
+        # Negative reward for colliding with terrain
+        if self.__persistent_reward_terrain_collision:
+            if self.__task.check_terrain_collision():
+                self.__task.get_logger().info(
+                    "[Curriculum] Robot collided with the terrain"
+                )
+                reward += self.__persistent_reward_terrain_collision
+
+        # Negative reward for having all objects outside of the workspace
+        if self.__persistent_reward_all_objects_outside_workspace:
+            if self.__task.check_all_objects_outside_workspace(
+                object_positions=object_positions
+            ):
+                self.__task.get_logger().warn(
+                    "[Curriculum] All objects are outside of the workspace"
+                )
+                reward += self.__persistent_reward_all_objects_outside_workspace
+                # Losing all objects ends the episode as a failure.
+                self.episode_failed = True
+
+        # Negative reward for arm getting stuck
+        if self.__persistent_reward_arm_stuck:
+            if ArmStuckChecker.is_robot_stuck(self):
+                self.__task.get_logger().error(
+                    f"[Curriculum] Robot appears to be stuck, resetting..."
+                )
+                reward += self.__persistent_reward_arm_stuck
+                self.episode_failed = True
+
+        return reward
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/manipulation.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/manipulation.py
new file mode 100644
index 0000000..b1bb7a0
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/manipulation.py
@@ -0,0 +1,707 @@
+import abc
+import multiprocessing
+import sys
+from itertools import count
+from threading import Thread
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import rclpy
+from env_manager.models.configs import RobotData
+from env_manager.models.robots import RobotWrapper, get_robot_model_class
+from env_manager.scene import Scene
+from env_manager.utils import Tf2Broadcaster, Tf2Listener
+from env_manager.utils.conversions import orientation_6d_to_quat
+from env_manager.utils.gazebo import (
+ Point,
+ Pose,
+ Quat,
+ get_model_orientation,
+ get_model_pose,
+ get_model_position,
+ transform_change_reference_frame_orientation,
+ transform_change_reference_frame_pose,
+ transform_change_reference_frame_position,
+)
+from env_manager.utils.math import quat_mul
+from gym_gz.base.task import Task
+from gym_gz.scenario.model_wrapper import ModelWrapper
+from gym_gz.utils.typing import (
+ Action,
+ ActionSpace,
+ Observation,
+ ObservationSpace,
+ Reward,
+)
+from rclpy.callback_groups import ReentrantCallbackGroup
+from rclpy.executors import MultiThreadedExecutor, SingleThreadedExecutor
+from rclpy.node import Node
+from robot_builder.elements.robot import Robot
+from scipy.spatial.transform import Rotation
+
+from rbs_gym.envs.control import (
+ CartesianForceController,
+ GripperController,
+ JointEffortController,
+)
+
+
+class Manipulation(Task, Node, abc.ABC):
+    # Monotonic counter assigning a unique id to every task instance; the id is
+    # baked into the ROS node name so multiple tasks can coexist.
+    _ids = count(0)
+
+    def __init__(
+        self,
+        agent_rate: float,
+        robot_model: str,
+        workspace_frame_id: str,
+        workspace_centre: Tuple[float, float, float],
+        workspace_volume: Tuple[float, float, float],
+        ignore_new_actions_while_executing: bool,
+        use_servo: bool,
+        scaling_factor_translation: float,
+        scaling_factor_rotation: float,
+        restrict_position_goal_to_workspace: bool,
+        enable_gripper: bool,
+        num_threads: int,
+        **kwargs,
+    ):
+        """
+        Base class for manipulation tasks: sets up the ROS 2 node, a background
+        executor, tf2 utilities and the controllers shared by concrete tasks.
+
+        Note:
+            Several parameters (`robot_model`, `workspace_centre`,
+            `workspace_volume`, `ignore_new_actions_while_executing`,
+            `use_servo`, `scaling_factor_*`,
+            `restrict_position_goal_to_workspace`) are currently unused — the
+            code consuming them is commented out below.
+        """
+
+        # Get next ID for this task instance
+        self.id = next(self._ids)
+
+        # Initialize the Task base class
+        Task.__init__(self, agent_rate=agent_rate)
+
+        # Initialize ROS 2 context (if not done before)
+        try:
+            rclpy.init()
+        except Exception as e:
+            # A second init() raises even when the context is fine; only abort
+            # if the context is genuinely unusable.
+            if not rclpy.ok():
+                sys.exit(f"ROS 2 context could not be initialised: {e}")
+
+        # Initialize ROS 2 Node base class
+        Node.__init__(self, f"rbs_gym_{self.id}")
+
+        # NOTE(review): `allow_undeclared_parameters` is normally passed to the
+        # Node constructor; assigning the attribute afterwards may have no
+        # effect — confirm.
+        self._allow_undeclared_parameters = True
+        self.declare_parameter("robot_description", "")
+        # Create callback group that allows execution of callbacks in parallel without restrictions
+        self._callback_group = ReentrantCallbackGroup()
+
+        # Create executor (num_threads <= 0 means "one thread per CPU core")
+        if num_threads == 1:
+            executor = SingleThreadedExecutor()
+        elif num_threads > 1:
+            executor = MultiThreadedExecutor(
+                num_threads=num_threads,
+            )
+        else:
+            executor = MultiThreadedExecutor(num_threads=multiprocessing.cpu_count())
+
+        # Add this node to the executor
+        executor.add_node(self)
+
+        # Spin this node in background thread(s)
+        self._executor_thread = Thread(target=executor.spin, daemon=True, args=())
+        self._executor_thread.start()
+
+        # Get class of the robot model based on passed argument
+        # self.robot_model_class = get_robot_model_class(robot_model)
+
+        # Store passed arguments for later use
+        # self.workspace_centre = (
+        #     workspace_centre[0],
+        #     workspace_centre[1],
+        #     workspace_centre[2] + self.robot_model_class.BASE_LINK_Z_OFFSET,
+        # )
+        # self.workspace_volume = workspace_volume
+        # self._restrict_position_goal_to_workspace = restrict_position_goal_to_workspace
+        # self._use_servo = use_servo
+        # self.__scaling_factor_translation = scaling_factor_translation
+        # self.__scaling_factor_rotation = scaling_factor_rotation
+        self._enable_gripper = enable_gripper
+        # Scene is attached later via the `scene` property setter.
+        self._scene: Scene | None = None
+        #
+        # # Get workspace bounds, useful is many computations
+        # workspace_volume_half = (
+        #     workspace_volume[0] / 2,
+        #     workspace_volume[1] / 2,
+        #     workspace_volume[2] / 2,
+        # )
+        # self.workspace_min_bound = (
+        #     self.workspace_centre[0] - workspace_volume_half[0],
+        #     self.workspace_centre[1] - workspace_volume_half[1],
+        #     self.workspace_centre[2] - workspace_volume_half[2],
+        # )
+        # self.workspace_max_bound = (
+        #     self.workspace_centre[0] + workspace_volume_half[0],
+        #     self.workspace_centre[1] + workspace_volume_half[1],
+        #     self.workspace_centre[2] + workspace_volume_half[2],
+        # )
+
+        # Determine robot name and prefix based on current ID of the task
+        # self.robot_prefix: str = self.robot_model_class.DEFAULT_PREFIX
+        # if 0 == self.id:
+        #     self.robot_name = self.robot_model_class.ROBOT_MODEL_NAME
+        # else:
+        #     self.robot_name = f"{self.robot_model_class.ROBOT_MODEL_NAME}{self.id}"
+        #     if self.robot_prefix.endswith("_"):
+        #         self.robot_prefix = f"{self.robot_prefix[:-1]}{self.id}_"
+        #     elif self.robot_prefix == "":
+        #         self.robot_prefix = f"robot{self.id}_"
+
+        # Names of specific robot links, useful all around the code
+        # self.robot_base_link_name = self.robot_model_class.get_robot_base_link_name(
+        #     self.robot_prefix
+        # )
+        # self.robot_arm_base_link_name = self.robot_model_class.get_arm_base_link_name(
+        #     self.robot_prefix
+        # )
+        # self.robot_ee_link_name = self.robot_model_class.get_ee_link_name(
+        #     self.robot_prefix
+        # )
+        # self.robot_arm_link_names = self.robot_model_class.get_arm_link_names(
+        #     self.robot_prefix
+        # )
+        # self.robot_gripper_link_names = self.robot_model_class.get_gripper_link_names(
+        #     self.robot_prefix
+        # )
+        # self.robot_arm_joint_names = self.robot_model_class.get_arm_joint_names(
+        #     self.robot_prefix
+        # )
+        # self.robot_gripper_joint_names = self.robot_model_class.get_gripper_joint_names(
+        #     self.robot_prefix
+        # )
+
+        # Get exact name substitution of the frame for workspace
+        self.workspace_frame_id = self.substitute_special_frame(workspace_frame_id)
+
+        # Specify initial positions (default configuration is used here)
+        # self.initial_arm_joint_positions = (
+        #     self.robot_model_class.DEFAULT_ARM_JOINT_POSITIONS
+        # )
+        # self.initial_gripper_joint_positions = (
+        #     self.robot_model_class.DEFAULT_GRIPPER_JOINT_POSITIONS
+        # )
+
+        # Names of important models (in addition to robot model)
+
+        # Setup listener and broadcaster of transforms via tf2
+        self.tf2_listener = Tf2Listener(node=self)
+        self.tf2_broadcaster = Tf2Broadcaster(node=self)
+
+        self.cartesian_control = True  # TODO: make it as an external parameter
+
+        # Setup controllers
+        self.controller = CartesianForceController(self)
+        if not self.cartesian_control:
+            self.joint_controller = JointEffortController(self)
+        if self._enable_gripper:
+            # 0.064 — presumably the gripper opening width in metres; TODO confirm.
+            self.gripper = GripperController(self, 0.064)
+
+        # Initialize task and randomizer overrides (e.g. from curriculum)
+        # Both of these are consumed at the beginning of reset
+        self.__task_parameter_overrides: dict[str, Any] = {}
+        self._randomizer_parameter_overrides: dict[str, Any] = {}
+
+    def create_spaces(self) -> Tuple[ActionSpace, ObservationSpace]:
+        """Build the (action, observation) space pair from the subclass hooks."""
+
+        action_space = self.create_action_space()
+        observation_space = self.create_observation_space()
+
+        return action_space, observation_space
+
+    # The following hooks must be provided by concrete task subclasses.
+
+    def create_action_space(self) -> ActionSpace:
+        raise NotImplementedError()
+
+    def create_observation_space(self) -> ObservationSpace:
+        raise NotImplementedError()
+
+    def set_action(self, action: Action):
+        raise NotImplementedError()
+
+    def get_observation(self) -> Observation:
+        raise NotImplementedError()
+
+    def get_reward(self) -> Reward:
+        raise NotImplementedError()
+
+    def is_done(self) -> bool:
+        raise NotImplementedError()
+
+    def reset_task(self):
+        """Per-episode reset hook; parameter-override consumption is disabled."""
+        # self.__consume_parameter_overrides()
+        pass
+
+ # # Helper functions #
+ # def get_relative_ee_position(
+ # self, translation: Tuple[float, float, float]
+ # ) -> Tuple[float, float, float]:
+ # # Scale relative action to metric units
+ # translation = self.scale_relative_translation(translation)
+ # # Get current position
+ # current_position = self.get_ee_position()
+ # # Compute target position
+ # target_position = (
+ # current_position[0] + translation[0],
+ # current_position[1] + translation[1],
+ # current_position[2] + translation[2],
+ # )
+ #
+ # # Restrict target position to a limited workspace, if desired
+ # if self._restrict_position_goal_to_workspace:
+ # target_position = self.restrict_position_goal_to_workspace(target_position)
+ #
+ # return target_position
+ #
+ # def get_relative_ee_orientation(
+ # self,
+ # rotation: Union[
+ # float,
+ # Tuple[float, float, float, float],
+ # Tuple[float, float, float, float, float, float],
+ # ],
+ # representation: str = "quat",
+ # ) -> Tuple[float, float, float, float]:
+ # # Get current orientation
+ # current_quat_xyzw = self.get_ee_orientation()
+ #
+ # # For 'z' representation, result should always point down
+ # # Therefore, create a new quatertnion that contains only yaw component
+ # if "z" == representation:
+ # current_yaw = Rotation.from_quat(current_quat_xyzw).as_euler("xyz")[2]
+ # current_quat_xyzw = Rotation.from_euler(
+ # "xyz", [np.pi, 0, current_yaw]
+ # ).as_quat()
+ #
+ # # Convert relative orientation representation to quaternion
+ # relative_quat_xyzw = None
+ # if "quat" == representation:
+ # relative_quat_xyzw = rotation
+ # elif "6d" == representation:
+ # vectors = tuple(
+ # rotation[x : x + 3] for x, _ in enumerate(rotation) if x % 3 == 0
+ # )
+ # relative_quat_xyzw = orientation_6d_to_quat(vectors[0], vectors[1])
+ # elif "z" == representation:
+ # rotation = self.scale_relative_rotation(rotation)
+ # relative_quat_xyzw = Rotation.from_euler("xyz", [0, 0, rotation]).as_quat()
+ #
+ # # Compute target position (combine quaternions)
+ # target_quat_xyzw = quat_mul(current_quat_xyzw, relative_quat_xyzw)
+ #
+ # # Normalise quaternion (should not be needed, but just to be safe)
+ # target_quat_xyzw /= np.linalg.norm(target_quat_xyzw)
+ #
+ # return target_quat_xyzw
+
+ # def scale_relative_translation(
+ # self, translation: Tuple[float, float, float]
+ # ) -> Tuple[float, float, float]:
+ # return (
+ # self.__scaling_factor_translation * translation[0],
+ # self.__scaling_factor_translation * translation[1],
+ # self.__scaling_factor_translation * translation[2],
+ # )
+
+ # def scale_relative_rotation(
+ # self,
+ # rotation: Union[float, Tuple[float, float, float], np.floating, np.ndarray],
+ # ) -> float:
+ # if not hasattr(rotation, "__len__"):
+ # return self.__scaling_factor_rotation * rotation
+ # else:
+ # return (
+ # self.__scaling_factor_rotation * rotation[0],
+ # self.__scaling_factor_rotation * rotation[1],
+ # self.__scaling_factor_rotation * rotation[2],
+ # )
+ #
+ # def restrict_position_goal_to_workspace(
+ # self, position: Tuple[float, float, float]
+ # ) -> Tuple[float, float, float]:
+ # return (
+ # min(
+ # self.workspace_max_bound[0],
+ # max(
+ # self.workspace_min_bound[0],
+ # position[0],
+ # ),
+ # ),
+ # min(
+ # self.workspace_max_bound[1],
+ # max(
+ # self.workspace_min_bound[1],
+ # position[1],
+ # ),
+ # ),
+ # min(
+ # self.workspace_max_bound[2],
+ # max(
+ # self.workspace_min_bound[2],
+ # position[2],
+ # ),
+ # ),
+ # )
+
+ # def restrict_servo_translation_to_workspace(
+ # self, translation: tuple[float, float, float]
+ # ) -> tuple[float, float, float]:
+ # current_ee_position = self.get_ee_position()
+ #
+ # translation = tuple(
+ # 0.0
+ # if (
+ # current_ee_position[i] > self.workspace_max_bound[i]
+ # and translation[i] > 0.0
+ # )
+ # or (
+ # current_ee_position[i] < self.workspace_min_bound[i]
+ # and translation[i] < 0.0
+ # )
+ # else translation[i]
+ # for i in range(3)
+ # )
+ #
+ # return translation
+
+    def get_ee_pose(
+        self,
+    ) -> Pose:
+        """
+        Return the current pose of the end effector with respect to arm base link.
+
+        Tries Gazebo first, then falls back to a tf2 lookup, and finally to an
+        identity pose ((0, 0, 0), (0, 0, 0, 1)) if both fail.
+        """
+
+        try:
+            # NOTE(review): sibling getters use `self.robot_wrapper` directly;
+            # this one re-fetches the model by name — consider unifying.
+            robot_model = self.world.to_gazebo().get_model(self.robot.name).to_gazebo()
+            ee_position, ee_quat_xyzw = get_model_pose(
+                world=self.world,
+                model=robot_model,
+                link=self.robot.ee_link,
+                xyzw=True,
+            )
+            return transform_change_reference_frame_pose(
+                world=self.world,
+                position=ee_position,
+                quat=ee_quat_xyzw,
+                target_model=robot_model,
+                target_link=self.robot.base_link,
+                xyzw=True,
+            )
+        except Exception as e:
+            self.get_logger().warn(
+                f"Cannot get end effector pose from Gazebo ({e}), using tf2..."
+            )
+            transform = self.tf2_listener.lookup_transform_sync(
+                source_frame=self.robot.ee_link,
+                target_frame=self.robot.base_link,
+                retry=False,
+            )
+            if transform is not None:
+                return (
+                    (
+                        transform.translation.x,
+                        transform.translation.y,
+                        transform.translation.z,
+                    ),
+                    (
+                        transform.rotation.x,
+                        transform.rotation.y,
+                        transform.rotation.z,
+                        transform.rotation.w,
+                    ),
+                )
+            else:
+                self.get_logger().error(
+                    "Cannot get pose of the end effector (default values are returned)"
+                )
+                return (
+                    (0.0, 0.0, 0.0),
+                    (0.0, 0.0, 0.0, 1.0),
+                )
+
+    def get_ee_position(self) -> Point:
+        """
+        Return the current position of the end effector with respect to arm base link.
+
+        Falls back to tf2, then to (0, 0, 0) if the lookup fails.
+        """
+
+        try:
+            robot_model = self.robot_wrapper
+            ee_position = get_model_position(
+                world=self.world,
+                model=robot_model,
+                link=self.robot.ee_link,
+            )
+            return transform_change_reference_frame_position(
+                world=self.world,
+                position=ee_position,
+                target_model=robot_model,
+                target_link=self.robot.base_link,
+            )
+        except Exception as e:
+            self.get_logger().debug(
+                f"Cannot get end effector position from Gazebo ({e}), using tf2..."
+            )
+            transform = self.tf2_listener.lookup_transform_sync(
+                source_frame=self.robot.ee_link,
+                target_frame=self.robot.base_link,
+                retry=False,
+            )
+            if transform is not None:
+                return (
+                    transform.translation.x,
+                    transform.translation.y,
+                    transform.translation.z,
+                )
+            else:
+                self.get_logger().error(
+                    "Cannot get position of the end effector (default values are returned)"
+                )
+                return (0.0, 0.0, 0.0)
+
+    def get_ee_orientation(self) -> Quat:
+        """
+        Return the current xyzw quaternion of the end effector with respect to arm base link.
+
+        Falls back to tf2, then to the identity quaternion if the lookup fails.
+        """
+
+        try:
+            robot_model = self.robot_wrapper
+            ee_quat_xyzw = get_model_orientation(
+                world=self.world,
+                model=robot_model,
+                link=self.robot.ee_link,
+                xyzw=True,
+            )
+            return transform_change_reference_frame_orientation(
+                world=self.world,
+                quat=ee_quat_xyzw,
+                target_model=robot_model,
+                target_link=self.robot.base_link,
+                xyzw=True,
+            )
+        except Exception as e:
+            self.get_logger().warn(
+                f"Cannot get end effector orientation from Gazebo ({e}), using tf2..."
+            )
+            transform = self.tf2_listener.lookup_transform_sync(
+                source_frame=self.robot.ee_link,
+                target_frame=self.robot.base_link,
+                retry=False,
+            )
+            if transform is not None:
+                return (
+                    transform.rotation.x,
+                    transform.rotation.y,
+                    transform.rotation.z,
+                    transform.rotation.w,
+                )
+            else:
+                self.get_logger().error(
+                    "Cannot get orientation of the end effector (default values are returned)"
+                )
+                return (0.0, 0.0, 0.0, 1.0)
+
+    def get_object_position(
+        self, object_model: ModelWrapper | str
+    ) -> Point:
+        """
+        Return the current position of an object with respect to arm base link.
+        Note: Only simulated objects are currently supported.
+        """
+
+        try:
+            object_position = get_model_position(
+                world=self.world,
+                model=object_model,
+            )
+            # NOTE(review): `target_model` is passed as a name string here,
+            # while get_object_positions() below passes the model object —
+            # verify both forms are accepted by the helper.
+            return transform_change_reference_frame_position(
+                world=self.world,
+                position=object_position,
+                target_model=self.robot_wrapper.name(),
+                target_link=self.robot.base_link,
+            )
+        except Exception as e:
+            self.get_logger().error(
+                f"Cannot get position of {object_model} object (default values are returned): {e}"
+            )
+            return (0.0, 0.0, 0.0)
+
+    def get_object_positions(self) -> dict[str, Point]:
+        """
+        Return the current position of all objects with respect to arm base link.
+        Note: Only simulated objects are currently supported.
+        """
+
+        object_positions = {}
+
+        try:
+            robot_model = self.robot_wrapper
+            robot_arm_base_link = robot_model.get_link(link_name=self.robot.base_link)
+            # NOTE(review): `self.scene.__objects_unique_names` is name-mangled
+            # *in this class* to `self.scene._Manipulation__objects_unique_names`,
+            # which Scene almost certainly does not define. The resulting
+            # AttributeError is swallowed by the except below, so this method
+            # always logs an error and returns {}. Scene should expose a public
+            # accessor for its unique object names — fix required.
+            for object_name in self.scene.__objects_unique_names:
+                object_position = get_model_position(
+                    world=self.world,
+                    model=object_name,
+                )
+                object_positions[object_name] = (
+                    transform_change_reference_frame_position(
+                        world=self.world,
+                        position=object_position,
+                        target_model=robot_model,
+                        target_link=robot_arm_base_link,
+                    )
+                )
+        except Exception as e:
+            self.get_logger().error(
+                f"Cannot get positions of all objects (empty Dict is returned): {e}"
+            )
+
+        return object_positions
+
+ def substitute_special_frame(self, frame_id: str) -> str:
+ if "arm_base_link" == frame_id:
+ return self.robot.base_link
+ elif "base_link" == frame_id:
+ return self.robot.base_link
+ elif "end_effector" == frame_id:
+ return self.robot.ee_link
+ elif "world" == frame_id:
+ try:
+ # In Gazebo, where multiple worlds are allowed
+ return self.world.to_gazebo().name()
+ except Exception as e:
+ self.get_logger().warn(f"{e}")
+ # Otherwise (e.g. real world)
+ return "rbs_gym_world"
+ else:
+ return frame_id
+
+    def wait_until_action_executed(self):
+        """Block until the last commanded gripper action has finished executing."""
+
+        if self._enable_gripper:
+            # FIXME: code is not tested
+            self.gripper.wait_until_executed()
+
+    def move_to_initial_joint_configuration(self):
+        """Move the arm to its initial joint configuration (not yet implemented)."""
+
+        # TODO: Write this code for cartesian_control
+        pass
+
+        # self.moveit2.move_to_configuration(self.initial_arm_joint_positions)
+
+        # if (
+        #     self.robot_model_class.CLOSED_GRIPPER_JOINT_POSITIONS
+        #     == self.initial_gripper_joint_positions
+        # ):
+        #     self.gripper.reset_close()
+        # else:
+        #     self.gripper.reset_open()
+
+ def check_terrain_collision(self) -> bool:
+ """
+ Returns true if robot links are in collision with the ground.
+ """
+ robot_name_len = len(self.robot.name)
+ model_names = self.world.model_names()
+ terrain_model = self.world.get_model(self.scene.terrain.name)
+
+ for contact in terrain_model.contacts():
+ body_b = contact.body_b
+
+ if body_b.startswith(self.robot.name) and len(body_b) > robot_name_len:
+ link = body_b[robot_name_len + 2 :]
+
+ if link != self.robot.base_link and (
+ link in self.robot.actuated_joint_names
+ or link in self.robot.gripper_actuated_joint_names
+ ):
+ return True
+
+ return False
+
+ # def check_all_objects_outside_workspace(
+ # self,
+ # object_positions: Dict[str, Tuple[float, float, float]],
+ # ) -> bool:
+ # """
+ # Returns true if all objects are outside the workspace
+ # """
+ #
+ # return all(
+ # [
+ # self.check_object_outside_workspace(object_position)
+ # for object_position in object_positions.values()
+ # ]
+ # )
+ #
+ # def check_object_outside_workspace(
+ # self,
+ # object_position: Tuple[float, float, float],
+ # ) -> bool:
+ # """
+ # Returns true if the object is outside the workspace
+ # """
+ #
+ # return (
+ # object_position[0] < self.workspace_min_bound[0]
+ # or object_position[1] < self.workspace_min_bound[1]
+ # or object_position[2] < self.workspace_min_bound[2]
+ # or object_position[0] > self.workspace_max_bound[0]
+ # or object_position[1] > self.workspace_max_bound[1]
+ # or object_position[2] > self.workspace_max_bound[2]
+ # )
+
+ # def add_parameter_overrides(self, parameter_overrides: Dict[str, Any]):
+ # self.add_task_parameter_overrides(parameter_overrides)
+ # self.add_randomizer_parameter_overrides(parameter_overrides)
+ # #
+ # def add_task_parameter_overrides(self, parameter_overrides: Dict[str, Any]):
+ # self.__task_parameter_overrides.update(parameter_overrides)
+ # #
+ # def add_randomizer_parameter_overrides(self, parameter_overrides: Dict[str, Any]):
+ # self._randomizer_parameter_overrides.update(parameter_overrides)
+ #
+ # def __consume_parameter_overrides(self):
+ # for key, value in self.__task_parameter_overrides.items():
+ # if hasattr(self, key):
+ # setattr(self, key, value)
+ # elif hasattr(self, f"_{key}"):
+ # setattr(self, f"_{key}", value)
+ # elif hasattr(self, f"__{key}"):
+ # setattr(self, f"__{key}", value)
+ # else:
+ # self.get_logger().error(
+    #             f"Override '{key}' is not supported by the task."
+ # )
+ #
+ # self.__task_parameter_overrides.clear()
+
+ @property
+ def robot(self) -> Robot:
+ """The robot property."""
+ if self._scene:
+ return self._scene.robot_wrapper.robot
+ else:
+ raise ValueError("R")
+
+ @property
+ def robot_data(self) -> RobotData:
+ """The robot_data property."""
+ if self._scene:
+ return self._scene.robot
+ else:
+ raise ValueError("RD")
+
+ @property
+ def robot_wrapper(self) -> RobotWrapper:
+ """The robot_wrapper property."""
+ if self._scene:
+ return self._scene.robot_wrapper
+ else:
+ raise ValueError("RW")
+
+ @property
+ def scene(self) -> Scene:
+ """The scene property."""
+ if self._scene:
+ return self._scene
+ else:
+ raise ValueError("S")
+
+ @scene.setter
+ def scene(self, value: Scene):
+ self._scene = value
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/__init__.py
new file mode 100644
index 0000000..1fc68c3
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/__init__.py
@@ -0,0 +1,4 @@
+from .reach import Reach
+from .reach_color_image import ReachColorImage
+from .reach_depth_image import ReachDepthImage
+# from .reach_octree import ReachOctree
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/pick_and_place_imitate.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/pick_and_place_imitate.py
new file mode 100644
index 0000000..dd2d408
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/pick_and_place_imitate.py
@@ -0,0 +1,120 @@
+import abc
+
+import gymnasium as gym
+import numpy as np
+from gym_gz.utils.typing import Observation, ObservationSpace
+
+from rbs_gym.envs.models.sensors import Camera
+from rbs_gym.envs.observation import CameraSubscriber
+from rbs_gym.envs.tasks.reach import Reach
+
+
+class ReachColorImage(Reach, abc.ABC):  # NOTE(review): class name duplicates reach_color_image.ReachColorImage despite the "pick_and_place_imitate" filename — confirm intended
+    def __init__(
+        self,
+        camera_info: list[dict] = [],  # NOTE(review): mutable default argument — safe only because it is never mutated here
+        monochromatic: bool = False,
+        **kwargs,
+    ):
+        # Initialize the Task base class
+        Reach.__init__(
+            self,
+            **kwargs,
+        )
+
+        # Store parameters for later use
+        self._camera_width = sum(w.get("width", 0) for w in camera_info)  # total width of the horizontally stacked camera images
+        self._camera_height = camera_info[0]["height"]  # assumes every camera shares the first camera's height — TODO confirm
+        self._monochromatic = monochromatic
+        self._saved = False
+
+        self.camera_subs: list[CameraSubscriber] = []
+        for camera in camera_info:
+            # Perception (RGB camera)
+            self.camera_subs.append(
+                CameraSubscriber(
+                    node=self,
+                    topic=Camera.get_color_topic(camera["name"], camera["type"]),
+                    is_point_cloud=False,
+                    callback_group=self._callback_group,
+                )
+            )
+
+    def create_observation_space(self) -> ObservationSpace:
+        """
+        Creates the observation space for the Imitation Learning algorithm.
+
+        This method returns a dictionary-based observation space that includes:
+
+        - **image**: A 2D image captured from the robot's camera. The image's shape is determined by the camera's
+          width and height, and the number of channels (either monochromatic or RGB).
+            - Dimensions: `(camera_height, camera_width, channels)` where `channels` is 1 for monochromatic images
+              and 3 for RGB images.
+            - Pixel values are in the range `[0, 255]` with data type `np.uint8`.
+
+        - **joint_states**: The joint positions of the robot arm, represented as continuous values within the
+          range `[-1.0, 1.0]`, where `nDoF` refers to the number of degrees of freedom of the robot arm.
+
+        Returns:
+            gym.spaces.Dict: A dictionary defining the observation space for the learning algorithm,
+            containing both image and joint_states information.
+        """
+        return gym.spaces.Dict(
+            {
+                "image": gym.spaces.Box(
+                    low=0,
+                    high=255,
+                    shape=(
+                        self._camera_height,
+                        self._camera_width,
+                        1 if self._monochromatic else 3,
+                    ),
+                    dtype=np.uint8,
+                ),
+                "joint_states": gym.spaces.Box(low=-1.0, high=1.0),  # NOTE(review): no shape given although the docstring promises nDoF joints — confirm
+            }
+        )
+
+    def get_observation(self) -> Observation:  # NOTE(review): returns a bare image although the declared space is a Dict with "joint_states" — mismatch to confirm
+        # Get the latest images
+        image = []
+
+        for sub in self.camera_subs:
+            image.append(sub.get_observation())
+
+        image_width = sum(i.width for i in image)
+        image_height = image[0].height
+        # if image_width == self._camera_width and image_height == self._camera_height:
+
+        # image_data = np.concatenate([i.data for i in image], axis=1)
+
+        assert (  # NOTE(review): asserts are stripped when run with python -O (the sibling scripts use -O in their shebang)
+            image_width == self._camera_width and image_height == self._camera_height
+        ), f"Error: Resolution of the input image does not match the configured observation space. ({image_width}x{image_height} instead of {self._camera_width}x{self._camera_height})"
+
+        # Reshape and create the observation
+        # color_image = np.array([i.data for i in image], dtype=np.uint8).reshape(
+        #     self._camera_height, self._camera_width, 3
+        # )
+        color_image = np.concatenate(
+            [
+                np.array(i.data, dtype=np.uint8).reshape(i.height, i.width, 3)
+                for i in image
+            ],
+            axis=1,
+        )
+
+        # # Debug save images
+        # from PIL import Image
+        # img_color = Image.fromarray(color_image)
+        # img_color.save("img_color.png")
+
+        if self._monochromatic:
+            observation = Observation(color_image[:, :, 0])  # keep only the first channel as an intensity image
+        else:
+            observation = Observation(color_image)
+
+        self.get_logger().debug(f"\nobservation: {observation}")
+
+        # Return the observation
+        return observation
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach.py
new file mode 100644
index 0000000..4d26ad5
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach.py
@@ -0,0 +1,199 @@
+import abc
+from typing import Tuple
+
+from geometry_msgs.msg import WrenchStamped
+import gymnasium as gym
+import numpy as np
+from gym_gz.utils.typing import (
+ Action,
+ ActionSpace,
+ Observation,
+ ObservationSpace,
+ Reward,
+)
+
+from rbs_gym.envs.tasks.manipulation import Manipulation
+from env_manager.utils.math import distance_to_nearest_point
+from std_msgs.msg import Float64MultiArray
+from rbs_gym.envs.observation import TwistSubscriber, JointStates
+from env_manager.utils.helper import get_object_position
+
+
+class Reach(Manipulation, abc.ABC):
+ def __init__(
+ self,
+ sparse_reward: bool,
+ collision_reward: float,
+ act_quick_reward: float,
+ required_accuracy: float,
+ **kwargs,
+ ):
+ # Initialize the Task base class
+ Manipulation.__init__(
+ self,
+ **kwargs,
+ )
+
+ # Additional parameters
+ self._sparse_reward: bool = sparse_reward
+ self._act_quick_reward = (
+ act_quick_reward if act_quick_reward >= 0.0 else -act_quick_reward
+ )
+ self._collision_reward = (
+ collision_reward if collision_reward >= 0.0 else -collision_reward
+ )
+ self._required_accuracy: float = required_accuracy
+
+ # Flag indicating if the task is done (performance - get_reward + is_done)
+ self._is_done: bool = False
+ self._is_truncated: bool = False
+ self._is_terminated: bool = False
+
+ # Distance to target in the previous step (or after reset)
+ self._previous_distance: float = 0.0
+ # self._collision_reward: float = collision_reward
+
+ # self.initial_gripper_joint_positions = self.robot_data.gripper_joint_positions
+ self.twist = TwistSubscriber(
+ self,
+ topic="/cartesian_force_controller/current_twist",
+ callback_group=self._callback_group,
+ )
+
+ self.joint_states = JointStates(
+ self, topic="/joint_states", callback_group=self._callback_group
+ )
+
+ self._action = WrenchStamped()
+ self._action_array: Action = np.array([])
+ self.followed_object_name = "sphere"
+
+ def create_action_space(self) -> ActionSpace:
+ # 0:3 - (x, y, z) force
+ # 3:6 - (x, y, z) torque
+ return gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
+
+ def create_observation_space(self) -> ObservationSpace:
+ # 0:3 - (x, y, z) end effector position
+ # 3:6 - (x, y, z) target position
+ # 6:9 - (x, y, z) current twist
+ # Note: These could theoretically be restricted to the workspace and object spawn area instead of inf
+ return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32)
+
+ def set_action(self, action: Action):
+ # self.get_logger().debug(f"action: {action}")
+ # act = Float64MultiArray()
+ # act.data = action
+ # self.joint_controller.publisher.publish(act)
+ if isinstance(action, np.number):
+ raise RuntimeError("For the task Reach the action space should be array")
+
+ # Store action for reward function
+ self._action_array = action
+
+ # self._action.header.frame_id = self.robot_ee_link_name
+ self._action.header.stamp = self.get_clock().now().to_msg()
+ self._action.header.frame_id = self.robot.ee_link
+ self._action.wrench.force.x = float(action[0]) * 30.0
+ self._action.wrench.force.y = float(action[1]) * 30.0
+ self._action.wrench.force.z = float(action[2]) * 30.0
+ # self._action.wrench.torque.x = float(action[3]) * 10.0
+ # self._action.wrench.torque.y = float(action[4]) * 10.0
+ # self._action.wrench.torque.z = float(action[5]) * 10.0
+ self.controller.publisher.publish(self._action)
+
+ def get_observation(self) -> Observation:
+ # Get current end-effector and target positions
+ self.object_names = []
+ ee_position = self.get_ee_position()
+ target_position = self.get_object_position(self.followed_object_name)
+ # target_position = self.get_object_position(object_model=self.object_names[0])
+ # joint_states = tuple(self.joint_states.get_positions())
+
+ # self.get_logger().warn(f"joint_states: {joint_states[:7]}")
+
+ twist = self.twist.get_observation()
+ twt: tuple[float, float, float, float, float, float] = (
+ twist.twist.linear.x,
+ twist.twist.linear.y,
+ twist.twist.linear.z,
+ twist.twist.angular.x,
+ twist.twist.angular.y,
+ twist.twist.angular.z,
+ )
+
+ observation: Observation = np.concatenate([ee_position, target_position, twt], dtype=np.float32)
+ # Create the observation
+ # observation = Observation(
+ # np.concatenate([ee_position, target_position, twt], dtype=np.float32)
+ # )
+
+ self.get_logger().debug(f"\nobservation: {observation}")
+
+ # Return the observation
+ return observation
+
+ def get_reward(self) -> Reward:
+ reward = 0.0
+
+ # Compute the current distance to the target
+ current_distance = self.get_distance_to_target()
+
+ # Mark the episode done if target is reached
+ if current_distance < self._required_accuracy:
+ self._is_terminated = True
+ reward += 1.0 if self._sparse_reward else 0.0 # 100.0
+ self.get_logger().debug(f"reward_target: {reward}")
+
+ # Give reward based on how much closer robot got relative to the target for dense reward
+ if not self._sparse_reward:
+ distance_delta = self._previous_distance - current_distance
+ reward += distance_delta * 10.0
+ self._previous_distance = current_distance
+ self.get_logger().debug(f"reward_distance: {reward}")
+
+ if self.check_terrain_collision():
+ reward -= self._collision_reward
+ self._is_truncated = True
+ self.get_logger().debug(f"reward_collision: {reward}")
+
+ # Reward control
+ # reward -= np.abs(self._action_array).sum() * 0.01
+ # self.get_logger().debug(f"reward_c: {reward}")
+
+ # Subtract a small reward each step to provide incentive to act quickly (if enabled)
+ reward += self._act_quick_reward
+
+ self.get_logger().debug(f"reward: {reward}")
+
+ return Reward(reward)
+
+ def is_terminated(self) -> bool:
+ self.get_logger().debug(f"terminated: {self._is_terminated}")
+ return self._is_terminated
+
+ def is_truncated(self) -> bool:
+ self.get_logger().debug(f"truncated: {self._is_truncated}")
+ return self._is_truncated
+
+ def reset_task(self):
+ Manipulation.reset_task(self)
+
+ self._is_done = False
+ self._is_truncated = False
+ self._is_terminated = False
+
+ # Compute and store the distance after reset if using dense reward
+ if not self._sparse_reward:
+ self._previous_distance = self.get_distance_to_target()
+
+ self.get_logger().debug("\ntask reset")
+
+ def get_distance_to_target(self) -> float:
+ ee_position = self.get_ee_position()
+ object_position = self.get_object_position(object_model=self.followed_object_name)
+ self.tf2_broadcaster.broadcast_tf(
+ "world", "object", object_position, (0.0, 0.0, 0.0, 1.0), xyzw=True
+ )
+
+ return distance_to_nearest_point(origin=ee_position, points=[object_position])
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_color_image.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_color_image.py
new file mode 100644
index 0000000..32ea90a
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_color_image.py
@@ -0,0 +1,89 @@
+import abc
+
+import gymnasium as gym
+import numpy as np
+from gym_gz.utils.typing import Observation, ObservationSpace
+
+from env_manager.models.sensors import Camera
+from rbs_gym.envs.observation import CameraSubscriber
+from rbs_gym.envs.tasks.reach import Reach
+
+
+class ReachColorImage(Reach, abc.ABC):  # Reach task variant observing a (horizontally stacked) RGB camera image
+    def __init__(
+        self,
+        camera_info: list[dict] = [],  # NOTE(review): mutable default argument — safe only because it is never mutated here
+        monochromatic: bool = False,
+        **kwargs,
+    ):
+
+        # Initialize the Task base class
+        Reach.__init__(
+            self,
+            **kwargs,
+        )
+
+        # Store parameters for later use
+        self._camera_width = sum(w.get('width', 0) for w in camera_info)  # total width of the horizontally stacked camera images
+        self._camera_height = camera_info[0]["height"]  # assumes every camera shares the first camera's height — TODO confirm
+        self._monochromatic = monochromatic
+        self._saved = False
+
+        self.camera_subs: list[CameraSubscriber] = []
+        for camera in camera_info:
+            # Perception (RGB camera)
+            self.camera_subs.append(CameraSubscriber(
+                node=self,
+                topic=Camera.get_color_topic(camera["name"], camera["type"]),
+                is_point_cloud=False,
+                callback_group=self._callback_group,
+            ))
+
+    def create_observation_space(self) -> ObservationSpace:
+
+        # 0:3*height*width - rgb image
+        # 0:1*height*width - monochromatic (intensity) image
+        return gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=(
+                self._camera_height,
+                self._camera_width,
+                1 if self._monochromatic else 3,
+            ),
+            dtype=np.uint8,
+        )
+
+    def get_observation(self) -> Observation:
+        # Get the latest images
+        image = []
+
+        for sub in self.camera_subs:
+            image.append(sub.get_observation())
+
+        image_width = sum(i.width for i in image)
+        image_height = image[0].height
+
+        assert (  # NOTE(review): asserts are stripped when run with python -O (the sibling scripts use -O in their shebang)
+            image_width == self._camera_width and image_height == self._camera_height
+        ), f"Error: Resolution of the input image does not match the configured observation space. ({image_width}x{image_height} instead of {self._camera_width}x{self._camera_height})"
+
+        color_image = np.concatenate(
+            [np.array(i.data, dtype=np.uint8).reshape(i.height, i.width, 3) for i in image],
+            axis=1
+        )
+
+        # # Debug save images
+        # from PIL import Image
+        # img_color = Image.fromarray(color_image)
+        # img_color.save("img_color.png")
+
+        if self._monochromatic:
+            observation = Observation(color_image[:, :, 0])  # keep only the first channel as an intensity image
+        else:
+            observation = Observation(color_image)
+
+        self.get_logger().debug(f"\nobservation: {observation}")
+
+        # Return the observation
+        return observation
diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_depth_image.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_depth_image.py
new file mode 100644
index 0000000..00d64b9
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_depth_image.py
@@ -0,0 +1,67 @@
+import abc
+
+import gymnasium as gym
+import numpy as np
+from gym_gz.utils.typing import Observation, ObservationSpace
+
+from env_manager.models.sensors import Camera
+from rbs_gym.envs.observation import CameraSubscriber
+from rbs_gym.envs.tasks.reach import Reach
+
+
+class ReachDepthImage(Reach, abc.ABC):
+ def __init__(
+ self,
+ camera_width: int,
+ camera_height: int,
+ camera_type: str = "depth_camera",
+ **kwargs,
+ ):
+
+ # Initialize the Task base class
+ Reach.__init__(
+ self,
+ **kwargs,
+ )
+
+ # Store parameters for later use
+ self._camera_width = camera_width
+ self._camera_height = camera_height
+
+ # Perception (depth camera)
+ self.camera_sub = CameraSubscriber(
+ node=self,
+ topic=Camera.get_depth_topic(camera_type),
+ is_point_cloud=False,
+ callback_group=self._callback_group,
+ )
+
+ def create_observation_space(self) -> ObservationSpace:
+
+ # 0:height*width - depth image
+ return gym.spaces.Box(
+ low=0,
+ high=np.inf,
+ shape=(self._camera_height, self._camera_width, 1),
+ dtype=np.float32,
+ )
+
+ def get_observation(self) -> Observation:
+
+ # Get the latest image
+ image = self.camera_sub.get_observation()
+
+ # Construct from buffer and reshape
+ depth_image = np.frombuffer(image.data, dtype=np.float32).reshape(
+ self._camera_height, self._camera_width, 1
+ )
+ # Replace all instances of infinity with 0
+ depth_image[depth_image == np.inf] = 0.0
+
+ # Create the observation
+ observation = Observation(depth_image)
+
+ self.get_logger().debug(f"\nobservation: {observation}")
+
+ # Return the observation
+ return observation
diff --git a/env_manager/rbs_gym/rbs_gym/scripts/evaluate.py b/env_manager/rbs_gym/rbs_gym/scripts/evaluate.py
new file mode 100755
index 0000000..813ab4d
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/scripts/evaluate.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env -S python3 -O
+
+import argparse
+import os
+from typing import Dict
+
+import numpy as np
+import torch as th
+import yaml
+from rbs_gym import envs as gz_envs
+from rbs_gym.utils import create_test_env, get_latest_run_id, get_saved_hyperparams
+from rbs_gym.utils.utils import ALGOS, StoreDict, str2bool
+from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecEnvWrapper
+
+
+def find_model_path(log_path: str, args) -> str:  # resolve which saved model file to load for this run
+    model_extensions = ["zip"]
+    for ext in model_extensions:
+        model_path = os.path.join(log_path, f"{args.env}.{ext}")
+        if os.path.isfile(model_path):
+            return model_path  # NOTE(review): the final "<env>.zip" wins even when --load-best/--load-checkpoint are set — confirm precedence
+
+    if args.load_best:
+        best_model_path = os.path.join(log_path, "best_model.zip")
+        if os.path.isfile(best_model_path):
+            return best_model_path
+
+    if args.load_checkpoint is not None:
+        checkpoint_model_path = os.path.join(
+            log_path, f"rl_model_{args.load_checkpoint}_steps.zip"
+        )
+        if os.path.isfile(checkpoint_model_path):
+            return checkpoint_model_path
+
+    raise ValueError(f"No model found for {args.algo} on {args.env}, path: {log_path}")
+
+
+def main():  # evaluate a trained agent for N episodes and print reward/length/success statistics
+    parser = argparse.ArgumentParser()
+
+    # Environment and its parameters
+    parser.add_argument(
+        "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
+    )
+    parser.add_argument(
+        "--env-kwargs",
+        type=str,
+        nargs="+",
+        action=StoreDict,
+        help="Optional keyword argument to pass to the env constructor",
+    )
+    parser.add_argument("--n-envs", type=int, default=1, help="Number of environments")
+
+    # Algorithm
+    parser.add_argument(
+        "--algo",
+        type=str,
+        choices=list(ALGOS.keys()),
+        required=False,
+        default="sac",
+        help="RL algorithm to use during the training",
+    )
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=-1,
+        help="Number of threads for PyTorch (-1 to use default)",
+    )
+
+    # Test duration
+    parser.add_argument(
+        "-n",
+        "--n-episodes",
+        type=int,
+        default=200,
+        help="Number of evaluation episodes",
+    )
+
+    # Random seed
+    parser.add_argument("--seed", type=int, default=0, help="Random generator seed")
+
+    # Model to test
+    parser.add_argument(
+        "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
+    )
+    parser.add_argument(
+        "--exp-id",
+        type=int,
+        default=0,
+        help="Experiment ID (default: 0: latest, -1: no exp folder)",
+    )
+    parser.add_argument(
+        "--load-best",
+        type=str2bool,
+        default=False,
+        help="Load best model instead of last model if available",
+    )
+    parser.add_argument(
+        "--load-checkpoint",
+        type=int,
+        help="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it",
+    )
+
+    # Deterministic/stochastic actions
+    parser.add_argument(
+        "--stochastic",
+        type=str2bool,
+        default=False,
+        help="Use stochastic actions instead of deterministic",
+    )
+
+    # Logging
+    parser.add_argument(
+        "--reward-log", type=str, default="reward_logs", help="Where to log reward"
+    )
+    parser.add_argument(
+        "--norm-reward",
+        type=str2bool,
+        default=False,
+        help="Normalize reward if applicable (trained with VecNormalize)",
+    )
+
+    # Disable render
+    parser.add_argument(
+        "--no-render",
+        type=str2bool,
+        default=False,
+        help="Do not render the environment (useful for tests)",
+    )
+
+    # Verbosity
+    parser.add_argument(
+        "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
+    )
+
+    args, unknown = parser.parse_known_args()
+
+    if args.exp_id == 0:
+        args.exp_id = get_latest_run_id(
+            os.path.join(args.log_folder, args.algo), args.env
+        )
+        print(f"Loading latest experiment, id={args.exp_id}")
+
+    # Sanity checks
+    if args.exp_id > 0:
+        log_path = os.path.join(args.log_folder, args.algo, f"{args.env}_{args.exp_id}")
+    else:
+        log_path = os.path.join(args.log_folder, args.algo)
+
+    assert os.path.isdir(log_path), f"The {log_path} folder was not found"
+
+    model_path = find_model_path(log_path, args)
+    off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
+
+    if args.algo in off_policy_algos:
+        args.n_envs = 1  # off-policy evaluation assumes a single env
+
+    set_random_seed(args.seed)
+
+    if args.num_threads > 0:
+        if args.verbose > 1:
+            print(f"Setting torch.num_threads to {args.num_threads}")
+        th.set_num_threads(args.num_threads)
+
+    stats_path = os.path.join(log_path, args.env)
+    hyperparams, stats_path = get_saved_hyperparams(
+        stats_path, norm_reward=args.norm_reward, test_mode=True
+    )
+
+    # load env_kwargs if existing
+    env_kwargs = {}
+    args_path = os.path.join(log_path, args.env, "args.yml")
+    if os.path.isfile(args_path):
+        with open(args_path, "r") as f:
+            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)  # NOTE(review): UnsafeLoader executes arbitrary tags — acceptable only for self-produced logs
+            if loaded_args["env_kwargs"] is not None:
+                env_kwargs = loaded_args["env_kwargs"]
+    # overwrite with command line arguments
+    if args.env_kwargs is not None:
+        env_kwargs.update(args.env_kwargs)
+
+    log_dir = args.reward_log if args.reward_log != "" else None
+
+    env = create_test_env(
+        args.env,
+        n_envs=args.n_envs,
+        stats_path=stats_path,
+        seed=args.seed,
+        log_dir=log_dir,
+        should_render=not args.no_render,
+        hyperparams=hyperparams,
+        env_kwargs=env_kwargs,
+    )
+
+    kwargs = dict(seed=args.seed)
+    if args.algo in off_policy_algos:
+        # Dummy buffer size as we don't need memory to evaluate the trained agent
+        kwargs.update(dict(buffer_size=1))
+
+    custom_objects = {
+        "observation_space": env.observation_space,
+        "action_space": env.action_space,
+    }
+
+    model = ALGOS[args.algo].load(
+        model_path, env=env, custom_objects=custom_objects, **kwargs
+    )
+
+    obs = env.reset()
+
+    # Deterministic by default
+    stochastic = args.stochastic
+    deterministic = not stochastic
+
+    print(
+        f"Evaluating for {args.n_episodes} episodes with a",
+        "deterministic" if deterministic else "stochastic",
+        "policy.",
+    )
+
+    state = None
+    episode_reward = 0.0
+    episode_rewards, episode_lengths, success_episode_lengths = [], [], []
+    ep_len = 0
+    episode = 0
+    # For HER, monitor success rate
+    successes = []
+    while episode < args.n_episodes:
+        action, state = model.predict(obs, state=state, deterministic=deterministic)
+        obs, reward, done, infos = env.step(action)
+        if not args.no_render:
+            env.render("human")  # NOTE(review): positional render mode is the legacy gym API — confirm the wrapper still accepts it
+
+        episode_reward += reward[0]  # index 0: statistics assume a single (vectorized) env
+        ep_len += 1
+
+        if done and args.verbose > 0:  # NOTE(review): 'done' is a per-env array; truthiness relies on n_envs == 1
+            episode += 1
+            print(f"--- Episode {episode}/{args.n_episodes}")
+            # NOTE: for env using VecNormalize, the mean reward
+            # is a normalized reward when `--norm_reward` flag is passed
+            print(f"Episode Reward: {episode_reward:.2f}")
+            episode_rewards.append(episode_reward)
+            print("Episode Length", ep_len)
+            episode_lengths.append(ep_len)
+            if infos[0].get("is_success") is not None:
+                print("Success?:", infos[0].get("is_success", False))
+                successes.append(infos[0].get("is_success", False))
+                if infos[0].get("is_success"):
+                    success_episode_lengths.append(ep_len)
+                print(f"Current success rate: {100 * np.mean(successes):.2f}%")
+            episode_reward = 0.0
+            ep_len = 0
+            state = None
+
+    if args.verbose > 0 and len(successes) > 0:
+        print(f"Success rate: {100 * np.mean(successes):.2f}%")
+
+    if args.verbose > 0 and len(episode_rewards) > 0:
+        print(
+            f"Mean reward: {np.mean(episode_rewards):.2f} "
+            f"+/- {np.std(episode_rewards):.2f}"
+        )
+
+    if args.verbose > 0 and len(episode_lengths) > 0:
+        print(
+            f"Mean episode length: {np.mean(episode_lengths):.2f} "
+            f"+/- {np.std(episode_lengths):.2f}"
+        )
+
+    if args.verbose > 0 and len(success_episode_lengths) > 0:
+        print(
+            f"Mean episode length of successful episodes: {np.mean(success_episode_lengths):.2f} "
+            f"+/- {np.std(success_episode_lengths):.2f}"
+        )
+
+    # Workaround for https://github.com/openai/gym/issues/893
+    if not args.no_render:
+        if args.n_envs == 1 and "Bullet" not in args.env and isinstance(env, VecEnv):
+            # DummyVecEnv
+            # Unwrap env
+            while isinstance(env, VecEnvWrapper):
+                env = env.venv
+            if isinstance(env, DummyVecEnv):
+                env.envs[0].env.close()
+            else:
+                env.close()
+        else:
+            # SubprocVecEnv
+            env.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/env_manager/rbs_gym/rbs_gym/scripts/optimize.py b/env_manager/rbs_gym/rbs_gym/scripts/optimize.py
new file mode 100644
index 0000000..d840a67
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/scripts/optimize.py
@@ -0,0 +1,234 @@
+#!/usr/bin/env -S python3 -O
+
+import argparse
+import difflib
+import os
+import uuid
+from typing import Dict
+
+import gymnasium as gym
+import numpy as np
+import torch as th
+from rbs_gym import envs as gz_envs
+from rbs_gym.utils.exp_manager import ExperimentManager
+from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool
+from stable_baselines3.common.utils import set_random_seed
+
+
+def main():  # parse CLI args and launch training or hyperparameter optimization via ExperimentManager
+    parser = argparse.ArgumentParser()
+
+    # Environment and its parameters
+    parser.add_argument(
+        "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
+    )
+    parser.add_argument(
+        "--env-kwargs",
+        type=str,
+        nargs="+",
+        action=StoreDict,
+        help="Optional keyword argument to pass to the env constructor",
+    )
+    parser.add_argument(
+        "--vec-env",
+        type=str,
+        choices=["dummy", "subproc"],
+        default="dummy",
+        help="Type of VecEnv to use",
+    )
+
+    # Algorithm and training
+    parser.add_argument(
+        "--algo",
+        type=str,
+        choices=list(ALGOS.keys()),
+        required=False,
+        default="sac",
+        help="RL algorithm to use during the training",
+    )
+    parser.add_argument(
+        "-params",
+        "--hyperparams",
+        type=str,
+        nargs="+",
+        action=StoreDict,
+        help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)",
+    )
+    parser.add_argument(
+        "-n",
+        "--n-timesteps",
+        type=int,
+        default=-1,
+        help="Overwrite the number of timesteps",
+    )
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=-1,
+        help="Number of threads for PyTorch (-1 to use default)",
+    )
+
+    # Continue training an already trained agent
+    parser.add_argument(
+        "-i",
+        "--trained-agent",
+        type=str,
+        default="",
+        help="Path to a pretrained agent to continue training",
+    )
+
+    # Random seed
+    parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")
+
+    # Saving of model
+    parser.add_argument(
+        "--save-freq",
+        type=int,
+        default=10000,
+        help="Save the model every n steps (if negative, no checkpoint)",
+    )
+    parser.add_argument(
+        "--save-replay-buffer",
+        type=str2bool,
+        default=False,
+        help="Save the replay buffer too (when applicable)",
+    )
+
+    # Pre-load a replay buffer and start training on it
+    parser.add_argument(
+        "--preload-replay-buffer",
+        type=empty_str2none,
+        default="",
+        help="Path to a replay buffer that should be preloaded before starting the training process",
+    )
+
+    # Logging
+    parser.add_argument(
+        "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
+    )
+    parser.add_argument(
+        "-tb",
+        "--tensorboard-log",
+        type=empty_str2none,
+        default="tensorboard_logs",
+        help="Tensorboard log dir",
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=-1,
+        help="Override log interval (default: -1, no change)",
+    )
+    parser.add_argument(
+        "-uuid",
+        "--uuid",
+        type=str2bool,
+        default=False,
+        help="Ensure that the run has a unique ID",
+    )
+
+    # Evaluation
+    parser.add_argument(
+        "--eval-freq",
+        type=int,
+        default=-1,
+        help="Evaluate the agent every n steps (if negative, no evaluation)",
+    )
+    parser.add_argument(
+        "--eval-episodes",
+        type=int,
+        default=5,
+        help="Number of episodes to use for evaluation",
+    )
+
+    # Verbosity
+    parser.add_argument(
+        "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
+    )
+
+    # HER specifics
+    parser.add_argument(
+        "--truncate-last-trajectory",
+        type=str2bool,
+        default=True,
+        help="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.",
+    )
+
+    args, unknown = parser.parse_known_args()
+
+    # Check if the selected environment is valid
+    # If it could not be found, suggest the closest match
+    registered_envs = set(gym.envs.registry.keys())
+    if args.env not in registered_envs:
+        try:
+            closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
+        except IndexError:
+            closest_match = "'no close match found...'"
+        raise ValueError(
+            f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
+        )
+
+    # If no specific seed is selected, choose a random one
+    if args.seed < 0:
+        args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
+
+    # Set the random seed across platforms
+    set_random_seed(args.seed)
+
+    # Setting num threads to 1 makes things run faster on cpu
+    if args.num_threads > 0:
+        if args.verbose > 1:
+            print(f"Setting torch.num_threads to {args.num_threads}")
+        th.set_num_threads(args.num_threads)
+
+    # Verify that pre-trained agent exists before continuing to train it
+    if args.trained_agent != "":
+        assert args.trained_agent.endswith(".zip") and os.path.isfile(
+            args.trained_agent
+        ), "The trained_agent must be a valid path to a .zip file"
+
+    # If enabled, ensure that the run has a unique ID
+    uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
+
+    print("=" * 10, args.env, "=" * 10)
+    print(f"Seed: {args.seed}")
+    env_kwargs = {"render_mode": "human"}  # NOTE(review): unused — only args.env_kwargs reaches ExperimentManager below; confirm
+
+    exp_manager = ExperimentManager(
+        args,
+        args.algo,
+        args.env,
+        args.log_folder,
+        args.tensorboard_log,
+        args.n_timesteps,
+        args.eval_freq,
+        args.eval_episodes,
+        args.save_freq,
+        args.hyperparams,
+        args.env_kwargs,
+        args.trained_agent,
+        truncate_last_trajectory=args.truncate_last_trajectory,
+        uuid_str=uuid_str,
+        seed=args.seed,
+        log_interval=args.log_interval,
+        save_replay_buffer=args.save_replay_buffer,
+        preload_replay_buffer=args.preload_replay_buffer,
+        verbose=args.verbose,
+        vec_env_type=args.vec_env,
+    )
+
+    # Prepare experiment
+    results = exp_manager.setup_experiment()
+    if results is not None:
+        model, saved_hyperparams = results  # saved_hyperparams currently unused here
+
+        # Normal training
+        if model is not None:
+            exp_manager.learn(model)
+            exp_manager.save_trained_model(model)
+    else:
+        exp_manager.hyperparameters_optimization()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/env_manager/rbs_gym/rbs_gym/scripts/test_agent.py b/env_manager/rbs_gym/rbs_gym/scripts/test_agent.py
new file mode 100755
index 0000000..e201b4b
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/scripts/test_agent.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+
+import argparse
+from typing import Dict
+
+import gymnasium as gym
+from stable_baselines3.common.env_checker import check_env
+
+from rbs_gym import envs as gz_envs
+from rbs_gym.utils.utils import StoreDict, str2bool
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ # Environment and its parameters
+ parser.add_argument(
+ "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
+ )
+ parser.add_argument(
+ "--env-kwargs",
+ type=str,
+ nargs="+",
+ action=StoreDict,
+ help="Optional keyword argument to pass to the env constructor",
+ )
+
+ # Number of episodes to run
+ parser.add_argument(
+ "-n",
+ "--n-episodes",
+ type=int,
+ default=10000,
+ help="Overwrite the number of episodes",
+ )
+
+ # Random seed
+ parser.add_argument("--seed", type=int, default=69, help="Random generator seed")
+
+ # Flag to check environment
+ parser.add_argument(
+ "--check-env",
+ type=str2bool,
+ default=True,
+ help="Flag to check the environment before running the random agent",
+ )
+
+ # Flag to enable rendering
+ parser.add_argument(
+ "--render",
+ type=str2bool,
+ default=False,
+ help="Flag to enable rendering",
+ )
+
+ args, unknown = parser.parse_known_args()
+
+ # Create the environment
+ if args.env_kwargs is not None:
+ env = gym.make(args.env, **args.env_kwargs)
+ else:
+ env = gym.make(args.env)
+
+ # Initialize random seed
+ env.seed(args.seed)
+
+ # Check the environment
+ if args.check_env:
+ check_env(env, warn=True, skip_render_check=True)
+
+ # Step environment for bunch of episodes
+ for episode in range(args.n_episodes):
+ # Initialize returned values
+ done = False
+ total_reward = 0
+
+ # Reset the environment
+ observation = env.reset()
+
+ # Step through the current episode until it is done
+ while not done:
+ # Sample random action
+ action = env.action_space.sample()
+
+ # Step the environment with the random action
+ observation, reward, truncated, terminated, info = env.step(action)
+
+ done = truncated or terminated
+
+ # Accumulate the reward
+ total_reward += reward
+
+ print(f"Episode #{episode}\n\treward: {total_reward}")
+
+ # Cleanup once done
+ env.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/env_manager/rbs_gym/rbs_gym/scripts/train.py b/env_manager/rbs_gym/rbs_gym/scripts/train.py
new file mode 100755
index 0000000..87ea498
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/scripts/train.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env -S python3 -O
+
+import argparse
+import difflib
+import os
+import uuid
+
+import gymnasium as gym
+import numpy as np
+import torch as th
+from stable_baselines3.common.utils import set_random_seed
+
+from rbs_gym import envs as gz_envs
+from rbs_gym.utils.exp_manager import ExperimentManager
+from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool
+
+
+def main():
+
+ parser = argparse.ArgumentParser()
+
+ # Environment and its parameters
+ parser.add_argument(
+ "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID"
+ )
+ parser.add_argument(
+ "--env-kwargs",
+ type=str,
+ nargs="+",
+ action=StoreDict,
+ help="Optional keyword argument to pass to the env constructor",
+ )
+ parser.add_argument(
+ "--vec-env",
+ type=str,
+ choices=["dummy", "subproc"],
+ default="dummy",
+ help="Type of VecEnv to use",
+ )
+
+ # Algorithm and training
+ parser.add_argument(
+ "--algo",
+ type=str,
+ choices=list(ALGOS.keys()),
+ required=False,
+ default="sac",
+ help="RL algorithm to use during the training",
+ )
+ parser.add_argument(
+ "-params",
+ "--hyperparams",
+ type=str,
+ nargs="+",
+ action=StoreDict,
+ help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)",
+ )
+ parser.add_argument(
+ "-n",
+ "--n-timesteps",
+ type=int,
+ default=-1,
+ help="Overwrite the number of timesteps",
+ )
+ parser.add_argument(
+ "--num-threads",
+ type=int,
+ default=-1,
+ help="Number of threads for PyTorch (-1 to use default)",
+ )
+
+ # Continue training an already trained agent
+ parser.add_argument(
+ "-i",
+ "--trained-agent",
+ type=str,
+ default="",
+ help="Path to a pretrained agent to continue training",
+ )
+
+ # Random seed
+ parser.add_argument("--seed", type=int, default=-1, help="Random generator seed")
+
+ # Saving of model
+ parser.add_argument(
+ "--save-freq",
+ type=int,
+ default=10000,
+ help="Save the model every n steps (if negative, no checkpoint)",
+ )
+ parser.add_argument(
+ "--save-replay-buffer",
+ type=str2bool,
+ default=False,
+ help="Save the replay buffer too (when applicable)",
+ )
+
+ # Pre-load a replay buffer and start training on it
+ parser.add_argument(
+ "--preload-replay-buffer",
+ type=empty_str2none,
+ default="",
+ help="Path to a replay buffer that should be preloaded before starting the training process",
+ )
+ parser.add_argument(
+ "--track",
+ type=str2bool,
+ default=False,
+ help="Track experiment using wandb"
+ )
+ # optimization parameters
+ parser.add_argument(
+ "--optimize-hyperparameters",
+ type=str2bool,
+ default=False,
+ help="Run optimization or not?"
+ )
+ parser.add_argument(
+ "--no-optim-plots", action="store_true", default=False, help="Disable hyperparameter optimization plots"
+ )
+
+ # Logging
+ parser.add_argument(
+ "-f", "--log-folder", type=str, default="logs", help="Path to the log directory"
+ )
+ parser.add_argument(
+ "-tb",
+ "--tensorboard-log",
+ type=empty_str2none,
+ default="tensorboard_logs",
+ help="Tensorboard log dir",
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=-1,
+ help="Override log interval (default: -1, no change)",
+ )
+ parser.add_argument(
+ "-uuid",
+ "--uuid",
+ type=str2bool,
+ default=False,
+ help="Ensure that the run has a unique ID",
+ )
+
+ # Evaluation
+ parser.add_argument(
+ "--eval-freq",
+ type=int,
+ default=-1,
+ help="Evaluate the agent every n steps (if negative, no evaluation)",
+ )
+ parser.add_argument(
+ "--eval-episodes",
+ type=int,
+ default=5,
+ help="Number of episodes to use for evaluation",
+ )
+
+ # Verbosity
+ parser.add_argument(
+ "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)"
+ )
+
+ # HER specifics
+ parser.add_argument(
+ "--truncate-last-trajectory",
+ type=str2bool,
+ default=True,
+ help="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.",
+ )
+
+ args, unknown = parser.parse_known_args()
+
+ # Check if the selected environment is valid
+ # If it could not be found, suggest the closest match
+ registered_envs = set(gym.envs.registry.keys())
+ if args.env not in registered_envs:
+ try:
+ closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0]
+ except IndexError:
+ closest_match = "'no close match found...'"
+ raise ValueError(
+ f"{args.env} not found in gym registry, you maybe meant {closest_match}?"
+ )
+
+ # If no specific seed is selected, choose a random one
+ if args.seed < 0:
+ args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item()
+
+ # Set the random seed across platforms
+ set_random_seed(args.seed)
+
+ # Setting num threads to 1 makes things run faster on cpu
+ if args.num_threads > 0:
+ if args.verbose > 1:
+ print(f"Setting torch.num_threads to {args.num_threads}")
+ th.set_num_threads(args.num_threads)
+
+ # Verify that pre-trained agent exists before continuing to train it
+ if args.trained_agent != "":
+ assert args.trained_agent.endswith(".zip") and os.path.isfile(
+ args.trained_agent
+ ), "The trained_agent must be a valid path to a .zip file"
+
+ # If enabled, ensure that the run has a unique ID
+ uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
+
+ print("=" * 10, args.env, "=" * 10)
+ print(f"Seed: {args.seed}")
+
+ if args.track:
+ try:
+ import datetime
+ except ImportError as e:
+ raise ImportError(
+ "if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
+ ) from e
+
+ run_name_str = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
+ args.tensorboard_log = "runs"
+
+
+ exp_manager = ExperimentManager(
+ args,
+ args.algo,
+ args.env,
+ args.log_folder,
+ args.tensorboard_log,
+ args.n_timesteps,
+ args.eval_freq,
+ args.eval_episodes,
+ args.save_freq,
+ args.hyperparams,
+ args.env_kwargs,
+ args.trained_agent,
+ args.optimize_hyperparameters,
+ run_name=run_name_str,
+ truncate_last_trajectory=args.truncate_last_trajectory,
+ uuid_str=uuid_str,
+ seed=args.seed,
+ log_interval=args.log_interval,
+ save_replay_buffer=args.save_replay_buffer,
+ preload_replay_buffer=args.preload_replay_buffer,
+ verbose=args.verbose,
+ vec_env_type=args.vec_env,
+ no_optim_plots=args.no_optim_plots,
+ )
+
+ # Prepare experiment
+ results = exp_manager.setup_experiment()
+ if results is not None:
+ model, saved_hyperparams = results
+
+ # Normal training
+ if model is not None:
+ exp_manager.learn(model, saved_hyperparams)
+ exp_manager.save_trained_model(model)
+ else:
+ exp_manager.hyperparameters_optimization()
+
+
+# Script entry point: run training / hyperparameter optimization.
+if __name__ == "__main__":
+    main()
diff --git a/env_manager/rbs_gym/rbs_gym/utils/__init__.py b/env_manager/rbs_gym/rbs_gym/utils/__init__.py
new file mode 100644
index 0000000..b55f8bc
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/__init__.py
@@ -0,0 +1,9 @@
+from .utils import (
+ ALGOS,
+ create_test_env,
+ get_latest_run_id,
+ get_saved_hyperparams,
+ get_trained_models,
+ get_wrapper_class,
+ linear_schedule,
+)
diff --git a/env_manager/rbs_gym/rbs_gym/utils/callbacks.py b/env_manager/rbs_gym/rbs_gym/utils/callbacks.py
new file mode 100644
index 0000000..66c6bdd
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/callbacks.py
@@ -0,0 +1,306 @@
+import os
+import tempfile
+import time
+from copy import deepcopy
+from functools import wraps
+from threading import Thread
+from typing import Optional
+
+import optuna
+from sb3_contrib import TQC
+from stable_baselines3 import SAC
+from stable_baselines3.common.callbacks import (
+ BaseCallback,
+ CheckpointCallback,
+ EvalCallback,
+)
+from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv
+
+
+class TrialEvalCallback(EvalCallback):
+ """
+ Callback used for evaluating and reporting a trial.
+ """
+
+ def __init__(
+ self,
+ eval_env: VecEnv,
+ trial: optuna.Trial,
+ n_eval_episodes: int = 5,
+ eval_freq: int = 10000,
+ deterministic: bool = True,
+ verbose: int = 0,
+ best_model_save_path: Optional[str] = None,
+ log_path: Optional[str] = None,
+ ):
+
+ super(TrialEvalCallback, self).__init__(
+ eval_env=eval_env,
+ n_eval_episodes=n_eval_episodes,
+ eval_freq=eval_freq,
+ deterministic=deterministic,
+ verbose=verbose,
+ best_model_save_path=best_model_save_path,
+ log_path=log_path,
+ )
+ self.trial = trial
+ self.eval_idx = 0
+ self.is_pruned = False
+
+ def _on_step(self) -> bool:
+ if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
+ print("Evaluating trial")
+ super(TrialEvalCallback, self)._on_step()
+ self.eval_idx += 1
+ # report best or report current ?
+ # report num_timesteps or elasped time ?
+ self.trial.report(self.last_mean_reward, self.eval_idx)
+ # Prune trial if need
+ if self.trial.should_prune():
+ self.is_pruned = True
+ return False
+ return True
+
+
+class ParallelTrainCallback(BaseCallback):
+    """
+    Callback to explore (collect experience) and train (do gradient steps)
+    at the same time using two separate threads.
+    Normally used with off-policy algorithms and `train_freq=(1, "episode")`.
+
+    - blocking mode: wait for the model to finish updating the policy before collecting new experience
+      at the end of a rollout
+    - force sync mode: stop training to update to the latest policy for collecting
+      new experience
+
+    :param gradient_steps: Number of gradient steps to do before
+        sending the new policy
+    :param verbose: Verbosity level
+    :param sleep_time: Limit the fps in the thread collecting experience.
+    """
+
+    def __init__(
+        self, gradient_steps: int = 100, verbose: int = 0, sleep_time: float = 0.0
+    ):
+        super(ParallelTrainCallback, self).__init__(verbose)
+        self.batch_size = 0
+        # Flag flipped by the training thread when a gradient update finished;
+        # guards against overlapping training runs.
+        self._model_ready = True
+        # Second copy of the model; gradient steps run on this one.
+        self._model = None
+        self.gradient_steps = gradient_steps
+        self.process = None
+        self.model_class = None
+        self.sleep_time = sleep_time
+
+    def _init_callback(self) -> None:
+        # Clone the model via save/load so the training thread owns an
+        # independent copy of the policy and optimizer state.
+        temp_file = tempfile.TemporaryFile()
+
+        # Windows TemporaryFile is not a io Buffer
+        # we save the model in the logs/ folder
+        if os.name == "nt":
+            temp_file = os.path.join("logs", "model_tmp.zip")
+
+        self.model.save(temp_file)
+
+        # TODO (external): add support for other algorithms
+        for model_class in [SAC, TQC]:
+            if isinstance(self.model, model_class):
+                self.model_class = model_class
+                break
+
+        assert (
+            self.model_class is not None
+        ), f"{self.model} is not supported for parallel training"
+        self._model = self.model_class.load(temp_file)
+
+        self.batch_size = self._model.batch_size
+
+        # Disable train method
+        def patch_train(function):
+            @wraps(function)
+            def wrapper(*args, **kwargs):
+                return
+
+            return wrapper
+
+        # Add logger for parallel training
+        self._model.set_logger(self.model.logger)
+        # The collecting model must not train itself — its train() becomes a no-op.
+        self.model.train = patch_train(self.model.train)
+
+        # Hack: Re-add correct values at save time
+        def patch_save(function):
+            @wraps(function)
+            def wrapper(*args, **kwargs):
+                # Delegate saving to the copy that actually holds the
+                # up-to-date trained parameters.
+                return self._model.save(*args, **kwargs)
+
+            return wrapper
+
+        self.model.save = patch_save(self.model.save)
+
+    def train(self) -> None:
+        # Launch one batch of gradient steps in a background thread.
+        self._model_ready = False
+
+        self.process = Thread(target=self._train_thread, daemon=True)
+        self.process.start()
+
+    def _train_thread(self) -> None:
+        self._model.train(
+            gradient_steps=self.gradient_steps, batch_size=self.batch_size
+        )
+        self._model_ready = True
+
+    def _on_step(self) -> bool:
+        # Optionally throttle the experience-collection thread.
+        if self.sleep_time > 0:
+            time.sleep(self.sleep_time)
+        return True
+
+    def _on_rollout_end(self) -> None:
+        # Only sync when the previous gradient update has finished.
+        if self._model_ready:
+            # Push fresh experience to the training copy; pull updated
+            # parameters back into the collecting copy.
+            self._model.replay_buffer = deepcopy(self.model.replay_buffer)
+            self.model.set_parameters(deepcopy(self._model.get_parameters()))
+            self.model.actor = self.model.policy.actor
+            if self.num_timesteps >= self._model.learning_starts:
+                self.train()
+            # Do not wait for the training loop to finish
+            # self.process.join()
+
+    def _on_training_end(self) -> None:
+        # Wait for the thread to terminate
+        if self.process is not None:
+            if self.verbose > 0:
+                print("Waiting for training thread to terminate")
+            self.process.join()
+
+
+class SaveVecNormalizeCallback(BaseCallback):
+ """
+ Callback for saving a VecNormalize wrapper every ``save_freq`` steps
+
+ :param save_freq: (int)
+ :param save_path: (str) Path to the folder where ``VecNormalize`` will be saved, as ``vecnormalize.pkl``
+ :param name_prefix: (str) Common prefix to the saved ``VecNormalize``, if None (default)
+ only one file will be kept.
+ """
+
+ def __init__(
+ self,
+ save_freq: int,
+ save_path: str,
+ name_prefix: Optional[str] = None,
+ verbose: int = 0,
+ ):
+ super(SaveVecNormalizeCallback, self).__init__(verbose)
+ self.save_freq = save_freq
+ self.save_path = save_path
+ self.name_prefix = name_prefix
+
+ def _init_callback(self) -> None:
+ # Create folder if needed
+ if self.save_path is not None:
+ os.makedirs(self.save_path, exist_ok=True)
+
+ def _on_step(self) -> bool:
+ if self.n_calls % self.save_freq == 0:
+ if self.name_prefix is not None:
+ path = os.path.join(
+ self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps.pkl"
+ )
+ else:
+ path = os.path.join(self.save_path, "vecnormalize.pkl")
+ if self.model.get_vec_normalize_env() is not None:
+ self.model.get_vec_normalize_env().save(path)
+ if self.verbose > 1:
+ print(f"Saving VecNormalize to {path}")
+ return True
+
+
+class CheckpointCallbackWithReplayBuffer(CheckpointCallback):
+ """
+ Callback for saving a model every ``save_freq`` steps
+ :param save_freq:
+ :param save_path: Path to the folder where the model will be saved.
+ :param name_prefix: Common prefix to the saved models
+ :param save_replay_buffer: If enabled, save replay buffer together with model (if supported by algorithm).
+ :param verbose:
+ """
+
+ def __init__(
+ self,
+ save_freq: int,
+ save_path: str,
+ name_prefix: str = "rl_model",
+ save_replay_buffer: bool = False,
+ verbose: int = 0,
+ ):
+ super(CheckpointCallbackWithReplayBuffer, self).__init__(
+ save_freq, save_path, name_prefix, verbose
+ )
+ self.save_replay_buffer = save_replay_buffer
+ # self.save_replay_buffer = hasattr(self.model, "save_replay_buffer") and save_replay_buffer
+
+ def _on_step(self) -> bool:
+ if self.n_calls % self.save_freq == 0:
+ path = os.path.join(
+ self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps"
+ )
+ self.model.save(path)
+ if self.verbose > 0:
+ print(f"Saving model checkpoint to {path}")
+ if self.save_replay_buffer:
+ path_replay_buffer = os.path.join(self.save_path, "replay_buffer.pkl")
+ self.model.save_replay_buffer(path_replay_buffer)
+ if self.verbose > 0:
+ print(f"Saving model checkpoint to {path_replay_buffer}")
+ return True
+
+
+class CurriculumLoggerCallback(BaseCallback):
+ """
+ Custom callback for logging curriculum values.
+ """
+
+ def __init__(self, verbose=0):
+ super(CurriculumLoggerCallback, self).__init__(verbose)
+
+ def _on_step(self) -> bool:
+
+ for infos in self.locals["infos"]:
+ for info_key, info_value in infos.items():
+ if not (
+ info_key.startswith("curriculum")
+ and info_key.count("__mean_step__")
+ ):
+ continue
+
+ self.logger.record_mean(
+ key=info_key.replace("__mean_step__", ""), value=info_value
+ )
+
+ return True
+
+ def _on_rollout_end(self) -> None:
+
+ for infos in self.locals["infos"]:
+ for info_key, info_value in infos.items():
+ if not info_key.startswith("curriculum"):
+ continue
+ if info_key.count("__mean_step__"):
+ continue
+
+ if info_key.count("__mean_episode__"):
+ self.logger.record_mean(
+ key=info_key.replace("__mean_episode__", ""), value=info_value
+ )
+ else:
+ if isinstance(info_value, str):
+ exclude = "tensorboard"
+ else:
+ exclude = None
+ self.logger.record(key=info_key, value=info_value, exclude=exclude)
+
+
+class MetricsCallback(BaseCallback):
+ def __init__(self, verbose: int = 0):
+ super(MetricsCallback, self).__init__(verbose)
+
+ def _on_step(self) -> bool:
+ pass
diff --git a/env_manager/rbs_gym/rbs_gym/utils/exp_manager.py b/env_manager/rbs_gym/rbs_gym/utils/exp_manager.py
new file mode 100644
index 0000000..eaa0b9b
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/exp_manager.py
@@ -0,0 +1,936 @@
+import argparse
+import os
+import pickle as pkl
+import time
+import warnings
+from collections import OrderedDict
+from pprint import pprint
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import gymnasium as gym
+import numpy as np
+import optuna
+import yaml
+from optuna.integration.skopt import SkoptSampler
+from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner
+from optuna.samplers import BaseSampler, RandomSampler, TPESampler
+from optuna.visualization import plot_optimization_history, plot_param_importances
+from stable_baselines3 import HerReplayBuffer
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.noise import (
+ NormalActionNoise,
+ OrnsteinUhlenbeckActionNoise,
+)
+from stable_baselines3.common.preprocessing import (
+ is_image_space,
+ is_image_space_channels_first,
+)
+from stable_baselines3.common.utils import constant_fn
+from stable_baselines3.common.vec_env import (
+ DummyVecEnv,
+ SubprocVecEnv,
+ VecEnv,
+ VecFrameStack,
+ VecNormalize,
+ VecTransposeImage,
+ is_vecenv_wrapped,
+)
+from torch import nn as nn
+
+from rbs_gym.utils.callbacks import (
+ CheckpointCallbackWithReplayBuffer,
+ SaveVecNormalizeCallback,
+ TrialEvalCallback,
+)
+from rbs_gym.utils.hyperparams_opt import HYPERPARAMS_SAMPLER
+from rbs_gym.utils.utils import (
+ ALGOS,
+ get_callback_list,
+ get_latest_run_id,
+ get_wrapper_class,
+ linear_schedule,
+)
+from aim.sb3 import AimCallback
+
+
+class ExperimentManager(object):
+ """
+ Experiment manager: read the hyperparameters,
+ preprocess them, create the environment and the RL model.
+
+ Please take a look at `train.py` to have the details for each argument.
+ """
+
+    def __init__(
+        self,
+        args: argparse.Namespace,
+        algo: str,
+        env_id: str,
+        log_folder: str,
+        tensorboard_log: str = "",
+        n_timesteps: int = 0,
+        eval_freq: int = 10000,
+        n_eval_episodes: int = 5,
+        save_freq: int = -1,
+        hyperparams: Optional[Dict[str, Any]] = None,
+        env_kwargs: Optional[Dict[str, Any]] = None,
+        trained_agent: str = "",
+        optimize_hyperparameters: bool = False,
+        storage: Optional[str] = None,
+        study_name: Optional[str] = None,
+        n_trials: int = 1,
+        n_jobs: int = 1,
+        sampler: str = "tpe",
+        pruner: str = "median",
+        optimization_log_path: Optional[str] = None,
+        n_startup_trials: int = 0,
+        n_evaluations: int = 1,
+        truncate_last_trajectory: bool = False,
+        uuid_str: str = "",
+        seed: int = 0,
+        log_interval: int = 0,
+        save_replay_buffer: bool = False,
+        preload_replay_buffer: str = "",
+        verbose: int = 1,
+        vec_env_type: str = "dummy",
+        n_eval_envs: int = 1,
+        no_optim_plots: bool = False,
+        run_name: str = "rbs_gym_run",
+    ):
+        """Store the experiment configuration; no env or model is created here
+        (that happens in :meth:`setup_experiment`). See `train.py` for the
+        meaning of each argument."""
+        super(ExperimentManager, self).__init__()
+        self.algo = algo
+        self.env_id = env_id
+        # Custom params
+        self.custom_hyperparams = hyperparams
+        self.env_kwargs = {} if env_kwargs is None else env_kwargs
+        self.n_timesteps = n_timesteps
+        self.normalize = False
+        self.normalize_kwargs = {}
+        self.env_wrapper = None
+        self.frame_stack = None
+        self.seed = seed
+        self.optimization_log_path = optimization_log_path
+        self.run_name = run_name
+
+        self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[
+            vec_env_type
+        ]
+
+        self.vec_env_kwargs = {}
+        # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}
+
+        # Callbacks
+        self.specified_callbacks = []
+        self.callbacks = []
+        self.save_freq = save_freq
+        self.eval_freq = eval_freq
+        self.n_eval_episodes = n_eval_episodes
+        self.n_eval_envs = n_eval_envs
+
+        self.n_envs = 1  # it will be updated when reading hyperparams
+        self.n_actions = None  # For DDPG/TD3 action noise objects
+        self._hyperparams = {}
+
+        self.trained_agent = trained_agent
+        # Continue training only when a valid .zip checkpoint path was given.
+        self.continue_training = trained_agent.endswith(".zip") and os.path.isfile(
+            trained_agent
+        )
+        self.truncate_last_trajectory = truncate_last_trajectory
+
+        self.preload_replay_buffer = preload_replay_buffer
+
+        self._is_atari = self.is_atari(env_id)
+        self._is_gazebo_env = self.is_gazebo_env(env_id)
+
+        # Hyperparameter optimization config
+        self.optimize_hyperparameters = optimize_hyperparameters
+        self.storage = storage
+        self.study_name = study_name
+        self.no_optim_plots = no_optim_plots
+        # maximum number of trials for finding the best hyperparams
+        self.n_trials = n_trials
+        # number of parallel jobs when doing hyperparameter search
+        self.n_jobs = n_jobs
+        self.sampler = sampler
+        self.pruner = pruner
+        self.n_startup_trials = n_startup_trials
+        self.n_evaluations = n_evaluations
+        # Atari evaluation uses stochastic actions; everything else is deterministic.
+        self.deterministic_eval = not self.is_atari(self.env_id)
+
+        # Logging
+        self.log_folder = log_folder
+        self.tensorboard_log = (
+            None if tensorboard_log == "" else os.path.join(tensorboard_log, run_name, env_id)
+        )
+        self.verbose = verbose
+        self.args = args
+        self.log_interval = log_interval
+        self.save_replay_buffer = save_replay_buffer
+
+        # Run directory: <log_folder>/<algo>/<env_id>_<run_id><uuid>
+        self.log_path = f"{log_folder}/{self.algo}/"
+        self.save_path = os.path.join(
+            self.log_path,
+            f"{self.env_id}_{get_latest_run_id(self.log_path, self.env_id) + 1}{uuid_str}",
+        )
+        self.params_path = f"{self.save_path}/{self.env_id}"
+
+    def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
+        """
+        Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects)
+        create the environment and possibly the model.
+
+        :return: ``(model, saved_hyperparams)``, or ``None`` when hyperparameter
+            optimization was requested (models are then built per trial).
+        """
+        hyperparams, saved_hyperparams = self.read_hyperparameters()
+        hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(
+            hyperparams
+        )
+
+        # Create env to have access to action space for action noise
+        self._env = self.create_envs(self.n_envs, no_log=False)
+
+        self.create_log_folder()
+        self.create_callbacks()
+
+        self._hyperparams = self._preprocess_action_noise(hyperparams, self._env)
+
+        if self.continue_training:
+            # Resume from the checkpoint given via --trained-agent
+            model = self._load_pretrained_agent(self._hyperparams, self._env)
+        elif self.optimize_hyperparameters:
+            return None
+        else:
+            # Train an agent from scratch
+            model = ALGOS[self.algo](
+                env=self._env,
+                tensorboard_log=self.tensorboard_log,
+                seed=self.seed,
+                verbose=self.verbose,
+                **self._hyperparams,
+            )
+
+        # Pre-load replay buffer if enabled
+        if self.preload_replay_buffer:
+            if self.preload_replay_buffer.endswith(".pkl"):
+                replay_buffer_path = self.preload_replay_buffer
+            else:
+                # A directory was given: assume the default file name inside it.
+                replay_buffer_path = os.path.join(
+                    self.preload_replay_buffer, "replay_buffer.pkl"
+                )
+            if os.path.exists(replay_buffer_path):
+                print("Pre-loading replay buffer")
+                if self.algo == "her":
+                    # HER must know whether to truncate the last (possibly
+                    # incomplete) trajectory stored in the buffer.
+                    model.load_replay_buffer(
+                        replay_buffer_path, self.truncate_last_trajectory
+                    )
+                else:
+                    model.load_replay_buffer(replay_buffer_path)
+            else:
+                raise Exception(f"Replay buffer {replay_buffer_path} " "does not exist")
+
+        self._save_config(saved_hyperparams)
+        return model, saved_hyperparams
+
+    def learn(self, model: BaseAlgorithm, hyperparams: dict = {}) -> None:
+        """
+        Run the training loop, saving the model on failure (best effort).
+
+        :param model: an initialized RL model
+        :param hyperparams: unused here; accepted for call-site compatibility
+        """
+        kwargs = {}
+        if self.log_interval > -1:
+            kwargs = {"log_interval": self.log_interval}
+
+        if len(self.callbacks) > 0:
+            # NOTE(review): Aim tracking is only activated when other callbacks
+            # already exist — confirm this is intended rather than unconditional.
+            aim_callback = AimCallback(repo=self.log_folder, experiment_name=self.run_name)
+            self.callbacks.append(aim_callback)
+            kwargs["callback"] = self.callbacks
+
+        if self.continue_training:
+            # Keep the timestep counter of the loaded checkpoint.
+            kwargs["reset_num_timesteps"] = False
+            model.env.reset()
+
+        try:
+            model.learn(self.n_timesteps, **kwargs)
+        except Exception as e:
+            # Best effort: persist whatever has been learned before returning.
+            print(f"Caught an exception during training of the model: {e}")
+            self.save_trained_model(model)
+        finally:
+            # Release resources
+            try:
+                model.env.close()
+            except EOFError:
+                pass
+
+ def save_trained_model(self, model: BaseAlgorithm) -> None:
+ """
+ Save trained model optionally with its replay buffer
+ and ``VecNormalize`` statistics
+
+ :param model:
+ """
+ print(f"Saving to {self.save_path}")
+ model.save(f"{self.save_path}/{self.env_id}")
+
+ if hasattr(model, "save_replay_buffer") and self.save_replay_buffer:
+ print("Saving replay buffer")
+ model.save_replay_buffer(os.path.join(self.save_path, "replay_buffer.pkl"))
+
+ if self.normalize:
+ # Important: save the running average, for testing the agent we need that normalization
+ model.get_vec_normalize_env().save(
+ os.path.join(self.params_path, "vecnormalize.pkl")
+ )
+
+ def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None:
+ """
+ Save unprocessed hyperparameters, this can be use later
+ to reproduce an experiment.
+
+ :param saved_hyperparams:
+ """
+ # Save hyperparams
+ with open(os.path.join(self.params_path, "config.yml"), "w") as f:
+ yaml.dump(saved_hyperparams, f)
+
+ # save command line arguments
+ with open(os.path.join(self.params_path, "args.yml"), "w") as f:
+ ordered_args = OrderedDict(
+ [(key, vars(self.args)[key]) for key in sorted(vars(self.args).keys())]
+ )
+ yaml.dump(ordered_args, f)
+
+ print(f"Log path: {self.save_path}")
+
+ def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ # Load hyperparameters from yaml file
+ hyperparams_dir = os.path.abspath(
+ os.path.join(
+ os.path.realpath(__file__), *3 * [os.path.pardir], "hyperparams"
+ )
+ )
+ with open(f"{hyperparams_dir}/{self.algo}.yml", "r") as f:
+ hyperparams_dict = yaml.safe_load(f)
+ if self.env_id in list(hyperparams_dict.keys()):
+ hyperparams = hyperparams_dict[self.env_id]
+ elif self._is_atari:
+ hyperparams = hyperparams_dict["atari"]
+ else:
+ raise ValueError(
+ f"Hyperparameters not found for {self.algo}-{self.env_id}"
+ )
+
+ if self.custom_hyperparams is not None:
+ # Overwrite hyperparams if needed
+ hyperparams.update(self.custom_hyperparams)
+ # Sort hyperparams that will be saved
+ saved_hyperparams = OrderedDict(
+ [(key, hyperparams[key]) for key in sorted(hyperparams.keys())]
+ )
+
+ if self.verbose > 0:
+ print(
+ "Default hyperparameters for environment (ones being tuned will be overridden):"
+ )
+ pprint(saved_hyperparams)
+
+ return hyperparams, saved_hyperparams
+
+ @staticmethod
+ def _preprocess_schedules(hyperparams: Dict[str, Any]) -> Dict[str, Any]:
+ # Create schedules
+ for key in ["learning_rate", "clip_range", "clip_range_vf"]:
+ if key not in hyperparams:
+ continue
+ if isinstance(hyperparams[key], str):
+ schedule, initial_value = hyperparams[key].split("_")
+ initial_value = float(initial_value)
+ hyperparams[key] = linear_schedule(initial_value)
+ elif isinstance(hyperparams[key], (float, int)):
+ # Negative value: ignore (ex: for clipping)
+ if hyperparams[key] < 0:
+ continue
+ hyperparams[key] = constant_fn(float(hyperparams[key]))
+ else:
+ raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
+ return hyperparams
+
+    def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract the ``normalize`` entry into ``self.normalize`` /
+        ``self.normalize_kwargs`` and remove it from *hyperparams*."""
+        if "normalize" in hyperparams.keys():
+            self.normalize = hyperparams["normalize"]
+
+            # Special case, instead of both normalizing
+            # both observation and reward, we can normalize one of the two.
+            # in that case `hyperparams["normalize"]` is a string
+            # that can be evaluated as python,
+            # ex: "dict(norm_obs=False, norm_reward=True)"
+            if isinstance(self.normalize, str):
+                # SECURITY NOTE: eval() on a config value — acceptable only
+                # because the hyperparameter YAML files are trusted local input.
+                self.normalize_kwargs = eval(self.normalize)
+                self.normalize = True
+
+            # Use the same discount factor as for the algorithm
+            if "gamma" in hyperparams:
+                self.normalize_kwargs["gamma"] = hyperparams["gamma"]
+
+            del hyperparams["normalize"]
+        return hyperparams
+
+    def _preprocess_hyperparams(
+        self, hyperparams: Dict[str, Any]
+    ) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback]]:
+        """Convert raw YAML hyperparameters into constructor-ready kwargs.
+
+        Extracts side-channel settings (n_envs, n_timesteps, normalization,
+        frame stacking, env wrapper, callbacks) into attributes / return
+        values and deletes them from the dict.
+
+        :return: ``(hyperparams, env_wrapper, callbacks)``
+        """
+        self.n_envs = hyperparams.get("n_envs", 1)
+
+        if self.verbose > 0:
+            print(f"Using {self.n_envs} environments")
+
+        # Convert schedule strings to objects
+        hyperparams = self._preprocess_schedules(hyperparams)
+
+        # Pre-process train_freq
+        if "train_freq" in hyperparams and isinstance(hyperparams["train_freq"], list):
+            # SB3 expects a tuple for (frequency, unit) train_freq specs.
+            hyperparams["train_freq"] = tuple(hyperparams["train_freq"])
+
+        # Should we overwrite the number of timesteps?
+        if self.n_timesteps > 0:
+            if self.verbose:
+                print(f"Overwriting n_timesteps with n={self.n_timesteps}")
+        else:
+            self.n_timesteps = int(hyperparams["n_timesteps"])
+
+        # Pre-process normalize config
+        hyperparams = self._preprocess_normalization(hyperparams)
+
+        # Pre-process policy/buffer keyword arguments
+        # Convert to python object if needed
+        # TODO: Use the new replay_buffer_class argument of offpolicy algorithms instead of monkey patch
+        for kwargs_key in {
+            "policy_kwargs",
+            "replay_buffer_class",
+            "replay_buffer_kwargs",
+        }:
+            if kwargs_key in hyperparams.keys() and isinstance(
+                hyperparams[kwargs_key], str
+            ):
+                # SECURITY NOTE: eval() on trusted local config values only.
+                hyperparams[kwargs_key] = eval(hyperparams[kwargs_key])
+
+        # Delete keys so the dict can be pass to the model constructor
+        if "n_envs" in hyperparams.keys():
+            del hyperparams["n_envs"]
+        del hyperparams["n_timesteps"]
+
+        if "frame_stack" in hyperparams.keys():
+            self.frame_stack = hyperparams["frame_stack"]
+            del hyperparams["frame_stack"]
+
+        # obtain a class object from a wrapper name string in hyperparams
+        # and delete the entry
+        env_wrapper = get_wrapper_class(hyperparams)
+        if "env_wrapper" in hyperparams.keys():
+            del hyperparams["env_wrapper"]
+
+        callbacks = get_callback_list(hyperparams)
+        if "callback" in hyperparams.keys():
+            self.specified_callbacks = hyperparams["callback"]
+            del hyperparams["callback"]
+
+        return hyperparams, env_wrapper, callbacks
+
+ def _preprocess_action_noise(
+ self, hyperparams: Dict[str, Any], env: VecEnv
+ ) -> Dict[str, Any]:
+ # Parse noise string
+ # Note: only off-policy algorithms are supported
+ if hyperparams.get("noise_type") is not None:
+ noise_type = hyperparams["noise_type"].strip()
+ noise_std = hyperparams["noise_std"]
+
+ # Save for later (hyperparameter optimization)
+ self.n_actions = env.action_space.shape[0]
+
+ if "normal" in noise_type:
+ hyperparams["action_noise"] = NormalActionNoise(
+ mean=np.zeros(self.n_actions),
+ sigma=noise_std * np.ones(self.n_actions),
+ )
+ elif "ornstein-uhlenbeck" in noise_type:
+ hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise(
+ mean=np.zeros(self.n_actions),
+ sigma=noise_std * np.ones(self.n_actions),
+ )
+ else:
+ raise RuntimeError(f'Unknown noise type "{noise_type}"')
+
+ print(f"Applying {noise_type} noise with std {noise_std}")
+
+ del hyperparams["noise_type"]
+ del hyperparams["noise_std"]
+
+ return hyperparams
+
+    def create_log_folder(self):
+        """Create the per-run directory where config/hyperparameters are saved."""
+        os.makedirs(self.params_path, exist_ok=True)
+
+    def create_callbacks(self):
+        """Append checkpoint and evaluation callbacks according to
+        ``save_freq`` / ``eval_freq`` settings."""
+
+        if self.save_freq > 0:
+            # Account for the number of parallel environments
+            self.save_freq = max(self.save_freq // self.n_envs, 1)
+            self.callbacks.append(
+                CheckpointCallbackWithReplayBuffer(
+                    save_freq=self.save_freq,
+                    save_path=self.save_path,
+                    name_prefix="rl_model",
+                    save_replay_buffer=self.save_replay_buffer,
+                    verbose=self.verbose,
+                )
+            )
+
+        # Create test env if needed, do not normalize reward
+        if self.eval_freq > 0 and not self.optimize_hyperparameters:
+            # Account for the number of parallel environments
+            self.eval_freq = max(self.eval_freq // self.n_envs, 1)
+
+            if self.verbose > 0:
+                print("Creating test environment")
+
+            # Persist VecNormalize stats whenever a new best model is found.
+            save_vec_normalize = SaveVecNormalizeCallback(
+                save_freq=1, save_path=self.params_path
+            )
+            # NOTE(review): evaluation currently reuses the TRAINING env.
+            eval_callback = EvalCallback(
+                eval_env=self._env,
+                # TODO: Use separate environment(s) for evaluation
+                # self.create_envs(self.n_eval_envs, eval_env=True),
+                callback_on_new_best=save_vec_normalize,
+                best_model_save_path=self.save_path,
+                n_eval_episodes=self.n_eval_episodes,
+                log_path=self.save_path,
+                eval_freq=self.eval_freq,
+                deterministic=self.deterministic_eval,
+            )
+
+            self.callbacks.append(eval_callback)
+
+ @staticmethod
+ def is_atari(env_id: str) -> bool:
+ entry_point = gym.envs.registry[env_id].entry_point
+ return "AtariEnv" in str(entry_point)
+
+ @staticmethod
+ def is_bullet(env_id: str) -> bool:
+ entry_point = gym.envs.registry[env_id].entry_point
+ return "pybullet_envs" in str(entry_point)
+
+ @staticmethod
+ def is_robotics_env(env_id: str) -> bool:
+ entry_point = gym.envs.registry[env_id].entry_point
+ return "gym.envs.robotics" in str(entry_point) or "panda_gym.envs" in str(
+ entry_point
+ )
+
+ @staticmethod
+ def is_gazebo_env(env_id: str) -> bool:
+ return "Gazebo" in gym.envs.registry[env_id].entry_point
+
+ def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv:
+ """
+ Wrap the env into a VecNormalize wrapper if needed
+ and load saved statistics when present.
+
+ :param env: vectorized environment to (possibly) wrap
+ :param eval_env: True when the env is used for evaluation only; then
+ statistics are frozen and reward normalization is disabled
+ :return: the (possibly wrapped) environment
+ """
+ # Pretrained model, load normalization
+ # Stats are expected next to the pretrained agent, under <env_id>/vecnormalize.pkl.
+ path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_id)
+ path_ = os.path.join(path_, "vecnormalize.pkl")
+
+ if os.path.exists(path_):
+ print("Loading saved VecNormalize stats")
+ env = VecNormalize.load(path_, env)
+ # Deactivate training and reward normalization
+ if eval_env:
+ env.training = False
+ env.norm_reward = False
+
+ elif self.normalize:
+ # Copy to avoid changing default values by reference
+ local_normalize_kwargs = self.normalize_kwargs.copy()
+ # Do not normalize reward for env used for evaluation
+ if eval_env:
+ if len(local_normalize_kwargs) > 0:
+ local_normalize_kwargs["norm_reward"] = False
+ else:
+ local_normalize_kwargs = {"norm_reward": False}
+
+ if self.verbose > 0:
+ if len(local_normalize_kwargs) > 0:
+ print(f"Normalization activated: {local_normalize_kwargs}")
+ else:
+ print("Normalizing input and reward")
+ # Note: The following line was added but not sure if it is still required
+ env.num_envs = self.n_envs
+ env = VecNormalize(env, **local_normalize_kwargs)
+ return env
+
+ def create_envs(
+ self, n_envs: int, eval_env: bool = False, no_log: bool = False
+ ) -> VecEnv:
+ """
+ Create the environment and wrap it if necessary.
+
+ :param n_envs: number of parallel environments
+ :param eval_env: Whether is it an environment used for evaluation or not
+ :param no_log: Do not log training when doing hyperparameter optim
+ (issue with writing the same file)
+ :return: the vectorized environment, with appropriate wrappers
+ """
+ # Do not log eval env (issue with writing the same file)
+ log_dir = None if eval_env or no_log else self.save_path
+
+ monitor_kwargs = {}
+ # Special case for GoalEnvs: log success rate too
+ if (
+ "Neck" in self.env_id
+ or self.is_robotics_env(self.env_id)
+ or "parking-v0" in self.env_id
+ ):
+ monitor_kwargs = dict(info_keywords=("is_success",))
+
+ # On most env, SubprocVecEnv does not help and is quite memory hungry
+ # therefore we use DummyVecEnv by default
+ env = make_vec_env(
+ env_id=self.env_id,
+ n_envs=n_envs,
+ seed=self.seed,
+ env_kwargs=self.env_kwargs,
+ monitor_dir=log_dir,
+ wrapper_class=self.env_wrapper,
+ vec_env_cls=self.vec_env_class,
+ vec_env_kwargs=self.vec_env_kwargs,
+ monitor_kwargs=monitor_kwargs,
+ )
+
+ # Wrap the env into a VecNormalize wrapper if needed
+ # and load saved statistics when present
+ env = self._maybe_normalize(env, eval_env)
+
+ # Optional Frame-stacking
+ if self.frame_stack is not None:
+ n_stack = self.frame_stack
+ env = VecFrameStack(env, n_stack)
+ if self.verbose > 0:
+ print(f"Stacking {n_stack} frames")
+
+ # Transpose image observations to channel-first (PyTorch convention)
+ # unless a VecTransposeImage wrapper is already present.
+ if not is_vecenv_wrapped(env, VecTransposeImage):
+ wrap_with_vectranspose = False
+ if isinstance(env.observation_space, gym.spaces.Dict):
+ # If even one of the keys is a image-space in need of transpose, apply transpose
+ # If the image spaces are not consistent (for instance one is channel first,
+ # the other channel last), VecTransposeImage will throw an error
+ for space in env.observation_space.spaces.values():
+ wrap_with_vectranspose = wrap_with_vectranspose or (
+ is_image_space(space)
+ and not is_image_space_channels_first(space)
+ )
+ else:
+ wrap_with_vectranspose = is_image_space(
+ env.observation_space
+ ) and not is_image_space_channels_first(env.observation_space)
+
+ if wrap_with_vectranspose:
+ if self.verbose >= 1:
+ print("Wrapping the env in a VecTransposeImage.")
+ env = VecTransposeImage(env)
+
+ return env
+
+ def _load_pretrained_agent(
+ self, hyperparams: Dict[str, Any], env: VecEnv
+ ) -> BaseAlgorithm:
+ """Load ``self.trained_agent`` (and, if present, its replay buffer) to resume training.
+
+ :param hyperparams: loaded hyperparameters; mutated in place ("policy" and
+ "policy_kwargs" are removed — they are baked into the saved model)
+ :param env: environment to attach to the loaded model
+ :return: the loaded algorithm instance
+ """
+ # Continue training
+ print(
+ f"Loading pretrained agent '{self.trained_agent}' to continue its training"
+ )
+ # Policy should not be changed
+ del hyperparams["policy"]
+
+ if "policy_kwargs" in hyperparams.keys():
+ del hyperparams["policy_kwargs"]
+
+ model = ALGOS[self.algo].load(
+ self.trained_agent,
+ env=env,
+ seed=self.seed,
+ tensorboard_log=self.tensorboard_log,
+ verbose=self.verbose,
+ **hyperparams,
+ )
+
+ replay_buffer_path = os.path.join(
+ os.path.dirname(self.trained_agent), "replay_buffer.pkl"
+ )
+
+ # Skipped when a replay buffer is pre-loaded elsewhere (self.preload_replay_buffer).
+ if not self.preload_replay_buffer and os.path.exists(replay_buffer_path):
+ print("Loading replay buffer")
+ # `truncate_last_traj` will be taken into account only if we use HER replay buffer
+ model.load_replay_buffer(
+ replay_buffer_path, truncate_last_traj=self.truncate_last_trajectory
+ )
+ return model
+
+ def _create_sampler(self, sampler_method: str) -> BaseSampler:
+ # n_warmup_steps: Disable pruner until the trial reaches the given number of step.
+ if sampler_method == "random":
+ sampler = RandomSampler(seed=self.seed)
+ elif sampler_method == "tpe":
+ # TODO (external): try with multivariate=True
+ sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed)
+ elif sampler_method == "skopt":
+ # cf https://scikit-optimize.github.io/#skopt.Optimizer
+ # GP: gaussian process
+ # Gradient boosted regression: GBRT
+ sampler = SkoptSampler(
+ skopt_kwargs={"base_estimator": "GP", "acq_func": "gp_hedge"}
+ )
+ else:
+ raise ValueError(f"Unknown sampler: {sampler_method}")
+ return sampler
+
+ def _create_pruner(self, pruner_method: str) -> BasePruner:
+ if pruner_method == "halving":
+ pruner = SuccessiveHalvingPruner(
+ min_resource=1, reduction_factor=4, min_early_stopping_rate=0
+ )
+ elif pruner_method == "median":
+ pruner = MedianPruner(
+ n_startup_trials=self.n_startup_trials,
+ n_warmup_steps=self.n_evaluations // 3,
+ )
+ elif pruner_method == "none":
+ # Do not prune
+ pruner = NopPruner()
+ else:
+ raise ValueError(f"Unknown pruner: {pruner_method}")
+ return pruner
+
+ # Important: Objective changed in this project (rbs_gym) to evaluate on the same environment that is used for training (cannot have two existing simulatenously with the current setup)
+ def objective(self, trial: optuna.Trial) -> float:
+ """Optuna objective: train a candidate model and return its last mean eval reward.
+
+ :param trial: current Optuna trial, used for sampling hyperparameters and pruning
+ :return: last mean reward reported by the trial's eval callback
+ :raises optuna.exceptions.TrialPruned: when the trial is pruned or training fails
+ """
+
+ kwargs = self._hyperparams.copy()
+
+ trial.model_class = None
+
+ # Hack to use DDPG/TD3 noise sampler
+ trial.n_actions = self._env.action_space.shape[0]
+
+ # Hack when using HerReplayBuffer
+ if kwargs.get("replay_buffer_class") == HerReplayBuffer:
+ trial.her_kwargs = kwargs.get("replay_buffer_kwargs", {})
+
+ # Sample candidate hyperparameters
+ kwargs.update(HYPERPARAMS_SAMPLER[self.algo](trial))
+ print(f"\nRunning a new trial with hyperparameters: {kwargs}")
+
+ # Write hyperparameters into a file
+ trial_params_path = os.path.join(self.params_path, "optimization")
+ os.makedirs(trial_params_path, exist_ok=True)
+ with open(
+ os.path.join(
+ trial_params_path, f"hyperparameters_trial_{trial.number}.yml"
+ ),
+ "w",
+ ) as f:
+ yaml.dump(kwargs, f)
+
+ model = ALGOS[self.algo](
+ env=self._env,
+ # Note: Here I enabled tensorboard logs
+ tensorboard_log=self.tensorboard_log,
+ # Note: Here I differ and I seed the trial. I want all trials to have the same starting conditions
+ seed=self.seed,
+ verbose=self.verbose,
+ **kwargs,
+ )
+
+ # Pre-load replay buffer if enabled
+ if self.preload_replay_buffer:
+ # Accept either a direct .pkl path or a directory containing replay_buffer.pkl.
+ if self.preload_replay_buffer.endswith(".pkl"):
+ replay_buffer_path = self.preload_replay_buffer
+ else:
+ replay_buffer_path = os.path.join(
+ self.preload_replay_buffer, "replay_buffer.pkl"
+ )
+ if os.path.exists(replay_buffer_path):
+ print("Pre-loading replay buffer")
+ if self.algo == "her":
+ model.load_replay_buffer(
+ replay_buffer_path, self.truncate_last_trajectory
+ )
+ else:
+ model.load_replay_buffer(replay_buffer_path)
+ else:
+ raise Exception(f"Replay buffer {replay_buffer_path} " "does not exist")
+
+ model.trial = trial
+
+ eval_freq = int(self.n_timesteps / self.n_evaluations)
+ # Account for parallel envs
+ eval_freq_ = max(eval_freq // model.get_env().num_envs, 1)
+ # Use non-deterministic eval for Atari
+ callbacks = get_callback_list({"callback": self.specified_callbacks})
+ path = None
+ if self.optimization_log_path is not None:
+ path = os.path.join(
+ self.optimization_log_path, f"trial_{str(trial.number)}"
+ )
+ eval_callback = TrialEvalCallback(
+ # TODO: Use a separate environment for evaluation during trial
+ model.env,
+ model.trial,
+ best_model_save_path=path,
+ log_path=path,
+ n_eval_episodes=self.n_eval_episodes,
+ eval_freq=eval_freq_,
+ deterministic=self.deterministic_eval,
+ verbose=self.verbose,
+ )
+ callbacks.append(eval_callback)
+
+ try:
+ model.learn(self.n_timesteps, callback=callbacks)
+ # Reset env
+ self._env.reset()
+ except AssertionError as e:
+ # Reset env
+ self._env.reset()
+ print("Trial stopped:", e)
+ # Prune hyperparams that generate NaNs
+ raise optuna.exceptions.TrialPruned()
+ except Exception as err:
+ exception_type = type(err).__name__
+ print("Trial stopped due to raised exception:", exception_type, err)
+ # Prune also all other exceptions
+ raise optuna.exceptions.TrialPruned()
+ is_pruned = eval_callback.is_pruned
+ reward = eval_callback.last_mean_reward
+
+ print(
+ f"\nFinished a trial with reward={reward}, is_pruned={is_pruned} "
+ f"for hyperparameters: {kwargs}"
+ )
+
+ # Free the model before the next trial to limit memory growth.
+ del model
+
+ if is_pruned:
+ raise optuna.exceptions.TrialPruned()
+
+ return reward
+
+ def hyperparameters_optimization(self) -> None:
+ """Run an Optuna study over ``self.objective`` and write a CSV/pickle report.
+
+ Uses ``self.sampler``/``self.pruner`` names, optional remote ``self.storage``,
+ and disables tensorboard logging for the duration of the study.
+ """
+
+ if self.verbose > 0:
+ print("Optimizing hyperparameters")
+
+ if self.storage is not None and self.study_name is None:
+ warnings.warn(
+ f"You passed a remote storage: {self.storage} but no `--study-name`."
+ "The study name will be generated by Optuna, make sure to re-use the same study name "
+ "when you want to do distributed hyperparameter optimization."
+ )
+
+ if self.tensorboard_log is not None:
+ warnings.warn(
+ "Tensorboard log is deactivated when running hyperparameter optimization"
+ )
+ self.tensorboard_log = None
+
+ # TODO (external): eval each hyperparams several times to account for noisy evaluation
+ sampler = self._create_sampler(self.sampler)
+ pruner = self._create_pruner(self.pruner)
+
+ if self.verbose > 0:
+ print(f"Sampler: {self.sampler} - Pruner: {self.pruner}")
+
+ study = optuna.create_study(
+ sampler=sampler,
+ pruner=pruner,
+ storage=self.storage,
+ study_name=self.study_name,
+ load_if_exists=True,
+ direction="maximize",
+ )
+
+ try:
+ study.optimize(
+ self.objective,
+ n_trials=self.n_trials,
+ n_jobs=self.n_jobs,
+ gc_after_trial=True,
+ show_progress_bar=True,
+ )
+ except KeyboardInterrupt:
+ # Allow interrupting the study early; results so far are still reported.
+ pass
+
+ print("Number of finished trials: ", len(study.trials))
+
+ print("Best trial:")
+ trial = study.best_trial
+
+ print("Value: ", trial.value)
+
+ print("Params: ")
+ for key, value in trial.params.items():
+ print(f" {key}: {value}")
+
+ report_name = (
+ f"report_{self.env_id}_{self.n_trials}-trials-{self.n_timesteps}"
+ f"-{self.sampler}-{self.pruner}_{int(time.time())}"
+ )
+
+ log_path = os.path.join(self.log_folder, self.algo, report_name)
+
+ if self.verbose:
+ print(f"Writing report to {log_path}")
+
+ # Write report
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
+ study.trials_dataframe().to_csv(f"{log_path}.csv")
+
+ # Save python object to inspect/re-use it later
+ with open(f"{log_path}.pkl", "wb+") as f:
+ pkl.dump(study, f)
+
+ # Skip plots
+ if self.no_optim_plots:
+ return
+
+ # Plot optimization result
+ try:
+ fig1 = plot_optimization_history(study)
+ fig2 = plot_param_importances(study)
+
+ fig1.show()
+ fig2.show()
+ except (ValueError, ImportError, RuntimeError):
+ # Plotting is best-effort (optional plotly backend may be missing).
+ pass
+
+ def collect_demonstration(self, model):
+ # Any random action will do (this won't actually be used since `preload_replay_buffer` env kwarg is enabled)
+ action = np.array([model.env.action_space.sample()])
+ # Reset env at the beginning
+ obs = model.env.reset()
+ # Collect transitions
+ for i in range(model.replay_buffer.buffer_size):
+ # Note: If `None` is passed to Grasp env, it uses custom action heuristic to reach the target
+ next_obs, rewards, dones, infos = model.env.unwrapped.step(action)
+ # Extract the actual actions from info
+ actual_actions = [info["actual_actions"] for info in infos]
+ # Add to replay buffer
+ model.replay_buffer.add(obs, next_obs, actual_actions, rewards, dones)
+ # Update current observation
+ obs = next_obs
+
+ print("Saving replay buffer")
+ model.save_replay_buffer(os.path.join(self.save_path, "replay_buffer.pkl"))
+ model.env.close()
+ exit
diff --git a/env_manager/rbs_gym/rbs_gym/utils/hyperparams_opt.py b/env_manager/rbs_gym/rbs_gym/utils/hyperparams_opt.py
new file mode 100644
index 0000000..9d36cf5
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/hyperparams_opt.py
@@ -0,0 +1,170 @@
+from typing import Any, Dict
+
+import numpy as np
+import optuna
+from stable_baselines3.common.noise import NormalActionNoise
+from torch import nn as nn
+
+
+def sample_sac_params(
+ trial: optuna.Trial,
+) -> Dict[str, Any]:
+ """
+ Sampler for SAC hyperparameters
+ """
+
+ buffer_size = 150000
+ # learning_starts = trial.suggest_categorical(
+ # "learning_starts", [5000, 10000, 20000])
+ learning_starts = 5000
+
+ batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256, 512, 1024, 2048])
+ learning_rate = trial.suggest_float(
+ "learning_rate", low=0.000001, high=0.001, log=True
+ )
+
+ gamma = trial.suggest_float("gamma", low=0.98, high=1.0, log=True)
+ tau = trial.suggest_float("tau", low=0.001, high=0.025, log=True)
+
+ ent_coef = "auto_0.5_0.1"
+ target_entropy = "auto"
+
+ noise_std = trial.suggest_float("noise_std", low=0.01, high=0.2, log=True)
+ action_noise = NormalActionNoise(
+ mean=np.zeros(trial.n_actions), sigma=np.ones(trial.n_actions) * noise_std
+ )
+
+ train_freq = 1
+ gradient_steps = trial.suggest_categorical("gradient_steps", [1, 2])
+
+ policy_kwargs = dict()
+ net_arch = trial.suggest_categorical("net_arch", [256, 384, 512])
+ policy_kwargs["net_arch"] = [net_arch] * 3
+
+ return {
+ "buffer_size": buffer_size,
+ "learning_starts": learning_starts,
+ "batch_size": batch_size,
+ "learning_rate": learning_rate,
+ "gamma": gamma,
+ "tau": tau,
+ "ent_coef": ent_coef,
+ "target_entropy": target_entropy,
+ "action_noise": action_noise,
+ "train_freq": train_freq,
+ "gradient_steps": gradient_steps,
+ "policy_kwargs": policy_kwargs,
+ }
+
+
+def sample_td3_params(
+ trial: optuna.Trial,
+) -> Dict[str, Any]:
+ """
+ Sampler for TD3 hyperparameters
+ """
+
+ buffer_size = 150000
+ # learning_starts = trial.suggest_categorical(
+ # "learning_starts", [5000, 10000, 20000])
+ learning_starts = 5000
+
+ batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+ learning_rate = trial.suggest_float(
+ "learning_rate", low=0.000001, high=0.001, log=True
+ )
+
+ gamma = trial.suggest_float("gamma", low=0.98, high=1.0, log=True)
+ tau = trial.suggest_float("tau", low=0.001, high=0.025, log=True)
+
+ target_policy_noise = trial.suggest_float(
+ "target_policy_noise", low=0.00000001, high=0.5, log=True
+ )
+ target_noise_clip = 0.5
+
+ noise_std = trial.suggest_float("noise_std", low=0.025, high=0.5, log=True)
+ action_noise = NormalActionNoise(
+ mean=np.zeros(trial.n_actions), sigma=np.ones(trial.n_actions) * noise_std
+ )
+
+ train_freq = 1
+ gradient_steps = trial.suggest_categorical("gradient_steps", [1, 2])
+
+ policy_kwargs = dict()
+ net_arch = trial.suggest_categorical("net_arch", [256, 384, 512])
+ policy_kwargs["net_arch"] = [net_arch] * 3
+
+ return {
+ "buffer_size": buffer_size,
+ "learning_starts": learning_starts,
+ "batch_size": batch_size,
+ "learning_rate": learning_rate,
+ "gamma": gamma,
+ "tau": tau,
+ "target_policy_noise": target_policy_noise,
+ "target_noise_clip": target_noise_clip,
+ "action_noise": action_noise,
+ "train_freq": train_freq,
+ "gradient_steps": gradient_steps,
+ "policy_kwargs": policy_kwargs,
+ }
+
+
+def sample_tqc_params(
+ trial: optuna.Trial,
+) -> Dict[str, Any]:
+ """
+ Sampler for TQC hyperparameters
+ """
+
+ buffer_size = 25000
+ learning_starts = 0
+
+ batch_size = 32
+ learning_rate = trial.suggest_float(
+ "learning_rate", low=0.000025, high=0.00075, log=True
+ )
+
+ gamma = 1.0 - trial.suggest_float("gamma", low=0.0001, high=0.025, log=True)
+ tau = trial.suggest_float("tau", low=0.0005, high=0.025, log=True)
+
+ ent_coef = "auto_0.1_0.05"
+ target_entropy = "auto"
+
+ noise_std = trial.suggest_float("noise_std", low=0.01, high=0.1, log=True)
+ action_noise = NormalActionNoise(
+ mean=np.zeros(trial.n_actions), sigma=np.ones(trial.n_actions) * noise_std
+ )
+
+ train_freq = 1
+ gradient_steps = trial.suggest_categorical("gradient_steps", [1, 2])
+
+ policy_kwargs = dict()
+ net_arch = trial.suggest_categorical("net_arch", [128, 256, 384, 512])
+ policy_kwargs["net_arch"] = [net_arch] * 2
+ policy_kwargs["n_quantiles"] = trial.suggest_int("n_quantiles", low=20, high=40)
+ top_quantiles_to_drop_per_net = round(0.08 * policy_kwargs["n_quantiles"])
+ policy_kwargs["n_critics"] = trial.suggest_categorical("n_critics", [2, 3])
+
+ return {
+ "buffer_size": buffer_size,
+ "learning_starts": learning_starts,
+ "batch_size": batch_size,
+ "learning_rate": learning_rate,
+ "gamma": gamma,
+ "tau": tau,
+ "ent_coef": ent_coef,
+ "target_entropy": target_entropy,
+ "top_quantiles_to_drop_per_net": top_quantiles_to_drop_per_net,
+ "action_noise": action_noise,
+ "train_freq": train_freq,
+ "gradient_steps": gradient_steps,
+ "policy_kwargs": policy_kwargs,
+ }
+
+
+ # Maps algorithm name -> hyperparameter sampling function used by the Optuna objective.
+ HYPERPARAMS_SAMPLER = {
+ "sac": sample_sac_params,
+ "td3": sample_td3_params,
+ "tqc": sample_tqc_params,
+ }
diff --git a/env_manager/rbs_gym/rbs_gym/utils/utils.py b/env_manager/rbs_gym/rbs_gym/utils/utils.py
new file mode 100644
index 0000000..c073dd6
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/utils.py
@@ -0,0 +1,411 @@
+import argparse
+import glob
+import importlib
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import gymnasium as gym
+
+# For custom activation fn
+import stable_baselines3 as sb3 # noqa: F401
+import torch as th # noqa: F401
+import yaml
+from sb3_contrib import QRDQN, TQC
+from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.sb2_compat.rmsprop_tf_like import ( # noqa: F401
+ RMSpropTFLike,
+)
+from stable_baselines3.common.vec_env import (
+ DummyVecEnv,
+ SubprocVecEnv,
+ VecEnv,
+ VecFrameStack,
+ VecNormalize,
+)
+from torch import nn as nn # noqa: F401 pylint: disable=unused-import
+
+ # Maps CLI algorithm name -> Stable-Baselines3 (or SB3-Contrib) algorithm class.
+ ALGOS = {
+ "a2c": A2C,
+ "ddpg": DDPG,
+ "dqn": DQN,
+ "ppo": PPO,
+ "sac": SAC,
+ "td3": TD3,
+ # SB3 Contrib,
+ "qrdqn": QRDQN,
+ "tqc": TQC,
+ }
+
+
+def flatten_dict_observations(env: gym.Env) -> gym.Env:
+ assert isinstance(env.observation_space, gym.spaces.Dict)
+ try:
+ return gym.wrappers.FlattenObservation(env)
+ except AttributeError:
+ keys = env.observation_space.spaces.keys()
+ return gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
+
+
+def get_wrapper_class(
+ hyperparams: Dict[str, Any]
+) -> Optional[Callable[[gym.Env], gym.Env]]:
+ """
+ Get one or more Gym environment wrapper class specified as a hyper parameter
+ "env_wrapper".
+ e.g.
+ env_wrapper: gym_minigrid.wrappers.FlatObsWrapper
+
+ for multiple, specify a list:
+
+ env_wrapper:
+ - utils.wrappers.PlotActionWrapper
+ - utils.wrappers.TimeFeatureWrapper
+
+
+ :param hyperparams:
+ :return: maybe a callable to wrap the environment
+ with one or multiple gym.Wrapper
+ """
+
+ def get_module_name(wrapper_name):
+ return ".".join(wrapper_name.split(".")[:-1])
+
+ def get_class_name(wrapper_name):
+ return wrapper_name.split(".")[-1]
+
+ if "env_wrapper" in hyperparams.keys():
+ wrapper_name = hyperparams.get("env_wrapper")
+
+ if wrapper_name is None:
+ return None
+
+ if not isinstance(wrapper_name, list):
+ wrapper_names = [wrapper_name]
+ else:
+ wrapper_names = wrapper_name
+
+ wrapper_classes = []
+ wrapper_kwargs = []
+ # Handle multiple wrappers
+ for wrapper_name in wrapper_names:
+ # Handle keyword arguments
+ if isinstance(wrapper_name, dict):
+ assert len(wrapper_name) == 1, (
+ "You have an error in the formatting "
+ f"of your YAML file near {wrapper_name}. "
+ "You should check the indentation."
+ )
+ wrapper_dict = wrapper_name
+ wrapper_name = list(wrapper_dict.keys())[0]
+ kwargs = wrapper_dict[wrapper_name]
+ else:
+ kwargs = {}
+ wrapper_module = importlib.import_module(get_module_name(wrapper_name))
+ wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name))
+ wrapper_classes.append(wrapper_class)
+ wrapper_kwargs.append(kwargs)
+
+ def wrap_env(env: gym.Env) -> gym.Env:
+ """
+ :param env:
+ :return:
+ """
+ for wrapper_class, kwargs in zip(wrapper_classes, wrapper_kwargs):
+ env = wrapper_class(env, **kwargs)
+ return env
+
+ return wrap_env
+ else:
+ return None
+
+
+ def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]:
+ """
+ Get one or more Callback class specified as a hyper-parameter
+ "callback".
+ e.g.
+ callback: stable_baselines3.common.callbacks.CheckpointCallback
+
+ for multiple, specify a list:
+
+ callback:
+ - utils.callbacks.PlotActionWrapper
+ - stable_baselines3.common.callbacks.CheckpointCallback
+
+ :param hyperparams: dict possibly containing a "callback" entry
+ :return: list of instantiated callbacks (empty when none are specified)
+ """
+
+ def get_module_name(callback_name):
+ return ".".join(callback_name.split(".")[:-1])
+
+ def get_class_name(callback_name):
+ return callback_name.split(".")[-1]
+
+ callbacks = []
+
+ if "callback" in hyperparams.keys():
+ callback_name = hyperparams.get("callback")
+
+ if callback_name is None:
+ return callbacks
+
+ if not isinstance(callback_name, list):
+ callback_names = [callback_name]
+ else:
+ callback_names = callback_name
+
+ # Handle multiple wrappers
+ for callback_name in callback_names:
+ # Handle keyword arguments
+ if isinstance(callback_name, dict):
+ assert len(callback_name) == 1, (
+ "You have an error in the formatting "
+ f"of your YAML file near {callback_name}. "
+ "You should check the indentation."
+ )
+ callback_dict = callback_name
+ callback_name = list(callback_dict.keys())[0]
+ kwargs = callback_dict[callback_name]
+ else:
+ kwargs = {}
+ # Resolve "pkg.module.ClassName" and instantiate with the given kwargs.
+ callback_module = importlib.import_module(get_module_name(callback_name))
+ callback_class = getattr(callback_module, get_class_name(callback_name))
+ callbacks.append(callback_class(**kwargs))
+
+ return callbacks
+
+
+def create_test_env(
+ env_id: str,
+ n_envs: int = 1,
+ stats_path: Optional[str] = None,
+ seed: int = 0,
+ log_dir: Optional[str] = None,
+ should_render: bool = True,
+ hyperparams: Optional[Dict[str, Any]] = None,
+ env_kwargs: Optional[Dict[str, Any]] = None,
+) -> VecEnv:
+ """
+ Create environment for testing a trained agent
+
+ :param env_id:
+ :param n_envs: number of processes
+ :param stats_path: path to folder containing saved running averaged
+ :param seed: Seed for random number generator
+ :param log_dir: Where to log rewards
+ :param should_render: For Pybullet env, display the GUI
+ :param hyperparams: Additional hyperparams (ex: n_stack)
+ :param env_kwargs: Optional keyword argument to pass to the env constructor
+ :return:
+ """
+ # Avoid circular import
+ from rbs_gym.utils.exp_manager import ExperimentManager
+
+ # Create the environment and wrap it if necessary
+ env_wrapper = get_wrapper_class(hyperparams)
+
+ hyperparams = {} if hyperparams is None else hyperparams
+
+ if "env_wrapper" in hyperparams.keys():
+ del hyperparams["env_wrapper"]
+
+ vec_env_kwargs = {}
+ vec_env_cls = DummyVecEnv
+ if n_envs > 1 or (ExperimentManager.is_bullet(env_id) and should_render):
+ # HACK: force SubprocVecEnv for Bullet env
+ # as Pybullet envs does not follow gym.render() interface
+ vec_env_cls = SubprocVecEnv
+ # start_method = 'spawn' for thread safe
+
+ env = make_vec_env(
+ env_id,
+ n_envs=n_envs,
+ monitor_dir=log_dir,
+ seed=seed,
+ wrapper_class=env_wrapper,
+ env_kwargs=env_kwargs,
+ vec_env_cls=vec_env_cls,
+ vec_env_kwargs=vec_env_kwargs,
+ )
+
+ # Load saved stats for normalizing input and rewards
+ # And optionally stack frames
+ if stats_path is not None:
+ if hyperparams["normalize"]:
+ print("Loading running average")
+ print(f"with params: {hyperparams['normalize_kwargs']}")
+ path_ = os.path.join(stats_path, "vecnormalize.pkl")
+ if os.path.exists(path_):
+ env = VecNormalize.load(path_, env)
+ # Deactivate training and reward normalization
+ env.training = False
+ env.norm_reward = False
+ else:
+ raise ValueError(f"VecNormalize stats {path_} not found")
+
+ n_stack = hyperparams.get("frame_stack", 0)
+ if n_stack > 0:
+ print(f"Stacking {n_stack} frames")
+ env = VecFrameStack(env, n_stack)
+ return env
+
+
+def linear_schedule(initial_value: Union[float, str]) -> Callable[[float], float]:
+ """
+ Linear learning rate schedule.
+
+ :param initial_value: (float or str)
+ :return: (function)
+ """
+ if isinstance(initial_value, str):
+ initial_value = float(initial_value)
+
+ def func(progress_remaining: float) -> float:
+ """
+ Progress will decrease from 1 (beginning) to 0
+ :param progress_remaining: (float)
+ :return: (float)
+ """
+ return progress_remaining * initial_value
+
+ return func
+
+
+def get_trained_models(log_folder: str) -> Dict[str, Tuple[str, str]]:
+ """
+ :param log_folder: Root log folder
+ :return: Dict[str, Tuple[str, str]] representing the trained agents
+ """
+ trained_models = {}
+ for algo in os.listdir(log_folder):
+ if not os.path.isdir(os.path.join(log_folder, algo)):
+ continue
+ for env_id in os.listdir(os.path.join(log_folder, algo)):
+ # Retrieve env name
+ env_id = env_id.split("_")[0]
+ trained_models[f"{algo}-{env_id}"] = (algo, env_id)
+ return trained_models
+
+
+def get_latest_run_id(log_path: str, env_id: str) -> int:
+ """
+ Returns the latest run number for the given log name and log path,
+ by finding the greatest number in the directories.
+
+ :param log_path: path to log folder
+ :param env_id:
+ :return: latest run number
+ """
+ max_run_id = 0
+ for path in glob.glob(os.path.join(log_path, env_id + "_[0-9]*")):
+ file_name = os.path.basename(path)
+ ext = file_name.split("_")[-1]
+ if (
+ env_id == "_".join(file_name.split("_")[:-1])
+ and ext.isdigit()
+ and int(ext) > max_run_id
+ ):
+ max_run_id = int(ext)
+ return max_run_id
+
+
+ def get_saved_hyperparams(
+ stats_path: str,
+ norm_reward: bool = False,
+ test_mode: bool = False,
+ ) -> Tuple[Dict[str, Any], Optional[str]]:
+ """Load hyperparameters saved next to a trained agent.
+
+ :param stats_path: folder expected to contain ``config.yml`` and/or
+ VecNormalize statistics; treated as absent when not a directory
+ :param norm_reward: whether to normalize reward when building normalize_kwargs
+ :param test_mode: when True, force ``norm_reward`` into saved normalize kwargs
+ :return: (hyperparams, stats_path) — stats_path is None when the folder is missing
+ """
+ hyperparams = {}
+ if not os.path.isdir(stats_path):
+ stats_path = None
+ else:
+ config_file = os.path.join(stats_path, "config.yml")
+ if os.path.isfile(config_file):
+ # Load saved hyperparameters
+ with open(os.path.join(stats_path, "config.yml"), "r") as f:
+ hyperparams = yaml.load(
+ f, Loader=yaml.UnsafeLoader
+ ) # pytype: disable=module-attr
+ hyperparams["normalize"] = hyperparams.get("normalize", False)
+ else:
+ # Legacy layout: presence of obs_rms.pkl implies normalization was used.
+ obs_rms_path = os.path.join(stats_path, "obs_rms.pkl")
+ hyperparams["normalize"] = os.path.isfile(obs_rms_path)
+
+ # Load normalization params
+ if hyperparams["normalize"]:
+ if isinstance(hyperparams["normalize"], str):
+ # SECURITY: "normalize" may hold a python dict literal from the saved
+ # config; eval is only safe because the config is self-produced.
+ normalize_kwargs = eval(hyperparams["normalize"])
+ if test_mode:
+ normalize_kwargs["norm_reward"] = norm_reward
+ else:
+ normalize_kwargs = {
+ "norm_obs": hyperparams["normalize"],
+ "norm_reward": norm_reward,
+ }
+ hyperparams["normalize_kwargs"] = normalize_kwargs
+ return hyperparams, stats_path
+
+
+class StoreDict(argparse.Action):
+ """
+ Custom argparse action for storing dict.
+
+ In: args1:0.0 args2:"dict(a=1)"
+ Out: {'args1': 0.0, arg2: dict(a=1)}
+ """
+
+ def __init__(self, option_strings, dest, nargs=None, **kwargs):
+ self._nargs = nargs
+ super(StoreDict, self).__init__(option_strings, dest, nargs=nargs, **kwargs)
+
+ def __call__(self, parser, namespace, values, option_string=None):
+
+ arg_dict = {}
+
+ if hasattr(namespace, self.dest):
+ current_arg = getattr(namespace, self.dest)
+ if isinstance(current_arg, Dict):
+ arg_dict = getattr(namespace, self.dest)
+
+ for arguments in values:
+ if not arguments:
+ continue
+ key = arguments.split(":")[0]
+ value = ":".join(arguments.split(":")[1:])
+ # Evaluate the string as python code
+ arg_dict[key] = eval(value)
+ setattr(namespace, self.dest, arg_dict)
+
+
+def str2bool(value: Union[str, bool]) -> bool:
+ """
+ Convert logical string to boolean. Can be used as argparse type.
+ """
+
+ if isinstance(value, bool):
+ return value
+ if value.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif value.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+def empty_str2none(value: Optional[str]) -> Optional[str]:
+ """
+ If string is empty, convert to None. Can be used as argparse type.
+ """
+
+ if not value:
+ return None
+ return value
diff --git a/env_manager/rbs_gym/rbs_gym/utils/wrappers.py b/env_manager/rbs_gym/rbs_gym/utils/wrappers.py
new file mode 100644
index 0000000..9107117
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/wrappers.py
@@ -0,0 +1,395 @@
+import gymnasium as gym
+import numpy as np
+from matplotlib import pyplot as plt
+from scipy.signal import iirfilter, sosfilt, zpk2sos
+
+
class DoneOnSuccessWrapper(gym.Wrapper):
    """
    Reset on success and offset the reward.
    Useful for GoalEnv.

    NOTE(review): uses the legacy 4-tuple step API while the file imports
    gymnasium (5-tuple step, 2-tuple reset) -- confirm the wrapped envs.

    :param env: Environment to wrap.
    :param reward_offset: Constant added to every step reward.
    :param n_successes: Consecutive successes required before forcing done.
    """

    def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1):
        super().__init__(env)
        self.reward_offset = reward_offset
        self.n_successes = n_successes
        self.current_successes = 0

    def reset(self):
        # Forget any success streak from the previous episode
        self.current_successes = 0
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # Count successes in a row; any failure resets the streak
        if info.get("is_success", False):
            self.current_successes += 1
        else:
            self.current_successes = 0
        if self.current_successes >= self.n_successes:
            done = True
        return obs, reward + self.reward_offset, done, info

    def compute_reward(self, achieved_goal, desired_goal, info):
        # Keep the recomputed reward consistent with the offset used in step()
        return self.env.compute_reward(achieved_goal, desired_goal, info) + self.reward_offset
+
+
class ActionNoiseWrapper(gym.Wrapper):
    """
    Add gaussian noise to the action (without telling the agent),
    to test the robustness of the control.

    :param env: (gym.Env)
    :param noise_std: (float) Standard deviation of the noise
    """

    def __init__(self, env, noise_std=0.1):
        super().__init__(env)
        self.noise_std = noise_std

    def step(self, action):
        # Per-component gaussian perturbation with standard deviation noise_std
        perturbation = np.random.normal(
            np.zeros_like(action), np.ones_like(action) * self.noise_std
        )
        return self.env.step(action + perturbation)
+
+
+# from https://docs.obspy.org
+def lowpass(data, freq, df, corners=4, zerophase=False):
+ """
+ Butterworth-Lowpass Filter.
+
+ Filter data removing data over certain frequency ``freq`` using ``corners``
+ corners.
+ The filter uses :func:`scipy.signal.iirfilter` (for design)
+ and :func:`scipy.signal.sosfilt` (for applying the filter).
+
+ :type data: numpy.ndarray
+ :param data: Data to filter.
+ :param freq: Filter corner frequency.
+ :param df: Sampling rate in Hz.
+ :param corners: Filter corners / order.
+ :param zerophase: If True, apply filter once forwards and once backwards.
+ This results in twice the number of corners but zero phase shift in
+ the resulting filtered trace.
+ :return: Filtered data.
+ """
+ fe = 0.5 * df
+ f = freq / fe
+ # raise for some bad scenarios
+ if f > 1:
+ f = 1.0
+ msg = (
+ "Selected corner frequency is above Nyquist. "
+ + "Setting Nyquist as high corner."
+ )
+ print(msg)
+ z, p, k = iirfilter(corners, f, btype="lowpass", ftype="butter", output="zpk")
+ sos = zpk2sos(z, p, k)
+ if zerophase:
+ firstpass = sosfilt(sos, data)
+ return sosfilt(sos, firstpass[::-1])[::-1]
+ else:
+ return sosfilt(sos, data)
+
+
class LowPassFilterWrapper(gym.Wrapper):
    """
    Butterworth low-pass filter applied to the actions.

    :param env: (gym.Env)
    :param freq: Filter corner frequency.
    :param df: Sampling rate in Hz.
    """

    def __init__(self, env, freq=5.0, df=25.0):
        super().__init__(env)
        self.freq = freq
        self.df = df
        self.signal = []

    def reset(self):
        # Drop the action history accumulated during the previous episode
        self.signal = []
        return self.env.reset()

    def step(self, action):
        self.signal.append(action)
        history = np.array(self.signal)
        smoothed = np.zeros_like(action)
        # Filter each action dimension independently and keep the latest sample
        for dim in range(self.action_space.shape[0]):
            smoothed[dim] = lowpass(history[:, dim], freq=self.freq, df=self.df)[-1]
        return self.env.step(smoothed)
+
+
class ActionSmoothingWrapper(gym.Wrapper):
    """
    Smooth the action using exponential moving average.

    :param env: (gym.Env)
    :param smoothing_coef: (float) Smoothing coefficient (0 no smoothing, 1 very smooth)
    """

    def __init__(self, env, smoothing_coef: float = 0.0):
        super().__init__(env)
        self.smoothing_coef = smoothing_coef
        self.smoothed_action = None
        # from https://github.com/rail-berkeley/softlearning/issues/3
        # for smoothing latent space
        # self.alpha = self.smoothing_coef
        # self.beta = np.sqrt(1 - self.alpha ** 2) / (1 - self.alpha)

    def reset(self):
        # Restart the moving average at the beginning of each episode
        self.smoothed_action = None
        return self.env.reset()

    def step(self, action):
        previous = (
            np.zeros_like(action) if self.smoothed_action is None else self.smoothed_action
        )
        # Exponential moving average: a higher coef means a slower-moving action
        self.smoothed_action = (
            self.smoothing_coef * previous + (1 - self.smoothing_coef) * action
        )
        return self.env.step(self.smoothed_action)
+
+
class DelayedRewardWrapper(gym.Wrapper):
    """
    Delay the reward by `delay` steps, it makes the task harder but more realistic.
    The reward is accumulated during those steps.

    :param env: (gym.Env)
    :param delay: (int) Number of steps the reward should be delayed.
    """

    def __init__(self, env, delay=10):
        super().__init__(env)
        self.delay = delay
        self.current_step = 0
        self.accumulated_reward = 0.0

    def reset(self):
        # Start each episode with an empty reward buffer
        self.current_step = 0
        self.accumulated_reward = 0.0
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        self.accumulated_reward += reward
        self.current_step += 1

        # Release the accumulated reward every `delay` steps (or at episode end)
        if done or self.current_step % self.delay == 0:
            delayed_reward = self.accumulated_reward
            self.accumulated_reward = 0.0
        else:
            delayed_reward = 0.0
        return obs, delayed_reward, done, info
+
+
class HistoryWrapper(gym.Wrapper):
    """
    Stack past observations and actions to give an history to the agent.

    The observation returned to the agent is the concatenation of the last
    `horizon` observations followed by the last `horizon` actions
    (flat Box spaces only).

    :param env: (gym.Env)
    :param horizon: (int) Number of steps to keep in the history.
    """

    def __init__(self, env: gym.Env, horizon: int = 5):
        # Only flat (Box) observation spaces are supported
        assert isinstance(env.observation_space, gym.spaces.Box)

        wrapped_obs_space = env.observation_space
        wrapped_action_space = env.action_space

        # TODO (external): double check, it seems wrong when we have different low and highs
        low_obs = np.repeat(wrapped_obs_space.low, horizon, axis=-1)
        high_obs = np.repeat(wrapped_obs_space.high, horizon, axis=-1)

        low_action = np.repeat(wrapped_action_space.low, horizon, axis=-1)
        high_action = np.repeat(wrapped_action_space.high, horizon, axis=-1)

        # Bounds of the stacked observation: [obs history | action history]
        low = np.concatenate((low_obs, low_action))
        high = np.concatenate((high_obs, high_action))

        # Overwrite the observation space
        env.observation_space = gym.spaces.Box(
            low=low, high=high, dtype=wrapped_obs_space.dtype
        )

        super(HistoryWrapper, self).__init__(env)

        self.horizon = horizon
        self.low_action, self.high_action = low_action, high_action
        self.low_obs, self.high_obs = low_obs, high_obs
        self.low, self.high = low, high
        # Rolling buffers with the most recent entries at the end (last axis)
        self.obs_history = np.zeros(low_obs.shape, low_obs.dtype)
        self.action_history = np.zeros(low_action.shape, low_action.dtype)

    def _create_obs_from_history(self):
        # Final observation layout: [obs history | action history]
        return np.concatenate((self.obs_history, self.action_history))

    def reset(self):
        # Flush the history
        self.obs_history[...] = 0
        self.action_history[...] = 0
        # NOTE(review): assumes the legacy gym API where reset() returns only
        # the observation -- confirm against the gymnasium 2-tuple reset API
        obs = self.env.reset()
        self.obs_history[..., -obs.shape[-1] :] = obs
        return self._create_obs_from_history()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        last_ax_size = obs.shape[-1]

        # Shift the buffer left by one observation and append the newest one
        self.obs_history = np.roll(self.obs_history, shift=-last_ax_size, axis=-1)
        self.obs_history[..., -obs.shape[-1] :] = obs

        self.action_history = np.roll(
            self.action_history, shift=-action.shape[-1], axis=-1
        )
        self.action_history[..., -action.shape[-1] :] = action
        return self._create_obs_from_history(), reward, done, info
+
+
class HistoryWrapperObsDict(gym.Wrapper):
    """
    History Wrapper for dict observation.

    Same stacking behaviour as HistoryWrapper, but the history is stored
    under the "observation" key of a dict observation space.

    :param env: (gym.Env)
    :param horizon: (int) Number of steps to keep in the history.
    """

    def __init__(self, env, horizon=5):
        # Only a flat (Box) "observation" sub-space is supported
        assert isinstance(env.observation_space.spaces["observation"], gym.spaces.Box)

        wrapped_obs_space = env.observation_space.spaces["observation"]
        wrapped_action_space = env.action_space

        # TODO (external): double check, it seems wrong when we have different low and highs
        low_obs = np.repeat(wrapped_obs_space.low, horizon, axis=-1)
        high_obs = np.repeat(wrapped_obs_space.high, horizon, axis=-1)

        low_action = np.repeat(wrapped_action_space.low, horizon, axis=-1)
        high_action = np.repeat(wrapped_action_space.high, horizon, axis=-1)

        # Bounds of the stacked observation: [obs history | action history]
        low = np.concatenate((low_obs, low_action))
        high = np.concatenate((high_obs, high_action))

        # Overwrite the observation space
        env.observation_space.spaces["observation"] = gym.spaces.Box(
            low=low, high=high, dtype=wrapped_obs_space.dtype
        )

        super(HistoryWrapperObsDict, self).__init__(env)

        self.horizon = horizon
        self.low_action, self.high_action = low_action, high_action
        self.low_obs, self.high_obs = low_obs, high_obs
        self.low, self.high = low, high
        # Rolling buffers with the most recent entries at the end (last axis)
        self.obs_history = np.zeros(low_obs.shape, low_obs.dtype)
        self.action_history = np.zeros(low_action.shape, low_action.dtype)

    def _create_obs_from_history(self):
        # Final observation layout: [obs history | action history]
        return np.concatenate((self.obs_history, self.action_history))

    def reset(self):
        # Flush the history
        self.obs_history[...] = 0
        self.action_history[...] = 0
        # NOTE(review): assumes the legacy gym API where reset() returns only
        # the observation dict -- confirm against the gymnasium 2-tuple API
        obs_dict = self.env.reset()
        obs = obs_dict["observation"]
        self.obs_history[..., -obs.shape[-1] :] = obs

        obs_dict["observation"] = self._create_obs_from_history()

        return obs_dict

    def step(self, action):
        obs_dict, reward, done, info = self.env.step(action)
        obs = obs_dict["observation"]
        last_ax_size = obs.shape[-1]

        # Shift the buffer left by one observation and append the newest one
        self.obs_history = np.roll(self.obs_history, shift=-last_ax_size, axis=-1)
        self.obs_history[..., -obs.shape[-1] :] = obs

        self.action_history = np.roll(
            self.action_history, shift=-action.shape[-1], axis=-1
        )
        self.action_history[..., -action.shape[-1] :] = action

        obs_dict["observation"] = self._create_obs_from_history()

        return obs_dict, reward, done, info
+
+
class PlotActionWrapper(gym.Wrapper):
    """
    Wrapper for plotting the taken actions.
    Only works with 1D actions for now.
    Optionally, it can be used to plot the observations too.

    :param env: (gym.Env)
    :param plot_freq: (int) Plot every `plot_freq` episodes
    """

    def __init__(self, env, plot_freq=5):
        super().__init__(env)
        self.plot_freq = plot_freq
        self.current_episode = 0
        # Observation buffer (Optional)
        # self.observations = []
        # One sub-list of actions per episode
        self.actions = []

    def reset(self):
        self.current_episode += 1
        if self.current_episode % self.plot_freq == 0:
            self.plot()
            # Start a fresh buffer after plotting
            self.actions = []
        obs = self.env.reset()
        self.actions.append([])
        # self.observations.append(obs)
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.actions[-1].append(action)
        # self.observations.append(obs)
        return obs, reward, done, info

    def plot(self):
        episodes = self.actions
        x = np.arange(sum(len(episode) for episode in episodes))
        plt.figure("Actions")
        plt.title("Actions during exploration", fontsize=14)
        plt.xlabel("Timesteps", fontsize=14)
        plt.ylabel("Action", fontsize=14)

        # Draw each episode's actions on a contiguous slice of the x axis
        start = 0
        for episode in episodes:
            end = start + len(episode)
            plt.plot(x[start:end], episode)
            # Clipped actions: real behavior, note that it is between [-2, 2] for the Pendulum
            # plt.scatter(x[start:end], np.clip(episode, -1, 1), s=1)
            # plt.scatter(x[start:end], episode, s=1)
            start = end

        plt.show()
+
+
class FeatureExtractorFreezeParammetersWrapper(gym.Wrapper):
    """
    Freezes parameters of the feature extractor.

    NOTE(review): relies on the wrapped env exposing a `feature_extractor`
    attribute with a torch-style `parameters()` method -- confirm with the
    target environments. (Class name typo "Parammeters" kept for
    backward compatibility with existing callers.)
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        # Disable gradients so the feature extractor is not updated during training
        for parameter in self.feature_extractor.parameters():
            parameter.requires_grad = False
diff --git a/env_manager/rbs_gym/resource/rbs_gym b/env_manager/rbs_gym/resource/rbs_gym
new file mode 100644
index 0000000..e69de29
diff --git a/env_manager/rbs_gym/setup.cfg b/env_manager/rbs_gym/setup.cfg
new file mode 100644
index 0000000..759913d
--- /dev/null
+++ b/env_manager/rbs_gym/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/rbs_gym
+[install]
+install_scripts=$base/lib/rbs_gym
diff --git a/env_manager/rbs_gym/setup.py b/env_manager/rbs_gym/setup.py
new file mode 100644
index 0000000..e280fa4
--- /dev/null
+++ b/env_manager/rbs_gym/setup.py
@@ -0,0 +1,36 @@
+import os
+from glob import glob
+
+from setuptools import find_packages, setup
+
# ROS 2 python package name; must match the resource file and package.xml
package_name = "rbs_gym"

setup(
    name=package_name,
    version="0.0.0",
    # Ship every python package except the tests
    packages=find_packages(exclude=["test"]),
    data_files=[
        # Register the package with the ament resource index
        ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
        ("share/" + package_name, ["package.xml"]),
        (
            # Install every launch file matching *launch.py / *launch.xml / *launch.yaml
            os.path.join("share", package_name, "launch"),
            glob(os.path.join("launch", "*launch.[pxy][yma]*")),
        ),
    ],
    install_requires=["setuptools"],
    zip_safe=True,
    maintainer="narmak",
    maintainer_email="ur.narmak@gmail.com",
    description="TODO: Package description",
    license="Apache-2.0",
    # NOTE(review): tests_require is deprecated in recent setuptools -- confirm
    tests_require=["pytest"],
    entry_points={
        # Executables exposed via `ros2 run rbs_gym <name>`
        "console_scripts": [
            "train = rbs_gym.scripts.train:main",
            "spawner = rbs_gym.scripts.spawner:main",
            "velocity = rbs_gym.scripts.velocity:main",
            "test_agent = rbs_gym.scripts.test_agent:main",
            "evaluate = rbs_gym.scripts.evaluate:main",
        ],
    },
)
diff --git a/env_manager/rbs_gym/test/test_copyright.py b/env_manager/rbs_gym/test/test_copyright.py
new file mode 100644
index 0000000..97a3919
--- /dev/null
+++ b/env_manager/rbs_gym/test/test_copyright.py
@@ -0,0 +1,25 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_copyright.main import main
+import pytest
+
+
# Remove the `skip` decorator once the source file(s) have a copyright header
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
    # Run the ament copyright linter over the package sources and tests
    return_code = main(argv=['.', 'test'])
    assert return_code == 0, 'Found errors'
diff --git a/env_manager/rbs_gym/test/test_flake8.py b/env_manager/rbs_gym/test/test_flake8.py
new file mode 100644
index 0000000..27ee107
--- /dev/null
+++ b/env_manager/rbs_gym/test/test_flake8.py
@@ -0,0 +1,25 @@
+# Copyright 2017 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_flake8.main import main_with_errors
+import pytest
+
+
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
    # Run the ament flake8 linter and report every style error it finds
    return_code, errors = main_with_errors(argv=[])
    report = 'Found %d code style errors / warnings:\n' % len(errors) + '\n'.join(errors)
    assert return_code == 0, report
diff --git a/env_manager/rbs_gym/test/test_pep257.py b/env_manager/rbs_gym/test/test_pep257.py
new file mode 100644
index 0000000..b234a38
--- /dev/null
+++ b/env_manager/rbs_gym/test/test_pep257.py
@@ -0,0 +1,23 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_pep257.main import main
+import pytest
+
+
@pytest.mark.linter
@pytest.mark.pep257
def test_pep257():
    # Run the ament pep257 docstring linter over the package sources and tests
    return_code = main(argv=['.', 'test'])
    assert return_code == 0, 'Found code style errors / warnings'
diff --git a/env_manager/rbs_runtime/LICENSE b/env_manager/rbs_runtime/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/env_manager/rbs_runtime/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/env_manager/rbs_runtime/config/default-scene-config.yaml b/env_manager/rbs_runtime/config/default-scene-config.yaml
new file mode 100644
index 0000000..7016a9a
--- /dev/null
+++ b/env_manager/rbs_runtime/config/default-scene-config.yaml
@@ -0,0 +1,249 @@
+camera:
+ - clip_color: !tuple
+ - 0.01
+ - 1000.0
+ clip_depth: !tuple
+ - 0.05
+ - 10.0
+ enable: true
+ height: 128
+ horizontal_fov: 1.0471975511965976
+ image_format: R8G8B8
+ name: robot_camera
+ noise_mean: null
+ noise_stddev: null
+ publish_color: true
+ publish_depth: true
+ publish_points: true
+ random_pose_focal_point_z_offset: 0.0
+ random_pose_mode: orbit
+ random_pose_orbit_distance: 1.0
+ random_pose_orbit_height_range: !tuple
+ - 0.1
+ - 0.7
+ random_pose_orbit_ignore_arc_behind_robot: 0.39269908169872414
+ random_pose_rollout_counter: 0.0
+ random_pose_rollouts_num: 1
+ random_pose_select_position_options: []
+ relative_to: base_link
+ spawn_position: !tuple
+ - 0.0
+ - 0.0
+ - 1.0
+ spawn_quat_xyzw: !tuple
+ - 0.0
+ - 0.70710678118
+ - 0.0
+ - 0.70710678118
+ type: rgbd_camera
+ update_rate: 10
+ vertical_fov: 1.0471975511965976
+ width: 128
+gravity: !tuple
+ - 0.0
+ - 0.0
+ - -9.80665
+gravity_std: !tuple
+ - 0.0
+ - 0.0
+ - 0.0232
+light:
+ color: !tuple
+ - 1.0
+ - 1.0
+ - 1.0
+ - 1.0
+ direction: !tuple
+ - 0.5
+ - -0.25
+ - -0.75
+ distance: 1000.0
+ model_rollouts_num: 1
+ radius: 25.0
+ random_minmax_elevation: !tuple
+ - -0.15
+ - -0.65
+ type: sun
+ visual: true
+objects:
+ - color: null
+ name: star
+ orientation: !tuple
+ - 1.0
+ - 0.0
+ - 0.0
+ - 0.0
+ position: !tuple
+ - -0.1
+ - -0.40
+ - 0.1
+ randomize:
+ count: 0
+ models_rollouts_num: 0
+ random_color: false
+ random_model: false
+ random_orientation: false
+ random_pose: false
+ random_position: false
+ random_spawn_position_segments: []
+ random_spawn_position_update_workspace_centre: false
+ random_spawn_volume: !tuple
+ - 0.5
+ - 0.5
+ - 0.5
+ relative_to: world
+ static: false
+ texture: []
+ type: "model"
+
+ - color: null
+ name: box
+ mass: 1.0
+ size: !tuple
+ - 0.02
+ - 0.02
+ - 0.05
+ orientation: !tuple
+ - 1.0
+ - 0.0
+ - 0.0
+ - 0.0
+ position: !tuple
+ - -0.1
+ - -0.37
+ - 0.1
+ randomize:
+ count: 0
+ models_rollouts_num: 0
+ random_color: false
+ random_model: false
+ random_orientation: false
+ random_pose: false
+ random_position: false
+ random_spawn_position_segments: []
+ random_spawn_position_update_workspace_centre: false
+ random_spawn_volume: !tuple
+ - 0.5
+ - 0.5
+ - 0.5
+ relative_to: world
+ static: false
+ type: "box"
+
+
+ - color: null
+ name: cylinder
+ mass: 1.0
+ radius: 0.01
+ length: 0.05
+ orientation: !tuple
+ - 1.0
+ - 0.0
+ - 0.0
+ - 0.0
+ position: !tuple
+ - -0.1
+ - -0.34
+ - 0.1
+ randomize:
+ count: 0
+ models_rollouts_num: 0
+ random_color: false
+ random_model: false
+ random_orientation: false
+ random_pose: false
+ random_position: false
+ random_spawn_position_segments: []
+ random_spawn_position_update_workspace_centre: false
+ random_spawn_volume: !tuple
+ - 0.5
+ - 0.5
+ - 0.5
+ relative_to: world
+ static: false
+ type: "cylinder"
+
+ - color: null
+ name: board
+ orientation: !tuple
+ - 1.0
+ - 0.0
+ - 0.0
+ - 0.0
+ position: !tuple
+ - 0.0
+ - -0.37
+ - 0.01
+ randomize:
+ count: 0
+ models_rollouts_num: 0
+ random_color: false
+ random_model: false
+ random_orientation: false
+ random_pose: false
+ random_position: false
+ random_spawn_position_segments: []
+ random_spawn_position_update_workspace_centre: false
+ random_spawn_volume: !tuple
+ - 0.5
+ - 0.5
+ - 0.5
+ relative_to: world
+ static: true
+ texture: []
+ type: "model"
+physics_rollouts_num: 0
+plugins:
+ fts_broadcaster: false
+ scene_broadcaster: false
+ sensors_render_engine: ogre2
+ user_commands: false
+robot:
+ gripper_joint_positions: 0.0
+ joint_positions:
+ - -1.57
+ - 0.5
+ - 3.14159
+ - 1.5
+ - 0.0
+ - 1.4
+ - 0.0
+ name: rbs_arm
+ randomizer:
+ joint_positions: false
+ joint_positions_above_object_spawn: false
+ joint_positions_above_object_spawn_elevation: 0.2
+ joint_positions_above_object_spawn_xy_randomness: 0.2
+ joint_positions_std: 0.1
+ pose: false
+ spawn_volume: !tuple
+ - 1.0
+ - 1.0
+ - 0.0
+ spawn_position: !tuple
+ - 0.0
+ - 0.0
+ - 0.0
+ spawn_quat_xyzw: !tuple
+ - 0.0
+ - 0.0
+ - 0.0
+ - 1.0
+ urdf_string: ""
+ with_gripper: true
+terrain:
+ model_rollouts_num: 1
+ name: ground
+ size: !tuple
+ - 1.5
+ - 1.5
+ spawn_position: !tuple
+ - 0
+ - 0
+ - 0
+ spawn_quat_xyzw: !tuple
+ - 0
+ - 0
+ - 0
+ - 1
+ type: flat
diff --git a/env_manager/rbs_runtime/launch/runtime.launch.py b/env_manager/rbs_runtime/launch/runtime.launch.py
new file mode 100644
index 0000000..6f56440
--- /dev/null
+++ b/env_manager/rbs_runtime/launch/runtime.launch.py
@@ -0,0 +1,356 @@
+import os
+
+import xacro
+import yaml
+from ament_index_python.packages import get_package_share_directory
+from launch import LaunchDescription
+from launch.actions import (
+ DeclareLaunchArgument,
+ IncludeLaunchDescription,
+ OpaqueFunction,
+ TimerAction,
+)
+from launch.launch_description_sources import PythonLaunchDescriptionSource
+from launch.launch_introspector import indent
+from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
+from launch_ros.actions import Node
+from launch_ros.substitutions import FindPackageShare
+from robot_builder.external.ros2_control import ControllerManager
+from robot_builder.parser.urdf import URDF_parser
+
+
+def launch_setup(context, *args, **kwargs):
+ # Initialize Arguments
+ robot_type = LaunchConfiguration("robot_type")
+ # General arguments
+ with_gripper_condition = LaunchConfiguration("with_gripper")
+ description_package = LaunchConfiguration("description_package")
+ description_file = LaunchConfiguration("description_file")
+ use_moveit = LaunchConfiguration("use_moveit")
+ moveit_config_package = LaunchConfiguration("moveit_config_package")
+ moveit_config_file = LaunchConfiguration("moveit_config_file")
+ use_sim_time = LaunchConfiguration("use_sim_time")
+ scene_config_file = LaunchConfiguration("scene_config_file").perform(context)
+ ee_link_name = LaunchConfiguration("ee_link_name").perform(context)
+ base_link_name = LaunchConfiguration("base_link_name").perform(context)
+ control_space = LaunchConfiguration("control_space").perform(context)
+ control_strategy = LaunchConfiguration("control_strategy").perform(context)
+ interactive = LaunchConfiguration("interactive").perform(context)
+ real_robot = LaunchConfiguration("real_robot").perform(context)
+
+ use_rbs_utils = LaunchConfiguration("use_rbs_utils")
+ assembly_config_name = LaunchConfiguration("assembly_config_name")
+
+ if not scene_config_file == "":
+ config_file = {"config_file": scene_config_file}
+ else:
+ config_file = {}
+
+ description_package_abs_path = get_package_share_directory(
+ description_package.perform(context)
+ )
+
+ controllers_file = os.path.join(
+ description_package_abs_path, "config", "controllers.yaml"
+ )
+
+ xacro_file = os.path.join(
+ description_package_abs_path,
+ "urdf",
+ description_file.perform(context),
+ )
+
+ # xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml"
+ xacro_config_file = os.path.join(
+ description_package_abs_path, "urdf", "xacro_args.yaml"
+ )
+
+ # TODO: hide this to another place
+ # Load xacro_args
+ def param_constructor(loader, node, local_vars):
+ value = loader.construct_scalar(node)
+ return LaunchConfiguration(value).perform(
+ local_vars.get("context", "Launch context if not defined")
+ )
+
+ def variable_constructor(loader, node, local_vars):
+ value = loader.construct_scalar(node)
+ return local_vars.get(value, f"Variable '{value}' not found")
+
+ def load_xacro_args(yaml_file, local_vars):
+ # Get valut from ros2 argument
+ yaml.add_constructor(
+ "!param", lambda loader, node: param_constructor(loader, node, local_vars)
+ )
+
+ # Get value from local variable in this code
+ # The local variable should be initialized before the loader was called
+ yaml.add_constructor(
+ "!variable",
+ lambda loader, node: variable_constructor(loader, node, local_vars),
+ )
+
+ with open(yaml_file, "r") as file:
+ return yaml.load(file, Loader=yaml.FullLoader)
+
+ mappings_data = load_xacro_args(xacro_config_file, locals())
+
+ robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data)
+
+ robot_description_semantic_content = ""
+
+ if use_moveit.perform(context) == "true":
+ srdf_config_file = os.path.join(
+ get_package_share_directory(moveit_config_package.perform(context)),
+ "srdf",
+ "xacro_args.yaml",
+ )
+ srdf_file = os.path.join(
+ get_package_share_directory(moveit_config_package.perform(context)),
+ "srdf",
+ moveit_config_file.perform(context),
+ )
+ srdf_mappings = load_xacro_args(srdf_config_file, locals())
+ robot_description_semantic_content = xacro.process_file(
+ srdf_file, mappings=srdf_mappings
+ )
+ robot_description_semantic_content = (
+ robot_description_semantic_content.toprettyxml(indent=" ")
+ )
+ control_space = "joint"
+ control_strategy = "position"
+ interactive = "false"
+
+ robot_description_content = robot_description_doc.toprettyxml(indent=" ")
+ robot_description = {"robot_description": robot_description_content}
+
+ # Parse robot and configure controller's file for ControllerManager
+ robot = URDF_parser.load_string(
+ robot_description_content, ee_link_name=ee_link_name
+ )
+ ControllerManager.save_to_yaml(
+ robot, description_package_abs_path, "controllers.yaml"
+ )
+
+ rbs_robot_setup = IncludeLaunchDescription(
+ PythonLaunchDescriptionSource(
+ [
+ PathJoinSubstitution(
+ [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"]
+ )
+ ]
+ ),
+ launch_arguments={
+ "with_gripper": with_gripper_condition,
+ "controllers_file": controllers_file,
+ "robot_type": robot_type,
+ "description_package": description_package,
+ "description_file": description_file,
+ "robot_name": robot_type,
+ "use_moveit": use_moveit,
+ "moveit_config_package": moveit_config_package,
+ "moveit_config_file": moveit_config_file,
+ "use_sim_time": use_sim_time,
+ "use_controllers": "true",
+ "robot_description": robot_description_content,
+ "robot_description_semantic": robot_description_semantic_content,
+ "base_link_name": base_link_name,
+ "ee_link_name": ee_link_name,
+ "control_space": control_space,
+ "control_strategy": control_strategy,
+ "interactive_control": interactive,
+ "use_rbs_utils": use_rbs_utils,
+ "assembly_config_name": assembly_config_name,
+ }.items(),
+ )
+
+ if real_robot == "true":
+ controller_manager_node = Node(
+ package="controller_manager",
+ executable="ros2_control_node",
+ parameters=[
+ robot_description,
+ controllers_file,
+ ],
+ output="screen",
+ )
+ return [rbs_robot_setup, controller_manager_node]
+
+ rbs_runtime = Node(
+ package="rbs_runtime",
+ executable="runtime",
+ arguments=[("--ros-args --remap log_level:=debug")],
+ parameters=[robot_description, config_file, {"use_sim_time": True}],
+ )
+
+ clock_bridge = Node(
+ package="ros_gz_bridge",
+ executable="parameter_bridge",
+ arguments=["/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock"],
+ output="screen",
+ )
+
+ delay_robot_control_stack = TimerAction(period=10.0, actions=[rbs_robot_setup])
+
+ nodes_to_start = [rbs_runtime, clock_bridge, delay_robot_control_stack]
+ return nodes_to_start
+
+
+def generate_launch_description():
+ declared_arguments = []
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_type",
+ description="Type of robot by name",
+ choices=[
+ "rbs_arm",
+ "ar4",
+ "ur3",
+ "ur3e",
+ "ur5",
+ "ur5e",
+ "ur10",
+ "ur10e",
+ "ur16e",
+ ],
+ default_value="rbs_arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_package",
+ default_value="rbs_arm",
+ description="Description package with robot URDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom description.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "description_file",
+ default_value="rbs_arm_modular.xacro",
+ description="URDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "robot_name",
+ default_value="arm0",
+ description="Name for robot, used to apply namespace for specific robot in multirobot setup",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_package",
+ default_value="rbs_arm",
+ description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \
+ is not set, it enables use of a custom moveit config.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "moveit_config_file",
+ default_value="rbs_arm.srdf.xacro",
+ description="MoveIt SRDF/XACRO description file with the robot.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_sim_time",
+ default_value="true",
+ description="Make MoveIt to use simulation time.\
+ This is needed for the trajectory planing in simulation.",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "with_gripper", default_value="true", description="With gripper or not?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_moveit", default_value="false", description="Launch moveit?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "launch_perception", default_value="false", description="Launch perception?"
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_controllers",
+ default_value="true",
+ description="Launch controllers?",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "scene_config_file",
+ default_value="",
+ description="Path to a scene configuration file",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "ee_link_name",
+ default_value="",
+ description="End effector name of robot arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "base_link_name",
+ default_value="",
+ description="Base link name if robot arm",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "control_space",
+ default_value="task",
+ choices=["task", "joint"],
+ description="Specify the control space for the robot (e.g., task space).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "control_strategy",
+ default_value="position",
+ choices=["position", "velocity", "effort"],
+ description="Specify the control strategy (e.g., position control).",
+ )
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "interactive",
+ default_value="true",
+ description="Wheter to run the motion_control_handle controller",
+ ),
+ )
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "real_robot",
+ default_value="false",
+ description="Wheter to run on the real robot",
+ ),
+ )
+
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "use_rbs_utils",
+ default_value="true",
+ description="Wheter to use rbs_utils",
+ ),
+ )
+
+ declared_arguments.append(
+ DeclareLaunchArgument(
+ "assembly_config_name",
+ default_value="",
+ description="Assembly config name from rbs_assets_library",
+ ),
+ )
+
+ return LaunchDescription(
+ declared_arguments + [OpaqueFunction(function=launch_setup)]
+ )
diff --git a/env_manager/rbs_runtime/package.nix b/env_manager/rbs_runtime/package.nix
new file mode 100644
index 0000000..706e14d
--- /dev/null
+++ b/env_manager/rbs_runtime/package.nix
@@ -0,0 +1,29 @@
+# Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0
+# Copyright 2025 None
+# Distributed under the terms of the Apache-2.0 license
+{
+  lib,
+  buildRosPackage,
+  ament-copyright,
+  ament-flake8,
+  ament-pep257,
+  env-manager,
+  env-manager-interfaces,
+  pythonPackages,
+  scenario,
+}:
+buildRosPackage rec {
+  pname = "ros-jazzy-rbs-runtime";
+  version = "0.0.0";
+
+  src = ./.;
+
+  buildType = "ament_python";
+  # Lint/test-only inputs — mirrors the test_depend entries in package.xml.
+  checkInputs = [ament-copyright ament-flake8 ament-pep257 pythonPackages.pytest];
+  # Runtime deps; dacite is used by rbs_runtime for dataclass deserialization.
+  propagatedBuildInputs = [pythonPackages.dacite scenario env-manager env-manager-interfaces];
+
+  meta = {
+    description = "TODO: Package description";
+    license = with lib.licenses; [asl20];
+  };
+}
diff --git a/env_manager/rbs_runtime/package.xml b/env_manager/rbs_runtime/package.xml
new file mode 100644
index 0000000..65e2d1d
--- /dev/null
+++ b/env_manager/rbs_runtime/package.xml
@@ -0,0 +1,21 @@
+<?xml version="1.0"?>
+<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
+<package format="3">
+  <name>rbs_runtime</name>
+  <version>0.0.0</version>
+  <description>TODO: Package description</description>
+  <maintainer email="ur.narmak@gmail.com">narmak</maintainer>
+  <license>Apache-2.0</license>
+
+  <depend>env_manager</depend>
+  <depend>env_manager_interfaces</depend>
+
+  <test_depend>ament_copyright</test_depend>
+  <test_depend>ament_flake8</test_depend>
+  <test_depend>ament_pep257</test_depend>
+  <test_depend>python3-pytest</test_depend>
+
+  <export>
+    <build_type>ament_python</build_type>
+  </export>
+</package>
diff --git a/env_manager/rbs_runtime/rbs_runtime/__init__.py b/env_manager/rbs_runtime/rbs_runtime/__init__.py
new file mode 100644
index 0000000..029c56e
--- /dev/null
+++ b/env_manager/rbs_runtime/rbs_runtime/__init__.py
@@ -0,0 +1,41 @@
+from pathlib import Path
+
+import yaml
+from dacite import from_dict
+from env_manager.models.configs import SceneData
+
+from typing import Dict, Any
+from env_manager.models.configs import (
+ BoxObjectData, CylinderObjectData, MeshData, ModelData, ObjectData, ObjectRandomizerData
+)
+
+def object_factory(obj_data: Dict[str, Any]) -> ObjectData:
+ obj_type = obj_data.get('type', '')
+
+ if "randomize" in obj_data and isinstance(obj_data["randomize"], dict):
+ obj_data["randomize"] = from_dict(data_class=ObjectRandomizerData, data=obj_data["randomize"])
+
+ if obj_type == 'box':
+ return from_dict(data_class=BoxObjectData, data=obj_data)
+ elif obj_type == 'cylinder':
+ return from_dict(data_class=CylinderObjectData, data=obj_data)
+ elif obj_type == 'mesh':
+ return from_dict(data_class=MeshData, data=obj_data)
+ elif obj_type == 'model':
+ return from_dict(data_class=ModelData, data=obj_data)
+ else:
+ return from_dict(data_class=ObjectData, data=obj_data)
+
+def scene_config_loader(file: str | Path) -> SceneData:
+ def tuple_constructor(loader, node):
+ return tuple(loader.construct_sequence(node))
+
+ with open(file, "r") as yaml_file:
+ yaml.SafeLoader.add_constructor("!tuple", tuple_constructor)
+ scene_cfg = yaml.load(yaml_file, Loader=yaml.SafeLoader)
+
+ scene_data = from_dict(data_class=SceneData, data=scene_cfg)
+
+ scene_data.objects = [object_factory(obj) for obj in scene_cfg.get('objects', [])]
+
+ return scene_data
diff --git a/env_manager/rbs_runtime/rbs_runtime/scripts/__init__.py b/env_manager/rbs_runtime/rbs_runtime/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/env_manager/rbs_runtime/rbs_runtime/scripts/runtime.py b/env_manager/rbs_runtime/rbs_runtime/scripts/runtime.py
new file mode 100755
index 0000000..e234263
--- /dev/null
+++ b/env_manager/rbs_runtime/rbs_runtime/scripts/runtime.py
@@ -0,0 +1,73 @@
+#!/usr/bin/python3
+
+import rclpy
+from ament_index_python.packages import get_package_share_directory
+from env_manager.scene import Scene
+from env_manager_interfaces.srv import ResetEnv
+from rbs_assets_library import get_world_file
+from rclpy.node import Node
+from scenario.bindings.gazebo import GazeboSimulator
+from rclpy.executors import MultiThreadedExecutor, ExternalShutdownException
+
+from .. import scene_config_loader
+
+class GazeboRuntime(Node):
+    """ROS 2 node that owns the Gazebo simulator loop and the managed scene.
+
+    Steps the simulator from a ROS timer and exposes a reset service that
+    defers the actual scene reset to the simulation loop.
+    """
+
+    def __init__(self) -> None:
+        super().__init__(node_name="rbs_gz_runtime")
+        self.declare_parameter("robot_description", "")
+        self.declare_parameter(
+            "config_file",
+            get_package_share_directory("rbs_runtime")
+            + "/config/default-scene-config.yaml",
+        )
+
+        # One simulator step per run() call; 1 ms step at real-time factor 1.0.
+        self.gazebo = GazeboSimulator(step_size=0.001, rtf=1.0, steps_per_run=1)
+
+        self.gazebo.insert_world_from_sdf(get_world_file("default"))
+        if not self.gazebo.initialize():
+            raise RuntimeError("Gazebo cannot be initialized")
+
+        config_file = (
+            self.get_parameter("config_file").get_parameter_value().string_value
+        )
+
+        scene_data = scene_config_loader(config_file)
+
+        self.scene = Scene(
+            node=self,
+            gazebo=self.gazebo,
+            scene=scene_data,
+            robot_urdf_string=self.get_parameter("robot_description")
+            .get_parameter_value()
+            .string_value,
+        )
+        self.scene.init_scene()
+
+        self.reset_env_srv = self.create_service(
+            ResetEnv, "/env_manager/reset_env", self.reset_env
+        )
+        # Reset requests are only flagged here and applied in gazebo_run so
+        # the scene is touched exclusively from the simulation timer.
+        self.is_env_reset = False
+        # Timer period equals the simulator step, keeping sim roughly at RTF 1.
+        self.timer = self.create_timer(self.gazebo.step_size(), self.gazebo_run)
+
+    def gazebo_run(self):
+        """Advance the simulation one step, applying any pending scene reset."""
+        if self.is_env_reset:
+            self.scene.reset_scene()
+            self.is_env_reset = False
+        self.gazebo.run()
+
+    def reset_env(self, req, res):
+        """Service callback: schedule a scene reset for the next sim step."""
+        self.is_env_reset = True
+        # Always reports success; the reset itself happens asynchronously.
+        res.ok = self.is_env_reset
+        return res
+
+
+def main():
+ rclpy.init()
+
+ executor = MultiThreadedExecutor()
+ node = GazeboRuntime()
+ executor.add_node(node)
+ try:
+ executor.spin()
+ except (KeyboardInterrupt, ExternalShutdownException):
+ node.destroy_node()
diff --git a/env_manager/rbs_runtime/resource/rbs_runtime b/env_manager/rbs_runtime/resource/rbs_runtime
new file mode 100644
index 0000000..e69de29
diff --git a/env_manager/rbs_runtime/setup.cfg b/env_manager/rbs_runtime/setup.cfg
new file mode 100644
index 0000000..27e6f37
--- /dev/null
+++ b/env_manager/rbs_runtime/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/rbs_runtime
+[install]
+install_scripts=$base/lib/rbs_runtime
diff --git a/env_manager/rbs_runtime/setup.py b/env_manager/rbs_runtime/setup.py
new file mode 100644
index 0000000..12ce2cd
--- /dev/null
+++ b/env_manager/rbs_runtime/setup.py
@@ -0,0 +1,36 @@
+import os
+from glob import glob
+
+from setuptools import find_packages, setup
+
+package_name = "rbs_runtime"
+
+setup(
+ name=package_name,
+ version="0.0.0",
+ packages=find_packages(exclude=["test"]),
+ data_files=[
+ ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
+ ("share/" + package_name, ["package.xml"]),
+ (
+ os.path.join("share", package_name, "launch"),
+ glob(os.path.join("launch", "*launch.[pxy][yma]*")),
+ ),
+ (
+ os.path.join("share", package_name, "config"),
+ glob(os.path.join("config", "*config.[yma]*")),
+ ),
+ ],
+ install_requires=["setuptools", "dacite"],
+ zip_safe=True,
+ maintainer="narmak",
+ maintainer_email="ur.narmak@gmail.com",
+ description="TODO: Package description",
+ license="Apache-2.0",
+ tests_require=["pytest"],
+ entry_points={
+ "console_scripts": [
+ "runtime = rbs_runtime.scripts.runtime:main",
+ ],
+ },
+)
diff --git a/env_manager/rbs_runtime/test/test_copyright.py b/env_manager/rbs_runtime/test/test_copyright.py
new file mode 100644
index 0000000..97a3919
--- /dev/null
+++ b/env_manager/rbs_runtime/test/test_copyright.py
@@ -0,0 +1,25 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_copyright.main import main
+import pytest
+
+
+# Remove the `skip` decorator once the source file(s) have a copyright header
+@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
+@pytest.mark.copyright
+@pytest.mark.linter
+def test_copyright():
+    """Run the ament_copyright linter over the package sources."""
+    rc = main(argv=['.', 'test'])
+    assert rc == 0, 'Found errors'
diff --git a/env_manager/rbs_runtime/test/test_flake8.py b/env_manager/rbs_runtime/test/test_flake8.py
new file mode 100644
index 0000000..27ee107
--- /dev/null
+++ b/env_manager/rbs_runtime/test/test_flake8.py
@@ -0,0 +1,25 @@
+# Copyright 2017 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_flake8.main import main_with_errors
+import pytest
+
+
+@pytest.mark.flake8
+@pytest.mark.linter
+def test_flake8():
+    """Run the ament flake8 linter and list every style error found."""
+    rc, errors = main_with_errors(argv=[])
+    assert rc == 0, \
+        'Found %d code style errors / warnings:\n' % len(errors) + \
+        '\n'.join(errors)
diff --git a/env_manager/rbs_runtime/test/test_pep257.py b/env_manager/rbs_runtime/test/test_pep257.py
new file mode 100644
index 0000000..b234a38
--- /dev/null
+++ b/env_manager/rbs_runtime/test/test_pep257.py
@@ -0,0 +1,23 @@
+# Copyright 2015 Open Source Robotics Foundation, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ament_pep257.main import main
+import pytest
+
+
+@pytest.mark.linter
+@pytest.mark.pep257
+def test_pep257():
+    """Run the ament pep257 docstring-style linter over the package."""
+    rc = main(argv=['.', 'test'])
+    assert rc == 0, 'Found code style errors / warnings'