diff --git a/env_manager/.gitignore b/env_manager/.gitignore new file mode 100644 index 0000000..48734be --- /dev/null +++ b/env_manager/.gitignore @@ -0,0 +1,178 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python,ros3 +# Edit at https://www.toptal.com/developers/gitignore?templates=python,ros3 + +### Python ### +# Byte-compiled / optimized / DLL files +**/__pycache__/ +**/*.py[cod] +**/*$py.class +**/*.cpython-*.pyc +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# NOTE: the 'ros3' template is not defined at gitignore.io; 
add ROS-specific ignore rules (e.g. install/, log/) here manually if needed. + +# End of https://www.toptal.com/developers/gitignore/api/python,ros3 diff --git a/env_manager/docs/about/img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png b/env_manager/docs/about/img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png new file mode 100644 index 0000000..df47ede Binary files /dev/null and b/env_manager/docs/about/img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png differ diff --git a/env_manager/docs/about/img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png b/env_manager/docs/about/img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png new file mode 100644 index 0000000..df5d2d8 Binary files /dev/null and b/env_manager/docs/about/img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png differ diff --git a/env_manager/docs/about/img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png b/env_manager/docs/about/img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png new file mode 100644 index 0000000..7c3dcba Binary files /dev/null and b/env_manager/docs/about/img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png differ diff --git a/env_manager/docs/about/img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png b/env_manager/docs/about/img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png new file mode 100644 index 0000000..aa7b626 Binary files /dev/null and b/env_manager/docs/about/img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png differ diff --git a/env_manager/docs/about/index.ru.md b/env_manager/docs/about/index.ru.md new file mode 100644 index 0000000..ee8fa08 --- /dev/null +++ b/env_manager/docs/about/index.ru.md @@ -0,0 +1,55 @@ +# Модуль управления виртуальными средами +--- + +При управлении роботами в симуляторе Gazebo через фреймворк ROS2 возникает необходимость конфигурировать не только робота-манипулятора, но и саму сцену.
Однако стандартный подход, основанный на конфигурационных файлах Gazebo, зачастую оказывается избыточным и недостаточно гибким для динамических сценариев, к которым относится обучение с подкреплением. + +**env_manager** — это пакет, предназначенный для конфигурирования сцен в симуляторе Gazebo, предоставляющий более удобный и гибкий подход к созданию и настройке симуляционных сред. + +## Возможности пакета + +С последнего обновления модуль был полностью переработан. Если ранее его функции ограничивались указанием объектов, находящихся в среде, для работы в ROS2, то теперь он предоставляет инструменты для: +- полного конфигурирования сцены, +- настройки объектов наблюдения для ROS2. + +Конфигурация осуществляется с использованием **датаклассов** или **YAML-файлов**, что соответствует декларативному подходу описания сцены. Это делает процесс настройки интуитивно понятным и легко масштабируемым. Пример файла описания сцены, а также файл с конфигурацией по умолчанию доступны [здесь](https://gitlab.com/solid-sinusoid/env_manager/-/blob/b425a1b012bc8320bba7b68e5481da187d64d76e/rbs_runtime/config/default-scene-config.yaml). + +## Возможности конфигурации + +Модуль поддерживает добавление различных типов объектов в сцену, включая: +- **Модель** +- **Меш** +- **Бокс** +- **Цилиндр** +- **Сферу** + +Различие между "моделью" и "мешем" заключается в том, находится ли объект в библиотеке **rbs_assets_library** (подробнее о ней см. [соответствующий раздел](https://robossembler.org/docs/software/ros2#rbs_assets_library)). Дополнительно поддерживается **рандомизация объектов**, позволяющая случайным образом изменять их цвет и положение в сцене. + +Помимо объектов, с помощью пакета можно настраивать: +- **Источники света** +- **Сенсоры** +- **Роботов** +- **Рабочие поверхности** + +Каждый тип объекта обладает как параметрами размещения, так и параметрами рандомизации. 
Для камер предусмотрены настройки публикации данных: +- изображения глубины +- цветного изображения +- облаков точек. + +Параметры рандомизации могут включать в себя положение и ориентацию в заданных пользователем пределах. Для рабочей поверхности также включается возможность рандомизации текстуры, а для робота имеется возможность рандомизировать его положение, в том числе конфигурацию и расположение базы робота. + +## Архитектура и спецификации + +Основная структура модуля включает обертки для добавления объектов в сцену. Полная спецификация доступных параметров и взаимосвязей между классами представлена в папке конфигураций. Для каждой категории объектов используются отдельные датаклассы, что упрощает организацию и модификацию параметров. + +Диаграмма классов конфигурации сцены представлена ниже: + +![scene_class_diagramm](./img/scene_data_class_diagram-9c873943a7c4492b3254a24973e1fac2.png) +*Диаграмма классов конфигурации сцены* + +## Примеры + +Ниже представлены различные сцены, созданные с использованием возможностей **env_manager**: + +| **Сценарий 1** | **Сценарий 2** | **Сценарий 3** | +|-----------------|-----------------|-----------------| +| ![one](./img/ar_textured_ground-1f72b8d6cb977cdca352bd6e81a3cd7d.png) | ![two](./img/ar_textured_ground2-79b89f8247de2d0e663dad39e151eeac.png) | ![three](./img/rbs_texture_ground_and_spawned_objects-38c103bb6197006d9decdb54fec2404f.png) | diff --git a/env_manager/docs/getting_started/getting_started.ru.md b/env_manager/docs/getting_started/getting_started.ru.md new file mode 100644 index 0000000..35af9b9 --- /dev/null +++ b/env_manager/docs/getting_started/getting_started.ru.md @@ -0,0 +1,48 @@ +# Начало работы с rbs_gym + +Пакет входит в проект Robossembler.
Для установки пакета необходимо произвести установку всего проекта по [инструкции](https://gitlab.com/robossembler/robossembler-ros2/-/blob/b109c97b5c093e665135179668cb2091e6708387/docs/ru/installation.md). + +## Запуск тестовой среды + +Для запуска тестовой среды необходимо выполнить команду: + +```sh +ros2 launch rbs_gym test_env.launch.py base_link_name:=base_link ee_link_name:=gripper_grasp_point control_space:=task control_strategy:=effort interactive:=false +``` + +Это продемонстрирует, что все установилось и работает адекватно. Для визуализации можно воспользоваться графическим клиентом Gazebo в новом терминале: +```sh +ign gazebo -g +``` + +## Запуск обучения тестовой среды + +Запуск обучения производится следующей командой: + +```sh +ros2 launch rbs_gym train.launch.py base_link_name:=base_link ee_link_name:=gripper_grasp_point control_space:=task control_strategy:=effort interactive:=false +``` + +Команда запустит обучение алгоритмом SAC. Полный перечень аргументов можно посмотреть общим флагом `--show-args`, применимым ко всем файлам запуска, запускаемым посредством команды `ros2 launch`. + +```sh +ros2 launch rbs_gym train.launch.py --show-args +``` + +Метрики качества обучения можно наблюдать в папке `logs`, которая автоматически создастся в том месте, откуда Вы запускали обучение агента. Для этого надо перейти в эту директорию и выполнить команду: + +```sh +cd logs +aim up +``` + +Это выведет в консоль ссылку на веб-интерфейс [Aim](https://aimstack.io/), который необходимо открыть в браузере. + +## Модификация сцены + +Обычно среда всегда прочно связана со сценой.
Поэтому для модификации сцены часто также необходимо вносить правки и в среду + +Сцена задается как конфигурация [`env_manager`](../about/index.ru.md) пример конфигурации для сцены можно подглядеть [тут](../../rbs_gym/rbs_gym/envs/__init__.py) для тестовой среды + +В настоящий момент функционал активно разрабатывается в дальнейшем появится более удобный способ модификации сцены и модификации параметров среды с использованием [Hydra](https://hydra.cc/) + diff --git a/env_manager/env_manager/LICENSE b/env_manager/env_manager/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/env_manager/env_manager/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/env_manager/env_manager/env_manager/__init__.py b/env_manager/env_manager/env_manager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/env_manager/env_manager/env_manager/models/__init__.py b/env_manager/env_manager/env_manager/models/__init__.py new file mode 100644 index 0000000..2cef2a1 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/__init__.py @@ -0,0 +1,5 @@ +from .lights import * +from .objects import * +from .robots import * +from .sensors import * +from .terrains import * diff --git a/env_manager/env_manager/env_manager/models/configs/__init__.py b/env_manager/env_manager/env_manager/models/configs/__init__.py new file mode 100644 index 0000000..10201ce --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/__init__.py @@ -0,0 +1,15 @@ +from .camera import CameraData +from .light import LightData +from .objects import ( + BoxObjectData, + CylinderObjectData, + MeshData, + ModelData, + ObjectData, + PlaneObjectData, + SphereObjectData, + ObjectRandomizerData, +) +from .robot import RobotData +from .terrain import TerrainData +from .scene import SceneData diff --git a/env_manager/env_manager/env_manager/models/configs/camera.py b/env_manager/env_manager/env_manager/models/configs/camera.py new file mode 100644 index 0000000..c76e8d9 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/camera.py @@ 
-0,0 +1,111 @@ +from dataclasses import dataclass, field + +import numpy as np + + +@dataclass +class CameraData: + """ + CameraData stores the parameters and configuration settings for an RGB-D camera. + + Attributes: + ---------- + name : str + The name of the camera instance. Default is an empty string. + enable : bool + Boolean flag to enable or disable the camera. Default is True. + type : str + Type of the camera. Default is "rgbd_camera". + relative_to : str + The reference frame relative to which the camera is placed. Default is "base_link". + + width : int + The image width (in pixels) captured by the camera. Default is 128. + height : int + The image height (in pixels) captured by the camera. Default is 128. + image_format : str + The image format used (e.g., "R8G8B8"). Default is "R8G8B8". + update_rate : int + The rate (in Hz) at which the camera provides updates. Default is 10 Hz. + horizontal_fov : float + The horizontal field of view (in radians). Default is pi / 3.0. + vertical_fov : float + The vertical field of view (in radians). Default is pi / 3.0. + + clip_color : tuple[float, float] + The near and far clipping planes for the color camera. Default is (0.01, 1000.0). + clip_depth : tuple[float, float] + The near and far clipping planes for the depth camera. Default is (0.05, 10.0). + + noise_mean : float | None + The mean value of the Gaussian noise applied to the camera's data. Default is None (no noise). + noise_stddev : float | None + The standard deviation of the Gaussian noise applied to the camera's data. Default is None (no noise). + + publish_color : bool + Whether or not to publish color data from the camera. Default is False. + publish_depth : bool + Whether or not to publish depth data from the camera. Default is False. + publish_points : bool + Whether or not to publish point cloud data from the camera. Default is False. 
+ + spawn_position : tuple[float, float, float] + The initial spawn position (x, y, z) of the camera relative to the reference frame. Default is (0, 0, 1). + spawn_quat_xyzw : tuple[float, float, float, float] + The initial spawn orientation of the camera in quaternion (x, y, z, w). Default is (0, 0.70710678118, 0, 0.70710678118). + + random_pose_rollouts_num : int + The number of random pose rollouts. Default is 1. + random_pose_mode : str + The mode for random pose generation (e.g., "orbit"). Default is "orbit". + random_pose_orbit_distance : float + The fixed distance from the object in "orbit" mode. Default is 1.0. + random_pose_orbit_height_range : tuple[float, float] + The range of vertical positions (z-axis) in "orbit" mode. Default is (0.1, 0.7). + random_pose_orbit_ignore_arc_behind_robot : float + The arc angle (in radians) behind the robot to ignore when generating random poses. Default is pi / 8. + random_pose_select_position_options : list[tuple[float, float, float]] + A list of preset position options for the camera in random pose mode. Default is an empty list. + random_pose_focal_point_z_offset : float + The offset in the z-direction of the focal point when generating random poses. Default is 0.0. + random_pose_rollout_counter : float + A counter tracking the number of rollouts completed for random poses. Default is 0.0. 
+ """ + + name: str = field(default_factory=str) + enable: bool = field(default=True) + type: str = field(default="rgbd_camera") + relative_to: str = field(default="base_link") + + width: int = field(default=128) + height: int = field(default=128) + image_format: str = field(default="R8G8B8") + update_rate: int = field(default=10) + horizontal_fov: float = field(default=np.pi / 3.0) + vertical_fov: float = field(default=np.pi / 3.0) + + clip_color: tuple[float, float] = field(default=(0.01, 1000.0)) + clip_depth: tuple[float, float] = field(default=(0.05, 10.0)) + + noise_mean: float | None = None + noise_stddev: float | None = None + + publish_color: bool = field(default=False) + publish_depth: bool = field(default=False) + publish_points: bool = field(default=False) + + spawn_position: tuple[float, float, float] = field(default=(0.0, 0.0, 1.0)) + spawn_quat_xyzw: tuple[float, float, float, float] = field( + default=(0.0, 0.70710678118, 0.0, 0.70710678118) + ) + + random_pose_rollouts_num: int = field(default=1) + random_pose_mode: str = field(default="orbit") + random_pose_orbit_distance: float = field(default=1.0) + random_pose_orbit_height_range: tuple[float, float] = field(default=(0.1, 0.7)) + random_pose_orbit_ignore_arc_behind_robot: float = field(default=np.pi / 8) + random_pose_select_position_options: list[tuple[float, float, float]] = field( + default_factory=list + ) + random_pose_focal_point_z_offset: float = field(default=0.0) + random_pose_rollout_counter: float = field(default=0.0) diff --git a/env_manager/env_manager/env_manager/models/configs/light.py b/env_manager/env_manager/env_manager/models/configs/light.py new file mode 100644 index 0000000..ea15f52 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/light.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field + + +@dataclass +class LightData: + """ + LightData stores the configuration and settings for a light source in a simulation or rendering environment. 
+ + Attributes: + ---------- + type : str + The type of the light source (e.g., "sun", "point", "spot"). Default is "sun". + direction : tuple[float, float, float] + The direction vector of the light source in (x, y, z). This is typically used for directional lights like the sun. + Default is (0.5, -0.25, -0.75). + random_minmax_elevation : tuple[float, float] + The minimum and maximum elevation angles (in radians) for randomizing the light's direction. + Default is (-0.15, -0.65). + color : tuple[float, float, float, float] + The RGBA color of the light source. Each value ranges from 0 to 1. Default is (1.0, 1.0, 1.0, 1.0), which represents white light. + distance : float + The effective distance of the light source. Default is 1000.0 units. + visual : bool + A flag indicating whether the light source is visually represented in the simulation (e.g., whether it casts visible rays). + Default is True. + radius : float + The radius of the light's influence (for point and spot lights). Default is 25.0 units. + model_rollouts_num : int + The number of rollouts for randomizing light configurations in different simulation runs. Default is 1. 
+ """ + + type: str = field(default="sun") + direction: tuple[float, float, float] = field(default=(0.5, -0.25, -0.75)) + random_minmax_elevation: tuple[float, float] = field(default=(-0.15, -0.65)) + color: tuple[float, float, float, float] = field(default=(1.0, 1.0, 1.0, 1.0)) + distance: float = field(default=1000.0) + visual: bool = field(default=True) + radius: float = field(default=25.0) + model_rollouts_num: int = field(default=1) diff --git a/env_manager/env_manager/env_manager/models/configs/objects.py b/env_manager/env_manager/env_manager/models/configs/objects.py new file mode 100644 index 0000000..167423d --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/objects.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass, field + +from env_manager.utils.types import Pose + + +@dataclass +class ObjectRandomizerData: + """ + ObjectRandomizerData contains parameters for randomizing object properties during simulation. + + Attributes: + ---------- + count : int + The number of objects to randomize. Default is 0. + random_pose : bool + Flag indicating whether to randomize both the position and orientation of objects. Default is False. + random_position : bool + Flag indicating whether to randomize the position of objects. Default is False. + random_orientation : bool + Flag indicating whether to randomize the orientation of objects. Default is False. + random_model : bool + Flag indicating whether to randomize the model of the objects. Default is False. + random_spawn_position_segments : list + List of segments within which the object can be randomly spawned. Default is an empty list. + random_spawn_position_update_workspace_centre : bool + Flag indicating whether to update the workspace center during random spawning. Default is False. + random_spawn_volume : tuple[float, float, float] + The volume within which objects can be randomly spawned, defined by (x, y, z). Default is (0.5, 0.5, 0.5). 
+ models_rollouts_num : int + The number of rollouts for randomizing models. Default is 0. + """ + + count: int = field(default=0) + random_pose: bool = field(default=False) + random_position: bool = field(default=False) + random_orientation: bool = field(default=False) + random_model: bool = field(default=False) + random_spawn_position_segments: list = field(default_factory=list) + random_spawn_position_update_workspace_centre: bool = field(default=False) + random_spawn_volume: tuple[float, float, float] = field(default=(0.5, 0.5, 0.5)) + models_rollouts_num: int = field(default=0) + random_color: bool = field(default=False) + + +@dataclass +class ObjectData: + """ + ObjectData stores the base properties for any object in the simulation environment. + + Attributes: + ---------- + name : str + The name of the object. Default is "object". + type : str + The type of the object (e.g., "sphere", "box"). Default is an empty string. + relative_to : str + The reference frame relative to which the object is positioned. Default is "world". + position : tuple[float, float, float] + The position of the object in (x, y, z) coordinates. Default is (0.0, 0.0, 0.0). + orientation : tuple[float, float, float, float] + The orientation of the object as a quaternion (w, x, y, z). Default is (1.0, 0.0, 0.0, 0.0), the identity rotation. + static : bool + Flag indicating if the object is static in the simulation (immovable). Default is False. + randomize : ObjectRandomizerData + Object randomizer settings for generating randomized object properties. Default is an empty ObjectRandomizerData instance. 
+ """ + + name: str = field(default="object") + type: str = field(default_factory=str) + relative_to: str = field(default="world") + position: tuple[float, float, float] = field(default=(0.0, 0.0, 0.0)) + orientation: tuple[float, float, float, float] = field(default=(1.0, 0.0, 0.0, 0.0)) + color: tuple[float, float, float, float] | None = field(default=None) + static: bool = field(default=False) + randomize: ObjectRandomizerData = field(default_factory=ObjectRandomizerData) + + +@dataclass +class PrimitiveObjectData(ObjectData): + """ + PrimitiveObjectData defines the base properties for primitive objects (e.g., spheres, boxes) in the simulation. + + Attributes: + ---------- + collision : bool + Flag indicating whether the object participates in collision detection. Default is True. + visual : bool + Flag indicating whether the object has a visual representation in the simulation. Default is True. + color : tuple[float, float, float, float] + The color of the object, represented in RGBA format. Default is (0.8, 0.8, 0.8, 1.0). + mass : float + The mass of the object. Default is 0.1. + """ + + collision: bool = field(default=True) + visual: bool = field(default=True) + mass: float = field(default=0.1) + + +@dataclass +class SphereObjectData(PrimitiveObjectData): + """ + SphereObjectData defines the specific properties for a spherical object in the simulation. + + Attributes: + ---------- + radius : float + The radius of the sphere. Default is 0.025. + friction : float + The friction coefficient of the sphere when in contact with surfaces. Default is 1.0. + """ + + radius: float = field(default=0.025) + friction: float = field(default=1.0) + + +@dataclass +class PlaneObjectData(PrimitiveObjectData): + """ + PlaneObjectData defines the specific properties for a planar object in the simulation. + + Attributes: + ---------- + size : tuple[float, float] + The size of the plane, defined by its width and length. Default is (1.0, 1.0). 
+ direction : tuple[float, float, float] + The normal vector representing the direction the plane faces. Default is (0.0, 0.0, 1.0). + friction : float + The friction coefficient of the plane when in contact with other objects. Default is 1.0. + """ + + size: tuple = field(default=(1.0, 1.0)) + direction: tuple = field(default=(0.0, 0.0, 1.0)) + friction: float = field(default=1.0) + + +@dataclass +class CylinderObjectData(PrimitiveObjectData): + """ + CylinderObjectData defines the specific properties for a cylindrical object in the simulation. + + Attributes: + ---------- + radius : float + The radius of the cylinder. Default is 0.025. + length : float + The length of the cylinder. Default is 0.1. + friction : float + The friction coefficient of the cylinder when in contact with surfaces. Default is 1.0. + """ + + radius: float = field(default=0.025) + length: float = field(default=0.1) + friction: float = field(default=1.0) + + +@dataclass +class BoxObjectData(PrimitiveObjectData): + """ + BoxObjectData defines the specific properties for a box-shaped object in the simulation. + + Attributes: + ---------- + size : tuple[float, float, float] + The dimensions of the box in (width, height, depth). Default is (0.05, 0.05, 0.05). + friction : float + The friction coefficient of the box when in contact with surfaces. Default is 1.0. + """ + + size: tuple = field(default=(0.05, 0.05, 0.05)) + friction: float = field(default=1.0) + + +@dataclass +class ModelData(ObjectData): + """ + MeshObjectData defines the specific properties for a mesh-based object in the simulation. + + Attributes: + ---------- + texture : list[float] + A list of texture coordinates or texture properties applied to the mesh. Default is an empty list. 
+ """ + + texture: list[float] = field(default_factory=list) + + +@dataclass +class MeshData(ModelData): + mass: float = field(default_factory=float) + inertia: tuple[float, float, float, float, float, float] = field( + default_factory=tuple + ) + cm: Pose = field(default_factory=Pose) + collision: str = field(default_factory=str) + visual: str = field(default_factory=str) + friction: float = field(default_factory=float) + density: float = field(default_factory=float) diff --git a/env_manager/env_manager/env_manager/models/configs/robot.py b/env_manager/env_manager/env_manager/models/configs/robot.py new file mode 100644 index 0000000..b2f90ca --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/robot.py @@ -0,0 +1,103 @@ +from dataclasses import dataclass, field +from enum import Enum + + +@dataclass +class ToolData: + name: str = field(default_factory=str) + type: str = field(default_factory=str) + + +@dataclass +class GripperData(ToolData): + pass + + +@dataclass +class ParallelGripperData(GripperData): + pass + + +@dataclass +class MultifingerGripperData(GripperData): + pass + + +@dataclass +class VacuumGripperData(GripperData): + pass + + +class GripperEnum(Enum): + parallel = ParallelGripperData + mulrifinger = MultifingerGripperData + vacuum = VacuumGripperData + + +@dataclass +class RobotRandomizerData: + """ + RobotRandomizerData stores configuration parameters for randomizing robot properties during simulation. + + Attributes: + ---------- + pose : bool + Flag indicating whether to randomize the robot's pose (position and orientation). Default is False. + spawn_volume : tuple[float, float, float] + The volume within which the robot can be spawned, defined by (x, y, z) dimensions. Default is (1.0, 1.0, 0.0). + joint_positions : bool + Flag indicating whether to randomize the robot's joint positions. Default is False. + joint_positions_std : float + The standard deviation for randomizing the robot's joint positions. Default is 0.1. 
+ joint_positions_above_object_spawn : bool + Flag indicating whether the robot's joint positions should be randomized to place the robot above the object's spawn position. Default is False. + joint_positions_above_object_spawn_elevation : float + The elevation above the object's spawn position when placing the robot's joints. Default is 0.2. + joint_positions_above_object_spawn_xy_randomness : float + The randomness in the x and y coordinates when placing the robot's joints above the object's spawn position. Default is 0.2. + """ + + pose: bool = field(default=False) + spawn_volume: tuple[float, float, float] = field(default=(1.0, 1.0, 0.0)) + joint_positions: bool = field(default=False) + joint_positions_std: float = field(default=0.1) + joint_positions_above_object_spawn: bool = field(default=False) + joint_positions_above_object_spawn_elevation: float = field(default=0.2) + joint_positions_above_object_spawn_xy_randomness: float = field(default=0.2) + + +@dataclass +class RobotData: + """ + RobotData stores the base properties and configurations for a robot in the simulation. + + Attributes: + ---------- + name : str + The name of the robot. Default is an empty string. + urdf_string : str + Optional parameter that can store a URDF. This parameter is overridden by the node parameters if set in the node configuration. + spawn_position : tuple[float, float, float] + The position where the robot will be spawned in (x, y, z) coordinates. Default is (0.0, 0.0, 0.0). + spawn_quat_xyzw : tuple[float, float, float, float] + The orientation of the robot in quaternion format (x, y, z, w) at spawn. Default is (0.0, 0.0, 0.0, 1.0). + joint_positions : list[float] + A list of the robot's joint positions. Default is an empty list. + with_gripper : bool + Flag indicating whether the robot is equipped with a gripper. Default is False. + gripper_joint_positions : list[float] | float + The joint positions for the gripper. Can be a list of floats or a single float. 
Default is an empty list. + randomizer : RobotRandomizerData + The randomization settings for the robot, allowing various properties to be randomized in simulation. Default is an instance of RobotRandomizerData. + """ + + name: str = field(default_factory=str) + urdf_string: str = field(default_factory=str) + spawn_position: tuple[float, float, float] = field(default=(0.0, 0.0, 0.0)) + spawn_quat_xyzw: tuple[float, float, float, float] = field( + default=(0.0, 0.0, 0.0, 1.0) + ) + joint_positions: list[float] = field(default_factory=list) + with_gripper: bool = field(default=False) + gripper_joint_positions: list[float] | float = field(default_factory=list) + randomizer: RobotRandomizerData = field(default_factory=RobotRandomizerData) diff --git a/env_manager/env_manager/env_manager/models/configs/scene.py b/env_manager/env_manager/env_manager/models/configs/scene.py new file mode 100644 index 0000000..94d078f --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/scene.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field + +from .camera import CameraData +from .light import LightData +from .objects import ObjectData +from .robot import RobotData +from .terrain import TerrainData + + +@dataclass +class PluginsData: + """ + PluginsData stores the configuration for various plugins used in the simulation environment. + + Attributes: + ---------- + scene_broadcaster : bool + Flag indicating whether the scene broadcaster plugin is enabled. Default is False. + user_commands : bool + Flag indicating whether the user commands plugin is enabled. Default is False. + fts_broadcaster : bool + Flag indicating whether the force torque sensor (FTS) broadcaster plugin is enabled. Default is False. + sensors_render_engine : str + The rendering engine used for sensors. Default is "ogre2". 
+ """ + + scene_broadcaster: bool = field(default_factory=bool) + user_commands: bool = field(default_factory=bool) + fts_broadcaster: bool = field(default_factory=bool) + sensors_render_engine: str = field(default="ogre2") + + +@dataclass +class SceneData: + """ + SceneData contains the configuration and settings for the simulation scene. + + Attributes: + ---------- + physics_rollouts_num : int + The number of physics rollouts to perform in the simulation. Default is 0. + gravity : tuple[float, float, float] + The gravitational acceleration vector applied in the scene, represented as (x, y, z). Default is (0.0, 0.0, -9.80665). + gravity_std : tuple[float, float, float] + The standard deviation for the gravitational acceleration, represented as (x, y, z). Default is (0.0, 0.0, 0.0232). + robot : RobotData + The configuration data for the robot present in the scene. Default is an instance of RobotData. + terrain : TerrainData + The configuration data for the terrain in the scene. Default is an instance of TerrainData. + light : LightData + The configuration data for the lighting in the scene. Default is an instance of LightData. + objects : list[ObjectData] + A list of objects present in the scene, represented by their ObjectData configurations. Default is an empty list. + camera : list[CameraData] + A list of cameras in the scene, represented by their CameraData configurations. Default is an empty list. + plugins : PluginsData + The configuration data for various plugins utilized in the simulation environment. Default is an instance of PluginsData. 
+ """ + + physics_rollouts_num: int = field(default=0) + gravity: tuple[float, float, float] = field(default=(0.0, 0.0, -9.80665)) + gravity_std: tuple[float, float, float] = field(default=(0.0, 0.0, 0.0232)) + robot: RobotData = field(default_factory=RobotData) + terrain: TerrainData = field(default_factory=TerrainData) + light: LightData = field(default_factory=LightData) + objects: list[ObjectData] = field(default_factory=list) + camera: list[CameraData] = field(default_factory=list) + plugins: PluginsData = field(default_factory=PluginsData) diff --git a/env_manager/env_manager/env_manager/models/configs/terrain.py b/env_manager/env_manager/env_manager/models/configs/terrain.py new file mode 100644 index 0000000..f240132 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/configs/terrain.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass, field + + +@dataclass +class TerrainData: + """ + TerrainData stores the configuration for the terrain in the simulation environment. + + Attributes: + ---------- + name : str + The name of the terrain. Default is "ground". + type : str + The type of terrain (e.g., "flat", "hilly", "uneven"). Default is "flat". + spawn_position : tuple[float, float, float] + The position where the terrain will be spawned in the simulation, represented as (x, y, z). Default is (0, 0, 0). + spawn_quat_xyzw : tuple[float, float, float, float] + The orientation of the terrain at spawn, represented as a quaternion (x, y, z, w). Default is (0, 0, 0, 1). + size : tuple[float, float] + The size of the terrain, represented as (width, length). Default is (1.5, 1.5). + model_rollouts_num : int + The number of rollouts for randomizing terrain models. Default is 1. 
+ """ + + name: str = field(default="ground") + type: str = field(default="flat") + spawn_position: tuple[float, float, float] = field(default=(0, 0, 0)) + spawn_quat_xyzw: tuple[float, float, float, float] = field(default=(0, 0, 0, 1)) + size: tuple[float, float] = field(default=(1.5, 1.5)) + ambient: tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0)) + specular: tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0)) + diffuse: tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0)) + model_rollouts_num: int = field(default=1) diff --git a/env_manager/env_manager/env_manager/models/lights/__init__.py b/env_manager/env_manager/env_manager/models/lights/__init__.py new file mode 100644 index 0000000..f791750 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/lights/__init__.py @@ -0,0 +1,34 @@ +from enum import Enum +from gym_gz.scenario.model_wrapper import ModelWrapper + +from .random_sun import RandomSun +from .sun import Sun + + +# Enum для типов света +class LightType(Enum): + SUN = "sun" + RANDOM_SUN = "random_sun" + + +LIGHT_MODEL_CLASSES = { + LightType.SUN: Sun, + LightType.RANDOM_SUN: RandomSun, +} + + +def get_light_model_class(light_type: str) -> type[ModelWrapper]: + try: + light_enum = LightType(light_type) + return LIGHT_MODEL_CLASSES[light_enum] + except KeyError: + raise ValueError(f"Unsupported light type: {light_type}") + + +def is_light_type_randomizable(light_type: str) -> bool: + try: + light_enum = LightType(light_type) + return light_enum == LightType.RANDOM_SUN + except ValueError: + return False + diff --git a/env_manager/env_manager/env_manager/models/lights/random_sun.py b/env_manager/env_manager/env_manager/models/lights/random_sun.py new file mode 100644 index 0000000..e91f56c --- /dev/null +++ b/env_manager/env_manager/env_manager/models/lights/random_sun.py @@ -0,0 +1,240 @@ +import numpy as np +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario 
import get_unique_model_name +from numpy.random import RandomState +from scenario import core as scenario +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class RandomSun(model_wrapper.ModelWrapper): + """ + RandomSun is a class that creates a randomly positioned directional light (like the Sun) in a simulation world. + + Attributes: + ---------- + world : scenario_gazebo.World + The simulation world where the model is inserted. + name : str + The name of the sun model. Default is "sun". + minmax_elevation : tuple[float, float] + Minimum and maximum values for the random elevation angle of the sun. Default is (-0.15, -0.65). + distance : float + Distance from the origin to the sun in the simulation. Default is 800.0. + visual : bool + Flag indicating whether the sun should have a visual sphere in the simulation. Default is True. + radius : float + Radius of the visual sphere representing the sun. Default is 20.0. + color_minmax_r : tuple[float, float] + Range for the red component of the sun's light color. Default is (1.0, 1.0). + color_minmax_g : tuple[float, float] + Range for the green component of the sun's light color. Default is (1.0, 1.0). + color_minmax_b : tuple[float, float] + Range for the blue component of the sun's light color. Default is (1.0, 1.0). + specular : float + The specular factor for the sun's light. Default is 1.0. + attenuation_minmax_range : tuple[float, float] + Range for the light attenuation (falloff) distance. Default is (750.0, 15000.0). + attenuation_minmax_constant : tuple[float, float] + Range for the constant component of the light attenuation. Default is (0.5, 1.0). + attenuation_minmax_linear : tuple[float, float] + Range for the linear component of the light attenuation. Default is (0.001, 0.1). + attenuation_minmax_quadratic : tuple[float, float] + Range for the quadratic component of the light attenuation. Default is (0.0001, 0.01). 
+ np_random : RandomState | None + Random state for generating random values. Default is None, in which case a new random generator is created. + + Raises: + ------- + RuntimeError: + Raised if the sun model fails to be inserted into the world. + + Methods: + -------- + get_sdf(): + Generates the SDF string for the sun model. + """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "sun", + minmax_elevation: tuple[float, float] = (-0.15, -0.65), + distance: float = 800.0, + visual: bool = True, + radius: float = 20.0, + color_minmax_r: tuple[float, float] = (1.0, 1.0), + color_minmax_g: tuple[float, float] = (1.0, 1.0), + color_minmax_b: tuple[float, float] = (1.0, 1.0), + specular: float = 1.0, + attenuation_minmax_range: tuple[float, float] = (750.0, 15000.0), + attenuation_minmax_constant: tuple[float, float] = (0.5, 1.0), + attenuation_minmax_linear: tuple[float, float] = (0.001, 0.1), + attenuation_minmax_quadratic: tuple[float, float] = (0.0001, 0.01), + np_random: RandomState | None = None, + **kwargs, + ): + if np_random is None: + np_random = np.random.default_rng() + + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Get random yaw direction + direction = np_random.uniform(-1.0, 1.0, (2,)) + # Normalize yaw direction + direction = direction / np.linalg.norm(direction) + + # Get random elevation + direction = np.append( + direction, + np_random.uniform(minmax_elevation[0], minmax_elevation[1]), + ) + # Normalize again + direction = direction / np.linalg.norm(direction) + + # Initial pose + initial_pose = scenario_core.Pose( + ( + -direction[0] * distance, + -direction[1] * distance, + -direction[2] * distance, + ), + (1, 0, 0, 0), + ) + + # Create SDF string for the model + sdf = self.get_sdf( + model_name=model_name, + direction=tuple(direction), + visual=visual, + radius=radius, + color_minmax_r=color_minmax_r, + color_minmax_g=color_minmax_g, + color_minmax_b=color_minmax_b, + 
attenuation_minmax_range=attenuation_minmax_range, + attenuation_minmax_constant=attenuation_minmax_constant, + attenuation_minmax_linear=attenuation_minmax_linear, + attenuation_minmax_quadratic=attenuation_minmax_quadratic, + specular=specular, + np_random=np_random, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + model_wrapper.ModelWrapper.__init__(self, model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + direction: tuple[float, float, float], + visual: bool, + radius: float, + color_minmax_r: tuple[float, float], + color_minmax_g: tuple[float, float], + color_minmax_b: tuple[float, float], + attenuation_minmax_range: tuple[float, float], + attenuation_minmax_constant: tuple[float, float], + attenuation_minmax_linear: tuple[float, float], + attenuation_minmax_quadratic: tuple[float, float], + specular: float, + np_random: RandomState, + ) -> str: + """ + Generates the SDF string for the sun model. + + Args: + ----- + model_name : str + The name of the model. + direction : Tuple[float, float, float] + The direction of the sun's light. + visual : bool + If True, a visual representation of the sun will be created. + radius : float + The radius of the visual representation. + color_minmax_r : Tuple[float, float] + Range for the red component of the light color. + color_minmax_g : Tuple[float, float] + Range for the green component of the light color. + color_minmax_b : Tuple[float, float] + Range for the blue component of the light color. + attenuation_minmax_range : Tuple[float, float] + Range for light attenuation distance. + attenuation_minmax_constant : Tuple[float, float] + Range for the constant attenuation factor. + attenuation_minmax_linear : Tuple[float, float] + Range for the linear attenuation factor. 
+ attenuation_minmax_quadratic : Tuple[float, float] + Range for the quadratic attenuation factor. + specular : float + The specular reflection factor for the light. + np_random : RandomState + The random number generator used to sample random values for the parameters. + + Returns: + -------- + str: + The SDF string for the sun model. + """ + # Sample random values for parameters + color_r = np_random.uniform(color_minmax_r[0], color_minmax_r[1]) + color_g = np_random.uniform(color_minmax_g[0], color_minmax_g[1]) + color_b = np_random.uniform(color_minmax_b[0], color_minmax_b[1]) + attenuation_range = np_random.uniform( + attenuation_minmax_range[0], attenuation_minmax_range[1] + ) + attenuation_constant = np_random.uniform( + attenuation_minmax_constant[0], attenuation_minmax_constant[1] + ) + attenuation_linear = np_random.uniform( + attenuation_minmax_linear[0], attenuation_minmax_linear[1] + ) + attenuation_quadratic = np_random.uniform( + attenuation_minmax_quadratic[0], attenuation_minmax_quadratic[1] + ) + + return f''' + + true + + + {direction[0]} {direction[1]} {direction[2]} + + {attenuation_range} + {attenuation_constant} + {attenuation_linear} + {attenuation_quadratic} + + {color_r} {color_g} {color_b} 1 + {specular*color_r} {specular*color_g} {specular*color_b} 1 + true + + { + f""" + + + + {radius} + + + + {color_r} {color_g} {color_b} 1 + + false + + """ if visual else "" + } + + + ''' diff --git a/env_manager/env_manager/env_manager/models/lights/sun.py b/env_manager/env_manager/env_manager/models/lights/sun.py new file mode 100644 index 0000000..a10c565 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/lights/sun.py @@ -0,0 +1,191 @@ +import numpy as np +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Sun(model_wrapper.ModelWrapper): + """ + Sun is a class that represents a 
directional light source in the simulation, similar to the Sun. + + It can have both visual and light properties, with customizable parameters such as color, direction, and attenuation. + + Attributes: + ---------- + world : scenario_gazebo.World + The Gazebo world where the sun model will be inserted. + name : str + The name of the sun model. Default is "sun". + direction : tuple[float, float, float] + The direction of the sunlight, normalized. Default is (0.5, -0.25, -0.75). + color : tuple[float, float, float, float] + The RGBA color values for the light's diffuse color. Default is (1.0, 1.0, 1.0, 1.0). + distance : float + The distance from the origin to the sun. Default is 800.0. + visual : bool + If True, a visual representation of the sun will be added. Default is True. + radius : float + The radius of the visual representation of the sun. Default is 20.0. + specular : float + The intensity of the specular reflection. Default is 1.0. + attenuation_range : float + The maximum range for the light attenuation. Default is 10000.0. + attenuation_constant : float + The constant attenuation factor. Default is 0.9. + attenuation_linear : float + The linear attenuation factor. Default is 0.01. + attenuation_quadratic : float + The quadratic attenuation factor. Default is 0.001. + + Raises: + ------- + RuntimeError: + If the sun model fails to be inserted into the Gazebo world. + + Methods: + -------- + get_sdf() -> str: + Generates the SDF string used to describe the sun model in the Gazebo simulation. 
+ """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "sun", + direction: tuple[float, float, float] = (0.5, -0.25, -0.75), + color: tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), + distance: float = 800.0, + visual: bool = True, + radius: float = 20.0, + specular: float = 1.0, + attenuation_range: float = 10000.0, + attenuation_constant: float = 0.9, + attenuation_linear: float = 0.01, + attenuation_quadratic: float = 0.001, + **kwargs, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Normalize direction + np_direction = np.array(direction) + np_direction = np_direction / np.linalg.norm(np_direction) + + # Initial pose + initial_pose = scenario_core.Pose( + ( + -np_direction[0] * distance, + -np_direction[1] * distance, + -np_direction[2] * distance, + ), + (1, 0, 0, 0), + ) + + # Create SDF string for the model + sdf = self.get_sdf( + model_name=model_name, + direction=tuple(np_direction), + color=color, + visual=visual, + radius=radius, + specular=specular, + attenuation_range=attenuation_range, + attenuation_constant=attenuation_constant, + attenuation_linear=attenuation_linear, + attenuation_quadratic=attenuation_quadratic, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + model_wrapper.ModelWrapper.__init__(self, model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + direction: tuple[float, float, float], + color: tuple[float, float, float, float], + visual: bool, + radius: float, + specular: float, + attenuation_range: float, + attenuation_constant: float, + attenuation_linear: float, + attenuation_quadratic: float, + ) -> str: + """ + Generates the SDF string for the sun model. 
+ + Args: + ----- + model_name : str + The name of the sun model. + direction : tuple[float, float, float] + The direction vector for the sunlight. + color : tuple[float, float, float, float] + The RGBA color values for the sunlight. + visual : bool + Whether to include a visual representation of the sun (a sphere). + radius : float + The radius of the visual sphere. + specular : float + The specular reflection intensity. + attenuation_range : float + The range of the light attenuation. + attenuation_constant : float + The constant factor for the light attenuation. + attenuation_linear : float + The linear factor for the light attenuation. + attenuation_quadratic : float + The quadratic factor for the light attenuation. + + Returns: + -------- + str: + The SDF string for the sun model. + """ + + return f''' + + true + + + {direction[0]} {direction[1]} {direction[2]} + + {attenuation_range} + {attenuation_constant} + {attenuation_linear} + {attenuation_quadratic} + + {color[0]} {color[1]} {color[2]} 1 + {specular*color[0]} {specular*color[1]} {specular*color[2]} 1 + true + + { + f""" + + + + {radius} + + + + {color[0]} {color[1]} {color[2]} 1 + + false + + """ if visual else "" + } + + + ''' diff --git a/env_manager/env_manager/env_manager/models/objects/__init__.py b/env_manager/env_manager/env_manager/models/objects/__init__.py new file mode 100644 index 0000000..a1a70c6 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/__init__.py @@ -0,0 +1,53 @@ +from enum import Enum + +from gym_gz.scenario.model_wrapper import ModelWrapper + +from .model import Model +from .mesh import Mesh +from .primitives import Box, Cylinder, Plane, Sphere +# from .random_object import RandomObject +from .random_primitive import RandomPrimitive + + +class ObjectType(Enum): + BOX = "box" + SPHERE = "sphere" + CYLINDER = "cylinder" + PLANE = "plane" + RANDOM_PRIMITIVE = "random_primitive" + RANDOM_MESH = "random_mesh" + MODEL = "model" + MESH = "mesh" + 
+OBJECT_MODEL_CLASSES = { + ObjectType.BOX: Box, + ObjectType.SPHERE: Sphere, + ObjectType.CYLINDER: Cylinder, + ObjectType.PLANE: Plane, + ObjectType.RANDOM_PRIMITIVE: RandomPrimitive, + # ObjectType.RANDOM_MESH: RandomObject, + ObjectType.MODEL: Model, + ObjectType.MESH: Mesh +} + + +RANDOMIZABLE_TYPES = { + ObjectType.RANDOM_PRIMITIVE, + ObjectType.RANDOM_MESH, +} + + +def get_object_model_class(object_type: str) -> type[ModelWrapper]: + try: + object_enum = ObjectType(object_type) + return OBJECT_MODEL_CLASSES[object_enum] + except KeyError: + raise ValueError(f"Unsupported object type: {object_type}") + + +def is_object_type_randomizable(object_type: str) -> bool: + try: + object_enum = ObjectType(object_type) + return object_enum in RANDOMIZABLE_TYPES + except ValueError: + return False diff --git a/env_manager/env_manager/env_manager/models/objects/mesh.py b/env_manager/env_manager/env_manager/models/objects/mesh.py new file mode 100644 index 0000000..f467dfd --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/mesh.py @@ -0,0 +1,183 @@ +import numpy as np +import trimesh +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo +from trimesh import Trimesh + +from env_manager.utils.types import Point + +InertiaTensor = tuple[float, float, float, float, float, float] + + +class Mesh(model_wrapper.ModelWrapper): + """ """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "object", + type: str = "mesh", + relative_to: str = "world", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + color: tuple[float, float, float, float] | None = None, + static: bool = False, + texture: list[float] = [], + mass: float = 0.0, + inertia: InertiaTensor | None = None, + cm: Point | None = None, + collision: str = "", + visual: str = "", + 
friction: float = 0.0, + density: float = 0.0, + **kwargs, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + # Calculate inertia parameters for provided mesh file if not exist + if not inertia and collision and density: + mass, cm, inertia = self.calculate_inertia(collision, density) + elif not inertia: + raise RuntimeError( + f"Please provide collision mesh filepath for model {name} to calculate inertia" + ) + + if not color: + color = tuple(np.random.uniform(0.0, 1.0, (3,))) + color = (color[0], color[1], color[2], 1.0) + + # Create SDF string for the model + sdf = self.get_sdf( + name=name, + static=static, + collision=collision, + visual=visual, + mass=mass, + inertia=inertia, + friction=friction, + color=color, + center_of_mass=cm, + gui_only=True, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError(f"Failed to insert {model_name}") + # Get the model + model = world.get_model(model_name) + + # Initialize base class + super().__init__(model=model) + + def calculate_inertia(self, file_path, density): + mesh = trimesh.load(file_path) + + if not isinstance(mesh, Trimesh): + raise RuntimeError("Please provide correct stl mesh filepath") + + volume = mesh.volume + mass: float = volume * density + center_of_mass: Point = tuple(mesh.center_mass) + inertia = mesh.moment_inertia + eigenvalues = np.linalg.eigvals(inertia) + inertia_tensor: InertiaTensor = ( + inertia[0, 0], + inertia[0, 1], + inertia[0, 2], + inertia[1, 1], + inertia[1, 2], + inertia[2, 2], + ) + return mass, center_of_mass, inertia_tensor + + @classmethod + def get_sdf( + cls, + name: str, + static: bool, + collision: str, + visual: str, + mass: float, + inertia: InertiaTensor, + friction: float, + color: tuple[float, float, float, float], + center_of_mass: Point, + gui_only: bool, + ) -> str: + 
""" + Generates the SDF string for the box model. + + Args: + - mesh_args (MeshPureData): Object that contain data of provided mesh data + + Returns: + The SDF string that defines the box model in Gazebo. + """ + return f''' + + {"true" if static else "false"} + + { + f""" + + + + {collision} + + + + + + {friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + """ if collision else "" + } + { + f""" + + + + {visual} + + + + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + + {1.0 - color[3]} + {'1 false' if gui_only else ''} + + """ if visual else "" + } + + {mass} + {center_of_mass[0]} {center_of_mass[1]} {center_of_mass[2]} 0 0 0 + + {inertia[0]} + {inertia[1]} + {inertia[2]} + {inertia[3]} + {inertia[4]} + {inertia[5]} + + + + + ''' diff --git a/env_manager/env_manager/env_manager/models/objects/model.py b/env_manager/env_manager/env_manager/models/objects/model.py new file mode 100644 index 0000000..1f6de89 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/model.py @@ -0,0 +1,75 @@ +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from rbs_assets_library import get_model_file +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Model(model_wrapper.ModelWrapper): + """ + Represents a 3D mesh model in the Gazebo simulation environment. + This class is responsible for loading and inserting a mesh model + into the Gazebo world with specified attributes. + + Args: + world (scenario_gazebo.World): The Gazebo world where the mesh model will be inserted. + name (str, optional): The name of the mesh model. Defaults to "object". + position (tuple[float, float, float], optional): The position of the mesh in the world. Defaults to (0, 0, 0). + orientation (tuple[float, float, float, float], optional): The orientation of the mesh in quaternion format. 
Defaults to (1, 0, 0, 0). + gui_only (bool, optional): If True, the visual representation of the mesh will only appear in the GUI. Defaults to False. + **kwargs: Additional keyword arguments. + + Raises: + RuntimeError: If the mesh model fails to be inserted into the Gazebo world. + + Methods: + get_sdf(model_name: str) -> str: + Generates the SDF string used to describe the mesh model in the Gazebo simulation. + """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "object", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + gui_only: bool = False, + **kwargs, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + # Create SDF string for the model + sdf = self.get_sdf( + model_name=name, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model(sdf, initial_pose, model_name) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + super().__init__(model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + ) -> str: + """ + Generates the SDF string for the specified mesh model. + + Args: + model_name (str): The name of the mesh model to generate the SDF for. + + Returns: + str: The SDF string that defines the mesh model in Gazebo. 
+ """ + return get_model_file(model_name) diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/__init__.py b/env_manager/env_manager/env_manager/models/objects/primitives/__init__.py new file mode 100644 index 0000000..0dcf311 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/primitives/__init__.py @@ -0,0 +1,4 @@ +from .box import Box +from .cylinder import Cylinder +from .plane import Plane +from .sphere import Sphere diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/box.py b/env_manager/env_manager/env_manager/models/objects/primitives/box.py new file mode 100644 index 0000000..fae02bb --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/primitives/box.py @@ -0,0 +1,176 @@ +import numpy as np +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Box(model_wrapper.ModelWrapper): + """ + The Box class represents a 3D box model in the Gazebo simulation environment. + It includes physical and visual properties such as size, mass, color, and collision properties. + + Attributes: + world (scenario_gazebo.World): The Gazebo world where the box model will be inserted. + name (str): The name of the box model. Default is "box". + position (tuple[float, float, float]): The position of the box in the world. Default is (0, 0, 0). + orientation (tuple[float, float, float, float]): The orientation of the box in quaternion format. Default is (1, 0, 0, 0). + size (tuple[float, float, float]): The size of the box (width, length, height). Default is (0.05, 0.05, 0.05). + mass (float): The mass of the box in kilograms. Default is 0.1. + static (bool): If True, the box will be immovable and static. Default is False. + collision (bool): If True, the box will have collision properties. Default is True. 
+ friction (float): The friction coefficient for the box’s surface. Default is 1.0. + visual (bool): If True, the box will have a visual representation. Default is True. + gui_only (bool): If True, the visual representation of the box will only appear in the GUI, not in the simulation physics. Default is False. + color (tuple[float, float, float, float]): The RGBA color of the box. Default is (0.8, 0.8, 0.8, 1.0). + + Raises: + RuntimeError: + If the box model fails to be inserted into the Gazebo world. + + Methods: + -------- + get_sdf() -> str: + Generates the SDF string used to describe the box model in the Gazebo simulation. + """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "box", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + size: tuple[float, float, float] = (0.05, 0.05, 0.05), + mass: float = 0.1, + static: bool = False, + collision: bool = True, + friction: float = 1.0, + visual: bool = True, + gui_only: bool = False, + color: tuple[float, float, float, float] | None = None, + **kwargs, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + if not color: + color = tuple(np.random.uniform(0.0, 1.0, (3,))) + color = (color[0], color[1], color[2], 1.0) + + # Create SDF string for the model + sdf = self.get_sdf( + model_name=model_name, + size=size, + mass=mass, + static=static, + collision=collision, + friction=friction, + visual=visual, + gui_only=gui_only, + color=color, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError(f"Failed to insert {model_name}") + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + super().__init__(model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + size: 
tuple[float, float, float], + mass: float, + static: bool, + collision: bool, + friction: float, + visual: bool, + gui_only: bool, + color: tuple[float, float, float, float], + ) -> str: + """ + Generates the SDF string for the box model. + + Args: + model_name (str): The name of the box model. + size (tuple[float, float, float]): The dimensions of the box (width, length, height). + mass (float): The mass of the box. + static (bool): If True, the box will be static and immovable. + collision (bool): If True, the box will have collision properties. + friction (float): The friction coefficient for the box. + visual (bool): If True, the box will have a visual representation. + gui_only (bool): If True, the box's visual representation will only appear in the GUI and will not affect physics. + color (tuple[float, float, float, float]): The RGBA color of the box. + + Returns: + The SDF string that defines the box model in Gazebo. + """ + return f''' + + {"true" if static else "false"} + + { + f""" + + + + {size[0]} {size[1]} {size[2]} + + + + + + {friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + """ if collision else "" + } + { + f""" + + + + {size[0]} {size[1]} {size[2]} + + + + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + + {1.0 - color[3]} + {'1 false' if gui_only else ''} + + """ if visual else "" + } + + {mass} + + {(size[1]**2 + size[2]**2) * mass / 12} + {(size[0]**2 + size[2]**2) * mass / 12} + {(size[0]**2 + size[1]**2) * mass / 12} + 0.0 + 0.0 + 0.0 + + + + + ''' diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/cylinder.py b/env_manager/env_manager/env_manager/models/objects/primitives/cylinder.py new file mode 100644 index 0000000..38df76f --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/primitives/cylinder.py @@ -0,0 +1,173 @@ +import numpy as np +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario 
import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Cylinder(model_wrapper.ModelWrapper): + """ + The Cylinder class represents a cylindrical model in the Gazebo simulation environment. + + Attributes: + world (scenario_gazebo.World): The Gazebo world where the cylinder model will be inserted. + name (str): The name of the cylinder model. Default is "cylinder". + position (tuple[float, float, float]): The position of the cylinder in the world. Default is (0, 0, 0). + orientation (tuple[float, float, float, float]): The orientation of the cylinder in quaternion format. Default is (1, 0, 0, 0). + radius (float): The radius of the cylinder. Default is 0.025. + length (float): The length/height of the cylinder. Default is 0.1. + mass (float): The mass of the cylinder in kilograms. Default is 0.1. + static (bool): If True, the cylinder will be immovable. Default is False. + collision (bool): If True, the cylinder will have collision properties. Default is True. + friction (float): The friction coefficient for the cylinder's surface. Default is 1.0. + visual (bool): If True, the cylinder will have a visual representation. Default is True. + gui_only (bool): If True, the visual representation of the cylinder will only appear in the GUI, not in the simulation physics. Default is False. + color (tuple[float, float, float, float]): The RGBA color of the cylinder. Default is (0.8, 0.8, 0.8, 1.0). + + Raises: + RuntimeError: If the cylinder model fails to be inserted into the Gazebo world. 
+ + """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "cylinder", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + radius: float = 0.025, + length: float = 0.1, + mass: float = 0.1, + static: bool = False, + collision: bool = True, + friction: float = 1.0, + visual: bool = True, + gui_only: bool = False, + color: tuple[float, float, float, float] | None = None, + **kwargs, + ): + model_name = get_unique_model_name(world, name) + + initial_pose = scenario_core.Pose(position, orientation) + + if not color: + color = tuple(np.random.uniform(0.0, 1.0, (3,))) + color = (color[0], color[1], color[2], 1.0) + + sdf = self.get_sdf( + model_name=model_name, + radius=radius, + length=length, + mass=mass, + static=static, + collision=collision, + friction=friction, + visual=visual, + gui_only=gui_only, + color=color, + ) + + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError(f"Failed to insert {model_name}") + + model = world.get_model(model_name) + + super().__init__(model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + radius: float, + length: float, + mass: float, + static: bool, + collision: bool, + friction: float, + visual: bool, + gui_only: bool, + color: tuple[float, float, float, float], + ) -> str: + """ + Generates the SDF string for the cylinder model. + + Args: + model_name (str): The name of the model. + radius (float): The radius of the cylinder. + length (float): The length or height of the cylinder. + mass (float): The mass of the cylinder in kilograms. + static (bool): If True, the cylinder will remain immovable in the simulation. + collision (bool): If True, adds collision properties to the cylinder. + friction (float): The friction coefficient for the cylinder's surface. + visual (bool): If True, a visual representation of the cylinder will be created. 
+ gui_only (bool): If True, the visual representation will only appear in the GUI, without impacting the simulation's physics. + color (tuple[float, float, float, float]): The RGBA color of the cylinder, where each value is between 0 and 1. + + Returns: + str: The SDF string representing the cylinder model + """ + inertia_xx_yy = (3 * radius**2 + length**2) * mass / 12 + + return f''' + + {"true" if static else "false"} + + { + f""" + + + + {radius} + {length} + + + + + + {friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + """ if collision else "" + } + { + f""" + + + + {radius} + {length} + + + + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + + {1.0-color[3]} + {'1 false' if gui_only else ''} + + """ if visual else "" + } + + {mass} + + {inertia_xx_yy} + {inertia_xx_yy} + {(mass*radius**2)/2} + 0.0 + 0.0 + 0.0 + + + + + ''' diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/plane.py b/env_manager/env_manager/env_manager/models/objects/primitives/plane.py new file mode 100644 index 0000000..ffbeefb --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/primitives/plane.py @@ -0,0 +1,100 @@ +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Plane(model_wrapper.ModelWrapper): + """ + The Plane class represents a plane model in the Gazebo simulation environment. + It allows for defining a flat surface with collision and visual properties, as well as its orientation and friction settings. + + Attributes: + world (scenario_gazebo.World): The Gazebo world where the plane model will be inserted. + name (str): The name of the plane model. Default is "plane". + position (tuple[float, float, float]): The position of the plane in the world. Default is (0, 0, 0). 
+ orientation (tuple[float, float, float, float]): The orientation of the plane in quaternion format. Default is (1, 0, 0, 0). + size (tuple[float, float]): The size of the plane along the x and y axes. Default is (1.0, 1.0). + direction (tuple[float, float, float]): The normal vector representing the plane's direction. Default is (0.0, 0.0, 1.0), representing a horizontal plane. + collision (bool): If True, the plane will have collision properties. Default is True. + friction (float): The friction coefficient for the plane's surface. Default is 1.0. + visual (bool): If True, the plane will have a visual representation. Default is True. + + Raises: + RuntimeError: If the plane model fails to be inserted into the Gazebo world. + """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "plane", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + size: tuple[float, float] = (1.0, 1.0), + direction: tuple[float, float, float] = (0.0, 0.0, 1.0), + collision: bool = True, + friction: float = 1.0, + visual: bool = True, + **kwargs, + ): + model_name = get_unique_model_name(world, name) + + initial_pose = scenario_core.Pose(position, orientation) + + sdf = f''' + + true + + { + f""" + + + + {direction[0]} {direction[1]} {direction[2]} + {size[0]} {size[1]} + + + + + + {friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + """ if collision else "" + } + { + f""" + + + + {direction[0]} {direction[1]} {direction[2]} + {size[0]} {size[1]} + + + + 0.8 0.8 0.8 1 + 0.8 0.8 0.8 1 + 0.8 0.8 0.8 1 + + + """ if visual else "" + } + + + ''' + + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError(f"Failed to insert {model_name}") + + model = world.get_model(model_name) + + super().__init__(model=model) diff --git a/env_manager/env_manager/env_manager/models/objects/primitives/sphere.py 
b/env_manager/env_manager/env_manager/models/objects/primitives/sphere.py new file mode 100644 index 0000000..1fa2761 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/primitives/sphere.py @@ -0,0 +1,176 @@ +import numpy as np +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Sphere(model_wrapper.ModelWrapper): + """ + A Sphere model for Gazebo simulation. + + This class represents a spherical object that can be inserted into a Gazebo world. + The sphere can be customized by specifying its physical properties, such as radius, mass, + collision parameters, friction, and visual attributes. + + Attributes: + world (scenario_gazebo.World): The Gazebo world where the sphere model will be inserted. + name (str): The name of the sphere model. Defaults to "sphere". + position (tuple[float, float, float]): The position of the sphere in the world. Defaults to (0, 0, 0). + orientation (tuple[float, float, float, float]): The orientation of the sphere in quaternion format. Defaults to (1, 0, 0, 0). + radius (float): The radius of the sphere. Defaults to 0.025. + mass (float): The mass of the sphere. Defaults to 0.1. + static (bool): If True, the sphere will be static in the simulation. Defaults to False. + collision (bool): If True, the sphere will have collision properties. Defaults to True. + friction (float): The friction coefficient of the sphere's surface. Defaults to 1.0. + visual (bool): If True, the sphere will have a visual representation. Defaults to True. + gui_only (bool): If True, the visual will only appear in the GUI. Defaults to False. + color (tuple[float, float, float, float]): The RGBA color of the sphere. Defaults to (0.8, 0.8, 0.8, 1.0). + + Raises: + RuntimeError: If the sphere fails to be inserted into the Gazebo world. 
+ """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "sphere", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + radius: float = 0.025, + mass: float = 0.1, + static: bool = False, + collision: bool = True, + friction: float = 1.0, + visual: bool = True, + gui_only: bool = False, + color: tuple[float, float, float, float] | None = None, + **kwargs, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + if not color: + color = tuple(np.random.uniform(0.0, 1.0, (3,))) + color = (color[0], color[1], color[2], 1.0) + + # Create SDF string for the model + sdf = self.get_sdf( + model_name=model_name, + radius=radius, + mass=mass, + static=static, + collision=collision, + friction=friction, + visual=visual, + gui_only=gui_only, + color=color, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError(f"Failed to insert {model_name}") + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + super().__init__(model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + radius: float, + mass: float, + static: bool, + collision: bool, + friction: float, + visual: bool, + gui_only: bool, + color: tuple[float, float, float, float], + ) -> str: + """ + Generates the SDF (Simulation Description Format) string for the sphere model. + + Args: + model_name (str): The name of the model. + radius (float): The radius of the sphere. + mass (float): The mass of the sphere. + static (bool): Whether the sphere is static in the simulation. + collision (bool): Whether the sphere should have collision properties. + friction (float): The friction coefficient for the sphere. + visual (bool): Whether the sphere should have a visual representation. 
+ gui_only (bool): Whether the visual representation is only visible in the GUI. + color (tuple[float, float, float, float]): The RGBA color of the sphere. + + Returns: + str: The SDF string representing the sphere. + """ + # Inertia is identical for all axes + inertia_xx_yy_zz = (mass * radius**2) * 2 / 5 + + return f''' + + {"true" if static else "false"} + + { + f""" + + + + {radius} + + + + + + {friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + """ if collision else "" + } + { + f""" + + + + {radius} + + + + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + {color[0]} {color[1]} {color[2]} {color[3]} + + {1.0 - color[3]} + {'1 false' if gui_only else ''} + + """ if visual else "" + } + + {mass} + + {inertia_xx_yy_zz} + {inertia_xx_yy_zz} + {inertia_xx_yy_zz} + 0.0 + 0.0 + 0.0 + + + + + ''' diff --git a/env_manager/env_manager/env_manager/models/objects/random_object.py b/env_manager/env_manager/env_manager/models/objects/random_object.py new file mode 100644 index 0000000..5f399c6 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/random_object.py @@ -0,0 +1,84 @@ +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from numpy.random import RandomState +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + +from env_manager.models.utils import ModelCollectionRandomizer + + +class RandomObject(model_wrapper.ModelWrapper): + """ + Represents a randomly selected 3D object model in the Gazebo simulation environment. + This class allows for the insertion of various models based on a collection of model paths, + utilizing a randomization strategy for the chosen model. + + Args: + world (scenario_gazebo.World): The Gazebo world where the random object model will be inserted. + name (str, optional): The name of the random object model. Defaults to "object". 
+ position (tuple[float, float, float], optional): The position of the object in the world. Defaults to (0, 0, 0). + orientation (tuple[float, float, float, float], optional): The orientation of the object in quaternion format. Defaults to (1, 0, 0, 0). + model_paths (str | None, optional): Paths to the model files. Must be set; raises an error if None. + owner (str, optional): The owner of the model collection. Defaults to "GoogleResearch". + collection (str, optional): The collection of models to choose from. Defaults to "Google Scanned Objects". + server (str, optional): The server URL for the model collection. Defaults to "https://fuel.ignitionrobotics.org". + server_version (str, optional): The version of the server to use. Defaults to "1.0". + unique_cache (bool, optional): If True, enables caching of unique models. Defaults to False. + reset_collection (bool, optional): If True, resets the model collection for new selections. Defaults to False. + np_random (RandomState | None, optional): An instance of RandomState for random number generation. Defaults to None. + **kwargs: Additional keyword arguments. + + Raises: + RuntimeError: If the model path is not set or if the random object model fails to be inserted into the Gazebo world. 
+ """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "object", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + model_paths: str | None = None, + owner: str = "GoogleResearch", + collection: str = "Google Scanned Objects", + server: str = "https://fuel.ignitionrobotics.org", + server_version: str = "1.0", + unique_cache: bool = False, + reset_collection: bool = False, + np_random: RandomState | None = None, + **kwargs, + ): + # Get a unique model name + if model_paths is None: + raise RuntimeError("Set model path for continue") + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + model_collection_randomizer = ModelCollectionRandomizer( + model_paths=model_paths, + owner=owner, + collection=collection, + server=server, + server_version=server_version, + unique_cache=unique_cache, + reset_collection=reset_collection, + np_random=np_random, + ) + + # Note: using default arguments here + modified_sdf_file = model_collection_randomizer.random_model() + + # Insert the model + ok_model = world.to_gazebo().insert_model( + modified_sdf_file, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + model_wrapper.ModelWrapper.__init__(self, model=model) diff --git a/env_manager/env_manager/env_manager/models/objects/random_primitive.py b/env_manager/env_manager/env_manager/models/objects/random_primitive.py new file mode 100644 index 0000000..58510b1 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/objects/random_primitive.py @@ -0,0 +1,164 @@ +import numpy as np +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from numpy.random import RandomState +from scenario import core as scenario_core +from scenario 
import gazebo as scenario_gazebo + +from . import Box, Cylinder, Sphere + + +class RandomPrimitive(model_wrapper.ModelWrapper): + """ + Represents a randomly generated primitive shape (box, cylinder, or sphere) in the Gazebo simulation environment. + This class allows for the insertion of various primitive models based on the user's specifications or randomly chosen. + + Args: + world (scenario_gazebo.World): The Gazebo world where the primitive model will be inserted. + name (str, optional): The name of the primitive model. Defaults to "primitive". + use_specific_primitive ((str | None), optional): If specified, the exact type of primitive to create ('box', 'cylinder', or 'sphere'). Defaults to None, which will randomly select a primitive type. + position (tuple[float, float, float], optional): The position of the primitive in the world. Defaults to (0, 0, 0). + orientation (tuple[float, float, float, float], optional): The orientation of the primitive in quaternion format. Defaults to (1, 0, 0, 0). + static (bool, optional): If True, the primitive will be static and immovable. Defaults to False. + collision (bool, optional): If True, the primitive will have collision properties. Defaults to True. + visual (bool, optional): If True, the primitive will have a visual representation. Defaults to True. + gui_only (bool, optional): If True, the visual representation will only appear in the GUI and not in the simulation physics. Defaults to False. + np_random (RandomState | None, optional): An instance of RandomState for random number generation. If None, a default random generator will be used. Defaults to None. + **kwargs: Additional keyword arguments. + + Raises: + RuntimeError: If the primitive model fails to be inserted into the Gazebo world. + TypeError: If the specified primitive type is not supported. 
+ """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "primitive", + use_specific_primitive: (str | None) = None, + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + static: bool = False, + collision: bool = True, + visual: bool = True, + gui_only: bool = False, + np_random: RandomState | None = None, + **kwargs, + ): + if np_random is None: + np_random = np.random.default_rng() + + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + # Create SDF string for the model + sdf = self.get_sdf( + model_name=model_name, + use_specific_primitive=use_specific_primitive, + static=static, + collision=collision, + visual=visual, + gui_only=gui_only, + np_random=np_random, + ) + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + model_wrapper.ModelWrapper.__init__(self, model=model) + + @classmethod + def get_sdf( + cls, + model_name: str, + use_specific_primitive: (str | None), + static: bool, + collision: bool, + visual: bool, + gui_only: bool, + np_random: RandomState, + ) -> str: + """ + Generates the SDF (Simulation Description Format) string for the specified primitive model. + + This method can create the SDF representation for a box, cylinder, or sphere based on the provided parameters. + If a specific primitive type is not provided, one will be randomly selected. + + Args: + model_name (str): The name of the model being generated. + use_specific_primitive ((str | None)): The specific type of primitive to create ('box', 'cylinder', or 'sphere'). If None, a random primitive will be chosen. 
+ static (bool): If True, the primitive will be static and immovable. + collision (bool): If True, the primitive will have collision properties. + visual (bool): If True, the primitive will have a visual representation. + gui_only (bool): If True, the visual representation will only appear in the GUI and will not affect physics. + np_random (RandomState): An instance of RandomState for random number generation. + + Returns: + str: The SDF string that defines the specified primitive model, including its physical and visual properties. + + Raises: + TypeError: If the specified primitive type is not supported. + + """ + if use_specific_primitive is not None: + primitive = use_specific_primitive + else: + primitive = np_random.choice(["box", "cylinder", "sphere"]) + + mass = np_random.uniform(0.05, 0.25) + friction = np_random.uniform(0.75, 1.5) + color = tuple(np_random.uniform(0.0, 1.0, (3,))) + color: tuple[float, float, float, float] = (color[0], color[1], color[2], 1.0) + + if "box" == primitive: + return Box.get_sdf( + model_name=model_name, + size=tuple(np_random.uniform(0.04, 0.06, (3,))), + mass=mass, + static=static, + collision=collision, + friction=friction, + visual=visual, + gui_only=gui_only, + color=color, + ) + elif "cylinder" == primitive: + return Cylinder.get_sdf( + model_name=model_name, + radius=np_random.uniform(0.01, 0.0375), + length=np_random.uniform(0.025, 0.05), + mass=mass, + static=static, + collision=collision, + friction=friction, + visual=visual, + gui_only=gui_only, + color=color, + ) + elif "sphere" == primitive: + return Sphere.get_sdf( + model_name=model_name, + radius=np_random.uniform(0.01, 0.0375), + mass=mass, + static=static, + collision=collision, + friction=friction, + visual=visual, + gui_only=gui_only, + color=color, + ) + else: + raise TypeError( + f"Type '{use_specific_primitive}' in not a supported primitive. Pleasure use 'box', 'cylinder' or 'sphere." 
+ ) diff --git a/env_manager/env_manager/env_manager/models/robots/__init__.py b/env_manager/env_manager/env_manager/models/robots/__init__.py new file mode 100644 index 0000000..c690e8c --- /dev/null +++ b/env_manager/env_manager/env_manager/models/robots/__init__.py @@ -0,0 +1,16 @@ +from enum import Enum + +from .rbs_arm import RbsArm +from .robot import RobotWrapper + + +class RobotEnum(Enum): + RBS_ARM = "rbs_arm" + + +def get_robot_model_class(robot_model: str) -> type[RobotWrapper]: + model_mapping = { + RobotEnum.RBS_ARM.value: RbsArm, + } + + return model_mapping.get(robot_model, RbsArm) diff --git a/env_manager/env_manager/env_manager/models/robots/rbs_arm.py b/env_manager/env_manager/env_manager/models/robots/rbs_arm.py new file mode 100644 index 0000000..d724946 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/robots/rbs_arm.py @@ -0,0 +1,153 @@ +from gym_gz.utils.scenario import get_unique_model_name +from rclpy.node import Node +from robot_builder.parser.urdf import URDF_parser +from robot_builder.elements.robot import Robot + +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + +from .robot import RobotWrapper + + +class RbsArm(RobotWrapper): + """ + A class representing a robotic arm built using the `robot_builder` library. + + This class is responsible for creating a robotic arm model within a Gazebo simulation environment. + It allows for the manipulation of joint positions for both the arm and the gripper. + + Attributes: + - DEFAULT_ARM_JOINT_POSITIONS (list[float]): The default joint positions for the arm. + - OPEN_GRIPPER_JOINT_POSITIONS (list[float]): The joint positions for the gripper in the open state. + - CLOSED_GRIPPER_JOINT_POSITIONS (list[float]): The joint positions for the gripper in the closed state. + - DEFAULT_GRIPPER_JOINT_POSITIONS (list[float]): The default joint positions for the gripper. 
+ + Args: + - world (scenario_gazebo.World): The Gazebo world where the robot model will be inserted. + - node (Node): The ROS2 node associated with the robotic arm. + - urdf_string (str): The URDF string defining the robot. + - name (str, optional): The name of the robotic arm. Defaults to "rbs_arm". + - position (tuple[float, float, float], optional): The initial position of the robot in the world. Defaults to (0.0, 0.0, 0.0). + - orientation (tuple[float, float, float, float], optional): The initial orientation of the robot in quaternion format. Defaults to (1.0, 0, 0, 0). + - initial_arm_joint_positions (list[float] | float, optional): The initial joint positions for the arm. Defaults to `DEFAULT_ARM_JOINT_POSITIONS`. + - initial_gripper_joint_positions (list[float] | float, optional): The initial joint positions for the gripper. Defaults to `DEFAULT_GRIPPER_JOINT_POSITIONS`. + + Raises: + - RuntimeError: If the robot model fails to be inserted into the Gazebo world. + """ + + DEFAULT_ARM_JOINT_POSITIONS: list[float] = [ + 0.0, + 0.5, + 3.14159, + 1.5, + 0.0, + 1.4, + 0.0, + ] + + OPEN_GRIPPER_JOINT_POSITIONS: list[float] = [ + 0.064, + ] + + CLOSED_GRIPPER_JOINT_POSITIONS: list[float] = [ + 0.0, + ] + + DEFAULT_GRIPPER_JOINT_POSITIONS: list[float] = CLOSED_GRIPPER_JOINT_POSITIONS + + def __init__( + self, + world: scenario_gazebo.World, + node: Node, + urdf_string: str, + name: str = "rbs_arm", + position: tuple[float, float, float] = (0.0, 0.0, 0.0), + orientation: tuple[float, float, float, float] = (1.0, 0, 0, 0), + initial_arm_joint_positions: list[float] | float = DEFAULT_ARM_JOINT_POSITIONS, + initial_gripper_joint_positions: list[float] + | float = DEFAULT_GRIPPER_JOINT_POSITIONS, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Setup initial pose + initial_pose = scenario_core.Pose(position, orientation) + + # Insert the model + ok_model = world.insert_model_from_string(urdf_string, initial_pose, model_name) + if 
not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Parse robot to get metadata + self._robot: Robot = URDF_parser.load_string(urdf_string) + + self.__initial_gripper_joint_positions = ( + [float(initial_gripper_joint_positions)] + * len(self._robot.gripper_actuated_joint_names) + if isinstance(initial_gripper_joint_positions, (int, float)) + else initial_gripper_joint_positions + ) + + self.__initial_arm_joint_positions = ( + [float(initial_arm_joint_positions)] * len(self._robot.actuated_joint_names) + if isinstance(initial_arm_joint_positions, (int, float)) + else initial_arm_joint_positions + ) + + # Set initial joint configuration + self.set_initial_joint_positions(model) + super().__init__(model=model) + + @property + def robot(self) -> Robot: + """Returns the robot metadata parsed from the URDF string. + + Returns: + Robot: The robot instance containing metadata. + """ + return self._robot + + @property + def initial_arm_joint_positions(self) -> list[float]: + """Returns the initial joint positions for the arm. + + Returns: + list[float]: The initial joint positions for the arm. + """ + return self.__initial_arm_joint_positions + + @property + def initial_gripper_joint_positions(self) -> list[float]: + """Returns the initial joint positions for the gripper. + + Returns: + list[float]: The initial joint positions for the gripper. + """ + return self.__initial_gripper_joint_positions + + def set_initial_joint_positions(self, model): + """Sets the initial positions for the robot's joints. + + This method resets the joint positions of both the arm and gripper to their specified initial values. + + Args: + model: The model representation of the robot within the Gazebo environment. + + Raises: + RuntimeError: If resetting the joint positions fails for any of the joints. 
+ """ + model = model.to_gazebo() + + joint_position_data = [ + (self.__initial_arm_joint_positions, self._robot.actuated_joint_names), + (self.__initial_gripper_joint_positions, self._robot.gripper_actuated_joint_names), + ] + + for positions, joint_names in joint_position_data: + print(f"Setting joint positions for {joint_names}: {positions}") + if not model.reset_joint_positions(positions, joint_names): + raise RuntimeError(f"Failed to set initial positions of {joint_names}'s joints") diff --git a/env_manager/env_manager/env_manager/models/robots/robot.py b/env_manager/env_manager/env_manager/models/robots/robot.py new file mode 100644 index 0000000..fdfa062 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/robots/robot.py @@ -0,0 +1,95 @@ +from gym_gz.scenario import model_wrapper +from robot_builder.elements.robot import Robot +from scenario import gazebo as scenario_gazebo +from abc import abstractmethod + + +class RobotWrapper(model_wrapper.ModelWrapper): + """ + An abstract base class for robot models in a Gazebo simulation. + + This class extends the ModelWrapper from gym_gz and provides a structure for creating + robot models with specific configurations, including joint positions for arms and grippers. + + Args: + world (scenario_gazebo.World, optional): The Gazebo world where the robot model will be inserted. + urdf_string (str, optional): The URDF (Unified Robot Description Format) string defining the robot. + name (str, optional): The name of the robot model. Defaults to None. + position (tuple[float, float, float], optional): The initial position of the robot in the world. Defaults to None. + orientation (tuple[float, float, float, float], optional): The initial orientation of the robot in quaternion format. Defaults to None. + initial_arm_joint_positions (list[float] | float, optional): The initial joint positions for the arm. Defaults to an empty list. 
+ initial_gripper_joint_positions (list[float] | float, optional): The initial joint positions for the gripper. Defaults to an empty list. + model (model_wrapper.ModelWrapper, optional): An existing model instance to initialize the wrapper. Must be provided. + **kwargs: Additional keyword arguments. + + Raises: + ValueError: If the model parameter is not provided. + """ + def __init__( + self, + world: scenario_gazebo.World | None = None, + urdf_string: str | None = None, + name: str | None = None, + position: tuple[float, float, float] | None = None, + orientation: tuple[float, float, float, float] | None = None, + initial_arm_joint_positions: list[float] | float = [], + initial_gripper_joint_positions: list[float] | float = [], + model: model_wrapper.ModelWrapper | None = None, + **kwargs, + ): + if model is not None: + super().__init__(model=model) + else: + raise ValueError("Model should be defined for the parent class") + + @property + @abstractmethod + def robot(self) -> Robot: + """Returns the robot instance containing metadata. + + This property must be implemented by subclasses to return the specific robot metadata. + + Returns: + Robot: The robot instance. + """ + pass + + + @property + @abstractmethod + def initial_gripper_joint_positions(self) -> list[float]: + """Returns the initial joint positions for the gripper. + + This property must be implemented by subclasses to provide the initial positions of the gripper joints. + + Returns: + list[float]: The initial joint positions for the gripper. + """ + pass + + @property + @abstractmethod + def initial_arm_joint_positions(self) -> list[float]: + """Returns the initial joint positions for the arm. + + This property must be implemented by subclasses to provide the initial positions of the arm joints. + + Returns: + list[float]: The initial joint positions for the arm. + """ + pass + + @abstractmethod + def set_initial_joint_positions(self, model): + """Sets the initial positions for the robot's joints. 
+ + This method must be implemented by subclasses to reset the joint positions of the robot + to their specified initial values. + + Args: + model: The model representation of the robot within the Gazebo environment. + + Raises: + RuntimeError: If resetting the joint positions fails. + """ + pass diff --git a/env_manager/env_manager/env_manager/models/sensors/__init__.py b/env_manager/env_manager/env_manager/models/sensors/__init__.py new file mode 100644 index 0000000..0cbbbe7 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/sensors/__init__.py @@ -0,0 +1 @@ +from .camera import Camera diff --git a/env_manager/env_manager/env_manager/models/sensors/camera.py b/env_manager/env_manager/env_manager/models/sensors/camera.py new file mode 100644 index 0000000..63e1616 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/sensors/camera.py @@ -0,0 +1,436 @@ +import os +from threading import Thread + +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + +# from env_manager.models.utils import ModelCollectionRandomizer + + +class Camera(model_wrapper.ModelWrapper): + """ + Represents a camera model in a Gazebo simulation. + + This class extends the ModelWrapper from gym_gz to define a camera model that can be inserted + into the Gazebo world. It supports different types of cameras and can bridge ROS2 topics for + camera data. + + Args: + world (scenario_gazebo.World): The Gazebo world where the camera will be inserted. + name (str, optional): The name of the camera model. If None, a unique name is generated. + position (tuple[float, float, float], optional): The initial position of the camera. Defaults to (0, 0, 0). + orientation (tuple[float, float, float, float], optional): The initial orientation of the camera in quaternion. Defaults to (1, 0, 0, 0). 
+ static (bool, optional): If True, the camera is a static model. Defaults to True. + camera_type (str, optional): The type of camera to create. Defaults to "rgbd_camera". + width (int, optional): The width of the camera image. Defaults to 212. + height (int, optional): The height of the camera image. Defaults to 120. + image_format (str, optional): The format of the camera image. Defaults to "R8G8B8". + update_rate (int, optional): The update rate for the camera sensor. Defaults to 15. + horizontal_fov (float, optional): The horizontal field of view for the camera. Defaults to 1.567821. + vertical_fov (float, optional): The vertical field of view for the camera. Defaults to 1.022238. + clip_color (tuple[float, float], optional): The near and far clipping distances for color. Defaults to (0.02, 1000.0). + clip_depth (tuple[float, float], optional): The near and far clipping distances for depth. Defaults to (0.02, 10.0). + noise_mean (float, optional): The mean of the noise added to the camera images. Defaults to None. + noise_stddev (float, optional): The standard deviation of the noise added to the camera images. Defaults to None. + ros2_bridge_color (bool, optional): If True, a ROS2 bridge for color images is created. Defaults to False. + ros2_bridge_depth (bool, optional): If True, a ROS2 bridge for depth images is created. Defaults to False. + ros2_bridge_points (bool, optional): If True, a ROS2 bridge for point cloud data is created. Defaults to False. + visibility_mask (int, optional): The visibility mask for the camera sensor. Defaults to 0. + visual (str, optional): The type of visual representation for the camera. Defaults to None. + + Raises: + ValueError: If the visual mesh or textures cannot be found. + RuntimeError: If the camera model fails to insert into the Gazebo world. 
+ """ + def __init__( + self, + world: scenario_gazebo.World, + name: str | None = None, + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + static: bool = True, + camera_type: str = "rgbd_camera", + width: int = 212, + height: int = 120, + image_format: str = "R8G8B8", + update_rate: int = 15, + horizontal_fov: float = 1.567821, + vertical_fov: float = 1.022238, + clip_color: tuple[float, float] = (0.02, 1000.0), + clip_depth: tuple[float, float] = (0.02, 10.0), + noise_mean: float | None = None, + noise_stddev: float | None = None, + ros2_bridge_color: bool = False, + ros2_bridge_depth: bool = False, + ros2_bridge_points: bool = False, + visibility_mask: int = 0, + visual: str | None = None, + ): + # Get a unique model name + if name is not None: + model_name = get_unique_model_name(world, name) + else: + model_name = get_unique_model_name(world, camera_type) + self._model_name = model_name + self._camera_type = camera_type + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + use_mesh: bool = False + mesh_path_visual: str = "" + + albedo_map = None + normal_map = None + roughness_map = None + metalness_map = None + # Get resources for visual (if enabled) + # if visual: + # if "intel_realsense_d435" == visual: + # use_mesh = True + # + # # Get path to the model and the important directories + # model_path = ModelCollectionRandomizer.get_collection_paths( + # owner="OpenRobotics", + # collection="", + # model_name="Intel RealSense D435", + # )[0] + # + # mesh_dir = os.path.join(model_path, "meshes") + # texture_dir = os.path.join(model_path, "materials", "textures") + # + # # Get path to the mesh + # mesh_path_visual = os.path.join(mesh_dir, "realsense.dae") + # # Make sure that it exists + # if not os.path.exists(mesh_path_visual): + # raise ValueError( + # f"Visual mesh '{mesh_path_visual}' for Camera model is not a valid file." 
+ # ) + # + # # Find PBR textures + # if texture_dir: + # # List all files + # texture_files = os.listdir(texture_dir) + # + # # Extract the appropriate files + # for texture in texture_files: + # texture_lower = texture.lower() + # if "basecolor" in texture_lower or "albedo" in texture_lower: + # albedo_map = os.path.join(texture_dir, texture) + # elif "normal" in texture_lower: + # normal_map = os.path.join(texture_dir, texture) + # elif "roughness" in texture_lower: + # roughness_map = os.path.join(texture_dir, texture) + # elif ( + # "specular" in texture_lower or "metalness" in texture_lower + # ): + # metalness_map = os.path.join(texture_dir, texture) + # + # if not (albedo_map and normal_map and roughness_map and metalness_map): + # raise ValueError("Not all textures for Camera model were found.") + + # Create SDF string for the model + sdf = f''' + + {static} + + + {model_name} + true + {update_rate} + + + {width} + {height} + {image_format} + + {horizontal_fov} + {vertical_fov} + + {clip_color[0]} + {clip_color[1]} + + { + f""" + + {clip_depth[0]} + {clip_depth[1]} + + """ if "rgbd" in model_name else "" + } + { + f""" + gaussian + {noise_mean} + {noise_stddev} + """ if noise_mean is not None and noise_stddev is not None else "" + } + {visibility_mask} + + true + + { + f""" + + -0.01 0 0 0 1.5707963 0 + + + 0.02 + 0.02 + + + + 0.0 0.8 0.0 + 0.0 0.8 0.0 + 0.0 0.8 0.0 + + + + -0.05 0 0 0 0 0 + + + 0.06 0.05 0.05 + + + + 0.0 0.8 0.0 + 0.0 0.8 0.0 + 0.0 0.8 0.0 + + + """ if visual and not use_mesh else "" + } + { + f""" + + 0.0615752 + + 9.108e-05 + 0.0 + 0.0 + 2.51e-06 + 0.0 + 8.931e-05 + + + + 0 0 0 0 0 1.5707963 + + + {mesh_path_visual} + + RealSense +
false
+
+
+
+ + 1 1 1 1 + 1 1 1 1 + + + {albedo_map} + {normal_map} + {roughness_map} + {metalness_map} + + + +
+ """ if visual and use_mesh else "" + } + +
+
+        '''
+
+        # Insert the model into the world; fail loudly if Gazebo rejects it.
+        ok_model = world.to_gazebo().insert_model_from_string(
+            sdf, initial_pose, model_name
+        )
+        if not ok_model:
+            raise RuntimeError("Failed to insert " + model_name)
+
+        # Get the model
+        model = world.get_model(model_name)
+
+        # Initialize base class
+        model_wrapper.ModelWrapper.__init__(self, model=model)
+
+        # Spawn one daemon thread per requested ros_gz bridge; each thread
+        # blocks in os.system for as long as the bridge process runs.
+        if ros2_bridge_color or ros2_bridge_depth or ros2_bridge_points:
+            self.__threads = []
+            self.__threads.append(
+                Thread(
+                    target=self.construct_ros2_bridge,
+                    args=(
+                        self.info_topic,
+                        "sensor_msgs/msg/CameraInfo",
+                        "gz.msgs.CameraInfo",
+                    ),
+                    daemon=True,
+                )
+            )
+            if ros2_bridge_color:
+                self.__threads.append(
+                    Thread(
+                        target=self.construct_ros2_bridge,
+                        args=(
+                            self.color_topic,
+                            "sensor_msgs/msg/Image",
+                            "gz.msgs.Image",
+                        ),
+                        daemon=True,
+                    )
+                )
+
+            if ros2_bridge_depth:
+                self.__threads.append(
+                    Thread(
+                        target=self.construct_ros2_bridge,
+                        args=(
+                            self.depth_topic,
+                            "sensor_msgs/msg/Image",
+                            "gz.msgs.Image",
+                        ),
+                        daemon=True,
+                    )
+                )
+
+            if ros2_bridge_points:
+                self.__threads.append(
+                    Thread(
+                        target=self.construct_ros2_bridge,
+                        args=(
+                            self.points_topic,
+                            "sensor_msgs/msg/PointCloud2",
+                            "gz.msgs.PointCloudPacked",
+                        ),
+                        daemon=True,
+                    )
+                )
+
+            for thread in self.__threads:
+                thread.start()
+
+    def __del__(self):
+        """Best-effort cleanup of ROS2 bridge threads when the Camera is deleted.
+
+        Fix: ``hasattr(self, "__threads")`` never matched, because string
+        literals are not subject to name mangling -- inside methods the
+        attribute is stored as ``_Camera__threads`` -- so the old check was
+        always False and the threads were never joined. The bridge threads
+        run blocking ``os.system`` calls and are daemonic, so join with a
+        timeout rather than risking an indefinite hang at interpreter
+        shutdown; the processes die with the interpreter anyway.
+        """
+        for thread in getattr(self, "_Camera__threads", []):
+            thread.join(timeout=1.0)
+
+    @classmethod
+    def construct_ros2_bridge(cls, topic: str, ros_msg: str, ign_msg: str):
+        """
+        Constructs and runs a ROS2 bridge command for a given topic.
+
+        Args:
+            topic (str): The topic to bridge.
+            ros_msg (str): The ROS2 message type to use.
+            ign_msg (str): The Ignition message type to use.
+ """ + node_name = "parameter_bridge" + topic.replace("/", "_") + command = ( + f"ros2 run ros_gz_bridge parameter_bridge {topic}@{ros_msg}[{ign_msg} " + + f"--ros-args --remap __node:={node_name} --ros-args -p use_sim_time:=true" + ) + os.system(command) + + @classmethod + def get_frame_id(cls, model_name: str) -> str: + """ + Gets the frame ID for the camera model. + + Args: + model_name (str): The name of the camera model. + + Returns: + str: The frame ID. + """ + return f"{model_name}/{model_name}_link/camera" + + @property + def frame_id(self) -> str: + """Returns the frame ID of the camera.""" + return self.get_frame_id(self._model_name) + + @classmethod + def get_color_topic(cls, model_name: str, camera_type: str) -> str: + """ + Gets the color topic for the camera. + + Args: + model_name (str): The name of the camera model. + camera_type (str): The type of the camera. + + Returns: + str: The color topic. + """ + return f"/{model_name}/image" if "rgbd" in camera_type else f"/{model_name}" + + @property + def color_topic(self) -> str: + """Returns the color topic for the camera.""" + return self.get_color_topic(self._model_name, self._camera_type) + + @classmethod + def get_depth_topic(cls, model_name: str, camera_type: str) -> str: + """ + Gets the depth topic for the camera. + + Args: + model_name (str): The name of the camera model. + camera_type (str): The type of the camera. + + Returns: + str: The depth topic. + """ + return ( + f"/{model_name}/depth_image" if "rgbd" in camera_type else f"/{model_name}" + ) + + @property + def depth_topic(self) -> str: + """Returns the depth topic for the camera.""" + return self.get_depth_topic(self._model_name, self._camera_type) + + @classmethod + def get_points_topic(cls, model_name: str) -> str: + """ + Gets the points topic for the camera. + + Args: + model_name (str): The name of the camera model. + + Returns: + str: /{model_name}/points. 
+ """ + return f"/{model_name}/points" + + @property + def points_topic(self) -> str: + """Returns the points topic for the camera.""" + return self.get_points_topic(self._model_name) + + @property + def info_topic(self): + """Returns the camera info topic.""" + return f"/{self._model_name}/camera_info" + + @classmethod + def get_link_name(cls, model_name: str) -> str: + """ + Gets the link name for the camera model. + + Args: + model_name (str): The name of the camera model. + + Returns: + str: {model_name}_link. + """ + return f"{model_name}_link" + + @property + def link_name(self) -> str: + """Returns the link name for the camera.""" + return self.get_link_name(self._model_name) diff --git a/env_manager/env_manager/env_manager/models/terrains/__init__.py b/env_manager/env_manager/env_manager/models/terrains/__init__.py new file mode 100644 index 0000000..9829bf3 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/terrains/__init__.py @@ -0,0 +1,25 @@ +# from gym_gz.scenario.model_wrapper import ModelWrapper + +from .ground import Ground + + +# def get_terrain_model_class(terrain_type: str) -> type[ModelWrapper]: +# # TODO: Refactor into enum +# +# if "flat" == terrain_type: +# return Ground +# elif "random_flat" == terrain_type: +# return RandomGround +# elif "lunar_heightmap" == terrain_type: +# return LunarHeightmap +# elif "lunar_surface" == terrain_type: +# return LunarSurface +# elif "random_lunar_surface" == terrain_type: +# return RandomLunarSurface +# else: +# raise AttributeError(f"Unsupported terrain [{terrain_type}]") +# +# +# def is_terrain_type_randomizable(terrain_type: str) -> bool: +# +# return "random_flat" == terrain_type or "random_lunar_surface" == terrain_type diff --git a/env_manager/env_manager/env_manager/models/terrains/ground.py b/env_manager/env_manager/env_manager/models/terrains/ground.py new file mode 100644 index 0000000..09044a3 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/terrains/ground.py @@ -0,0 
+1,104 @@ +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + + +class Ground(model_wrapper.ModelWrapper): + """ + Represents a ground model in a Gazebo simulation. + + This class extends the ModelWrapper from gym_gz to define a ground model that can be + inserted into the Gazebo world. The ground is defined by its size, position, orientation, + and friction properties. + + Args: + world (scenario_gazebo.World): The Gazebo world where the ground will be inserted. + name (str, optional): The name of the ground model. Defaults to "ground". + position (tuple[float, float, float], optional): The initial position of the ground. Defaults to (0, 0, 0). + orientation (tuple[float, float, float, float], optional): The initial orientation of the ground in quaternion. Defaults to (1, 0, 0, 0). + size (tuple[float, float], optional): The size of the ground plane. Defaults to (1.5, 1.5). + collision_thickness (float, optional): The thickness of the collision surface. Defaults to 0.05. + friction (float, optional): The friction coefficient for the ground surface. Defaults to 5.0. + ambient (tuple[float, float, float, float], optional): The ambient color of the material. Defaults to (0.8, 0.8, 0.8, 1.0). + specular (tuple[float, float, float, float], optional): The specular color of the material. Defaults to (0.8, 0.8, 0.8, 1.0). + diffuse (tuple[float, float, float, float], optional): The diffuse color of the material. Defaults to (0.8, 0.8, 0.8, 1.0). + **kwargs: Additional keyword arguments for future extensions. + + Raises: + RuntimeError: If the ground model fails to insert into the Gazebo world. 
+ """ + + def __init__( + self, + world: scenario_gazebo.World, + name: str = "ground", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + size: tuple[float, float] = (1.5, 1.5), + collision_thickness=0.05, + friction: float = 5.0, + ambient: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0), + specular: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0), + diffuse: tuple[float, float, float, float] = (0.8, 0.8, 0.8, 1.0), + **kwargs, + ): + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + # Create SDF string for the model with the provided material properties + sdf = f""" + + true + + + + + 0 0 1 + {size[0]} {size[1]} + + + + + + {friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + + + + 0 0 1 + {size[0]} {size[1]} + + + + {' '.join(map(str, ambient))} + {' '.join(map(str, diffuse))} + {' '.join(map(str, specular))} + + + + + """ + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + model_wrapper.ModelWrapper.__init__(self, model=model) diff --git a/env_manager/env_manager/env_manager/models/terrains/random_ground.py b/env_manager/env_manager/env_manager/models/terrains/random_ground.py new file mode 100644 index 0000000..948b376 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/terrains/random_ground.py @@ -0,0 +1,137 @@ +import os +import numpy as np + +from gym_gz.scenario import model_wrapper +from gym_gz.utils.scenario import get_unique_model_name +from scenario import core as scenario_core +from scenario import gazebo as scenario_gazebo + +from rbs_assets_library import get_textures_path + + +class 
RandomGround(model_wrapper.ModelWrapper): + def __init__( + self, + world: scenario_gazebo.World, + name: str = "random_ground", + position: tuple[float, float, float] = (0, 0, 0), + orientation: tuple[float, float, float, float] = (1, 0, 0, 0), + size: tuple[float, float] = (1.0, 1.0), + friction: float = 5.0, + texture_dir: str | None = None, + **kwargs, + ): + + np_random = np.random.default_rng() + + # Get a unique model name + model_name = get_unique_model_name(world, name) + + # Initial pose + initial_pose = scenario_core.Pose(position, orientation) + + # Get textures from ENV variable if not directly specified + if not texture_dir: + texture_dir = os.environ.get("TEXTURE_DIRS", default="") + if not texture_dir: + texture_dir = get_textures_path() + + # Find random PBR texture + albedo_map = None + normal_map = None + roughness_map = None + metalness_map = None + if texture_dir: + if ":" in texture_dir: + textures = [] + for d in texture_dir.split(":"): + textures.extend([os.path.join(d, f) for f in os.listdir(d)]) + else: + # Get list of the available textures + textures = os.listdir(texture_dir) + + # Choose a random texture from these + random_texture_dir = str(np_random.choice(textures)) + + random_texture_dir = texture_dir + random_texture_dir + + # List all files + texture_files = os.listdir(random_texture_dir) + + # Extract the appropriate files + for texture in texture_files: + texture_lower = texture.lower() + if "color" in texture_lower or "albedo" in texture_lower: + albedo_map = os.path.join(random_texture_dir, texture) + elif "normal" in texture_lower: + normal_map = os.path.join(random_texture_dir, texture) + elif "roughness" in texture_lower: + roughness_map = os.path.join(random_texture_dir, texture) + elif "specular" in texture_lower or "metalness" in texture_lower: + metalness_map = os.path.join(random_texture_dir, texture) + + # Create SDF string for the model + sdf = f""" + + true + + + + + 0 0 1 + {size[0]} {size[1]} + + + + + + 
{friction} + {friction} + 0 0 0 + 0.0 + 0.0 + + + + + + + + 0 0 1 + {size[0]} {size[1]} + + + + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + + + {"%s" + % albedo_map if albedo_map is not None else ""} + {"%s" + % normal_map if normal_map is not None else ""} + {"%s" + % roughness_map if roughness_map is not None else ""} + {"%s" + % metalness_map if metalness_map is not None else ""} + + + + + + + """ + + # Insert the model + ok_model = world.to_gazebo().insert_model_from_string( + sdf, initial_pose, model_name + ) + if not ok_model: + raise RuntimeError("Failed to insert " + model_name) + + # Get the model + model = world.get_model(model_name) + + # Initialize base class + model_wrapper.ModelWrapper.__init__(self, model=model) diff --git a/env_manager/env_manager/env_manager/models/utils/__init__.py b/env_manager/env_manager/env_manager/models/utils/__init__.py new file mode 100644 index 0000000..1c438a3 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/utils/__init__.py @@ -0,0 +1,2 @@ +# from .model_collection_randomizer import ModelCollectionRandomizer +from .xacro2sdf import xacro2sdf diff --git a/env_manager/env_manager/env_manager/models/utils/model_collection_randomizer.py b/env_manager/env_manager/env_manager/models/utils/model_collection_randomizer.py new file mode 100644 index 0000000..428a7a5 --- /dev/null +++ b/env_manager/env_manager/env_manager/models/utils/model_collection_randomizer.py @@ -0,0 +1,1266 @@ +import glob +import os + +import numpy as np +import trimesh +from gym_gz.utils import logger +from numpy.random import RandomState +from pcg_gazebo.parsers import parse_sdf +from pcg_gazebo.parsers.sdf import create_sdf_element +from scenario import gazebo as scenario_gazebo + +# Note: only models with mesh geometry are supported + + +class ModelCollectionRandomizer: + """ + A class to randomize and manage paths for model collections in simulation environments. 
+ + This class allows for the selection and management of model paths, enabling features + like caching, blacklisting, and random sampling from a specified collection. + + Attributes: + _class_model_paths (list[str] | None): Class-level cache for model paths. + __sdf_base_name (str): The base name for the SDF model file. + __configured_sdf_base_name (str): The base name for the modified SDF model file. + __blacklisted_base_name (str): The base name used for blacklisted models. + __collision_mesh_dir (str): Directory where collision meshes are stored. + __collision_mesh_file_type (str): The file type of the collision meshes. + __original_scale_base_name (str): The base name for original scale metadata files. + _unique_cache (bool): Whether to use a unique cache for model paths. + _enable_blacklisting (bool): Whether to enable blacklisting of unusable models. + _model_paths (list[str] | None): Instance-specific model paths if unique cache is enabled. + np_random (RandomState): Numpy random state for random sampling. + """ + + _class_model_paths = None + __sdf_base_name = "model.sdf" + __configured_sdf_base_name = "model_modified.sdf" + __blacklisted_base_name = "BLACKLISTED" + __collision_mesh_dir = "meshes/collision/" + __collision_mesh_file_type = "stl" + __original_scale_base_name = "original_scale.txt" + + def __init__( + self, + model_paths=None, + owner="GoogleResearch", + collection="Google Scanned Objects", + server="https://fuel.ignitionrobotics.org", + server_version="1.0", + unique_cache=False, + reset_collection=False, + enable_blacklisting=True, + np_random: RandomState | None = None, + ): + """ + Initializes the ModelCollectionRandomizer. + + Args: + model_paths (list[str] | None): List of model paths to initialize with. + owner (str): The owner of the model collection (default is "GoogleResearch"). + collection (str): The name of the model collection (default is "Google Scanned Objects"). 
+ server (str): The server URL for the model collection (default is "https://fuel.ignitionrobotics.org"). + server_version (str): The server version to use (default is "1.0"). + unique_cache (bool): Whether to use a unique cache for the instance (default is False). + reset_collection (bool): Whether to reset the class-level model paths (default is False). + enable_blacklisting (bool): Whether to enable blacklisting of models (default is True). + np_random (RandomState | None): Numpy random state for random operations (default is None). + + Raises: + ValueError: If an invalid model path is provided. + """ + + # If enabled, the newly created objects of this class will use its own individual cache + # for model paths and must discover/download them on its own + self._unique_cache = unique_cache + + # Flag that determines if models that cannot be used are blacklisted + self._enable_blacklisting = enable_blacklisting + + # If enabled, the cache of the class used to store model paths among instances will be reset + if reset_collection and not self._unique_cache: + self._class_model_paths = None + + # Get file path to all models from + # a) `model_paths` arg + # b) local cache owner (if `owner` has some models, i.e `collection` is already downloaded) + # c) Fuel collection (if `owner` has no models in local cache) + if model_paths is not None: + # Use arg + if self._unique_cache: + self._model_paths = model_paths + else: + self._class_model_paths = model_paths + else: + # Use local cache or Fuel + if self._unique_cache: + self._model_paths = self.get_collection_paths( + owner=owner, + collection=collection, + server=server, + server_version=server_version, + ) + elif self._class_model_paths is None: + # Executed only once, unless the paths are reset with `reset_collection` arg + self._class_model_paths = self.get_collection_paths( + owner=owner, + collection=collection, + server=server, + server_version=server_version, + ) + + # Initialise rng with (with seed is desired) 
+ if np_random is not None: + self.np_random = np_random + else: + self.np_random = np.random.default_rng() + + @classmethod + def get_collection_paths( + cls, + owner="GoogleResearch", + collection="Google Scanned Objects", + server="https://fuel.ignitionrobotics.org", + server_version="1.0", + model_name: str = "", + ) -> list[str]: + """ + Retrieves model paths from the local cache or downloads them from a specified server. + + The method first checks the local cache for models belonging to the specified owner. + If no models are found, it attempts to download the models from the specified Fuel server. + + Args: + cls: The class reference. + owner (str): The owner of the model collection (default is "GoogleResearch"). + collection (str): The name of the model collection (default is "Google Scanned Objects"). + server (str): The server URL for the model collection (default is "https://fuel.ignitionrobotics.org"). + server_version (str): The server version to use (default is "1.0"). + model_name (str): The name of the specific model to fetch (default is an empty string). + + Returns: + list[str]: A list of paths to the retrieved models. + + Raises: + RuntimeError: If the download command fails or if no models are found after the download. 
+ """ + + # First check the local cache (for performance) + # Note: This unfortunately does not check if models belong to the specified collection + # TODO: Make sure models belong to the collection if sampled from local cache + model_paths = scenario_gazebo.get_local_cache_model_paths( + owner=owner, name=model_name + ) + if len(model_paths) > 0: + return model_paths + + # Else download the models from Fuel and then try again + if collection: + download_uri = "%s/%s/%s/collections/%s" % ( + server, + server_version, + owner, + collection, + ) + elif model_name: + download_uri = "%s/%s/%s/models/%s" % ( + server, + server_version, + owner, + model_name, + ) + download_command = 'ign fuel download -v 3 -t model -j %s -u "%s"' % ( + os.cpu_count(), + download_uri, + ) + os.system(download_command) + + model_paths = scenario_gazebo.get_local_cache_model_paths( + owner=owner, name=model_name + ) + if 0 == len(model_paths): + logger.error( + 'URI "%s" is not valid and does not contain any models that are \ + owned by the owner of the collection' + % download_uri + ) + pass + + return model_paths + + def random_model( + self, + min_scale=0.125, + max_scale=0.175, + min_mass=0.05, + max_mass=0.25, + min_friction=0.75, + max_friction=1.5, + decimation_fraction_of_visual=0.25, + decimation_min_faces=40, + decimation_max_faces=200, + max_faces=40000, + max_vertices=None, + component_min_faces_fraction=0.1, + component_max_volume_fraction=0.35, + fix_mtl_texture_paths=True, + skip_blacklisted=True, + return_sdf_path=True, + ) -> str: + """ + Selects and configures a random model from the collection. + + The method attempts to find a valid model, applying randomization + and returning the path to the configured SDF file or the model directory. + + Args: + min_scale (float): Minimum scale for the model. + max_scale (float): Maximum scale for the model. + min_mass (float): Minimum mass for the model. + max_mass (float): Maximum mass for the model. 
+ min_friction (float): Minimum friction coefficient. + max_friction (float): Maximum friction coefficient. + decimation_fraction_of_visual (float): Fraction of visual decimation. + decimation_min_faces (int): Minimum number of faces for decimation. + decimation_max_faces (int): Maximum number of faces for decimation. + max_faces (int): Maximum faces for the model. + max_vertices (int, optional): Maximum vertices for the model. + component_min_faces_fraction (float): Minimum face fraction for components. + component_max_volume_fraction (float): Maximum volume fraction for components. + fix_mtl_texture_paths (bool): Whether to fix MTL texture paths. + skip_blacklisted (bool): Whether to skip blacklisted models. + return_sdf_path (bool): If True, return the configured SDF file path. + + Returns: + str: Path to the configured SDF file or model directory. + + Raises: + RuntimeError: If a valid model cannot be found after multiple attempts. + """ + + # Loop until a model is found, checked for validity, configured and returned + # If any of these steps fail, sample another model and try again + # Note: Due to this behaviour, the function could stall if all models are invalid + # TODO: Add a simple timeout to random sampling of valid model (# of attempts or time-based) + while True: + # Get path to a random model from the collection + model_path = self.get_random_model_path() + + # Check if the model is already blacklisted and skip if desired + if skip_blacklisted and self.is_blacklisted(model_path): + continue + + # Check is the model is already configured + if self.is_configured(model_path): + # If so, break the loop + break + + # Process the model and break loop only if it is valid + if self.process_model( + model_path, + decimation_fraction_of_visual=decimation_fraction_of_visual, + decimation_min_faces=decimation_min_faces, + decimation_max_faces=decimation_max_faces, + max_faces=max_faces, + max_vertices=max_vertices, + 
component_min_faces_fraction=component_min_faces_fraction, + component_max_volume_fraction=component_max_volume_fraction, + fix_mtl_texture_paths=fix_mtl_texture_paths, + ): + break + + # Apply randomization + self.randomize_configured_model( + model_path, + min_scale=min_scale, + max_scale=max_scale, + min_friction=min_friction, + max_friction=max_friction, + min_mass=min_mass, + max_mass=max_mass, + ) + + if return_sdf_path: + # Return path to the configured SDF file + return self.get_configured_sdf_path(model_path) + else: + # Return path to the model directory + return model_path + + def process_all_models( + self, + decimation_fraction_of_visual=0.025, + decimation_min_faces=8, + decimation_max_faces=400, + max_faces=40000, + max_vertices=None, + component_min_faces_fraction=0.1, + component_max_volume_fraction=0.35, + fix_mtl_texture_paths=True, + ): + """ + Processes all models in the collection, applying configuration and decimation. + + This method iterates over each model in the collection, applying visual decimation + and checking for validity. If a model cannot be processed successfully, it is + blacklisted. The method tracks and prints the number of processed models and + the count of blacklisted models at the end of execution. + + Args: + decimation_fraction_of_visual (float): Fraction of visual decimation to apply. + Default is 0.025. + decimation_min_faces (int): Minimum number of faces for decimation. Default is 8. + decimation_max_faces (int): Maximum number of faces for decimation. Default is 400. + max_faces (int): Maximum faces allowed for the model. Default is 40000. + max_vertices (int, optional): Maximum vertices allowed for the model. Default is None. + component_min_faces_fraction (float): Minimum face fraction for components during + processing. Default is 0.1. + component_max_volume_fraction (float): Maximum volume fraction for components during + processing. Default is 0.35. 
+ fix_mtl_texture_paths (bool): Whether to fix MTL texture paths for the models. + Default is True. + + Returns: + None: This method modifies the models in place and does not return a value. + + Prints: + - The status of each processed model along with its index. + - The total number of blacklisted models after processing all models. + + Raises: + None: The method does not raise exceptions but may log issues if models cannot be + processed. + """ + if self._unique_cache: + model_paths = self._model_paths + else: + model_paths = self._class_model_paths + + blacklist_model_counter = 0 + for i in range(len(model_paths)): + if not self.process_model( + model_paths[i], + decimation_fraction_of_visual=decimation_fraction_of_visual, + decimation_min_faces=decimation_min_faces, + decimation_max_faces=decimation_max_faces, + max_faces=max_faces, + max_vertices=max_vertices, + component_min_faces_fraction=component_min_faces_fraction, + component_max_volume_fraction=component_max_volume_fraction, + fix_mtl_texture_paths=fix_mtl_texture_paths, + ): + blacklist_model_counter += 1 + + print('Processed model %i/%i "%s"' % (i, len(model_paths), model_paths[i])) + + print("Number of blacklisted models: %i" % blacklist_model_counter) + + def process_model( + self, + model_path, + decimation_fraction_of_visual=0.25, + decimation_min_faces=40, + decimation_max_faces=200, + max_faces=40000, + max_vertices=None, + component_min_faces_fraction=0.1, + component_max_volume_fraction=0.35, + fix_mtl_texture_paths=True, + ) -> bool: + """ + Processes a specified model to configure it for simulation, applying geometry + decimation and checking for various validity criteria. + + This method parses the SDF (Simulation Description Format) of the specified + model, processes its components, checks for geometry validity, and updates + inertial properties as needed. The processed model is then saved back to + an SDF file. + + Args: + model_path (str): Path to the model to be processed. 
+ decimation_fraction_of_visual (float): Fraction of visual geometry to retain + during decimation. Default is 0.25. + decimation_min_faces (int): Minimum number of faces allowed in the decimated + visual geometry. Default is 40. + decimation_max_faces (int): Maximum number of faces allowed in the decimated + visual geometry. Default is 200. + max_faces (int): Maximum number of faces for the entire model. Default is 40000. + max_vertices (Optional[int]): Maximum number of vertices for the model. Default is None. + component_min_faces_fraction (float): Minimum face fraction for components to + be considered valid during processing. Default is 0.1. + component_max_volume_fraction (float): Maximum volume fraction for components + to be considered valid during processing. Default is 0.35. + fix_mtl_texture_paths (bool): Whether to fix texture paths in MTL files for + mesh formats. Default is True. + + Returns: + bool: Returns True if the model was processed successfully; otherwise, + returns False if any validity checks fail or if processing issues are encountered. + + Raises: + None: The method does not raise exceptions but may return False if processing + fails due to model invalidity. + + Notes: + - The model is blacklisted if it fails to meet the specified criteria for + geometry and inertial properties. + - The method updates the SDF of the model with processed geometries and + inertial properties. 
+ """ + + # Parse the SDF of the model + sdf = parse_sdf(self.get_sdf_path(model_path)) + + # Process the model(s) contained in the SDF + for model in sdf.models: + # Process the link(s) of each model + for link in model.links: + # Get rid of the existing collisions prior to simplifying it + link.collisions.clear() + + # Values for the total inertial properties of current link + # These values will be updated for each body that the link contains + total_mass = 0.0 + total_inertia = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + common_centre_of_mass = [0.0, 0.0, 0.0] + + # Go through the visuals and process them + for visual in link.visuals: + # Get path to the mesh of the link's visual + mesh_path = self.get_mesh_path(model_path, visual) + + # If desired, fix texture path in 'mtl' files for '.obj' mesh format + if fix_mtl_texture_paths: + self.fix_mtl_texture_paths( + model_path, mesh_path, model.attributes["name"] + ) + + # Load the mesh (without materials) + mesh = trimesh.load(mesh_path, force="mesh", skip_materials=True) + + # Check if model has too much geometry (blacklist if needed) + if not self.check_excessive_geometry( + mesh, model_path, max_faces=max_faces, max_vertices=max_vertices + ): + return False + + # Check if model has disconnected geometry/components (blacklist if needed) + if not self.check_disconnected_components( + mesh, + model_path, + component_min_faces_fraction=component_min_faces_fraction, + component_max_volume_fraction=component_max_volume_fraction, + ): + return False + + # Compute inertial properties for this mesh + ( + total_mass, + total_inertia, + common_centre_of_mass, + ) = self.sum_inertial_properties( + mesh, total_mass, total_inertia, common_centre_of_mass + ) + + # Add decimated collision geometry to the SDF + self.add_collision( + mesh, + link, + model_path, + fraction_of_visual=decimation_fraction_of_visual, + min_faces=decimation_min_faces, + max_faces=decimation_max_faces, + ) + + # Write original scale (size) 
into the SDF + # This is used for later reference during randomization (for scale limits) + self.write_original_scale(mesh, model_path) + + # Make sure the link has valid inertial properties (blacklist if needed) + if not self.check_inertial_properties( + model_path, total_mass, total_inertia + ): + return False + + # Write inertial properties to the SDF of the link + self.write_inertial_properties( + link, total_mass, total_inertia, common_centre_of_mass + ) + + # Write the configured SDF into a file + sdf.export_xml(self.get_configured_sdf_path(model_path)) + return True + + def add_collision( + self, + mesh, + link, + model_path, + fraction_of_visual=0.05, + min_faces=8, + max_faces=750, + friction=1.0, + ): + """ + Adds collision geometry to a specified link in a model, using a simplified + version of the provided mesh. + + This method creates a collision mesh by decimating the visual geometry of + the model, generating an SDF (Simulation Description Format) element for + the collision, and adding it to the specified link. + + Args: + mesh (trimesh.Trimesh): The original mesh geometry from which the + collision geometry will be created. + link (Link): The link to which the collision geometry will be added. + model_path (str): The path to the model where the collision mesh will + be stored. + fraction_of_visual (float): The fraction of the visual mesh to retain + for the collision mesh. Default is 0.05. + min_faces (int): The minimum number of faces allowed in the collision + mesh after decimation. Default is 8. + max_faces (int): The maximum number of faces allowed in the collision + mesh after decimation. Default is 750. + friction (float): The friction coefficient to be applied to the collision + surface. Default is 1.0. + + Returns: + None: This method does not return a value. + + Notes: + - The collision mesh is generated through quadratic decimation of the + original visual mesh based on the specified parameters. 
+ - The created collision mesh is stored in the same directory as the + model, and its path is relative to the model path. + - The method updates the link's SDF with the collision geometry and + surface properties. + """ + + # Determine name of path to the collision geometry + collision_name = ( + link.attributes["name"] + "_collision_" + str(len(link.collisions)) + ) + collision_mesh_path = self.get_collision_mesh_path(model_path, collision_name) + + # Determine number of faces to keep after the decimation + face_count = min( + max(fraction_of_visual * len(mesh.faces), min_faces), max_faces + ) + + # Simplify mesh via decimation + collision_mesh = mesh.simplify_quadratic_decimation(face_count) + + # Export the collision mesh to the appropriate location + os.makedirs(os.path.dirname(collision_mesh_path), exist_ok=True) + collision_mesh.export( + collision_mesh_path, file_type=self.__collision_mesh_file_type + ) + + # Create collision SDF element + collision = create_sdf_element("collision") + + # Add collision geometry to the SDF + collision.geometry.mesh = create_sdf_element("mesh") + collision.geometry.mesh.uri = os.path.relpath( + collision_mesh_path, start=model_path + ) + + # Add surface friction to the SDF of collision (default to 1 and randomize later) + collision.surface = create_sdf_element("surface") + collision.surface.friction = create_sdf_element("friction", "surface") + collision.surface.friction.ode = create_sdf_element("ode", "collision") + collision.surface.friction.ode.mu = friction + collision.surface.friction.ode.mu2 = friction + + # Add it to the SDF of the link + collision_name = os.path.basename(collision_mesh_path).split(".")[0] + link.add_collision(collision_name, collision) + + def sum_inertial_properties( + self, mesh, total_mass, total_inertia, common_centre_of_mass, density=1.0 + ) -> tuple[float, float, float]: + # Arbitrary density is used here + # The mass will be randomized once it is fully computed for a link + mesh.density = 
density + + # Tmp variable to store the mass of all previous geometry, used to determine centre of mass + mass_of_others = total_mass + + # For each additional mesh, simply add the mass and inertia + total_mass += mesh.mass + total_inertia += mesh.moment_inertia + + # Compute a common centre of mass between all previous geometry and the new mesh + common_centre_of_mass = [ + mass_of_others * common_centre_of_mass[0] + mesh.mass * mesh.center_mass[0], + mass_of_others * common_centre_of_mass[1] + mesh.mass * mesh.center_mass[1], + mass_of_others * common_centre_of_mass[2] + mesh.mass * mesh.center_mass[2], + ] / total_mass + + return total_mass, total_inertia, common_centre_of_mass + + def randomize_configured_model( + self, + model_path, + min_scale=0.05, + max_scale=0.25, + min_mass=0.1, + max_mass=3.0, + min_friction=0.75, + max_friction=1.5, + ): + """ + Randomizes the scale, mass, and friction properties of a configured model + in the specified SDF file. + + This method modifies the properties of each link in the model, applying + random values within specified ranges for scale, mass, and friction. + The randomized values overwrite the original settings in the configured + SDF file. + + Args: + model_path (str): The path to the model directory containing the + configured SDF file. + min_scale (float): The minimum scale factor for randomization. + Default is 0.05. + max_scale (float): The maximum scale factor for randomization. + Default is 0.25. + min_mass (float): The minimum mass for randomization. Default is 0.1. + max_mass (float): The maximum mass for randomization. Default is 3.0. + min_friction (float): The minimum friction coefficient for randomization. + Default is 0.75. + max_friction (float): The maximum friction coefficient for randomization. + Default is 1.5. + + Returns: + None: This method does not return a value. 
+ + Notes: + - The configured SDF file is updated in place, meaning the original + properties are overwritten with the new randomized values. + - The randomization is applied to each link in the model, ensuring + a diverse range of physical properties. + """ + + # Get path to the configured SDF file + configured_sdf_path = self.get_configured_sdf_path(model_path) + + # Parse the configured SDF that needs to be randomized + sdf = parse_sdf(configured_sdf_path) + + # Process the model(s) contained in the SDF + for model in sdf.models: + # Process the link(s) of each model + for link in model.links: + # Randomize scale of the link + self.randomize_scale( + model_path, link, min_scale=min_scale, max_scale=max_scale + ) + + # Randomize inertial properties of the link + self.randomize_inertial(link, min_mass=min_mass, max_mass=max_mass) + + # Randomize friction of the link + self.randomize_friction( + link, min_friction=min_friction, max_friction=max_friction + ) + + # Overwrite the configured SDF file with randomized values + sdf.export_xml(configured_sdf_path) + + def randomize_scale(self, model_path, link, min_scale=0.05, max_scale=0.25): + """ + Randomizes the scale of a link's geometry within a specified range. + + This method modifies the scale of the visual and collision geometry + of a link in the model. The scale is randomized using a uniform + distribution within the given minimum and maximum scale values. + Additionally, the link's inertial properties (mass and inertia) + are recalculated based on the new scale. + + Args: + model_path (str): The path to the model directory, used to + read the original scale of the mesh. + link (Link): The link object containing the visual and collision + geometries to be scaled. + min_scale (float): The minimum scale factor for randomization. + Default is 0.05. + max_scale (float): The maximum scale factor for randomization. + Default is 0.25. 
+ + Returns: + bool: Returns `False` if the link has more than one visual, + indicating that scaling is not supported. Returns `None` + if the scaling is successful. + + Notes: + - This method currently supports only links that contain a + single mesh geometry. + - The mass of the link is scaled by the cube of the scale factor, + while the inertial properties are scaled by the fifth power of + the scale factor. + - The scale values are applied uniformly to all dimensions (x, y, z). + """ + + # Note: This function currently supports only scaling of links with single mesh geometry + if len(link.visuals) > 1: + return False + + # Get a random scale for the size of mesh + random_scale = self.np_random.uniform(min_scale, max_scale) + + # Determine a scale factor that will result in such scale for the size of mesh + original_mesh_scale = self.read_original_scale(model_path) + scale_factor = random_scale / original_mesh_scale + + # Determine scale factor for inertial properties based on random scale and current scale + current_scale = link.visuals[0].geometry.mesh.scale.value[0] + inertial_scale_factor = scale_factor / current_scale + + # Write scale factor into SDF for visual and collision geometry + link.visuals[0].geometry.mesh.scale = [scale_factor] * 3 + link.collisions[0].geometry.mesh.scale = [scale_factor] * 3 + + # Recompute inertial properties according to the scale + link.inertial.pose.x *= inertial_scale_factor + link.inertial.pose.y *= inertial_scale_factor + link.inertial.pose.z *= inertial_scale_factor + + # Mass is scaled n^3 + link.mass = link.mass.value * inertial_scale_factor**3 + + # Inertia is scaled n^5 + inertial_scale_factor_n5 = inertial_scale_factor**5 + link.inertia.ixx = link.inertia.ixx.value * inertial_scale_factor_n5 + link.inertia.iyy = link.inertia.iyy.value * inertial_scale_factor_n5 + link.inertia.izz = link.inertia.izz.value * inertial_scale_factor_n5 + link.inertia.ixy = link.inertia.ixy.value * inertial_scale_factor_n5 + 
link.inertia.ixz = link.inertia.ixz.value * inertial_scale_factor_n5 + link.inertia.iyz = link.inertia.iyz.value * inertial_scale_factor_n5 + + def randomize_inertial( + self, link, min_mass=0.1, max_mass=3.0 + ) -> tuple[float, float]: + """ + Randomizes the mass and updates the inertial properties of a link. + + This method assigns a random mass to the link within the specified + range and recalculates its inertial properties based on the new mass. + + Args: + link (Link): The link object whose mass and inertial properties + will be randomized. + min_mass (float): The minimum mass for randomization. Default is 0.1. + max_mass (float): The maximum mass for randomization. Default is 3.0. + + Returns: + tuple[float, float]: A tuple containing the new mass and the + mass scale factor used to update the inertial properties. + + Notes: + - The method modifies the link's mass and updates its + inertial properties (ixx, iyy, izz, ixy, ixz, iyz) based on the + ratio of the new mass to the original mass. + """ + + random_mass = self.np_random.uniform(min_mass, max_mass) + mass_scale_factor = random_mass / link.mass.value + + link.mass = random_mass + link.inertia.ixx = link.inertia.ixx.value * mass_scale_factor + link.inertia.iyy = link.inertia.iyy.value * mass_scale_factor + link.inertia.izz = link.inertia.izz.value * mass_scale_factor + link.inertia.ixy = link.inertia.ixy.value * mass_scale_factor + link.inertia.ixz = link.inertia.ixz.value * mass_scale_factor + link.inertia.iyz = link.inertia.iyz.value * mass_scale_factor + + def randomize_friction(self, link, min_friction=0.75, max_friction=1.5): + """ + Randomizes the friction coefficients of a link's collision surfaces. + + This method assigns random friction values to each collision surface + of the link within the specified range. + + Args: + link (Link): The link object whose collision surfaces' friction + coefficients will be randomized. 
+ min_friction (float): The minimum friction coefficient for + randomization. Default is 0.75. + max_friction (float): The maximum friction coefficient for + randomization. Default is 1.5. + + Notes: + - The friction coefficients are applied to both the + 'mu' and 'mu2' attributes of the collision surface's friction + properties. + """ + + for collision in link.collisions: + random_friction = self.np_random.uniform(min_friction, max_friction) + + collision.surface.friction.ode.mu = random_friction + collision.surface.friction.ode.mu2 = random_friction + + def write_inertial_properties(self, link, mass, inertia, centre_of_mass): + """ + Writes the specified mass and inertial properties to a link. + + This method updates the link's mass, inertia tensor, and + centre of mass based on the provided values. + + Args: + link (Link): The link object to be updated with new inertial properties. + mass (float): The mass to set for the link. + inertia (list[list[float]]): A 3x3 list representing the + inertia tensor of the link. + centre_of_mass (list[float]): A list containing the x, y, z + coordinates of the link's centre of mass. + + Notes: + - The method directly modifies the link's mass and inertia properties, + as well as its inertial pose, which is set to the specified centre of mass. + """ + + link.mass = mass + + link.inertia.ixx = inertia[0][0] + link.inertia.iyy = inertia[1][1] + link.inertia.izz = inertia[2][2] + link.inertia.ixy = inertia[0][1] + link.inertia.ixz = inertia[0][2] + link.inertia.iyz = inertia[1][2] + + link.inertial.pose = [ + centre_of_mass[0], + centre_of_mass[1], + centre_of_mass[2], + 0.0, + 0.0, + 0.0, + ] + + def write_original_scale(self, mesh, model_path): + """ + Writes the original scale of the mesh to a file. + + This method records the scale of the provided mesh into a text file, + which can be used later for reference during randomization or scaling + operations. 
+ + Args: + mesh (Mesh): The mesh object whose original scale is to be recorded. + model_path (str): The path to the model directory where the + original scale file will be saved. + + Notes: + - The scale is written as a string representation of the mesh's + scale property. + - The file is created or overwritten at the location returned + by the `get_original_scale_path` method. + """ + + file = open(self.get_original_scale_path(model_path), "w") + file.write(str(mesh.scale)) + file.close() + + def read_original_scale(self, model_path) -> float: + """ + Reads the original scale of a model from a file. + + This method retrieves the original scale value that was previously + saved for a specific model. The scale is expected to be stored as + a string in a file, and this method converts it back to a float. + + Args: + model_path (str): The path to the model directory from which + to read the original scale. + + Returns: + float: The original scale of the model. + + Raises: + ValueError: If the contents of the scale file cannot be converted to a float. + """ + + file = open(self.get_original_scale_path(model_path), "r") + original_scale = file.read() + file.close() + + return float(original_scale) + + def check_excessive_geometry( + self, mesh, model_path, max_faces=40000, max_vertices=None + ) -> bool: + """ + Checks if the mesh exceeds the specified geometry limits. + + This method evaluates the number of faces and vertices in the + given mesh. If the mesh exceeds the defined limits for faces + or vertices, it blacklists the model and returns False. + + Args: + mesh (Mesh): The mesh object to be checked for excessive geometry. + model_path (str): The path to the model directory for logging purposes. + max_faces (int, optional): The maximum allowed number of faces. + Defaults to 40000. + max_vertices (int, optional): The maximum allowed number of vertices. + Defaults to None. + + Returns: + bool: True if the mesh does not exceed the limits; otherwise, False. 
+ """ + + if max_faces is not None: + num_faces = len(mesh.faces) + if num_faces > max_faces: + self.blacklist_model( + model_path, reason="Excessive geometry (%d faces)" % num_faces + ) + return False + + if max_vertices is not None: + num_vertices = len(mesh.vertices) + if num_vertices > max_vertices: + self.blacklist_model( + model_path, reason="Excessive geometry (%d vertices)" % num_vertices + ) + return False + + return True + + def check_disconnected_components( + self, + mesh, + model_path, + component_min_faces_fraction=0.05, + component_max_volume_fraction=0.1, + ) -> bool: + """ + Checks for disconnected components within the mesh. + + This method identifies connected components within the mesh. If + there are multiple components, it evaluates their relative volumes. + If any component exceeds the specified volume fraction threshold, + the model is blacklisted. + + Args: + mesh (Mesh): The mesh object to be evaluated for disconnected components. + model_path (str): The path to the model directory for logging purposes. + component_min_faces_fraction (float, optional): The minimum fraction + of faces required for a component to be considered. Defaults to 0.05. + component_max_volume_fraction (float, optional): The maximum allowed + volume fraction for a component. Defaults to 0.1. + + Returns: + bool: True if the mesh has no significant disconnected components; + otherwise, False. 
+ """ + + # Get a list of all connected components inside the mesh + # Consider components only with `component_min_faces_fraction` percent faces + min_faces = round(component_min_faces_fraction * len(mesh.faces)) + connected_components = trimesh.graph.connected_components( + mesh.face_adjacency, min_len=min_faces + ) + + # If more than 1 objects were detected, consider also relative volume of the meshes + if len(connected_components) > 1: + total_volume = mesh.volume + + large_component_counter = 0 + for component in connected_components: + submesh = mesh.copy() + mask = np.zeros(len(mesh.faces), dtype=np.bool) + mask[component] = True + submesh.update_faces(mask) + + volume_fraction = submesh.volume / total_volume + if volume_fraction > component_max_volume_fraction: + large_component_counter += 1 + + if large_component_counter > 1: + self.blacklist_model( + model_path, + reason="Disconnected components (%d instances)" + % len(connected_components), + ) + return False + + return True + + def check_inertial_properties(self, model_path, mass, inertia) -> bool: + """ + Checks the validity of inertial properties for a model. + + This method evaluates whether the mass and the principal moments of + inertia are valid. If any of these values are below a specified threshold, + the model is blacklisted. + + Args: + model_path (str): The path to the model directory for logging purposes. + mass (float): The mass of the model to be checked. + inertia (list of list of float): A 2D list representing the inertia matrix. + + Returns: + bool: True if the inertial properties are valid; otherwise, False. + """ + if ( + mass < 1e-10 + or inertia[0][0] < 1e-10 + or inertia[1][1] < 1e-10 + or inertia[2][2] < 1e-10 + ): + self.blacklist_model(model_path, reason="Invalid inertial properties") + return False + + return True + + def get_random_model_path(self) -> str: + """ + Retrieves a random model path from the available models. + + This method selects and returns a random model path. 
The selection + depends on whether a unique cache is being used or not. + + Returns: + str: A randomly selected model path. + """ + if self._unique_cache: + return self.np_random.choice(self._model_paths) + else: + return self.np_random.choice(self._class_model_paths) + + def get_collision_mesh_path(self, model_path, collision_name) -> str: + """ + Constructs the path for the collision mesh file. + + This method generates and returns the file path for a collision + mesh based on the model path and the name of the collision. + + Args: + model_path (str): The path to the model directory. + collision_name (str): The name of the collision geometry. + + Returns: + str: The full path to the collision mesh file. + """ + return os.path.join( + model_path, + self.__collision_mesh_dir, + collision_name + "." + self.__collision_mesh_file_type, + ) + + def get_sdf_path(self, model_path) -> str: + """ + Constructs the path for the SDF (Simulation Description Format) file. + + This method generates and returns the file path for the SDF file + corresponding to the given model path. + + Args: + model_path (str): The path to the model directory. + + Returns: + str: The full path to the SDF file. + """ + return os.path.join(model_path, self.__sdf_base_name) + + def get_configured_sdf_path(self, model_path) -> str: + """ + Constructs the path for the configured SDF file. + + This method generates and returns the file path for the configured + SDF file that includes any modifications made to the original model. + + Args: + model_path (str): The path to the model directory. + + Returns: + str: The full path to the configured SDF file. + """ + return os.path.join(model_path, self.__configured_sdf_base_name) + + def get_blacklisted_path(self, model_path) -> str: + """ + Constructs the path for the blacklisted model file. + + This method generates and returns the file path where information + about blacklisted models is stored. 
+ + Args: + model_path (str): The path to the model directory. + + Returns: + str: The full path to the blacklisted model file. + """ + return os.path.join(model_path, self.__blacklisted_base_name) + + def get_mesh_path(self, model_path, visual_or_collision) -> str: + """ + Constructs the path for the mesh associated with a visual or collision. + + This method retrieves the URI from the specified visual or collision + object and constructs the full path to the corresponding mesh file. + + Args: + model_path (str): The path to the model directory. + visual_or_collision (object): An object containing the geometry + information for either visual or collision. + + Returns: + str: The full path to the mesh file. + + Notes: + This method may require adjustments for specific collections or models. + """ + # TODO: This might need fixing for certain collections/models + mesh_uri = visual_or_collision.geometry.mesh.uri.value + return os.path.join(model_path, mesh_uri) + + def get_original_scale_path(self, model_path) -> str: + """ + Constructs the path for the original scale file. + + This method generates and returns the file path for the file + that stores the original scale of the model. + + Args: + model_path (str): The path to the model directory. + + Returns: + str: The full path to the original scale file. + """ + return os.path.join(model_path, self.__original_scale_base_name) + + def blacklist_model(self, model_path, reason="Unknown"): + """ + Blacklists a model by writing its path and the reason to a file. + + If blacklisting is enabled, this method writes the specified reason + for blacklisting the model to a designated file. It also logs the + blacklisting action. + + Args: + model_path (str): The path to the model directory. + reason (str): The reason for blacklisting the model. Default is "Unknown". 
+ """ + if self._enable_blacklisting: + bl_file = open(self.get_blacklisted_path(model_path), "w") + bl_file.write(reason) + bl_file.close() + logger.warn( + '%s model "%s". Reason: %s.' + % ( + "Blacklisting" if self._enable_blacklisting else "Skipping", + model_path, + reason, + ) + ) + + def is_blacklisted(self, model_path) -> bool: + """ + Checks if a model is blacklisted. + + This method checks if the blacklisted file for a given model exists, + indicating that the model has been blacklisted. + + Args: + model_path (str): The path to the model directory. + + Returns: + bool: True if the model is blacklisted; otherwise, False. + """ + return os.path.isfile(self.get_blacklisted_path(model_path)) + + def is_configured(self, model_path) -> bool: + """ + Checks if a model is configured. + + This method checks if the configured SDF file for a given model + exists, indicating that the model has been processed and configured. + + Args: + model_path (str): The path to the model directory. + + Returns: + bool: True if the model is configured; otherwise, False. + """ + return os.path.isfile(self.get_configured_sdf_path(model_path)) + + def fix_mtl_texture_paths(self, model_path, mesh_path, model_name): + """ + Fixes the texture paths in the associated MTL file of an OBJ model. + + This method modifies the texture paths in the MTL file associated + with a given OBJ file to ensure they are correctly linked to + the model's textures. It also ensures that the texture files are + uniquely named by prepending the model's name to avoid conflicts. + + Args: + model_path (str): The path to the model directory where textures are located. + mesh_path (str): The path to the OBJ file for which the MTL file needs fixing. + model_name (str): The name of the model, used to make texture file names unique. + + Notes: + - This method scans the model directory for texture files, identifies + the relevant MTL file, and updates any texture paths in that MTL file. 
+ - If a texture file's name does not include the model's name, the method + renames it to include the model's name. + - The method only processes MTL files linked to OBJ files. + """ + # The `.obj` files use mtl + if mesh_path.endswith(".obj"): + # Find all textures located in the model path, used later to relative linking + texture_files = glob.glob(os.path.join(model_path, "**", "textures", "*.*")) + + # Find location of mtl file, if any + mtllib_file = None + with open(mesh_path, "r") as file: + for line in file: + if "mtllib" in line: + mtllib_file = line.split(" ")[-1].strip() + break + + if mtllib_file is not None: + mtllib_file = os.path.join(os.path.dirname(mesh_path), mtllib_file) + + fin = open(mtllib_file, "r") + data = fin.read() + for line in data.splitlines(): + if "map_" in line: + # Find the name of the texture/map in the mtl + map_file = line.split(" ")[-1].strip() + + # Find the first match of the texture/map file + for texture_file in texture_files: + if os.path.basename( + texture_file + ) == map_file or os.path.basename( + texture_file + ) == os.path.basename(map_file): + # Make the file unique to the model (unless it already is) + if model_name in texture_file: + new_texture_file_name = texture_file + else: + new_texture_file_name = texture_file.replace( + map_file, model_name + "_" + map_file + ) + + os.rename(texture_file, new_texture_file_name) + + # Apply the correct relative path + data = data.replace( + map_file, + os.path.relpath( + new_texture_file_name, + start=os.path.dirname(mesh_path), + ), + ) + break + fin.close() + + # Write in the correct data + fout = open(mtllib_file, "w") + fout.write(data) + fout.close() diff --git a/env_manager/env_manager/env_manager/models/utils/xacro2sdf.py b/env_manager/env_manager/env_manager/models/utils/xacro2sdf.py new file mode 100644 index 0000000..2a682ee --- /dev/null +++ b/env_manager/env_manager/env_manager/models/utils/xacro2sdf.py @@ -0,0 +1,37 @@ + +import subprocess +import tempfile 
+from typing import Dict, Optional, Tuple + +import xacro + + +def xacro2sdf( + input_file_path: str, mappings: Dict, model_path_remap: Optional[Tuple[str, str]] +) -> str: + """Convert xacro (URDF variant) with given arguments to SDF and return as a string.""" + + # Convert all values in mappings to strings + for keys, values in mappings.items(): + mappings[keys] = str(values) + + # Convert xacro to URDF + urdf_xml = xacro.process(input_file_name=input_file_path, mappings=mappings) + + # Create temporary file for URDF (`ign sdf -p` accepts only files) + with tempfile.NamedTemporaryFile() as tmp_urdf: + with open(tmp_urdf.name, "w") as urdf_file: + urdf_file.write(urdf_xml) + + # Convert to SDF + result = subprocess.run( + ["ign", "sdf", "-p", tmp_urdf.name], stdout=subprocess.PIPE + ) + sdf_xml = result.stdout.decode("utf-8") + + # Remap package name to model name, such that meshes can be located by Ignition + if model_path_remap is not None: + sdf_xml = sdf_xml.replace(model_path_remap[0], model_path_remap[1]) + + # Return as string + return sdf_xml diff --git a/env_manager/env_manager/env_manager/scene/__init__.py b/env_manager/env_manager/env_manager/scene/__init__.py new file mode 100644 index 0000000..421848f --- /dev/null +++ b/env_manager/env_manager/env_manager/scene/__init__.py @@ -0,0 +1 @@ +from .scene import Scene diff --git a/env_manager/env_manager/env_manager/scene/scene.py b/env_manager/env_manager/env_manager/scene/scene.py new file mode 100644 index 0000000..a42e4f8 --- /dev/null +++ b/env_manager/env_manager/env_manager/scene/scene.py @@ -0,0 +1,1480 @@ +from dataclasses import asdict + +import numpy as np +import rclpy +from geometry_msgs.msg import Point as RosPoint +from geometry_msgs.msg import Quaternion as RosQuat +from gym_gz.scenario.model_wrapper import ModelWrapper +from rbs_assets_library import get_model_meshes_info +from rcl_interfaces.srv import GetParameters +from rclpy.node import Node, Parameter +from scenario.bindings.gazebo 
import GazeboSimulator, PhysicsEngine_dart, World +from scipy.spatial import distance +from scipy.spatial.transform import Rotation +from visualization_msgs.msg import Marker, MarkerArray + +from env_manager.models.terrains.random_ground import RandomGround + +from ..models import Box, Camera, Cylinder, Ground, Mesh, Model, Plane, Sphere, Sun +from ..models.configs import CameraData, ObjectData, SceneData +from ..models.robots import get_robot_model_class +from ..utils import Tf2Broadcaster, Tf2DynamicBroadcaster +from ..utils.conversions import quat_to_wxyz +from ..utils.gazebo import ( + transform_move_to_model_pose, + transform_move_to_model_position, +) +from ..utils.types import Point, Pose + + +# TODO: Split scene randomizer and scene +class Scene: + """ + Manages the simulation scene during runtime, including object and environment randomization. + + The Scene class initializes and manages various components in a Gazebo simulation + environment, such as robots, cameras, lights, and terrain. It also provides methods + for scene setup, parameter retrieval, and scene initialization. + + Attributes: + gazebo (GazeboSimulator): An instance of the Gazebo simulator. + robot (RobotData): The robot data configuration. + cameras (list[CameraData]): List of camera configurations. + objects (list[ObjectData]): List of object configurations in the scene. + node (Node): The ROS 2 node for communication. + light (LightData): The light configuration for the scene. + terrain (TerrainData): The terrain configuration for the scene. + world (World): The Gazebo world instance. + """ + + def __init__( + self, + node: Node, + gazebo: GazeboSimulator, + scene: SceneData, + robot_urdf_string: str, + ) -> None: + """ + Initializes the Scene object with the necessary components and parameters. + + Args: + - node (Node): The ROS 2 node for managing communication and parameters. + - gazebo (GazeboSimulator): An instance of the Gazebo simulator to manage the simulation. 
+ - scene (SceneData): A data object containing configuration for the scene. + - robot_urdf_string (str): The URDF string of the robot model. + """ + self.gazebo = gazebo + self.robot = scene.robot + self.cameras = scene.camera + self.objects = scene.objects + self.node = node + self.light = scene.light + self.terrain = scene.terrain + self.__scene_initialized = False + self.__plugin_scene_broadcaster = True + self.__plugin_user_commands = True + self.__fts_enable = True + self.__plugint_sensor_render_engine = "ogre2" + self.world: World = gazebo.get_world() + + self.__objects_unique_names = {} + self.object_models: list[ModelWrapper] = [] + self.__object_positions: dict[str, Point] = {} + self.__objects_workspace_centre: Point = (0.0, 0.0, 0.0) + + if not robot_urdf_string: + self.param_client = node.create_client( + GetParameters, "/robot_state_publisher/get_parameters" + ) + self.get_parameter_sync() + # scene.robot.urdf_string = ( + # node.get_parameter("robot_description") + # .get_parameter_value() + # .string_value + # ) + else: + scene.robot.urdf_string = robot_urdf_string + + self.tf2_broadcaster = Tf2Broadcaster(node=node) + self.tf2_broadcaster_dyn = Tf2DynamicBroadcaster(node=node) + + self.__workspace_center = (0.0, 0.0, 0.0) + self.__workspace_dimensions = (1.5, 1.5, 1.5) + + half_width = self.__workspace_dimensions[0] / 2 + half_height = self.__workspace_dimensions[1] / 2 + half_depth = self.__workspace_dimensions[2] / 2 + + self.__workspace_min_bound = ( + self.__workspace_center[0] - half_width, + self.__workspace_center[1] - half_height, + self.__workspace_center[2] - half_depth, + ) + self.__workspace_max_bound = ( + self.__workspace_center[0] + half_width, + self.__workspace_center[1] + half_height, + self.__workspace_center[2] + half_depth, + ) + + # self.marker_array = MarkerArray() + + def get_parameter_sync(self): + """ + Retrieves the `robot_description` parameter synchronously, waiting for the asynchronous call to complete. 
+ """ + while not self.param_client.wait_for_service(timeout_sec=5.0): + self.node.get_logger().info( + "Service /robot_state_publisher/get_parameters is unavailable, waiting ..." + ) + + request = self.create_get_parameters_request() + future = self.param_client.call_async(request) + + rclpy.spin_until_future_complete(self.node, future) + + if future.result() is None: + raise RuntimeError("Failed to retrieve the robot_description parameter") + + response = future.result() + if not response.values or not response.values[0].string_value: + raise RuntimeError("The robot_description parameter is missing or empty") + + value = response.values[0] + self.node.get_logger().info( + "Succesfully got parameter response from robot_state_publisher" + ) + param = Parameter( + "robot_description", Parameter.Type.STRING, value.string_value + ) + self.robot.urdf_string = value.string_value + self.node.set_parameters([param]) + + def create_get_parameters_request(self) -> GetParameters.Request: + """ + Creates a request to get the parameters from the robot state publisher. + + Returns: + GetParameters.Request: A request object containing the parameter name to retrieve. + """ + return GetParameters.Request(names=["robot_description"]) + + def init_scene(self): + """ + Initializes the simulation scene by adding the terrain, lights, robots, and objects. + + This method pauses the Gazebo simulation, sets up the physics engine, and + configures the environment and components as defined in the SceneData. + + Note: + This method must be called after the Scene object has been initialized + and before running the simulation. 
+ """ + + self.gazebo_paused_run() + + self.init_world_plugins() + + self.world.set_physics_engine(PhysicsEngine_dart) + + self.add_terrain() + self.add_light() + self.add_robot() + for obj in self.objects: + self.add_object(obj) + for camera in self.cameras: + self.add_camera(camera) + + self.__scene_initialized = True + + def reset_scene(self): + """ + Resets the scene to its initial state by resetting the robot joint positions + and randomizing the positions of objects and cameras. + + This method performs the following actions: + - Resets the joint positions of the robot. + - Clears previously stored object positions. + - For each object in the scene, if randomization is enabled, + randomizes its position; otherwise, resets it to its original pose. + - Randomizes the pose of enabled cameras if their pose has expired. + + Note: + This method assumes that the objects and cameras have been properly initialized + and that their properties (like randomization and pose expiration) are set correctly. + """ + self.reset_robot_joint_positions() + + object_randomized = set() + self.__object_positions.clear() + for object in self.objects: + if object.name in object_randomized: + continue + # TODO: Add color randomization (something that affects the runtime during the second spawn of an object) + # if object.randomize.random_color: + # self.randomize_object_color(object) + if object.randomize.random_pose: + self.randomize_object_position(object) + else: + self.reset_objects_pose(object) + object_randomized.add(object.name) + + for camera in self.cameras: + if camera.enable and self.is_camera_pose_expired(camera): + self.randomize_camera_pose(camera=camera) + + def init_world_plugins(self): + """ + Initializes and inserts world plugins into the Gazebo simulator for various functionalities. 
+ + This method configures the following plugins based on the internal settings: + - **SceneBroadcaster**: Inserts the scene broadcaster plugin to enable GUI clients to receive + updates about the scene. This is only done if the scene broadcaster is enabled. + - **UserCommands**: Inserts the user commands plugin to allow user interactions with the scene. + - **Sensors**: Inserts the sensors plugin if any cameras are enabled. The rendering engine for + sensors can be specified. + - **ForceTorque**: Inserts the Force Torque Sensor plugin if enabled. + + The method pauses the Gazebo simulation after inserting plugins to ensure that changes take effect. + + Note: + If plugins are already active, this method does not reinsert them. + """ + # SceneBroadcaster + if self.__plugin_scene_broadcaster: + if not self.gazebo.scene_broadcaster_active(self.world.to_gazebo().name()): + self.node.get_logger().info( + "Inserting world plugins for broadcasting scene to GUI clients..." + ) + self.world.to_gazebo().insert_world_plugin( + "gz-sim-scene-broadcaster-system", + "gz::sim::systems::SceneBroadcaster", + ) + + self.gazebo_paused_run() + + # UserCommands + if self.__plugin_user_commands: + self.node.get_logger().info( + "Inserting world plugins to enable user commands..." + ) + self.world.to_gazebo().insert_world_plugin( + "gz-sim-user-commands-system", + "gz::sim::systems::UserCommands", + ) + + self.gazebo_paused_run() + + # Sensors + self._camera_enable = False + for camera in self.cameras: + if camera.enable: + self._camera_enable = True + + if self._camera_enable: + self.node.get_logger().info( + f"Inserting world plugins for sensors with {self.__plugint_sensor_render_engine} rendering engine..." 
+ ) + self.world.to_gazebo().insert_world_plugin( + "libgz-sim-sensors-system.so", + "gz::sim::systems::Sensors", + "" + f"{self.__plugint_sensor_render_engine}" + "", + ) + + if self.__fts_enable: + self.world.to_gazebo().insert_world_plugin( + "gz-sim-forcetorque-system", + "gz::sim::systems::ForceTorque", + ) + + self.gazebo_paused_run() + + def add_robot(self): + """ + Configure and insert the robot into the simulation. + + This method performs the following actions: + - Retrieves the robot model class based on the robot's name. + - Instantiates the robot model with specified parameters including position, orientation, + URDF string, and initial joint positions for both the arm and gripper. + - Ensures that the instantiated robot's name is unique and returns the actual name used. + - Enables contact detection for all gripper links (fingers) of the robot. + - Pauses the Gazebo simulation to apply the changes. + - Resets the robot's joints to their default positions. + + Note: + The method assumes that the robot name is valid and that the associated model class + can be retrieved. If the robot's URDF string or joint positions are not correctly specified, + this may lead to errors during robot instantiation. + + Raises: + RuntimeError: If the robot model class cannot be retrieved for the specified robot name. 
+ """ + + robot_model_class = get_robot_model_class(self.robot.name) + + # Instantiate robot class based on the selected model + self.robot_wrapper = robot_model_class( + world=self.world, + name=self.robot.name, + node=self.node, + position=self.robot.spawn_position, + orientation=quat_to_wxyz(self.robot.spawn_quat_xyzw), + urdf_string=self.robot.urdf_string, + initial_arm_joint_positions=self.robot.joint_positions, + initial_gripper_joint_positions=self.robot.gripper_joint_positions, + ) + + # The desired name is passed as arg on creation, however, a different name might be selected to be unique + # Therefore, return the name back to the task + self.robot_name = self.robot_wrapper.name() + + # Enable contact detection for all gripper links (fingers) + robot_gazebo = self.robot_wrapper.to_gazebo() + robot_gazebo.enable_contacts() + + self.gazebo_paused_run() + + # Reset robot joints to the defaults + self.reset_robot_joint_positions() + + def add_camera(self, camera: CameraData): + """ + Configure and insert a camera into the simulation, placing it with respect to the robot. + + The method performs the following steps: + - Determines the appropriate position and orientation of the camera based on whether it is + relative to the world or to the robot. + - Creates an instance of the `Camera` class with the specified parameters, including + position, orientation, type, dimensions, image formatting, update rate, field of view, + clipping properties, noise characteristics, and ROS 2 bridge publishing options. + - Pauses the Gazebo simulation to apply changes. + - Attaches the camera to the robot if the camera's reference is not the world. + - Broadcasts the transformation (TF) of the camera relative to the robot. + + Args: + - camera (CameraData): The camera data containing configuration options including: + + Raises: + - Exception: If the camera cannot be attached to the robot. + + Notes: + The method expects the camera data to be correctly configured. 
If the camera's + `relative_to` field does not match any existing model in the world, it will raise an + exception when attempting to attach the camera to the robot. + """ + + if self.world.to_gazebo().name() == camera.relative_to: + camera_position = camera.spawn_position + camera_quat_wxyz = quat_to_wxyz(camera.spawn_quat_xyzw) + else: + # Transform the pose of camera to be with respect to robot - but still represented in world reference frame for insertion into the world + camera_position, camera_quat_wxyz = transform_move_to_model_pose( + world=self.world, + position=camera.spawn_position, + quat=quat_to_wxyz(camera.spawn_quat_xyzw), + target_model=self.robot_wrapper, + target_link=camera.relative_to, + xyzw=False, + ) + + # Create camera + self.camera_wrapper = Camera( + name=camera.name, + world=self.world, + position=camera_position, + orientation=camera_quat_wxyz, + camera_type=camera.type, + width=camera.width, + height=camera.height, + image_format=camera.image_format, + update_rate=camera.update_rate, + horizontal_fov=camera.horizontal_fov, + vertical_fov=camera.vertical_fov, + clip_color=camera.clip_color, + clip_depth=camera.clip_depth, + noise_mean=camera.noise_mean, + noise_stddev=camera.noise_stddev, + ros2_bridge_color=camera.publish_color, + ros2_bridge_depth=camera.publish_depth, + ros2_bridge_points=camera.publish_points, + ) + + self.gazebo_paused_run() + + # Attach to robot if needed + if self.world.to_gazebo().name() != camera.relative_to: + if not self.robot_wrapper.to_gazebo().attach_link( + camera.relative_to, + self.camera_wrapper.name(), + self.camera_wrapper.link_name, + ): + raise Exception("Cannot attach camera link to robot") + + self.__is_camera_attached = True + + self.gazebo_paused_run() + + # Broadcast tf + self.tf2_broadcaster.broadcast_tf( + parent_frame_id=camera.relative_to, + child_frame_id=self.camera_wrapper.frame_id, + translation=camera.spawn_position, + rotation=camera.spawn_quat_xyzw, + xyzw=True, + ) + + def 
add_object(self, obj: ObjectData): + """ + Configure and insert an object into the simulation. + + Args: + obj (ObjectData): The object data containing configuration options. + + Raises: + NotImplementedError: If the specified type is not supported. + """ + obj_dict = asdict(obj) + obj_dict["world"] = self.world + + try: + object_wrapper = self._create_object_wrapper(obj, obj_dict) + self._register_object(obj, object_wrapper) + self._enable_contact_detection(object_wrapper) + + self.node.get_logger().info( + f"Added object: {obj.name}, relative_to: {obj.relative_to}" + ) + self.node.get_logger().info( + f"Position: {obj.position}, Orientation: {obj.orientation}" + ) + + self._initialize_tf_and_markers(obj, object_wrapper) + + except Exception as ex: + self.node.get_logger().warn(f"Model could not be inserted. Reason: {ex}") + + def _create_object_wrapper(self, obj: ObjectData, obj_dict: dict) -> ModelWrapper: + """Create the appropriate object wrapper based on the type.""" + if obj.type != "": + return self._create_by_type(obj.type, obj_dict) + else: + return self._create_by_data_class(obj, obj_dict) + + def _create_by_type(self, obj_type: str, obj_dict: dict) -> ModelWrapper: + match obj_type: + case "box": + return Box(**obj_dict) + case "plane": + return Plane(**obj_dict) + case "sphere": + return Sphere(**obj_dict) + case "cylinder": + return Cylinder(**obj_dict) + case "model": + return Model(**obj_dict) + case "mesh": + return Mesh(**obj_dict) + case _: + raise NotImplementedError(f"Unsupported object type: {obj_type}") + + def _create_by_data_class(self, obj: ObjectData, obj_dict: dict) -> ModelWrapper: + from ..models.configs import ( + BoxObjectData, + CylinderObjectData, + MeshData, + ModelData, + PlaneObjectData, + SphereObjectData, + ) + + match obj: + case BoxObjectData(): + return Box(**obj_dict) + case PlaneObjectData(): + return Plane(**obj_dict) + case SphereObjectData(): + return Sphere(**obj_dict) + case CylinderObjectData(): + return 
Cylinder(**obj_dict) + case ModelData(): + return Model(**obj_dict) + case MeshData(): + return Mesh(**obj_dict) + case _: + raise NotImplementedError("Unsupported object data class") + + def _register_object(self, obj: ObjectData, wrapper: ModelWrapper): + model_name = wrapper.name() + self.__objects_unique_names.setdefault(obj.name, []).append(model_name) + self.__object_positions[model_name] = obj.position + self.object_models.append(wrapper.model) + + def _enable_contact_detection(self, wrapper: ModelWrapper): + for link_name in wrapper.link_names(): + link = wrapper.to_gazebo().get_link(link_name=link_name) + link.enable_contact_detection(True) + + def _initialize_tf_and_markers(self, obj: ObjectData, wrapper: ModelWrapper): + self._setup_marker_array() + marker = self._create_marker(obj, wrapper, len(self.object_models)) + if marker: + self.marker_array.markers.append(marker) + # Start broadcasting once when len of markers == 1 + if len(self.marker_array.markers) == 1: + self._setup_tf_broadcaster() + + def _setup_tf_broadcaster(self): + if not hasattr(self, "tf_broadcaster_and_markers"): + self.tf_broadcaster_and_markers = self.node.create_timer( + 0.1, self.publish_marker + ) + + def _setup_marker_array(self): + if not hasattr(self, "marker_array"): + self.marker_array = MarkerArray() + self.marker_publisher = self.node.create_publisher( + MarkerArray, "/markers", 10 + ) + + def _create_marker( + self, obj: ObjectData, wrapper: ModelWrapper, id: int + ) -> Marker | None: + """Create a visualization marker based on the object type.""" + marker = Marker() + marker.header.frame_id = obj.relative_to + marker.header.stamp = self.node.get_clock().now().to_msg() + marker.ns = "assembly" + marker.id = id + + type_map = { + "mesh": Marker.MESH_RESOURCE, + "model": Marker.MESH_RESOURCE, + "sphere": Marker.SPHERE, + "box": Marker.CUBE, + "cylinder": Marker.CYLINDER, + } + + if obj.type not in type_map: + self.node.get_logger().fatal(f"Unsupported object type: 
{obj.type}") + return None + + marker.type = type_map[obj.type] + self._configure_marker_geometry_and_color(marker, obj) + + model_position, model_orientation = self.rotate_around_parent( + wrapper.base_position(), + wrapper.base_orientation(), + (0.0, 0.0, 0.0), + (1.0, 0.0, 0.0, 0.0), + ) + + position = RosPoint() + orientation = RosQuat() + position.x = model_position[0] + position.y = model_position[1] + position.z = model_position[2] + + orientation.w = model_orientation[0] + orientation.x = model_orientation[1] + orientation.y = model_orientation[2] + orientation.z = model_orientation[3] + + marker.pose.position = position + marker.pose.orientation = orientation + marker.action = Marker.ADD + + return marker + + def _configure_marker_geometry_and_color(self, marker: Marker, obj: ObjectData): + """Set marker geometry and color.""" + if obj.type in ["mesh", "model"]: + marker.mesh_resource = f"file://{get_model_meshes_info(obj.name)['visual']}" + marker.mesh_use_embedded_materials = True + marker.scale.x = marker.scale.y = marker.scale.z = 1.0 + elif obj.type == "sphere": + marker.scale.x = marker.scale.y = marker.scale.z = obj.radius + elif obj.type == "box": + marker.scale.x, marker.scale.y, marker.scale.z = obj.size + elif obj.type == "cylinder": + marker.scale.x = marker.scale.y = obj.radius * 2 + marker.scale.z = obj.length + + marker.color.r, marker.color.g, marker.color.b, marker.color.a = ( + 0.0, + 0.5, + 1.0, + 1.0, + ) + + def publish_marker(self): + if not self.marker_array: + return + + for idx, model in enumerate(self.object_models): + if self.marker_array.markers[idx]: + model_position, model_orientation = ( + model.base_position(), + model.base_orientation(), + ) + + position = RosPoint() + orientation = RosQuat() + position.x = model_position[0] + position.y = model_position[1] + position.z = model_position[2] + + orientation.w = model_orientation[0] + orientation.x = model_orientation[1] + orientation.y = model_orientation[2] + orientation.z = 
model_orientation[3] + self.marker_array.markers[idx].pose.position = position + self.marker_array.markers[idx].pose.orientation = orientation + + self.tf2_broadcaster_dyn.broadcast_tf( + parent_frame_id="world", + child_frame_id=model.base_frame(), + translation=model_position, + rotation=model_orientation, + xyzw=False, + ) + self.marker_publisher.publish(self.marker_array) + + def rotate_around_parent( + self, position, orientation, parent_position, parent_orientation + ): + """ + Apply a 180-degree rotation around the parent's Z-axis to the object's pose. + + Args: + position: Tuple (x, y, z) - the object's position relative to the parent. + orientation: Tuple (w, x, y, z) - the object's orientation relative to the parent. + parent_position: Tuple (x, y, z) - the parent's position. + parent_orientation: Tuple (w, x, y, z) - the parent's orientation. + + Returns: + A tuple containing the updated position and orientation: + (new_position, new_orientation) + """ + # Define a quaternion for 180-degree rotation around Z-axis (w, x, y, z) + rotation_180_z = (0.0, 0.0, 0.0, 1.0) + w1, x1, y1, z1 = orientation + # Кватернион для 180° поворота вокруг Z + w2, x2, y2, z2 = (0, 0, 0, 1) + + # Умножение кватернионов + w3 = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x3 = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y3 = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z3 = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + new_orientation = (w3, x3, y3, z3) + + # Compute the new orientation: parent_orientation * rotation_180_z * orientation + # new_orientation = ( + # parent_orientation[0] * rotation_180_z[0] - parent_orientation[1] * rotation_180_z[1] - + # parent_orientation[2] * rotation_180_z[2] - parent_orientation[3] * rotation_180_z[3], + # parent_orientation[0] * rotation_180_z[1] + parent_orientation[1] * rotation_180_z[0] + + # parent_orientation[2] * rotation_180_z[3] - parent_orientation[3] * rotation_180_z[2], + # parent_orientation[0] * rotation_180_z[2] - parent_orientation[1] * 
rotation_180_z[3] + + # parent_orientation[2] * rotation_180_z[0] + parent_orientation[3] * rotation_180_z[1], + # parent_orientation[0] * rotation_180_z[3] + parent_orientation[1] * rotation_180_z[2] - + # parent_orientation[2] * rotation_180_z[1] + parent_orientation[3] * rotation_180_z[0], + # ) + + # Apply rotation to the position + rotated_position = ( + -position[0], # Rotate 180 degrees around Z-axis: x -> -x + -position[1], # y -> -y + position[2], # z remains unchanged + ) + + # Add parent position + new_position = ( + parent_position[0] + rotated_position[0], + parent_position[1] + rotated_position[1], + parent_position[2] + rotated_position[2], + ) + + return new_position, new_orientation + + def add_light(self): + """ + Configure and insert a default light source into the simulation. + + This method initializes and adds a light source to the Gazebo simulation based on the + specified light type. Currently, it supports a "sun" type light. If an unsupported + light type is specified, it raises an error. + + Raises: + RuntimeError: If the specified light type is unsupported, including: + - "random_sun": Currently not implemented. + - Any other unrecognized light type. + + Notes: + - The method constructs the light using the properties defined in the `self.light` object, + which should include direction, color, distance, visual properties, and radius, + as applicable for the specified light type. + - After adding the light to the simulation, the method pauses the Gazebo simulation + to ensure the scene is updated. 
+ """ + + # Create light + match self.light.type: + case "sun": + self.light = Sun( + self.world, + direction=self.light.direction, + color=self.light.color, + distance=self.light.distance, + visual=self.light.visual, + radius=self.light.radius, + ) + case "random_sun": + raise RuntimeError("random_sun is not supported yet") + case _: + raise RuntimeError( + f"Type of light [{self.light.type}] currently not supported" + ) + + self.gazebo_paused_run() + + def add_terrain(self): + """ + Configure and insert default terrain into the simulation. + + This method initializes and adds a terrain element to the Gazebo simulation using + the properties specified in the `self.terrain` object. Currently, it creates a + ground plane, but there are plans to support additional terrain types in the future. + + Raises: + NotImplementedError: If there are attempts to add terrain types other than + the default ground. + + Notes: + - The method utilizes properties such as name, spawn position, spawn quaternion + (for orientation), and size defined in `self.terrain` to create the terrain. + - After the terrain is added, contact detection is enabled for all links associated + with the terrain object. + - The method pauses the Gazebo simulation after adding the terrain to ensure + changes are reflected in the environment. 
+ """ + + if self.terrain.type == "flat": + self.terrain_wrapper = Ground( + self.world, + name=self.terrain.name, + position=self.terrain.spawn_position, + orientation=quat_to_wxyz(self.terrain.spawn_quat_xyzw), + size=self.terrain.size, + ambient=self.terrain.ambient, + diffuse=self.terrain.diffuse, + specular=self.terrain.specular, + ) + elif self.terrain.type == "random_flat": + self.terrain_wrapper = RandomGround( + self.world, + name=self.terrain.name, + position=self.terrain.spawn_position, + orientation=quat_to_wxyz(self.terrain.spawn_quat_xyzw), + size=self.terrain.size, + ) + else: + raise ValueError( + "Type of ground is not supported, supported type is: [ground, random_ground]" + ) + + # Enable contact detection + for link_name in self.terrain_wrapper.link_names(): + link = self.terrain_wrapper.to_gazebo().get_link(link_name=link_name) + link.enable_contact_detection(True) + + self.gazebo_paused_run() + + # def randomize_object_color(self, object: ObjectData): + # # debugpy.listen(5678) + # # debugpy.wait_for_client() + # task_object_names = self.__objects_unique_names[object.name] + # for object_name in task_object_names: + # if not self.world.to_gazebo().remove_model(object_name): + # self.node.get_logger().error(f"Failed to remove {object_name}") + # raise RuntimeError(f"Failed to remove {object_name}") + # + # self.gazebo_paused_run() + # del self.__objects_unique_names[object.name] + # self.add_object(object) + + def randomize_object_position(self, object: ObjectData): + """ + Randomize the position and/or orientation of a specified object in the simulation. + + This method retrieves the names of the object instances from the scene and updates + their positions and orientations based on the randomization settings defined in + the `object` parameter. The method applies random transformations to the object's + pose according to the specified flags in the `ObjectData` instance. 
+ + Args: + object (ObjectData): An instance of `ObjectData` containing the properties + of the object to be randomized. + + Raises: + KeyError: If the object name is not found in the unique object names dictionary. + + Notes: + - If `random_pose` is `True`, both position and orientation are randomized. + - If `random_orientation` is `True`, only the orientation is randomized, while + the position remains unchanged. + - If `random_position` is `True`, only the position is randomized, while the + orientation remains unchanged. + """ + task_object_names = self.__objects_unique_names[object.name] + for object_name in task_object_names: + position, quat_random = self.get_random_object_pose(object) + obj = self.world.to_gazebo().get_model(object_name).to_gazebo() + if object.randomize.random_pose: + obj.reset_base_pose(position, quat_random) + elif object.randomize.random_orientation: + obj.reset_base_pose(object.position, quat_random) + elif object.randomize.random_position: + obj.reset_base_pose(position, object.orientation) + + self.__object_positions[object_name] = position + + self.gazebo_paused_run() + + def randomize_camera_pose(self, camera: CameraData): + """ + Randomize the pose of a specified camera in the simulation. + + This method updates the camera's position and orientation based on the selected + randomization mode. The new pose is computed relative to either the world frame or + a specified target model (e.g., the robot). The camera can be detached from the robot + during the pose update and reattached afterward if needed. + + Args: + camera (CameraData): An instance of `CameraData` containing properties of the + camera to be randomized. + + Raises: + TypeError: If the randomization mode specified in `camera.random_pose_mode` is + invalid. + Exception: If the camera cannot be detached from or attached to the robot. + + Notes: + - Supported randomization modes for camera pose: + - `"orbit"`: The camera is positioned in an orbital path around the object. 
+ - `"select_random"`: A random camera pose is selected from predefined options. + - `"select_nearest"`: The camera is positioned based on the nearest predefined + camera pose. + """ + # Get random camera pose, centered at object position (or center of object spawn box) + if "orbit" == camera.random_pose_mode: + camera_position, camera_quat_xyzw = self.get_random_camera_pose_orbit( + camera + ) + elif "select_random" == camera.random_pose_mode: + ( + camera_position, + camera_quat_xyzw, + ) = self.get_random_camera_pose_sample_random(camera) + elif "select_nearest" == camera.random_pose_mode: + ( + camera_position, + camera_quat_xyzw, + ) = self.get_random_camera_pose_sample_nearest(camera) + else: + raise TypeError("Invalid mode for camera pose randomization.") + + if self.world.to_gazebo().name() == camera.relative_to: + transformed_camera_position = camera_position + transformed_camera_quat_wxyz = quat_to_wxyz(camera_quat_xyzw) + else: + # Transform the pose of camera to be with respect to robot - but represented in world reference frame for insertion into the world + ( + transformed_camera_position, + transformed_camera_quat_wxyz, + ) = transform_move_to_model_pose( + world=self.world, + position=camera_position, + quat=quat_to_wxyz(camera_quat_xyzw), + target_model=self.robot_wrapper, + target_link=camera.relative_to, + xyzw=False, + ) + + # Detach camera if needed + if self.__is_camera_attached: + if not self.robot_wrapper.to_gazebo().detach_link( + camera.relative_to, + self.camera_wrapper.name(), + self.camera_wrapper.link_name, + ): + raise Exception("Cannot detach camera link from robot") + + self.gazebo_paused_run() + + # Move pose of the camera + camera_gazebo = self.camera_wrapper.to_gazebo() + camera_gazebo.reset_base_pose( + transformed_camera_position, transformed_camera_quat_wxyz + ) + + self.gazebo_paused_run() + + # Attach to robot again if needed + if self.__is_camera_attached: + if not self.robot_wrapper.to_gazebo().attach_link( + 
camera.relative_to, + self.camera_wrapper.name(), + self.camera_wrapper.link_name, + ): + raise Exception("Cannot attach camera link to robot") + + self.gazebo_paused_run() + + # Broadcast tf + self.tf2_broadcaster.broadcast_tf( + parent_frame_id=camera.relative_to, + child_frame_id=self.camera_wrapper.frame_id, + translation=camera_position, + rotation=camera_quat_xyzw, + xyzw=True, + ) + + def get_random_camera_pose_orbit(self, camera: CameraData) -> Pose: + """ + Generate a random camera pose in an orbital path around a focal point. + + This method computes a random position for the camera based on an orbital distance + and height range. The camera's orientation is calculated to focus on a specified + focal point. The generated position ensures that it does not fall within a specified + arc behind the robot. + + Args: + camera (CameraData): An instance of `CameraData` containing properties of the + camera, including the orbital distance, height range, + focal point offset, and arc behind the robot to ignore. + + Returns: + Pose: A tuple containing the randomly generated position and orientation + (as a quaternion) of the camera. + + Notes: + - The method utilizes a random number generator to create a uniformly + distributed position within specified bounds. + - The computed camera orientation is represented in the quaternion format. 
+ """ + rng = np.random.default_rng() + + while True: + position = rng.uniform( + low=(-1.0, -1.0, camera.random_pose_orbit_height_range[0]), + high=(1.0, 1.0, camera.random_pose_orbit_height_range[1]), + ) + + norm = np.linalg.norm(position) + position /= norm + + if ( + abs(np.arctan2(position[0], position[1]) + np.pi / 2) + > camera.random_pose_orbit_ignore_arc_behind_robot + ): + break + + rpy = np.array( + [ + 0.0, + np.arctan2( + position[2] - camera.random_pose_focal_point_z_offset, + np.linalg.norm(position[:2]), + ), + np.arctan2(position[1], position[0]) + np.pi, + ] + ) + quat_xyzw = Rotation.from_euler("xyz", rpy).as_quat() + + position *= camera.random_pose_orbit_distance + position[:2] += self.__objects_workspace_centre[:2] + + return tuple(position), tuple(quat_xyzw) + + def get_random_camera_pose_sample_random(self, camera: CameraData) -> Pose: + """ + Select a random position from the predefined camera position options. + + This method randomly selects one of the available camera position options + specified in the `camera` object and processes it to generate the camera's pose. + + Args: + camera (CameraData): An instance of `CameraData` containing the available + position options for random selection. + + Returns: + Pose: A tuple containing the selected camera position and its orientation + (as a quaternion). + """ + # Select a random entry from the options + camera.random_pose_select_position_options + + selection = camera.random_pose_select_position_options[ + np.random.randint(len(camera.random_pose_select_position_options)) + ] + + return self.get_random_camera_pose_sample_process(camera, selection) + + def get_random_camera_pose_sample_nearest(self, camera: CameraData) -> Pose: + """ + Select the nearest entry from the predefined camera position options. + + This method calculates the squared distances of the camera position options + from a central point and selects the nearest one to generate the camera's pose. 
+ + Args: + camera (CameraData): An instance of `CameraData` containing the position options + to choose from. + + Returns: + Pose: A tuple containing the nearest camera position and its orientation + (as a quaternion). + """ + # Select the nearest entry + dist_sqr = np.sum( + ( + np.array(camera.random_pose_select_position_options) + - np.array(camera.random_pose_select_position_options) + ) + ** 2, + axis=1, + ) + nearest = camera.random_pose_select_position_options[np.argmin(dist_sqr)] + + return self.get_random_camera_pose_sample_process(camera, nearest) + + def get_random_camera_pose_sample_process( + self, camera: CameraData, position: Point + ) -> Pose: + """ + Process the selected camera position to determine its orientation. + + This method calculates the camera's orientation based on the selected position, + ensuring that the camera faces a focal point adjusted by a specified offset. + + Args: + camera (CameraData): An instance of `CameraData` that contains properties + used for orientation calculations. + position (Point): A 3D point representing the camera's position. + + Returns: + Pose: A tuple containing the camera's position and its orientation + (as a quaternion). + + Notes: + - The orientation is computed using roll, pitch, and yaw angles based on + the position relative to the workspace center. + """ + # Determine orientation such that camera faces the center + rpy = [ + 0.0, + np.arctan2( + position[2] - camera.random_pose_focal_point_z_offset, + np.linalg.norm( + ( + position[0] - self.__objects_workspace_centre[0], + position[1] - self.__objects_workspace_centre[1], + ), + 2, + ), + ), + np.arctan2( + position[1] - self.__objects_workspace_centre[1], + position[0] - self.__objects_workspace_centre[0], + ) + + np.pi, + ] + quat_xyzw = Rotation.from_euler("xyz", rpy).as_quat() + + return camera.spawn_position, tuple(quat_xyzw) + + def reset_robot_pose(self): + """ + Reset the robot's pose to a randomized position or its default spawn position. 
+ + If the robot's randomizer is enabled for pose randomization, this method will + generate a new random position within the specified spawn volume and assign + a random orientation. Otherwise, it will reset the robot's pose to its default + spawn position and orientation. + + The new position is calculated by adding a random offset within the spawn + volume to the robot's initial spawn position. The orientation is generated + as a random rotation around the z-axis. + + Raises: + Exception: If there are issues resetting the robot pose in the Gazebo simulation. + + Notes: + - The position is a 3D vector represented as [x, y, z]. + - The orientation is represented as a quaternion in the form (x, y, z, w). + """ + if self.robot.randomizer.pose: + position = [ + self.robot.spawn_position[0] + + np.random.RandomState().uniform( + -self.robot.randomizer.spawn_volume[0] / 2, + self.robot.randomizer.spawn_volume[0] / 2, + ), + self.robot.spawn_position[1] + + np.random.RandomState().uniform( + -self.robot.randomizer.spawn_volume[1] / 2, + self.robot.randomizer.spawn_volume[1] / 2, + ), + self.robot.spawn_position[2] + + np.random.RandomState().uniform( + -self.robot.randomizer.spawn_volume[2] / 2, + self.robot.randomizer.spawn_volume[2] / 2, + ), + ] + quat_xyzw = Rotation.from_euler( + "xyz", (0, 0, np.random.RandomState().uniform(-np.pi, np.pi)) + ).as_quat() + quat_xyzw = tuple(quat_xyzw) + else: + position = self.robot.spawn_position + quat_xyzw = self.robot.spawn_quat_xyzw + + gazebo_robot = self.robot_wrapper.to_gazebo() + gazebo_robot.reset_base_pose(position, quat_to_wxyz(quat_xyzw)) + + self.gazebo_paused_run() + + def reset_robot_joint_positions(self): + """ + Reset the robot's joint positions and velocities to their initial values. + + This method retrieves the robot's initial arm joint positions and applies + them to the robot in the Gazebo simulation. 
If joint position randomization + is enabled, it adds Gaussian noise to the joint positions based on the + specified standard deviation. + + Additionally, the method resets the velocities of both the arm and gripper + joints to zero. After updating the joint states, it executes an unpaused step + in Gazebo to process the modifications and update the joint states. + + Raises: + RuntimeError: If there are issues resetting joint positions or velocities + in the Gazebo simulation, or if the Gazebo step fails. + + Notes: + - This method assumes the robot has a gripper only if `self.robot.with_gripper` is true. + - Joint positions for the gripper are reset only if gripper actuated joint names are available. + """ + gazebo_robot = self.robot_wrapper.to_gazebo() + + arm_joint_positions = self.robot_wrapper.initial_arm_joint_positions + + # Add normal noise if desired + if self.robot.randomizer.joint_positions: + for joint_position in arm_joint_positions: + joint_position += np.random.RandomState().normal( + loc=0.0, scale=self.robot.randomizer.joint_positions_std + ) + + # Arm joints - apply joint positions zero out velocities + if not gazebo_robot.reset_joint_positions( + arm_joint_positions, self.robot_wrapper.robot.actuated_joint_names + ): + raise RuntimeError("Failed to reset robot joint positions") + if not gazebo_robot.reset_joint_velocities( + [0.0] * len(self.robot_wrapper.robot.actuated_joint_names), + self.robot_wrapper.robot.actuated_joint_names, + ): + raise RuntimeError("Failed to reset robot joint velocities") + + # Gripper joints - apply joint positions zero out velocities + if ( + self.robot.with_gripper + and self.robot_wrapper.robot.gripper_actuated_joint_names + ): + if not gazebo_robot.reset_joint_positions( + self.robot_wrapper.initial_gripper_joint_positions, + self.robot_wrapper.robot.gripper_actuated_joint_names, + ): + raise RuntimeError("Failed to reset gripper joint positions") + if not gazebo_robot.reset_joint_velocities( + [0.0] * 
len(self.robot_wrapper.robot.gripper_actuated_joint_names), + self.robot_wrapper.robot.gripper_actuated_joint_names, + ): + raise RuntimeError("Failed to reset gripper joint velocities") + + # Execute an unpaused run to process model modification and get new JointStates + if not self.gazebo.step(): + raise RuntimeError("Failed to execute an unpaused Gazebo step") + + # if self.robot.with_gripper: + # if ( + # self.robot_wrapper.CLOSED_GRIPPER_JOINT_POSITIONS + # == self.robot.gripper_joint_positions + # ): + # self.gripper_controller.close() + # else: + # self.gripper.open() + + def reset_objects_pose(self, object: ObjectData): + """ + Reset the pose of specified objects in the simulation. + + This method updates the base position and orientation of all instances of + the specified object type in the Gazebo simulation to their defined + positions and orientations. + + Args: + object (ObjectData): The object data containing the name and the desired + position and orientation for the objects to be reset. + + Raises: + KeyError: If the specified object's name is not found in the unique names + registry. + + Notes: + - The method retrieves the list of unique names for the specified object + type and resets the pose of each instance to the provided position and + orientation from the `object` data. + """ + task_object_names = self.__objects_unique_names[object.name] + for object_name in task_object_names: + obj = self.world.to_gazebo().get_model(object_name).to_gazebo() + obj.reset_base_pose(object.position, object.orientation) + + self.gazebo_paused_run() + + def get_random_object_pose( + self, + obj: ObjectData, + name: str = "", + min_distance_to_other_objects: float = 0.2, + min_distance_decay_factor: float = 0.95, + ) -> Pose: + """ + Generate a random pose for an object, ensuring it is sufficiently + distanced from other objects. + + This method computes a random position within a specified spawn volume for + the given object. 
It checks that the new position does not violate minimum + distance constraints from existing objects in the simulation. + + Args: + obj (ObjectData): The object data containing the base position, + randomization parameters, and relative positioning info. + name (str, optional): The name of the object being positioned. Defaults to an empty string. + min_distance_to_other_objects (float, optional): The minimum required distance + from other objects. Defaults to 0.2. + min_distance_decay_factor (float, optional): Factor by which the minimum distance + will decay if the new position is + too close to existing objects. Defaults to 0.95. + + Returns: + Pose: A tuple containing the randomly generated position (x, y, z) and a + randomly generated orientation (quaternion). + + Notes: + - The method uses a uniform distribution to select random coordinates + within the object's defined spawn volume. + - The quaternion is generated to provide a random orientation for the object. + - The method iteratively checks for proximity to other objects and adjusts the + minimum distance if necessary. 
+ """ + is_too_close = True + # The default value for randomization + object_position = obj.position + while is_too_close: + object_position = ( + obj.position[0] + + np.random.uniform( + -obj.randomize.random_spawn_volume[0] / 2, + obj.randomize.random_spawn_volume[0] / 2, + ), + obj.position[1] + + np.random.uniform( + -obj.randomize.random_spawn_volume[1] / 2, + obj.randomize.random_spawn_volume[1] / 2, + ), + obj.position[2] + + np.random.uniform( + -obj.randomize.random_spawn_volume[2] / 2, + obj.randomize.random_spawn_volume[2] / 2, + ), + ) + + if self.world.to_gazebo().name() != obj.relative_to: + # Transform the pose of camera to be with respect to robot - but represented in world reference frame for insertion into the world + object_position = transform_move_to_model_position( + world=self.world, + position=object_position, + target_model=self.robot_wrapper, + target_link=obj.relative_to, + ) + + # Check if position is far enough from other + is_too_close = False + for ( + existing_object_name, + existing_object_position, + ) in self.__object_positions.items(): + if existing_object_name == name: + # Do not compare to itself + continue + if ( + distance.euclidean(object_position, existing_object_position) + < min_distance_to_other_objects + ): + min_distance_to_other_objects *= min_distance_decay_factor + is_too_close = True + break + + quat = np.random.uniform(-1, 1, 4) + quat /= np.linalg.norm(quat) + + return object_position, tuple(quat) + + def is_camera_pose_expired(self, camera: CameraData) -> bool: + """ + Check if the camera pose has expired and needs to be randomized. + + This method evaluates whether the camera's current pose has reached its + limit for randomization based on the specified rollout count. If the + camera has exceeded the allowed number of rollouts, it will reset the + counter and indicate that the pose should be randomized. 
+ + Args: + camera (CameraData): The camera data containing the current rollout + counter and the maximum number of rollouts + allowed before a new pose is required. + + Returns: + bool: True if the camera pose needs to be randomized (i.e., it has + expired), False otherwise. + """ + + if camera.random_pose_rollouts_num == 0: + return False + + camera.random_pose_rollout_counter += 1 + + if camera.random_pose_rollout_counter >= camera.random_pose_rollouts_num: + camera.random_pose_rollout_counter = 0 + return True + + return False + + def gazebo_paused_run(self): + """ + Execute a run in Gazebo while paused. + + This method attempts to run the Gazebo simulation in a paused state. + If the operation fails, it raises a RuntimeError to indicate the failure. + + Raises: + RuntimeError: If the Gazebo run operation fails. + """ + if not self.gazebo.run(paused=True): + raise RuntimeError("Failed to execute a paused Gazebo run") + + def check_object_outside_workspace( + self, + object_position: Point, + ) -> bool: + """ + Check if an object is outside the defined workspace. + + This method evaluates whether the given object position lies outside the + specified workspace boundaries, which are defined by a parallelepiped + determined by minimum and maximum bounds. + + Args: + object_position (Point): A tuple or list representing the 3D position + of the object in the format (x, y, z). + + Returns: + bool: True if the object is outside the workspace, False otherwise. 
+ """ + + return ( + object_position[0] < self.__workspace_min_bound[0] + or object_position[1] < self.__workspace_min_bound[1] + or object_position[2] < self.__workspace_min_bound[2] + or object_position[0] > self.__workspace_max_bound[0] + or object_position[1] > self.__workspace_max_bound[1] + or object_position[2] > self.__workspace_max_bound[2] + ) + + def check_object_overlapping( + self, + object: ObjectData, + allowed_penetration_depth: float = 0.001, + terrain_allowed_penetration_depth: float = 0.002, + ) -> bool: + """ + Check for overlapping objects in the simulation and reset their positions if necessary. + + This method iterates through all objects and ensures that none of them are overlapping. + If an object is found to be overlapping with another object (or the robot), it resets + its position to a random valid pose. Collisions or overlaps with terrain are ignored. + + Args: + object (ObjectData): The object data to check for overlaps. + allowed_penetration_depth (float, optional): The maximum allowed penetration depth + between objects. Defaults to 0.001. + terrain_allowed_penetration_depth (float, optional): The maximum allowed penetration + depth when in contact with terrain. Defaults to 0.002. + + Returns: + bool: True if all objects are in valid positions, False if any had to be reset. 
+ """ + + # Update object positions + for object_name in self.__objects_unique_names[object.name]: + model = self.world.get_model(object_name).to_gazebo() + self.__object_positions[object_name] = model.get_link( + link_name=model.link_names()[0] + ).position() + + for object_name in self.__objects_unique_names[object.name]: + obj = self.world.get_model(object_name).to_gazebo() + + # Make sure the object is inside workspace + if self.check_object_outside_workspace( + self.__object_positions[object_name] + ): + position, quat_random = self.get_random_object_pose( + object, + name=object_name, + ) + obj.reset_base_pose(position, quat_random) + obj.reset_base_world_velocity([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]) + return False + + # Make sure the object is not intersecting other objects + try: + for contact in obj.contacts(): + depth = np.mean([point.depth for point in contact.points]) + if ( + self.terrain_wrapper.name() in contact.body_b + and depth < terrain_allowed_penetration_depth + ): + continue + if ( + self.robot_wrapper.name() in contact.body_b + or depth > allowed_penetration_depth + ): + position, quat_random = self.get_random_object_pose( + object, + name=object_name, + ) + obj.reset_base_pose(position, quat_random) + obj.reset_base_world_velocity([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]) + return False + except Exception as e: + self.node.get_logger().error( + f"Runtime error encountered while checking objects intersections: {e}" + ) + + return True diff --git a/env_manager/env_manager/env_manager/utils/__init__.py b/env_manager/env_manager/env_manager/utils/__init__.py new file mode 100644 index 0000000..88b3cc1 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/__init__.py @@ -0,0 +1,4 @@ +from . 
import conversions, gazebo, logging, math, types +from .tf2_broadcaster import Tf2Broadcaster, Tf2BroadcasterStandalone +from .tf2_dynamic_broadcaster import Tf2DynamicBroadcaster +from .tf2_listener import Tf2Listener, Tf2ListenerStandalone diff --git a/env_manager/env_manager/env_manager/utils/conversions.py b/env_manager/env_manager/env_manager/utils/conversions.py new file mode 100644 index 0000000..064c33d --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/conversions.py @@ -0,0 +1,182 @@ +from typing import Tuple, Union + +import numpy +# import open3d +# import pyoctree +import sensor_msgs +from scipy.spatial.transform import Rotation + +# from sensor_msgs.msg import PointCloud2 +# from open3d.geometry import PointCloud +from geometry_msgs.msg import Transform + + +# def pointcloud2_to_open3d( +# ros_point_cloud2: PointCloud2, +# include_color: bool = False, +# include_intensity: bool = False, +# # Note: Order does not matter for DL, that's why channel swapping is disabled by default +# fix_rgb_channel_order: bool = False, +# ) -> PointCloud: +# +# # Create output Open3D PointCloud +# open3d_pc = PointCloud() +# +# size = ros_point_cloud2.width * ros_point_cloud2.height +# xyz_dtype = ">f4" if ros_point_cloud2.is_bigendian else " 3: +# bgr = numpy.ndarray( +# shape=(size, 3), +# dtype=numpy.uint8, +# buffer=ros_point_cloud2.data, +# offset=ros_point_cloud2.fields[3].offset, +# strides=(ros_point_cloud2.point_step, 1), +# ) +# if fix_rgb_channel_order: +# # Swap channels to gain rgb (faster than `bgr[:, [2, 1, 0]]`) +# bgr[:, 0], bgr[:, 2] = bgr[:, 2], bgr[:, 0].copy() +# open3d_pc.colors = open3d.utility.Vector3dVector( +# (bgr[valid_points] / 255).astype(numpy.float64) +# ) +# else: +# open3d_pc.colors = open3d.utility.Vector3dVector( +# numpy.zeros((len(valid_points), 3), dtype=numpy.float64) +# ) +# # TODO: Update octree creator once L8 image format is supported in Ignition Gazebo +# # elif include_intensity: +# # # Faster approach, but only 
the first channel gets the intensity value (rest is 0) +# # intensities = numpy.zeros((len(valid_points), 3), dtype=numpy.float64) +# # intensities[:, [0]] = ( +# # numpy.ndarray( +# # shape=(size, 1), +# # dtype=numpy.uint8, +# # buffer=ros_point_cloud2.data, +# # offset=ros_point_cloud2.fields[3].offset, +# # strides=(ros_point_cloud2.point_step, 1), +# # )[valid_points] +# # / 255 +# # ).astype(numpy.float64) +# # open3d_pc.colors = open3d.utility.Vector3dVector(intensities) +# # # # Slower approach, but all channels get the intensity value +# # # intensities = numpy.ndarray( +# # # shape=(size, 1), +# # # dtype=numpy.uint8, +# # # buffer=ros_point_cloud2.data, +# # # offset=ros_point_cloud2.fields[3].offset, +# # # strides=(ros_point_cloud2.point_step, 1), +# # # ) +# # # open3d_pc.colors = open3d.utility.Vector3dVector( +# # # numpy.tile(intensities[valid_points] / 255, (1, 3)).astype(numpy.float64) +# # # ) +# +# # Return the converted Open3D PointCloud +# return open3d_pc + + +def transform_to_matrix(transform: Transform) -> numpy.ndarray: + + transform_matrix = numpy.zeros((4, 4)) + transform_matrix[3, 3] = 1.0 + + transform_matrix[0:3, 0:3] = open3d.geometry.get_rotation_matrix_from_quaternion( + [ + transform.rotation.w, + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + ] + ) + transform_matrix[0, 3] = transform.translation.x + transform_matrix[1, 3] = transform.translation.y + transform_matrix[2, 3] = transform.translation.z + + return transform_matrix + + +# def open3d_point_cloud_to_octree_points( +# open3d_point_cloud: PointCloud, +# include_color: bool = False, +# include_intensity: bool = False, +# ) -> pyoctree.Points: +# +# octree_points = pyoctree.Points() +# +# if include_color: +# features = numpy.reshape(numpy.asarray(open3d_point_cloud.colors), -1) +# elif include_intensity: +# features = numpy.asarray(open3d_point_cloud.colors)[:, 0] +# else: +# features = [] +# +# octree_points.set_points( +# # XYZ points +# 
numpy.reshape(numpy.asarray(open3d_point_cloud.points), -1), +# # Normals +# numpy.reshape(numpy.asarray(open3d_point_cloud.normals), -1), +# # Other features, e.g. color +# features, +# # Labels - not used +# [], +# ) +# +# return octree_points + + +def orientation_6d_to_quat( + v1: Tuple[float, float, float], v2: Tuple[float, float, float] +) -> Tuple[float, float, float, float]: + + # Normalize vectors + col1 = v1 / numpy.linalg.norm(v1) + col2 = v2 / numpy.linalg.norm(v2) + + # Find their orthogonal vector via cross product + col3 = numpy.cross(col1, col2) + + # Stack into rotation matrix as columns, convert to quaternion and return + quat_xyzw = Rotation.from_matrix(numpy.array([col1, col2, col3]).T).as_quat() + return quat_xyzw + + +def orientation_quat_to_6d( + quat_xyzw: Tuple[float, float, float, float] +) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]: + + # Convert quaternion into rotation matrix + rot_mat = Rotation.from_quat(quat_xyzw).as_matrix() + + # Return first two columns (already normalised) + return (tuple(rot_mat[:, 0]), tuple(rot_mat[:, 1])) + + +def quat_to_wxyz( + xyzw: tuple[float, float, float, float] +) -> tuple[float, float, float, float]: + + return (xyzw[3], xyzw[0], xyzw[1], xyzw[2]) + + # return xyzw[[3, 0, 1, 2]] + + +def quat_to_xyzw( + wxyz: Union[numpy.ndarray, Tuple[float, float, float, float]] +) -> numpy.ndarray: + + if isinstance(wxyz, tuple): + return (wxyz[1], wxyz[2], wxyz[3], wxyz[0]) + + return wxyz[[1, 2, 3, 0]] diff --git a/env_manager/env_manager/env_manager/utils/gazebo.py b/env_manager/env_manager/env_manager/utils/gazebo.py new file mode 100644 index 0000000..2786de8 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/gazebo.py @@ -0,0 +1,262 @@ +from typing import Tuple, Union + +from gym_gz.scenario.model_wrapper import ModelWrapper +from numpy import exp +from scenario.bindings.core import Link, World +from scipy.spatial.transform import Rotation + +from .conversions import 
quat_to_wxyz, quat_to_xyzw +from .math import quat_mul +from .types import Pose, Point, Quat + + +#NOTE: Errors in pyright will be fixed only with pybind11 in the scenario module +def get_model_pose( + world: World, + model: ModelWrapper | str, + link: Link | str | None = None, + xyzw: bool = False, +) -> Pose: + """ + Return pose of model's link. Orientation is represented as wxyz quaternion or xyzw based on the passed argument `xyzw`. + """ + + if isinstance(model, str): + # Get reference to the model from its name if string is passed + model = world.to_gazebo().get_model(model).to_gazebo() + + if link is None: + # Use the first link if not specified + link = model.get_link(link_name=model.link_names()[0]) + elif isinstance(link, str): + # Get reference to the link from its name if string + link = model.get_link(link_name=link) + + # Get position and orientation + position = link.position() + quat = link.orientation() + + # Convert to xyzw order if desired + if xyzw: + quat = quat_to_xyzw(quat) + + # Return pose of the model's link + return ( + position, + quat, + ) + + +def get_model_position( + world: World, + model: ModelWrapper | str, + link: Link | str | None = None, +) -> Point: + """ + Return position of model's link. + """ + + if isinstance(model, str): + # Get reference to the model from its name if string is passed + model = world.to_gazebo().get_model(model).to_gazebo() + + if link is None: + # Use the first link if not specified + link = model.get_link(link_name=model.link_names()[0]) + elif isinstance(link, str): + # Get reference to the link from its name if string + link = model.get_link(link_name=link) + + # Return position of the model's link + return link.position() + + +def get_model_orientation( + world: World, + model: ModelWrapper | str, + link: Link | str | None = None, + xyzw: bool = False, +) -> Quat: + """ + Return orientation of model's link that is represented as wxyz quaternion or xyzw based on the passed argument `xyzw`. 
+ """ + + if isinstance(model, str): + # Get reference to the model from its name if string is passed + model = world.to_gazebo().get_model(model).to_gazebo() + + if link is None: + # Use the first link if not specified + link = model.get_link(link_name=model.link_names()[0]) + elif isinstance(link, str): + # Get reference to the link from its name if string + link = model.get_link(link_name=link) + + # Get orientation + quat = link.orientation() + + # Convert to xyzw order if desired + if xyzw: + quat = quat_to_xyzw(quat) + + # Return orientation of the model's link + return quat + + +def transform_move_to_model_pose( + world: World, + position: Point, + quat: Quat, + target_model: ModelWrapper | str, + target_link: Link | str | None = None, + xyzw: bool = False, +) -> Pose: + """ + Transform such that original `position` and `quat` are represented with respect to `target_model::target_link`. + The resulting pose is still represented in world coordinate system. + """ + + target_frame_position, target_frame_quat = get_model_pose( + world, + model=target_model, + link=target_link, + xyzw=True, + ) + + transformed_position = Rotation.from_quat(target_frame_quat).apply(position) + transformed_position = ( + transformed_position[0] + target_frame_position[0], + transformed_position[1] + target_frame_position[1], + transformed_position[2] + target_frame_position[2], + ) + + if not xyzw: + target_frame_quat = quat_to_wxyz(target_frame_quat) + transformed_quat = quat_mul(quat, target_frame_quat, xyzw=xyzw) + + return (transformed_position, transformed_quat) + + +def transform_move_to_model_position( + world: World, + position: Point, + target_model: ModelWrapper | str, + target_link: Link | str | None = None, +) -> Point: + + target_frame_position, target_frame_quat_xyzw = get_model_pose( + world, + model=target_model, + link=target_link, + xyzw=True, + ) + + transformed_position = Rotation.from_quat(target_frame_quat_xyzw).apply(position) + transformed_position = ( + 
target_frame_position[0] + transformed_position[0], + target_frame_position[1] + transformed_position[1], + target_frame_position[2] + transformed_position[2], + ) + + return transformed_position + + +def transform_move_to_model_orientation( + world: World, + quat: Quat, + target_model: ModelWrapper | str, + target_link: Link | str | None = None, + xyzw: bool = False, +) -> Quat: + + target_frame_quat = get_model_orientation( + world, + model=target_model, + link=target_link, + xyzw=xyzw, + ) + + transformed_quat = quat_mul(quat, target_frame_quat, xyzw=xyzw) + + return transformed_quat + + +def transform_change_reference_frame_pose( + world: World, + position: Point, + quat: Quat, + target_model: ModelWrapper | str, + target_link: Link | str | None, + xyzw: bool = False, +) -> Pose: + """ + Change reference frame of original `position` and `quat` from world coordinate system to `target_model::target_link` coordinate system. + """ + + target_frame_position, target_frame_quat = get_model_pose( + world, + model=target_model, + link=target_link, + xyzw=True, + ) + + transformed_position = ( + position[0] - target_frame_position[0], + position[1] - target_frame_position[1], + position[2] - target_frame_position[2], + ) + transformed_position = Rotation.from_quat(target_frame_quat).apply( + transformed_position, inverse=True + ) + + if not xyzw: + target_frame_quat = quat_to_wxyz(target_frame_quat) + transformed_quat = quat_mul(target_frame_quat, quat, xyzw=xyzw) + + return (tuple(transformed_position), transformed_quat) + + +def transform_change_reference_frame_position( + world: World, + position: Point, + target_model: ModelWrapper | str, + target_link: Link | str | None = None, +) -> Point: + + target_frame_position, target_frame_quat_xyzw = get_model_pose( + world, + model=target_model, + link=target_link, + xyzw=True, + ) + + transformed_position = ( + position[0] - target_frame_position[0], + position[1] - target_frame_position[1], + position[2] - 
target_frame_position[2], + ) + transformed_position = Rotation.from_quat(target_frame_quat_xyzw).apply( + transformed_position, inverse=True + ) + + return tuple(transformed_position) + + +def transform_change_reference_frame_orientation( + world: World, + quat: Quat, + target_model: ModelWrapper | str, + target_link: Link | str | None = None, + xyzw: bool = False, +) -> Quat: + + target_frame_quat = get_model_orientation( + world, + model=target_model, + link=target_link, + xyzw=xyzw, + ) + + transformed_quat = quat_mul(target_frame_quat, quat, xyzw=xyzw) + + return transformed_quat diff --git a/env_manager/env_manager/env_manager/utils/helper.py b/env_manager/env_manager/env_manager/utils/helper.py new file mode 100644 index 0000000..a5bfd08 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/helper.py @@ -0,0 +1,488 @@ +import numpy as np +from gym_gz.scenario.model_wrapper import ModelWrapper +from rclpy.node import Node +from scenario.bindings.core import World +from scipy.spatial.transform import Rotation + +from .conversions import orientation_6d_to_quat +from .gazebo import ( + get_model_orientation, + get_model_pose, + get_model_position, + transform_change_reference_frame_orientation, + transform_change_reference_frame_pose, + transform_change_reference_frame_position, +) +from .math import quat_mul +from .tf2_listener import Tf2Listener +from .types import Point, Pose, PoseRpy, Quat, Rpy + + +# Helper functions # +def get_relative_ee_position(self, translation: Point) -> Point: + # Scale relative action to metric units + translation = self.scale_relative_translation(translation) + # Get current position + current_position = self.get_ee_position() + # Compute target position + target_position = ( + current_position[0] + translation[0], + current_position[1] + translation[1], + current_position[2] + translation[2], + ) + + # Restrict target position to a limited workspace, if desired + if self._restrict_position_goal_to_workspace: + 
target_position = self.restrict_position_goal_to_workspace(target_position) + + return target_position + + +def get_relative_ee_orientation( + self, + rotation: float | Quat | PoseRpy, + representation: str = "quat", +) -> Quat: + current_quat_xyzw = self.get_ee_orientation() + + if representation == "z": + current_yaw = Rotation.from_quat(current_quat_xyzw).as_euler("xyz")[2] + current_quat_xyzw = Rotation.from_euler( + "xyz", [np.pi, 0, current_yaw] + ).as_quat() + + if isinstance(rotation, tuple): + if len(rotation) == 4 and representation == "quat": + relative_quat_xyzw = rotation + elif len(rotation) == 6 and representation == "6d": + vectors = (rotation[0:3], rotation[3:6]) + relative_quat_xyzw = orientation_6d_to_quat(vectors[0], vectors[1]) + else: + raise ValueError("Invalid rotation tuple length for representation.") + elif isinstance(rotation, float) and representation == "z": + rotation = self.scale_relative_rotation(rotation) + relative_quat_xyzw = Rotation.from_euler("xyz", [0, 0, rotation]).as_quat() + else: + raise TypeError("Invalid type for rotation or representation.") + + target_quat_xyzw = quat_mul(tuple(current_quat_xyzw), tuple(relative_quat_xyzw)) + + target_quat_xyzw = normalize_quaternion(tuple(relative_quat_xyzw)) + + return target_quat_xyzw + + +def normalize_quaternion( + target_quat_xyzw: tuple[float, float, float, float], +) -> tuple[float, float, float, float]: + quat_array = np.array(target_quat_xyzw) + normalized_quat = quat_array / np.linalg.norm(quat_array) + return tuple(normalized_quat) + + +def scale_relative_translation(self, translation: Point) -> Point: + return ( + self.__scaling_factor_translation * translation[0], + self.__scaling_factor_translation * translation[1], + self.__scaling_factor_translation * translation[2], + ) + + +def scale_relative_rotation( + self, rotation: float | Rpy | np.floating | np.ndarray +) -> float | Rpy: + scaling_factor = self.__scaling_factor_rotation + + if isinstance(rotation, (int, float, 
np.floating)): + return scaling_factor * rotation + + return tuple(scaling_factor * r for r in rotation) + + +def restrict_position_goal_to_workspace(self, position: Point) -> Point: + return ( + min( + self.workspace_max_bound[0], + max( + self.workspace_min_bound[0], + position[0], + ), + ), + min( + self.workspace_max_bound[1], + max( + self.workspace_min_bound[1], + position[1], + ), + ), + min( + self.workspace_max_bound[2], + max( + self.workspace_min_bound[2], + position[2], + ), + ), + ) + + +# def restrict_servo_translation_to_workspace( +# self, translation: tuple[float, float, float] +# ) -> tuple[float, float, float]: +# current_ee_position = self.get_ee_position() +# +# translation = tuple( +# 0.0 +# if ( +# current_ee_position[i] > self.workspace_max_bound[i] +# and translation[i] > 0.0 +# ) +# or ( +# current_ee_position[i] < self.workspace_min_bound[i] +# and translation[i] < 0.0 +# ) +# else translation[i] +# for i in range(3) +# ) +# +# return translation + + +def get_ee_pose( + node: Node, + world: World, + robot_ee_link_name: str, + robot_name: str, + robot_arm_base_link_name: str, + tf2_listener: Tf2Listener, +) -> Pose | None: + """ + Return the current pose of the end effector with respect to arm base link. + """ + + try: + robot_model = world.to_gazebo().get_model(robot_name).to_gazebo() + ee_position, ee_quat_xyzw = get_model_pose( + world=world, + model=robot_model, + link=robot_ee_link_name, + xyzw=True, + ) + return transform_change_reference_frame_pose( + world=world, + position=ee_position, + quat=ee_quat_xyzw, + target_model=robot_model, + target_link=robot_arm_base_link_name, + xyzw=True, + ) + except Exception as e: + node.get_logger().warn( + f"Cannot get end effector pose from Gazebo ({e}), using tf2..." 
+ ) + transform = tf2_listener.lookup_transform_sync( + source_frame=robot_ee_link_name, + target_frame=robot_arm_base_link_name, + retry=False, + ) + if transform is not None: + return ( + ( + transform.translation.x, + transform.translation.y, + transform.translation.z, + ), + ( + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + transform.rotation.w, + ), + ) + else: + node.get_logger().error( + "Cannot get pose of the end effector (default values are returned)" + ) + return ( + (0.0, 0.0, 0.0), + (0.0, 0.0, 0.0, 1.0), + ) + + +def get_ee_position( + world: World, + robot_name: str, + robot_ee_link_name: str, + robot_arm_base_link_name: str, + node: Node, + tf2_listener: Tf2Listener, +) -> Point: + """ + Return the current position of the end effector with respect to arm base link. + """ + + try: + robot_model = world.to_gazebo().get_model(robot_name).to_gazebo() + ee_position = get_model_position( + world=world, + model=robot_model, + link=robot_ee_link_name, + ) + return transform_change_reference_frame_position( + world=world, + position=ee_position, + target_model=robot_model, + target_link=robot_arm_base_link_name, + ) + except Exception as e: + node.get_logger().debug( + f"Cannot get end effector position from Gazebo ({e}), using tf2..." + ) + transform = tf2_listener.lookup_transform_sync( + source_frame=robot_ee_link_name, + target_frame=robot_arm_base_link_name, + retry=False, + ) + if transform is not None: + return ( + transform.translation.x, + transform.translation.y, + transform.translation.z, + ) + else: + node.get_logger().error( + "Cannot get position of the end effector (default values are returned)" + ) + return (0.0, 0.0, 0.0) + + +def get_ee_orientation( + world: World, + robot_name: str, + robot_ee_link_name: str, + robot_arm_base_link_name: str, + node: Node, + tf2_listener: Tf2Listener, +) -> Quat: + """ + Return the current xyzw quaternion of the end effector with respect to arm base link. 
+ """ + + try: + robot_model = world.to_gazebo().get_model(robot_name).to_gazebo() + ee_quat_xyzw = get_model_orientation( + world=world, + model=robot_model, + link=robot_ee_link_name, + xyzw=True, + ) + return transform_change_reference_frame_orientation( + world=world, + quat=ee_quat_xyzw, + target_model=robot_model, + target_link=robot_arm_base_link_name, + xyzw=True, + ) + except Exception as e: + node.get_logger().warn( + f"Cannot get end effector orientation from Gazebo ({e}), using tf2..." + ) + transform = tf2_listener.lookup_transform_sync( + source_frame=robot_ee_link_name, + target_frame=robot_arm_base_link_name, + retry=False, + ) + if transform is not None: + return ( + transform.rotation.x, + transform.rotation.y, + transform.rotation.z, + transform.rotation.w, + ) + else: + node.get_logger().error( + "Cannot get orientation of the end effector (default values are returned)" + ) + return (0.0, 0.0, 0.0, 1.0) + + +def get_object_position( + object_model: ModelWrapper | str, + node: Node, + world: World, + robot_name: str, + robot_arm_base_link_name: str, +) -> Point: + """ + Return the current position of an object with respect to arm base link. + Note: Only simulated objects are currently supported. + """ + + try: + object_position = get_model_position( + world=world, + model=object_model, + ) + return transform_change_reference_frame_position( + world=world, + position=object_position, + target_model=robot_name, + target_link=robot_arm_base_link_name, + ) + except Exception as e: + node.get_logger().error( + f"Cannot get position of {object_model} object (default values are returned): {e}" + ) + return (0.0, 0.0, 0.0) + + +def get_object_positions( + node: Node, + world: World, + object_names: list[str], + robot_name: str, + robot_arm_base_link_name: str, +) -> dict[str, tuple[float, float, float]]: + """ + Return the current position of all objects with respect to arm base link. + Note: Only simulated objects are currently supported. 
+ """ + + object_positions = {} + + try: + robot_model = world.to_gazebo().get_model(robot_name).to_gazebo() + robot_arm_base_link = robot_model.get_link(link_name=robot_arm_base_link_name) + for object_name in object_names: + object_position = get_model_position( + world=world, + model=object_name, + ) + object_positions[object_name] = transform_change_reference_frame_position( + world=world, + position=object_position, + target_model=robot_model, + target_link=robot_arm_base_link, + ) + except Exception as e: + node.get_logger().error( + f"Cannot get positions of all objects (empty Dict is returned): {e}" + ) + + return object_positions + + +# def substitute_special_frame(self, frame_id: str) -> str: +# if "arm_base_link" == frame_id: +# return self.robot_arm_base_link_name +# elif "base_link" == frame_id: +# return self.robot_base_link_name +# elif "end_effector" == frame_id: +# return self.robot_ee_link_name +# elif "world" == frame_id: +# try: +# # In Gazebo, where multiple worlds are allowed +# return self.world.to_gazebo().name() +# except Exception as e: +# self.get_logger().warn(f"") +# # Otherwise (e.g. real world) +# return "rbs_gym_world" +# else: +# return frame_id + + +def move_to_initial_joint_configuration(self): + pass + + # self.moveit2.move_to_configuration(self.initial_arm_joint_positions) + + # if ( + # self.robot_model_class.CLOSED_GRIPPER_JOINT_POSITIONS + # == self.initial_gripper_joint_positions + # ): + # self.gripper.reset_close() + # else: + # self.gripper.reset_open() + + +def check_terrain_collision( + world: World, + robot_name: str, + terrain_name: str, + robot_base_link_name: str, + robot_arm_link_names: list[str], + robot_gripper_link_names: list[str], +) -> bool: + """ + Returns true if robot links are in collision with the ground. 
+ """ + robot_name_len = len(robot_name) + terrain_model = world.get_model(terrain_name) + + for contact in terrain_model.contacts(): + body_b = contact.body_b + + if body_b.startswith(robot_name) and len(body_b) > robot_name_len: + link = body_b[robot_name_len + 2 :] + + if link != robot_base_link_name and ( + link in robot_arm_link_names or link in robot_gripper_link_names + ): + return True + + return False + + +def check_all_objects_outside_workspace( + workspace: Pose, + object_positions: dict[str, Point], +) -> bool: + """ + Returns True if all objects are outside the workspace. + """ + + return all( + [ + check_object_outside_workspace(workspace, object_position) + for object_position in object_positions.values() + ] + ) + + +def check_object_outside_workspace(workspace: Pose, object_position: Point) -> bool: + """ + Returns True if the object is outside the workspace. + """ + workspace_min_bound, workspace_max_bound = workspace + return any( + coord < min_bound or coord > max_bound + for coord, min_bound, max_bound in zip( + object_position, workspace_min_bound, workspace_max_bound + ) + ) + + +# def add_parameter_overrides(self, parameter_overrides: Dict[str, any]): +# self.add_task_parameter_overrides(parameter_overrides) +# self.add_randomizer_parameter_overrides(parameter_overrides) +# +# +# def add_task_parameter_overrides(self, parameter_overrides: Dict[str, any]): +# self.__task_parameter_overrides.update(parameter_overrides) +# +# +# def add_randomizer_parameter_overrides(self, parameter_overrides: Dict[str, any]): +# self._randomizer_parameter_overrides.update(parameter_overrides) +# +# +# def __consume_parameter_overrides(self): +# for key, value in self.__task_parameter_overrides.items(): +# if hasattr(self, key): +# setattr(self, key, value) +# elif hasattr(self, f"_{key}"): +# setattr(self, f"_{key}", value) +# elif hasattr(self, f"__{key}"): +# setattr(self, f"__{key}", value) +# else: +# self.get_logger().error(f"Override '{key}' is not 
supperted by the task.") +# +# self.__task_parameter_overrides.clear() diff --git a/env_manager/env_manager/env_manager/utils/logging.py b/env_manager/env_manager/env_manager/utils/logging.py new file mode 100644 index 0000000..e9b8c09 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/logging.py @@ -0,0 +1,25 @@ +from typing import Union + +from gymnasium import logger as gym_logger +from gym_gz.utils import logger as gym_ign_logger + + +def set_log_level(log_level: Union[int, str]): + """ + Set log level for (Gym) Ignition. + """ + + if not isinstance(log_level, int): + log_level = str(log_level).upper() + + if "WARNING" == log_level: + log_level = "WARN" + elif not log_level in ["DEBUG", "INFO", "WARN", "ERROR", "DISABLED"]: + log_level = "DISABLED" + + log_level = getattr(gym_logger, log_level) + + gym_ign_logger.set_level( + level=log_level, + scenario_level=log_level, + ) diff --git a/env_manager/env_manager/env_manager/utils/math.py b/env_manager/env_manager/env_manager/utils/math.py new file mode 100644 index 0000000..291cdf5 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/math.py @@ -0,0 +1,44 @@ +import numpy as np + + +def quat_mul( + quat_0: tuple[float, float, float, float], + quat_1: tuple[float, float, float, float], + xyzw: bool = True, +) -> tuple[float, float, float, float]: + """ + Multiply two quaternions + """ + if xyzw: + x0, y0, z0, w0 = quat_0 + x1, y1, z1, w1 = quat_1 + return ( + x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0, + -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0, + x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0, + -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0, + ) + else: + w0, x0, y0, z0 = quat_0 + w1, x1, y1, z1 = quat_1 + return ( + -x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0, + x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0, + -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0, + x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0, + ) + + +def distance_to_nearest_point( + origin: tuple[float, float, float], points: list[tuple[float, float, float]] +) -> float: + + 
return np.linalg.norm(np.array(points) - np.array(origin), axis=1).min() + + +def get_nearest_point( + origin: tuple[float, float, float], points: list[tuple[float, float, float]] +) -> tuple[float, float, float]: + + target_distances = np.linalg.norm(np.array(points) - np.array(origin), axis=1) + return points[target_distances.argmin()] diff --git a/env_manager/env_manager/env_manager/utils/tf2_broadcaster.py b/env_manager/env_manager/env_manager/utils/tf2_broadcaster.py new file mode 100644 index 0000000..27fb52a --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/tf2_broadcaster.py @@ -0,0 +1,74 @@ +import sys +from typing import Tuple + +import rclpy +from geometry_msgs.msg import TransformStamped +from rclpy.node import Node +from rclpy.parameter import Parameter +from tf2_ros import StaticTransformBroadcaster + + +class Tf2Broadcaster: + def __init__( + self, + node: Node, + ): + + self._node = node + self.__tf2_broadcaster = StaticTransformBroadcaster(node=self._node) + self._transform_stamped = TransformStamped() + + def broadcast_tf( + self, + parent_frame_id: str, + child_frame_id: str, + translation: Tuple[float, float, float], + rotation: Tuple[float, float, float, float], + xyzw: bool = True, + ): + """ + Broadcast transformation of the camera + """ + + self._transform_stamped.header.frame_id = parent_frame_id + self._transform_stamped.child_frame_id = child_frame_id + + self._transform_stamped.header.stamp = self._node.get_clock().now().to_msg() + + self._transform_stamped.transform.translation.x = float(translation[0]) + self._transform_stamped.transform.translation.y = float(translation[1]) + self._transform_stamped.transform.translation.z = float(translation[2]) + + if xyzw: + self._transform_stamped.transform.rotation.x = float(rotation[0]) + self._transform_stamped.transform.rotation.y = float(rotation[1]) + self._transform_stamped.transform.rotation.z = float(rotation[2]) + self._transform_stamped.transform.rotation.w = 
float(rotation[3]) + else: + self._transform_stamped.transform.rotation.w = float(rotation[0]) + self._transform_stamped.transform.rotation.x = float(rotation[1]) + self._transform_stamped.transform.rotation.y = float(rotation[2]) + self._transform_stamped.transform.rotation.z = float(rotation[3]) + + self.__tf2_broadcaster.sendTransform(self._transform_stamped) + + +class Tf2BroadcasterStandalone(Node, Tf2Broadcaster): + def __init__( + self, + node_name: str = "env_manager_tf_broadcaster", + use_sim_time: bool = True, + ): + + try: + rclpy.init() + except Exception as e: + if not rclpy.ok(): + sys.exit(f"ROS 2 context could not be initialised: {e}") + + Node.__init__(self, node_name) + self.set_parameters( + [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)] + ) + + Tf2Broadcaster.__init__(self, node=self) diff --git a/env_manager/env_manager/env_manager/utils/tf2_dynamic_broadcaster.py b/env_manager/env_manager/env_manager/utils/tf2_dynamic_broadcaster.py new file mode 100644 index 0000000..4737754 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/tf2_dynamic_broadcaster.py @@ -0,0 +1,74 @@ +import sys +from typing import Tuple + +import rclpy +from geometry_msgs.msg import TransformStamped +from rclpy.node import Node +from rclpy.parameter import Parameter +from tf2_ros import TransformBroadcaster + + +class Tf2DynamicBroadcaster: + def __init__( + self, + node: Node, + ): + + self._node = node + self.__tf2_broadcaster = TransformBroadcaster(node=self._node) + self._transform_stamped = TransformStamped() + + def broadcast_tf( + self, + parent_frame_id: str, + child_frame_id: str, + translation: Tuple[float, float, float], + rotation: Tuple[float, float, float, float], + xyzw: bool = True, + ): + """ + Broadcast transformation of the camera + """ + + self._transform_stamped.header.frame_id = parent_frame_id + self._transform_stamped.child_frame_id = child_frame_id + + self._transform_stamped.header.stamp = 
self._node.get_clock().now().to_msg() + + self._transform_stamped.transform.translation.x = float(translation[0]) + self._transform_stamped.transform.translation.y = float(translation[1]) + self._transform_stamped.transform.translation.z = float(translation[2]) + + if xyzw: + self._transform_stamped.transform.rotation.x = float(rotation[0]) + self._transform_stamped.transform.rotation.y = float(rotation[1]) + self._transform_stamped.transform.rotation.z = float(rotation[2]) + self._transform_stamped.transform.rotation.w = float(rotation[3]) + else: + self._transform_stamped.transform.rotation.w = float(rotation[0]) + self._transform_stamped.transform.rotation.x = float(rotation[1]) + self._transform_stamped.transform.rotation.y = float(rotation[2]) + self._transform_stamped.transform.rotation.z = float(rotation[3]) + + self.__tf2_broadcaster.sendTransform(self._transform_stamped) + + +class Tf2BroadcasterStandalone(Node, Tf2DynamicBroadcaster): + def __init__( + self, + node_name: str = "env_manager_tf_broadcaster", + use_sim_time: bool = True, + ): + + try: + rclpy.init() + except Exception as e: + if not rclpy.ok(): + sys.exit(f"ROS 2 context could not be initialised: {e}") + + Node.__init__(self, node_name) + self.set_parameters( + [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)] + ) + + Tf2DynamicBroadcaster.__init__(self, node=self) diff --git a/env_manager/env_manager/env_manager/utils/tf2_listener.py b/env_manager/env_manager/env_manager/utils/tf2_listener.py new file mode 100644 index 0000000..ef22bb0 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/tf2_listener.py @@ -0,0 +1,74 @@ +import sys +from typing import Optional + +import rclpy +from geometry_msgs.msg import Transform +from rclpy.node import Node +from rclpy.parameter import Parameter +from tf2_ros import Buffer, TransformListener + + +class Tf2Listener: + def __init__( + self, + node: Node, + ): + + self._node = node + + # Create tf2 buffer and listener for 
transform lookup + self.__tf2_buffer = Buffer() + TransformListener(buffer=self.__tf2_buffer, node=node) + + def lookup_transform_sync( + self, target_frame: str, source_frame: str, retry: bool = True + ) -> Optional[Transform]: + + try: + return self.__tf2_buffer.lookup_transform( + target_frame=target_frame, + source_frame=source_frame, + time=rclpy.time.Time(), + ).transform + except: + if retry: + while rclpy.ok(): + if self.__tf2_buffer.can_transform( + target_frame=target_frame, + source_frame=source_frame, + time=rclpy.time.Time(), + timeout=rclpy.time.Duration(seconds=1, nanoseconds=0), + ): + return self.__tf2_buffer.lookup_transform( + target_frame=target_frame, + source_frame=source_frame, + time=rclpy.time.Time(), + ).transform + + self._node.get_logger().warn( + f'Lookup of transform from "{source_frame}"' + f' to "{target_frame}" failed, retrying...' + ) + else: + return None + + +class Tf2ListenerStandalone(Node, Tf2Listener): + def __init__( + self, + node_name: str = "rbs_gym_tf_listener", + use_sim_time: bool = True, + ): + + try: + rclpy.init() + except Exception as e: + if not rclpy.ok(): + sys.exit(f"ROS 2 context could not be initialised: {e}") + + Node.__init__(self, node_name) + self.set_parameters( + [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)] + ) + + Tf2Listener.__init__(self, node=self) diff --git a/env_manager/env_manager/env_manager/utils/types.py b/env_manager/env_manager/env_manager/utils/types.py new file mode 100644 index 0000000..a502bd6 --- /dev/null +++ b/env_manager/env_manager/env_manager/utils/types.py @@ -0,0 +1,5 @@ +Point = tuple[float, float, float] +Quat = tuple[float, float, float, float] +Rpy = tuple[float, float, float] +Pose = tuple[Point, Quat] +PoseRpy = tuple[Point, Rpy] diff --git a/env_manager/env_manager/package.nix b/env_manager/env_manager/package.nix new file mode 100644 index 0000000..49bca32 --- /dev/null +++ b/env_manager/env_manager/package.nix @@ -0,0 +1,31 @@ +# 
Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0 +# Copyright 2025 None +# Distributed under the terms of the Apache-2.0 license +{ + lib, + buildRosPackage, + ament-copyright, + ament-flake8, + ament-pep257, + gym-gz-ros-python, + pythonPackages, + scenario, +}: +let + rbs-assets-library = pythonPackages.callPackage ../../../../rbs_assets_library/default.nix {}; +in +buildRosPackage rec { + pname = "ros-jazzy-env-manager"; + version = "0.0.0"; + + src = ./.; + + buildType = "ament_python"; + checkInputs = [ament-copyright ament-flake8 ament-pep257 gym-gz-ros-python pythonPackages.pytest scenario]; + propagatedBuildInputs = [gym-gz-ros-python rbs-assets-library]; + + meta = { + description = "TODO: Package description"; + license = with lib.licenses; [asl20]; + }; +} diff --git a/env_manager/env_manager/package.xml b/env_manager/env_manager/package.xml new file mode 100644 index 0000000..5018612 --- /dev/null +++ b/env_manager/env_manager/package.xml @@ -0,0 +1,21 @@ + + + + env_manager + 0.0.0 + TODO: Package description + narmak + Apache-2.0 + + scenario + gym-gz-ros-python + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/env_manager/env_manager/resource/env_manager b/env_manager/env_manager/resource/env_manager new file mode 100644 index 0000000..e69de29 diff --git a/env_manager/env_manager/setup.cfg b/env_manager/env_manager/setup.cfg new file mode 100644 index 0000000..36219d7 --- /dev/null +++ b/env_manager/env_manager/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/env_manager +[install] +install_scripts=$base/lib/env_manager diff --git a/env_manager/env_manager/setup.py b/env_manager/env_manager/setup.py new file mode 100644 index 0000000..3e62b74 --- /dev/null +++ b/env_manager/env_manager/setup.py @@ -0,0 +1,25 @@ +from setuptools import find_packages, setup + +package_name = 'env_manager' + +setup( + name=package_name, + version='0.0.0', + 
packages=find_packages(exclude=['test']), + data_files=[ + ('share/ament_index/resource_index/packages', + ['resource/' + package_name]), + ('share/' + package_name, ['package.xml']), + ], + install_requires=['setuptools', 'trimesh'], + zip_safe=True, + maintainer='narmak', + maintainer_email='ur.narmak@gmail.com', + description='TODO: Package description', + license='Apache-2.0', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + ], + }, +) diff --git a/env_manager/env_manager/test/test_copyright.py b/env_manager/env_manager/test/test_copyright.py new file mode 100644 index 0000000..97a3919 --- /dev/null +++ b/env_manager/env_manager/test/test_copyright.py @@ -0,0 +1,25 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_copyright.main import main +import pytest + + +# Remove the `skip` decorator once the source file(s) have a copyright header +@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') +@pytest.mark.copyright +@pytest.mark.linter +def test_copyright(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found errors' diff --git a/env_manager/env_manager/test/test_flake8.py b/env_manager/env_manager/test/test_flake8.py new file mode 100644 index 0000000..27ee107 --- /dev/null +++ b/env_manager/env_manager/test/test_flake8.py @@ -0,0 +1,25 @@ +# Copyright 2017 Open Source Robotics Foundation, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_flake8.main import main_with_errors +import pytest + + +@pytest.mark.flake8 +@pytest.mark.linter +def test_flake8(): + rc, errors = main_with_errors(argv=[]) + assert rc == 0, \ + 'Found %d code style errors / warnings:\n' % len(errors) + \ + '\n'.join(errors) diff --git a/env_manager/env_manager/test/test_pep257.py b/env_manager/env_manager/test/test_pep257.py new file mode 100644 index 0000000..b234a38 --- /dev/null +++ b/env_manager/env_manager/test/test_pep257.py @@ -0,0 +1,23 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ament_pep257.main import main +import pytest + + +@pytest.mark.linter +@pytest.mark.pep257 +def test_pep257(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found code style errors / warnings' diff --git a/env_manager/env_manager_interfaces/CMakeLists.txt b/env_manager/env_manager_interfaces/CMakeLists.txt new file mode 100644 index 0000000..a0a5c75 --- /dev/null +++ b/env_manager/env_manager_interfaces/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.8) +project(env_manager_interfaces) + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +# find dependencies +find_package(ament_cmake REQUIRED) +find_package(lifecycle_msgs REQUIRED) +find_package(rosidl_default_generators REQUIRED) +# uncomment the following section in order to fill in +# further dependencies manually. +# find_package( REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + srv/StartEnv.srv + srv/ConfigureEnv.srv + srv/LoadEnv.srv + srv/UnloadEnv.srv + srv/ResetEnv.srv + msg/EnvState.msg + DEPENDENCIES lifecycle_msgs +) + +if(BUILD_TESTING) + find_package(ament_lint_auto REQUIRED) + # the following line skips the linter which checks for copyrights + # comment the line when a copyright and license is added to all source files + set(ament_cmake_copyright_FOUND TRUE) + # the following line skips cpplint (only works in a git repo) + # comment the line when this package is in a git repo and when + # a copyright and license is added to all source files + set(ament_cmake_cpplint_FOUND TRUE) + ament_lint_auto_find_test_dependencies() +endif() + +ament_package() diff --git a/env_manager/env_manager_interfaces/LICENSE b/env_manager/env_manager_interfaces/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/env_manager/env_manager_interfaces/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, 
REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/env_manager/env_manager_interfaces/msg/EnvState.msg b/env_manager/env_manager_interfaces/msg/EnvState.msg new file mode 100644 index 0000000..fa4c82e --- /dev/null +++ b/env_manager/env_manager_interfaces/msg/EnvState.msg @@ -0,0 +1,4 @@ +string name +string type +string plugin_name +lifecycle_msgs/State state diff --git a/env_manager/env_manager_interfaces/package.nix b/env_manager/env_manager_interfaces/package.nix new file mode 100644 index 0000000..3402f72 --- /dev/null +++ b/env_manager/env_manager_interfaces/package.nix @@ -0,0 +1,23 @@ +# Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0 + +# Copyright 2025 None +# Distributed under the terms of the Apache-2.0 license + +{ lib, buildRosPackage, ament-cmake, ament-lint-auto, ament-lint-common, builtin-interfaces, lifecycle-msgs, rosidl-default-generators, rosidl-default-runtime }: +buildRosPackage rec { + pname = "ros-jazzy-env-manager-interfaces"; + version = "0.0.0"; + + src = ./.; + + buildType = "ament_cmake"; + buildInputs = [ ament-cmake rosidl-default-generators ]; + checkInputs = [ ament-lint-auto ament-lint-common ]; + propagatedBuildInputs = [ builtin-interfaces lifecycle-msgs rosidl-default-runtime ]; + nativeBuildInputs = [ ament-cmake rosidl-default-generators ]; + + meta = { + description = "TODO: Package description"; + license = with lib.licenses; [ asl20 ]; + }; +} diff --git a/env_manager/env_manager_interfaces/package.xml b/env_manager/env_manager_interfaces/package.xml new file mode 100644 index 0000000..4fb3920 --- /dev/null +++ b/env_manager/env_manager_interfaces/package.xml @@ -0,0 +1,27 @@ + + + + env_manager_interfaces + 0.0.0 + TODO: Package description + solid-sinusoid + Apache-2.0 + + ament_cmake + rosidl_default_generators + + builtin_interfaces + lifecycle_msgs + + builtin_interfaces + lifecycle_msgs + rosidl_default_runtime + + ament_lint_auto + ament_lint_common + + rosidl_interface_packages + + ament_cmake + + diff --git 
a/env_manager/env_manager_interfaces/srv/ConfigureEnv.srv b/env_manager/env_manager_interfaces/srv/ConfigureEnv.srv new file mode 100644 index 0000000..4292382 --- /dev/null +++ b/env_manager/env_manager_interfaces/srv/ConfigureEnv.srv @@ -0,0 +1,3 @@ +string name +--- +bool ok \ No newline at end of file diff --git a/env_manager/env_manager_interfaces/srv/LoadEnv.srv b/env_manager/env_manager_interfaces/srv/LoadEnv.srv new file mode 100644 index 0000000..d6d46cc --- /dev/null +++ b/env_manager/env_manager_interfaces/srv/LoadEnv.srv @@ -0,0 +1,4 @@ +string name +string type +--- +bool ok \ No newline at end of file diff --git a/env_manager/env_manager_interfaces/srv/ResetEnv.srv b/env_manager/env_manager_interfaces/srv/ResetEnv.srv new file mode 100644 index 0000000..5d0cbc4 --- /dev/null +++ b/env_manager/env_manager_interfaces/srv/ResetEnv.srv @@ -0,0 +1,3 @@ + +--- +bool ok diff --git a/env_manager/env_manager_interfaces/srv/StartEnv.srv b/env_manager/env_manager_interfaces/srv/StartEnv.srv new file mode 100644 index 0000000..d6d46cc --- /dev/null +++ b/env_manager/env_manager_interfaces/srv/StartEnv.srv @@ -0,0 +1,4 @@ +string name +string type +--- +bool ok \ No newline at end of file diff --git a/env_manager/env_manager_interfaces/srv/UnloadEnv.srv b/env_manager/env_manager_interfaces/srv/UnloadEnv.srv new file mode 100644 index 0000000..4292382 --- /dev/null +++ b/env_manager/env_manager_interfaces/srv/UnloadEnv.srv @@ -0,0 +1,3 @@ +string name +--- +bool ok \ No newline at end of file diff --git a/env_manager/rbs_gym/LICENSE b/env_manager/rbs_gym/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/env_manager/rbs_gym/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/env_manager/rbs_gym/hyperparams/sac.yml b/env_manager/rbs_gym/hyperparams/sac.yml new file mode 100644 index 0000000..d372fc7 --- /dev/null +++ b/env_manager/rbs_gym/hyperparams/sac.yml @@ -0,0 +1,42 @@ +# Reach +Reach-Gazebo-v0: + policy: "MlpPolicy" + policy_kwargs: + n_critics: 2 + net_arch: [128, 64] + n_timesteps: 200000 + buffer_size: 25000 + learning_starts: 5000 + batch_size: 512 + learning_rate: lin_0.0002 + gamma: 0.95 + tau: 0.001 + ent_coef: "auto_0.1" + target_entropy: "auto" + train_freq: [1, "episode"] + gradient_steps: 100 + noise_type: "normal" + noise_std: 0.025 + use_sde: False + optimize_memory_usage: False + +Reach-ColorImage-Gazebo-v0: + policy: "CnnPolicy" + policy_kwargs: + n_critics: 2 + net_arch: [128, 128] + n_timesteps: 50000 + buffer_size: 25000 + learning_starts: 5000 + batch_size: 32 + learning_rate: lin_0.0002 + gamma: 0.95 + tau: 0.0005 + ent_coef: "auto_0.1" + target_entropy: "auto" + train_freq: [1, "episode"] + gradient_steps: 100 + noise_type: "normal" + noise_std: 0.025 + use_sde: False + optimize_memory_usage: False diff --git a/env_manager/rbs_gym/hyperparams/td3.yml b/env_manager/rbs_gym/hyperparams/td3.yml new file mode 100644 index 0000000..6f57e0a --- /dev/null +++ b/env_manager/rbs_gym/hyperparams/td3.yml @@ -0,0 +1,39 @@ +Reach-Gazebo-v0: + policy: "MlpPolicy" + policy_kwargs: + n_critics: 2 + net_arch: [128, 64] + n_timesteps: 200000 + buffer_size: 25000 + learning_starts: 5000 + batch_size: 512 + learning_rate: lin_0.0002 + gamma: 0.95 + tau: 0.001 + train_freq: [1, "episode"] + gradient_steps: 100 + target_policy_noise: 0.1 + target_noise_clip: 0.2 + noise_type: "normal" + noise_std: 0.025 + optimize_memory_usage: False + +Reach-ColorImage-Gazebo-v0: + policy: "CnnPolicy" + policy_kwargs: + n_critics: 2 + net_arch: [128, 128] + n_timesteps: 50000 + buffer_size: 25000 + learning_starts: 5000 + batch_size: 32 + learning_rate: lin_0.0002 + gamma: 0.95 + tau: 0.0005 + train_freq: [1, "episode"] + 
gradient_steps: 100 + target_policy_noise: 0.1 + target_noise_clip: 0.2 + noise_type: "normal" + noise_std: 0.025 + optimize_memory_usage: True diff --git a/env_manager/rbs_gym/hyperparams/tqc.yml b/env_manager/rbs_gym/hyperparams/tqc.yml new file mode 100644 index 0000000..7e483af --- /dev/null +++ b/env_manager/rbs_gym/hyperparams/tqc.yml @@ -0,0 +1,46 @@ +# Reach +Reach-Gazebo-v0: + policy: "MlpPolicy" + policy_kwargs: + n_quantiles: 25 + n_critics: 2 + net_arch: [128, 64] + n_timesteps: 200000 + buffer_size: 25000 + learning_starts: 5000 + batch_size: 512 + learning_rate: lin_0.0002 + gamma: 0.95 + tau: 0.001 + ent_coef: "auto_0.1" + target_entropy: "auto" + top_quantiles_to_drop_per_net: 2 + train_freq: [1, "episode"] + gradient_steps: 100 + noise_type: "normal" + noise_std: 0.025 + use_sde: False + optimize_memory_usage: False + +Reach-ColorImage-Gazebo-v0: + policy: "CnnPolicy" + policy_kwargs: + n_quantiles: 25 + n_critics: 2 + net_arch: [128, 128] + n_timesteps: 50000 + buffer_size: 25000 + learning_starts: 5000 + batch_size: 32 + learning_rate: lin_0.0002 + gamma: 0.95 + tau: 0.0005 + ent_coef: "auto_0.1" + target_entropy: "auto" + top_quantiles_to_drop_per_net: 2 + train_freq: [1, "episode"] + gradient_steps: 100 + noise_type: "normal" + noise_std: 0.025 + use_sde: False + optimize_memory_usage: True diff --git a/env_manager/rbs_gym/launch/evaluate.launch.py b/env_manager/rbs_gym/launch/evaluate.launch.py new file mode 100644 index 0000000..4e69f23 --- /dev/null +++ b/env_manager/rbs_gym/launch/evaluate.launch.py @@ -0,0 +1,426 @@ +from launch import LaunchContext, LaunchDescription +from launch.actions import ( + DeclareLaunchArgument, + IncludeLaunchDescription, + OpaqueFunction, + SetEnvironmentVariable, + TimerAction +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch.substitutions import LaunchConfiguration, PathJoinSubstitution +from launch_ros.substitutions import FindPackageShare +from launch_ros.actions 
import Node +import os +from os import cpu_count +from ament_index_python.packages import get_package_share_directory + +def launch_setup(context, *args, **kwargs): + # Initialize Arguments + robot_type = LaunchConfiguration("robot_type") + # General arguments + with_gripper_condition = LaunchConfiguration("with_gripper") + controllers_file = LaunchConfiguration("controllers_file") + cartesian_controllers = LaunchConfiguration("cartesian_controllers") + description_package = LaunchConfiguration("description_package") + description_file = LaunchConfiguration("description_file") + robot_name = LaunchConfiguration("robot_name") + start_joint_controller = LaunchConfiguration("start_joint_controller") + initial_joint_controller = LaunchConfiguration("initial_joint_controller") + launch_simulation = LaunchConfiguration("launch_sim") + launch_moveit = LaunchConfiguration("launch_moveit") + launch_task_planner = LaunchConfiguration("launch_task_planner") + launch_perception = LaunchConfiguration("launch_perception") + moveit_config_package = LaunchConfiguration("moveit_config_package") + moveit_config_file = LaunchConfiguration("moveit_config_file") + use_sim_time = LaunchConfiguration("use_sim_time") + sim_gazebo = LaunchConfiguration("sim_gazebo") + hardware = LaunchConfiguration("hardware") + env_manager = LaunchConfiguration("env_manager") + launch_controllers = LaunchConfiguration("launch_controllers") + gazebo_gui = LaunchConfiguration("gazebo_gui") + gripper_name = LaunchConfiguration("gripper_name") + + # training arguments + env = LaunchConfiguration("env") + algo = LaunchConfiguration("algo") + num_threads = LaunchConfiguration("num_threads") + seed = LaunchConfiguration("seed") + log_folder = LaunchConfiguration("log_folder") + verbose = LaunchConfiguration("verbose") + # use_sim_time = LaunchConfiguration("use_sim_time") + log_level = LaunchConfiguration("log_level") + env_kwargs = LaunchConfiguration("env_kwargs") + + n_episodes = 
LaunchConfiguration("n_episodes") + exp_id = LaunchConfiguration("exp_id") + load_best = LaunchConfiguration("load_best") + load_checkpoint = LaunchConfiguration("load_checkpoint") + stochastic = LaunchConfiguration("stochastic") + reward_log = LaunchConfiguration("reward_log") + norm_reward = LaunchConfiguration("norm_reward") + no_render = LaunchConfiguration("no_render") + + + sim_gazebo = LaunchConfiguration("sim_gazebo") + launch_simulation = LaunchConfiguration("launch_sim") + + initial_joint_controllers_file_path = os.path.join( + get_package_share_directory('rbs_arm'), 'config', 'controllers.yaml' + ) + + single_robot_setup = IncludeLaunchDescription( + PythonLaunchDescriptionSource([ + PathJoinSubstitution([ + FindPackageShare('rbs_bringup'), + "launch", + "rbs_robot.launch.py" + ]) + ]), + launch_arguments={ + "env_manager": env_manager, + "with_gripper": with_gripper_condition, + "gripper_name": gripper_name, + # "controllers_file": controllers_file, + "robot_type": robot_type, + "controllers_file": initial_joint_controllers_file_path, + "cartesian_controllers": cartesian_controllers, + "description_package": description_package, + "description_file": description_file, + "robot_name": robot_name, + "start_joint_controller": start_joint_controller, + "initial_joint_controller": initial_joint_controller, + "launch_simulation": launch_simulation, + "launch_moveit": launch_moveit, + "launch_task_planner": launch_task_planner, + "launch_perception": launch_perception, + "moveit_config_package": moveit_config_package, + "moveit_config_file": moveit_config_file, + "use_sim_time": use_sim_time, + "sim_gazebo": sim_gazebo, + "hardware": hardware, + "launch_controllers": launch_controllers, + # "gazebo_gui": gazebo_gui + }.items() + ) + + args = [ + "--env", + env, + "--env-kwargs", + env_kwargs, + "--algo", + algo, + "--seed", + seed, + "--num-threads", + num_threads, + "--n-episodes", + n_episodes, + "--log-folder", + log_folder, + "--exp-id", + exp_id, + 
"--load-best", + load_best, + "--load-checkpoint", + load_checkpoint, + "--stochastic", + stochastic, + "--reward-log", + reward_log, + "--norm-reward", + norm_reward, + "--no-render", + no_render, + "--verbose", + verbose, + "--ros-args", + "--log-level", + log_level, + ] + + rl_task = Node( + package="rbs_gym", + executable="evaluate", + output="log", + arguments=args, + parameters=[{"use_sim_time": use_sim_time}], + ) + + + + delay_robot_control_stack = TimerAction( + period=10.0, + actions=[single_robot_setup] + ) + + nodes_to_start = [ + # env, + rl_task, + delay_robot_control_stack + ] + return nodes_to_start + + +def generate_launch_description(): + declared_arguments = [] + declared_arguments.append( + DeclareLaunchArgument( + "robot_type", + description="Type of robot by name", + choices=["rbs_arm","ur3", "ur3e", "ur5", "ur5e", "ur10", "ur10e", "ur16e"], + default_value="rbs_arm", + ) + ) + # General arguments + declared_arguments.append( + DeclareLaunchArgument( + "controllers_file", + default_value="controllers_gazebosim.yaml", + description="YAML file with the controllers configuration.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_package", + default_value="rbs_arm", + description="Description package with robot URDF/XACRO files. 
Usually the argument \ + is not set, it enables use of a custom description.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_file", + default_value="rbs_arm_modular.xacro", + description="URDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "robot_name", + default_value="arm0", + description="Name for robot, used to apply namespace for specific robot in multirobot setup", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "start_joint_controller", + default_value="false", + description="Enable headless mode for robot control", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "initial_joint_controller", + default_value="joint_trajectory_controller", + description="Robot controller to start.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_package", + default_value="rbs_arm", + description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \ + is not set, it enables use of a custom moveit config.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_file", + default_value="rbs_arm.srdf.xacro", + description="MoveIt SRDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_sim_time", + default_value="true", + description="Make MoveIt to use simulation time.\ + This is needed for the trajectory planing in simulation.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "gripper_name", + default_value="rbs_gripper", + choices=["rbs_gripper", ""], + description="choose gripper by name (leave empty if hasn't)", + ) + ) + declared_arguments.append( + DeclareLaunchArgument("with_gripper", + default_value="true", + description="With gripper or not?") + ) + declared_arguments.append( + DeclareLaunchArgument("sim_gazebo", + default_value="true", + description="Gazebo Simulation") + ) + declared_arguments.append( + 
DeclareLaunchArgument("env_manager", + default_value="false", + description="Launch env_manager?") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_sim", + default_value="true", + description="Launch simulator (Gazebo)?\ + Most general arg") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_moveit", + default_value="false", + description="Launch moveit?") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_perception", + default_value="false", + description="Launch perception?") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_task_planner", + default_value="false", + description="Launch task_planner?") + ) + declared_arguments.append( + DeclareLaunchArgument("cartesian_controllers", + default_value="true", + description="Load cartesian\ + controllers?") + ) + declared_arguments.append( + DeclareLaunchArgument("hardware", + choices=["gazebo", "mock"], + default_value="gazebo", + description="Choose your hardware_interface") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_controllers", + default_value="true", + description="Launch controllers?") + ) + declared_arguments.append( + DeclareLaunchArgument("gazebo_gui", + default_value="true", + description="Launch gazebo with gui?") + ) + # training arguments + declared_arguments.append( + DeclareLaunchArgument( + "env", + default_value="Reach-Gazebo-v0", + description="Environment ID", + )) + declared_arguments.append( + DeclareLaunchArgument( + "env_kwargs", + default_value="", + description="Optional keyword argument to pass to the env constructor.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "vec_env", + default_value="dummy", + description="Type of VecEnv to use (dummy or subproc).", + )) + # Algorithm and training + declared_arguments.append( + DeclareLaunchArgument( + "algo", + default_value="sac", + description="RL algorithm to use during the training.", + )) + declared_arguments.append( + DeclareLaunchArgument( + 
"num_threads", + default_value="-1", + description="Number of threads for PyTorch (-1 to use default).", + )) + # Random seed + declared_arguments.append( + DeclareLaunchArgument( + "seed", + default_value="84", + description="Random generator seed.", + )) + # Logging + declared_arguments.append( + DeclareLaunchArgument( + "log_folder", + default_value="logs", + description="Path to the log directory.", + )) + # Verbosity + declared_arguments.append( + DeclareLaunchArgument( + "verbose", + default_value="1", + description="Verbose mode (0: no output, 1: INFO).", + )) + # HER specifics + declared_arguments.append( + DeclareLaunchArgument( + "log_level", + default_value="error", + description="The level of logging that is applied to all ROS 2 nodes launched by this script.", + )) + + + declared_arguments.append( + DeclareLaunchArgument( + "n_episodes", + default_value="200", + description="Number of evaluation episodes.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "exp_id", + default_value="0", + description="Experiment ID (default: 0: latest, -1: no exp folder).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "load_best", + default_value="false", + description="Load best model instead of last model if available.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "load_checkpoint", + default_value="0", + description="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "stochastic", + default_value="false", + description="Use stochastic actions instead of deterministic.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "reward_log", + default_value="reward_logs", + description="Where to log reward.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "norm_reward", + default_value="false", + description="Normalize reward if applicable (trained with VecNormalize)", + )) + 
declared_arguments.append( + DeclareLaunchArgument( + "no_render", + default_value="true", + description="Do not render the environment (useful for tests).", + )) + + + + env_variables = [ + SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"), + SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_count() // 2)) + ] + + return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)] + env_variables) diff --git a/env_manager/rbs_gym/launch/optimize.launch.py b/env_manager/rbs_gym/launch/optimize.launch.py new file mode 100644 index 0000000..7e06a9c --- /dev/null +++ b/env_manager/rbs_gym/launch/optimize.launch.py @@ -0,0 +1,519 @@ +from launch import LaunchDescription +from launch.actions import ( + DeclareLaunchArgument, + IncludeLaunchDescription, + OpaqueFunction, + TimerAction +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch.substitutions import LaunchConfiguration, PathJoinSubstitution +from launch_ros.substitutions import FindPackageShare +from launch_ros.actions import Node +import os +from os import cpu_count +from ament_index_python.packages import get_package_share_directory + +def launch_setup(context, *args, **kwargs): + # Initialize Arguments + robot_type = LaunchConfiguration("robot_type") + # General arguments + with_gripper_condition = LaunchConfiguration("with_gripper") + controllers_file = LaunchConfiguration("controllers_file") + cartesian_controllers = LaunchConfiguration("cartesian_controllers") + description_package = LaunchConfiguration("description_package") + description_file = LaunchConfiguration("description_file") + robot_name = LaunchConfiguration("robot_name") + start_joint_controller = LaunchConfiguration("start_joint_controller") + initial_joint_controller = LaunchConfiguration("initial_joint_controller") + launch_simulation = LaunchConfiguration("launch_sim") + launch_moveit = LaunchConfiguration("launch_moveit") + launch_task_planner = 
LaunchConfiguration("launch_task_planner") + launch_perception = LaunchConfiguration("launch_perception") + moveit_config_package = LaunchConfiguration("moveit_config_package") + moveit_config_file = LaunchConfiguration("moveit_config_file") + use_sim_time = LaunchConfiguration("use_sim_time") + sim_gazebo = LaunchConfiguration("sim_gazebo") + hardware = LaunchConfiguration("hardware") + env_manager = LaunchConfiguration("env_manager") + launch_controllers = LaunchConfiguration("launch_controllers") + gripper_name = LaunchConfiguration("gripper_name") + + # training arguments + env = LaunchConfiguration("env") + env_kwargs = LaunchConfiguration("env_kwargs") + algo = LaunchConfiguration("algo") + hyperparams = LaunchConfiguration("hyperparams") + n_timesteps = LaunchConfiguration("n_timesteps") + num_threads = LaunchConfiguration("num_threads") + seed = LaunchConfiguration("seed") + preload_replay_buffer = LaunchConfiguration("preload_replay_buffer") + log_folder = LaunchConfiguration("log_folder") + tensorboard_log = LaunchConfiguration("tensorboard_log") + log_interval = LaunchConfiguration("log_interval") + uuid = LaunchConfiguration("uuid") + eval_episodes = LaunchConfiguration("eval_episodes") + verbose = LaunchConfiguration("verbose") + truncate_last_trajectory = LaunchConfiguration("truncate_last_trajectory") + use_sim_time = LaunchConfiguration("use_sim_time") + log_level = LaunchConfiguration("log_level") + sampler = LaunchConfiguration("sampler") + pruner = LaunchConfiguration("pruner") + n_trials = LaunchConfiguration("n_trials") + n_startup_trials = LaunchConfiguration("n_startup_trials") + n_evaluations = LaunchConfiguration("n_evaluations") + n_jobs = LaunchConfiguration("n_jobs") + storage = LaunchConfiguration("storage") + study_name = LaunchConfiguration("study_name") + + + sim_gazebo = LaunchConfiguration("sim_gazebo") + launch_simulation = LaunchConfiguration("launch_sim") + + initial_joint_controllers_file_path = os.path.join( + 
get_package_share_directory('rbs_arm'), 'config', 'controllers.yaml' + ) + + single_robot_setup = IncludeLaunchDescription( + PythonLaunchDescriptionSource([ + PathJoinSubstitution([ + FindPackageShare('rbs_bringup'), + "launch", + "rbs_robot.launch.py" + ]) + ]), + launch_arguments={ + "env_manager": env_manager, + "with_gripper": with_gripper_condition, + "gripper_name": gripper_name, + # "controllers_file": controllers_file, + "robot_type": robot_type, + "controllers_file": initial_joint_controllers_file_path, + "cartesian_controllers": cartesian_controllers, + "description_package": description_package, + "description_file": description_file, + "robot_name": robot_name, + "start_joint_controller": start_joint_controller, + "initial_joint_controller": initial_joint_controller, + "launch_simulation": launch_simulation, + "launch_moveit": launch_moveit, + "launch_task_planner": launch_task_planner, + "launch_perception": launch_perception, + "moveit_config_package": moveit_config_package, + "moveit_config_file": moveit_config_file, + "use_sim_time": use_sim_time, + "sim_gazebo": sim_gazebo, + "hardware": hardware, + "launch_controllers": launch_controllers, + # "gazebo_gui": gazebo_gui + }.items() + ) + + args = [ + "--env", + env, + "--env-kwargs", + env_kwargs, + "--algo", + algo, + "--seed", + seed, + "--num-threads", + num_threads, + "--n-timesteps", + n_timesteps, + "--preload-replay-buffer", + preload_replay_buffer, + "--log-folder", + log_folder, + "--tensorboard-log", + tensorboard_log, + "--log-interval", + log_interval, + "--uuid", + uuid, + "--optimize-hyperparameters", + "True", + "--sampler", + sampler, + "--pruner", + pruner, + "--n-trials", + n_trials, + "--n-startup-trials", + n_startup_trials, + "--n-evaluations", + n_evaluations, + "--n-jobs", + n_jobs, + "--storage", + storage, + "--study-name", + study_name, + "--eval-episodes", + eval_episodes, + "--verbose", + verbose, + "--truncate-last-trajectory", + truncate_last_trajectory, + "--ros-args", 
+ "--log-level", + log_level, + ] + + rl_task = Node( + package="rbs_gym", + executable="train", + output="log", + arguments = args, + parameters=[{"use_sim_time": True}] + ) + + + delay_robot_control_stack = TimerAction( + period=10.0, + actions=[single_robot_setup] + ) + + nodes_to_start = [ + rl_task, + delay_robot_control_stack + ] + return nodes_to_start + + +def generate_launch_description(): + declared_arguments = [] + declared_arguments.append( + DeclareLaunchArgument( + "robot_type", + description="Type of robot by name", + choices=["rbs_arm","ur3", "ur3e", "ur5", "ur5e", "ur10", "ur10e", "ur16e"], + default_value="rbs_arm", + ) + ) + # General arguments + declared_arguments.append( + DeclareLaunchArgument( + "controllers_file", + default_value="controllers_gazebosim.yaml", + description="YAML file with the controllers configuration.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_package", + default_value="rbs_arm", + description="Description package with robot URDF/XACRO files. 
Usually the argument \ + is not set, it enables use of a custom description.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_file", + default_value="rbs_arm_modular.xacro", + description="URDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "robot_name", + default_value="arm0", + description="Name for robot, used to apply namespace for specific robot in multirobot setup", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "start_joint_controller", + default_value="false", + description="Enable headless mode for robot control", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "initial_joint_controller", + default_value="joint_trajectory_controller", + description="Robot controller to start.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_package", + default_value="rbs_arm", + description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \ + is not set, it enables use of a custom moveit config.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_file", + default_value="rbs_arm.srdf.xacro", + description="MoveIt SRDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_sim_time", + default_value="true", + description="Make MoveIt to use simulation time.\ + This is needed for the trajectory planing in simulation.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "gripper_name", + default_value="rbs_gripper", + choices=["rbs_gripper", ""], + description="choose gripper by name (leave empty if hasn't)", + ) + ) + declared_arguments.append( + DeclareLaunchArgument("with_gripper", + default_value="true", + description="With gripper or not?") + ) + declared_arguments.append( + DeclareLaunchArgument("sim_gazebo", + default_value="true", + description="Gazebo Simulation") + ) + declared_arguments.append( + 
DeclareLaunchArgument("env_manager", + default_value="false", + description="Launch env_manager?") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_sim", + default_value="true", + description="Launch simulator (Gazebo)?\ + Most general arg") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_moveit", + default_value="false", + description="Launch moveit?") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_perception", + default_value="false", + description="Launch perception?") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_task_planner", + default_value="false", + description="Launch task_planner?") + ) + declared_arguments.append( + DeclareLaunchArgument("cartesian_controllers", + default_value="true", + description="Load cartesian\ + controllers?") + ) + declared_arguments.append( + DeclareLaunchArgument("hardware", + choices=["gazebo", "mock"], + default_value="gazebo", + description="Choose your hardware_interface") + ) + declared_arguments.append( + DeclareLaunchArgument("launch_controllers", + default_value="true", + description="Launch controllers?") + ) + declared_arguments.append( + DeclareLaunchArgument("gazebo_gui", + default_value="true", + description="Launch gazebo with gui?") + ) + # training arguments + declared_arguments.append( + DeclareLaunchArgument( + "env", + default_value="Reach-Gazebo-v0", + description="Environment ID", + )) + declared_arguments.append( + DeclareLaunchArgument( + "env_kwargs", + default_value="", + description="Optional keyword argument to pass to the env constructor.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "vec_env", + default_value="dummy", + description="Type of VecEnv to use (dummy or subproc).", + )) + # Algorithm and training + declared_arguments.append( + DeclareLaunchArgument( + "algo", + default_value="sac", + description="RL algorithm to use during the training.", + )) + declared_arguments.append( + DeclareLaunchArgument( + 
"n_timesteps", + default_value="-1", + description="Overwrite the number of timesteps.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "hyperparams", + default_value="", + description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "num_threads", + default_value="-1", + description="Number of threads for PyTorch (-1 to use default).", + )) + # Continue training an already trained agent + declared_arguments.append( + DeclareLaunchArgument( + "trained_agent", + default_value="", + description="Path to a pretrained agent to continue training.", + )) + # Random seed + declared_arguments.append( + DeclareLaunchArgument( + "seed", + default_value="84", + description="Random generator seed.", + )) + # Saving of model + declared_arguments.append( + DeclareLaunchArgument( + "save_freq", + default_value="10000", + description="Save the model every n steps (if negative, no checkpoint).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "save_replay_buffer", + default_value="False", + description="Save the replay buffer too (when applicable).", + )) + # Pre-load a replay buffer and start training on it + declared_arguments.append( + DeclareLaunchArgument( + "preload_replay_buffer", + default_value="", + description="Path to a replay buffer that should be preloaded before starting the training process.", + )) + # Logging + declared_arguments.append( + DeclareLaunchArgument( + "log_folder", + default_value="logs", + description="Path to the log directory.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "tensorboard_log", + default_value="tensorboard_logs", + description="Tensorboard log dir.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "log_interval", + default_value="-1", + description="Override log interval (default: -1, no change).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "uuid", + 
default_value="False", + description="Ensure that the run has a unique ID.", + )) + + declared_arguments.append( + DeclareLaunchArgument( + "sampler", + default_value="tpe", + description="Sampler to use when optimizing hyperparameters (random, tpe or skopt).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "pruner", + default_value="median", + description="Pruner to use when optimizing hyperparameters (halving, median or none).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "n_trials", + default_value="10", + description="Number of trials for optimizing hyperparameters.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "n_startup_trials", + default_value="5", + description="Number of trials before using optuna sampler.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "n_evaluations", + default_value="2", + description="Number of evaluations for hyperparameter optimization.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "n_jobs", + default_value="1", + description="Number of parallel jobs when optimizing hyperparameters.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "storage", + default_value="", + description="Database storage path if distributed optimization should be used.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "study_name", + default_value="", + description="Study name for distributed optimization.", + )) + + # Evaluation + declared_arguments.append( + DeclareLaunchArgument( + "eval_freq", + default_value="-1", + description="Evaluate the agent every n steps (if negative, no evaluation).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "eval_episodes", + default_value="5", + description="Number of episodes to use for evaluation.", + )) + # Verbosity + declared_arguments.append( + DeclareLaunchArgument( + "verbose", + default_value="1", + description="Verbose mode (0: no output, 1: INFO).", + )) + # HER specifics + 
declared_arguments.append( + DeclareLaunchArgument( + "truncate_last_trajectory", + default_value="True", + description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer." + )) + declared_arguments.append( + DeclareLaunchArgument( + "log_level", + default_value="error", + description="The level of logging that is applied to all ROS 2 nodes launched by this script.", + )) + +# env_variables = [ +# SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"), +# SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_count() // 2)) +# ] + + return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)]) diff --git a/env_manager/rbs_gym/launch/test_env.launch.py b/env_manager/rbs_gym/launch/test_env.launch.py new file mode 100644 index 0000000..9cfbdde --- /dev/null +++ b/env_manager/rbs_gym/launch/test_env.launch.py @@ -0,0 +1,507 @@ +import os +from os import cpu_count + +import xacro +import yaml +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import ( + DeclareLaunchArgument, + IncludeLaunchDescription, + OpaqueFunction, + SetEnvironmentVariable, + TimerAction, +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch.substitutions import LaunchConfiguration, PathJoinSubstitution +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare +from robot_builder.external.ros2_control import ControllerManager +from robot_builder.parser.urdf import URDF_parser + + +def launch_setup(context, *args, **kwargs): + # Initialize Arguments + robot_type = LaunchConfiguration("robot_type") + # General arguments + with_gripper_condition = LaunchConfiguration("with_gripper") + description_package = LaunchConfiguration("description_package") + description_file = LaunchConfiguration("description_file") + use_moveit = 
LaunchConfiguration("use_moveit") + moveit_config_package = LaunchConfiguration("moveit_config_package") + moveit_config_file = LaunchConfiguration("moveit_config_file") + use_sim_time = LaunchConfiguration("use_sim_time") + scene_config_file = LaunchConfiguration("scene_config_file").perform(context) + + ee_link_name = LaunchConfiguration("ee_link_name").perform(context) + base_link_name = LaunchConfiguration("base_link_name").perform(context) + + control_space = LaunchConfiguration("control_space").perform(context) + control_strategy = LaunchConfiguration("control_strategy").perform(context) + use_rbs_utils = LaunchConfiguration("use_rbs_utils") + + interactive = LaunchConfiguration("interactive").perform(context) + + # training arguments + env = LaunchConfiguration("env") + use_sim_time = LaunchConfiguration("use_sim_time") + log_level = LaunchConfiguration("log_level") + env_kwargs = LaunchConfiguration("env_kwargs") + + description_package_abs_path = get_package_share_directory( + description_package.perform(context) + ) + + xacro_file = os.path.join( + description_package_abs_path, + "urdf", + description_file.perform(context), + ) + + if not scene_config_file == "": + config_file = {"config_file": scene_config_file} + else: + config_file = {} + + description_package_abs_path = get_package_share_directory( + description_package.perform(context) + ) + + controllers_file = os.path.join( + description_package_abs_path, "config", "controllers.yaml" + ) + + xacro_file = os.path.join( + description_package_abs_path, + "urdf", + description_file.perform(context), + ) + + # xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml" + xacro_config_file = os.path.join( + description_package_abs_path, "urdf", "xacro_args.yaml" + ) + + # TODO: hide this to another place + # Load xacro_args + def param_constructor(loader, node, local_vars): + value = loader.construct_scalar(node) + return LaunchConfiguration(value).perform( + local_vars.get("context", 
"Launch context if not defined") + ) + + def variable_constructor(loader, node, local_vars): + value = loader.construct_scalar(node) + return local_vars.get(value, f"Variable '{value}' not found") + + def load_xacro_args(yaml_file, local_vars): + # Get valut from ros2 argument + yaml.add_constructor( + "!param", lambda loader, node: param_constructor(loader, node, local_vars) + ) + + # Get value from local variable in this code + # The local variable should be initialized before the loader was called + yaml.add_constructor( + "!variable", + lambda loader, node: variable_constructor(loader, node, local_vars), + ) + + with open(yaml_file, "r") as file: + return yaml.load(file, Loader=yaml.FullLoader) + + mappings_data = load_xacro_args(xacro_config_file, locals()) + + robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data) + + robot_description_semantic_content = "" + + robot_description_content = robot_description_doc.toprettyxml(indent=" ") + robot_description = {"robot_description": robot_description_content} + + # Parse robot and configure controller's file for ControllerManager + robot = URDF_parser.load_string( + robot_description_content, ee_link_name=ee_link_name + ) + ControllerManager.save_to_yaml( + robot, description_package_abs_path, "controllers.yaml" + ) + + single_robot_setup = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + PathJoinSubstitution( + [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"] + ) + ] + ), + launch_arguments={ + "with_gripper": with_gripper_condition, + "controllers_file": controllers_file, + "robot_type": robot_type, + "description_package": description_package, + "description_file": description_file, + "robot_name": robot_type, + "use_moveit": "false", + "moveit_config_package": moveit_config_package, + "moveit_config_file": moveit_config_file, + "use_sim_time": "true", + "use_skills": "false", + "use_controllers": "true", + "robot_description": robot_description_content, 
+ "robot_description_semantic": robot_description_semantic_content, + "base_link_name": base_link_name, + "ee_link_name": ee_link_name, + "control_space": control_space, + "control_strategy": control_strategy, + "use_rbs_utils": use_rbs_utils, + "interactive_control": "false", + }.items(), + ) + + args = [ + "--env", + env, + "--env-kwargs", + env_kwargs, + "--ros-args", + "--log-level", + log_level, + ] + + rl_task = Node( + package="rbs_gym", + executable="test_agent", + output="log", + arguments=args, + parameters=[{"use_sim_time": True}, robot_description], + ) + + clock_bridge = Node( + package="ros_gz_bridge", + executable="parameter_bridge", + arguments=["/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock"], + output="screen", + ) + + delay_robot_control_stack = TimerAction(period=10.0, actions=[single_robot_setup]) + + nodes_to_start = [ + # env, + rl_task, + clock_bridge, + delay_robot_control_stack, + ] + return nodes_to_start + + +def generate_launch_description(): + declared_arguments = [] + declared_arguments.append( + DeclareLaunchArgument( + "robot_type", + description="Type of robot by name", + choices=[ + "rbs_arm", + "ar4", + "ur3", + "ur3e", + "ur5", + "ur5e", + "ur10", + "ur10e", + "ur16e", + ], + default_value="rbs_arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_package", + default_value="rbs_arm", + description="Description package with robot URDF/XACRO files. 
Usually the argument \ + is not set, it enables use of a custom description.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_file", + default_value="rbs_arm_modular.xacro", + description="URDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "robot_name", + default_value="arm0", + description="Name for robot, used to apply namespace for specific robot in multirobot setup", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_package", + default_value="rbs_arm", + description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \ + is not set, it enables use of a custom moveit config.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_file", + default_value="rbs_arm.srdf.xacro", + description="MoveIt SRDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_sim_time", + default_value="true", + description="Make MoveIt to use simulation time.\ + This is needed for the trajectory planing in simulation.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "with_gripper", default_value="true", description="With gripper or not?" + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_moveit", default_value="false", description="Launch moveit?" + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "launch_perception", default_value="false", description="Launch perception?" 
+ ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_controllers", + default_value="true", + description="Launch controllers?", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "scene_config_file", + default_value="", + description="Path to a scene configuration file", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "ee_link_name", + default_value="", + description="End effector name of robot arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "base_link_name", + default_value="", + description="Base link name if robot arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "control_space", + default_value="task", + choices=["task", "joint"], + description="Specify the control space for the robot (e.g., task space).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "control_strategy", + default_value="position", + choices=["position", "velocity", "effort"], + description="Specify the control strategy (e.g., position control).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "interactive", + default_value="true", + description="Wheter to run the motion_control_handle controller", + ), + ) + + declared_arguments.append( + DeclareLaunchArgument( + "use_rbs_utils", + default_value="false", + description="Wheter to use rbs_utils", + ), + ) + # training arguments + declared_arguments.append( + DeclareLaunchArgument( + "env", + default_value="Reach-Gazebo-v0", + description="Environment ID", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "env_kwargs", + default_value="", + description="Optional keyword argument to pass to the env constructor.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "vec_env", + default_value="dummy", + description="Type of VecEnv to use (dummy or subproc).", + ) + ) + # Algorithm and training + declared_arguments.append( + DeclareLaunchArgument( + "algo", + default_value="sac", + description="RL 
algorithm to use during the training.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "n_timesteps", + default_value="-1", + description="Overwrite the number of timesteps.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "hyperparams", + default_value="", + description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "num_threads", + default_value="-1", + description="Number of threads for PyTorch (-1 to use default).", + ) + ) + # Continue training an already trained agent + declared_arguments.append( + DeclareLaunchArgument( + "trained_agent", + default_value="", + description="Path to a pretrained agent to continue training.", + ) + ) + # Random seed + declared_arguments.append( + DeclareLaunchArgument( + "seed", + default_value="-1", + description="Random generator seed.", + ) + ) + # Saving of model + declared_arguments.append( + DeclareLaunchArgument( + "save_freq", + default_value="10000", + description="Save the model every n steps (if negative, no checkpoint).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "save_replay_buffer", + default_value="False", + description="Save the replay buffer too (when applicable).", + ) + ) + # Pre-load a replay buffer and start training on it + declared_arguments.append( + DeclareLaunchArgument( + "preload_replay_buffer", + default_value="", + description="Path to a replay buffer that should be preloaded before starting the training process.", + ) + ) + # Logging + declared_arguments.append( + DeclareLaunchArgument( + "log_folder", + default_value="logs", + description="Path to the log directory.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "tensorboard_log", + default_value="tensorboard_logs", + description="Tensorboard log dir.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "log_interval", + default_value="-1", + description="Override 
log interval (default: -1, no change).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "uuid", + default_value="False", + description="Ensure that the run has a unique ID.", + ) + ) + # Evaluation + declared_arguments.append( + DeclareLaunchArgument( + "eval_freq", + default_value="-1", + description="Evaluate the agent every n steps (if negative, no evaluation).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "eval_episodes", + default_value="5", + description="Number of episodes to use for evaluation.", + ) + ) + # Verbosity + declared_arguments.append( + DeclareLaunchArgument( + "verbose", + default_value="1", + description="Verbose mode (0: no output, 1: INFO).", + ) + ) + # HER specifics + declared_arguments.append( + DeclareLaunchArgument( + "truncate_last_trajectory", + default_value="True", + description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after) reloading the replay buffer.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "log_level", + default_value="error", + description="The level of logging that is applied to all ROS 2 nodes launched by this script.", + ) + ) + cpu_c = cpu_count() + if cpu_c is not None: + env_variables = [ + SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"), + SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_c // 2)), + ] + else: + env_variables = [ + SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"), + ] + + return LaunchDescription( + declared_arguments + [OpaqueFunction(function=launch_setup)] + env_variables + ) diff --git a/env_manager/rbs_gym/launch/train.launch.py b/env_manager/rbs_gym/launch/train.launch.py new file mode 100644 index 0000000..abea297 --- /dev/null +++ b/env_manager/rbs_gym/launch/train.launch.py @@ -0,0 +1,524 @@ +from launch import LaunchDescription +from launch.actions import ( + DeclareLaunchArgument, + IncludeLaunchDescription, + OpaqueFunction, + 
SetEnvironmentVariable, + TimerAction +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch.substitutions import LaunchConfiguration, PathJoinSubstitution +from launch_ros.substitutions import FindPackageShare +from launch_ros.actions import Node +import os +from os import cpu_count +from ament_index_python.packages import get_package_share_directory +import yaml +import xacro + +def launch_setup(context, *args, **kwargs): + # Initialize Arguments + robot_type = LaunchConfiguration("robot_type") + # General arguments + with_gripper_condition = LaunchConfiguration("with_gripper") + description_package = LaunchConfiguration("description_package") + description_file = LaunchConfiguration("description_file") + use_moveit = LaunchConfiguration("use_moveit") + moveit_config_package = LaunchConfiguration("moveit_config_package") + moveit_config_file = LaunchConfiguration("moveit_config_file") + use_sim_time = LaunchConfiguration("use_sim_time") + scene_config_file = LaunchConfiguration("scene_config_file").perform(context) + + ee_link_name = LaunchConfiguration("ee_link_name").perform(context) + base_link_name = LaunchConfiguration("base_link_name").perform(context) + + control_space = LaunchConfiguration("control_space").perform(context) + control_strategy = LaunchConfiguration("control_strategy").perform(context) + + interactive = LaunchConfiguration("interactive").perform(context) + + # training arguments + env = LaunchConfiguration("env") + algo = LaunchConfiguration("algo") + hyperparams = LaunchConfiguration("hyperparams") + n_timesteps = LaunchConfiguration("n_timesteps") + num_threads = LaunchConfiguration("num_threads") + seed = LaunchConfiguration("seed") + trained_agent = LaunchConfiguration("trained_agent") + save_freq = LaunchConfiguration("save_freq") + save_replay_buffer = LaunchConfiguration("save_replay_buffer") + preload_replay_buffer = LaunchConfiguration("preload_replay_buffer") + log_folder = 
LaunchConfiguration("log_folder") + tensorboard_log = LaunchConfiguration("tensorboard_log") + log_interval = LaunchConfiguration("log_interval") + uuid = LaunchConfiguration("uuid") + eval_freq = LaunchConfiguration("eval_freq") + eval_episodes = LaunchConfiguration("eval_episodes") + verbose = LaunchConfiguration("verbose") + truncate_last_trajectory = LaunchConfiguration("truncate_last_trajectory") + use_sim_time = LaunchConfiguration("use_sim_time") + log_level = LaunchConfiguration("log_level") + env_kwargs = LaunchConfiguration("env_kwargs") + + track = LaunchConfiguration("track") + + description_package_abs_path = get_package_share_directory( + description_package.perform(context) + ) + + + xacro_file = os.path.join( + description_package_abs_path, + "urdf", + description_file.perform(context), + ) + + if not scene_config_file == "": + config_file = {"config_file": scene_config_file} + else: + config_file = {} + + description_package_abs_path = get_package_share_directory( + description_package.perform(context) + ) + + controllers_file = os.path.join( + description_package_abs_path, "config", "controllers.yaml" + ) + + xacro_file = os.path.join( + description_package_abs_path, + "urdf", + description_file.perform(context), + ) + + # xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml" + xacro_config_file = os.path.join( + description_package_abs_path, + "urdf", + "xacro_args.yaml" + ) + + + # TODO: hide this to another place + # Load xacro_args + def param_constructor(loader, node, local_vars): + value = loader.construct_scalar(node) + return LaunchConfiguration(value).perform( + local_vars.get("context", "Launch context if not defined") + ) + + def variable_constructor(loader, node, local_vars): + value = loader.construct_scalar(node) + return local_vars.get(value, f"Variable '{value}' not found") + + def load_xacro_args(yaml_file, local_vars): + # Get valut from ros2 argument + yaml.add_constructor( + "!param", lambda loader, node: 
param_constructor(loader, node, local_vars) + ) + + # Get value from local variable in this code + # The local variable should be initialized before the loader was called + yaml.add_constructor( + "!variable", + lambda loader, node: variable_constructor(loader, node, local_vars), + ) + + with open(yaml_file, "r") as file: + return yaml.load(file, Loader=yaml.FullLoader) + + mappings_data = load_xacro_args(xacro_config_file, locals()) + + robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data) + + robot_description_semantic_content = "" + + + robot_description_content = robot_description_doc.toprettyxml(indent=" ") + robot_description = {"robot_description": robot_description_content} + + single_robot_setup = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + PathJoinSubstitution( + [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"] + ) + ] + ), + launch_arguments={ + "with_gripper": with_gripper_condition, + "controllers_file": controllers_file, + "robot_type": robot_type, + "description_package": description_package, + "description_file": description_file, + "robot_name": robot_type, + "use_moveit": use_moveit, + "moveit_config_package": moveit_config_package, + "moveit_config_file": moveit_config_file, + "use_sim_time": use_sim_time, + "use_controllers": "true", + "robot_description": robot_description_content, + "robot_description_semantic": robot_description_semantic_content, + "base_link_name": base_link_name, + "ee_link_name": ee_link_name, + "control_space": control_space, + "control_strategy": control_strategy, + "interactive_control": interactive, + }.items(), + ) + + args = [ + "--env", + env, + "--env-kwargs", + env_kwargs, + "--algo", + algo, + "--hyperparams", + hyperparams, + "--n-timesteps", + n_timesteps, + "--num-threads", + num_threads, + "--seed", + seed, + "--trained-agent", + trained_agent, + "--save-freq", + save_freq, + "--save-replay-buffer", + save_replay_buffer, + 
"--preload-replay-buffer", + preload_replay_buffer, + "--log-folder", + log_folder, + "--tensorboard-log", + tensorboard_log, + "--log-interval", + log_interval, + "--uuid", + uuid, + "--eval-freq", + eval_freq, + "--eval-episodes", + eval_episodes, + "--verbose", + verbose, + "--track", + track, + "--truncate-last-trajectory", + truncate_last_trajectory, + "--ros-args", + "--log-level", + log_level, + ] + + clock_bridge = Node( + package='ros_gz_bridge', + executable='parameter_bridge', + arguments=['/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock'], + output='screen') + + rl_task = Node( + package="rbs_gym", + executable="train", + output="log", + arguments=args, + parameters=[{"use_sim_time": True}] + ) + + + delay_robot_control_stack = TimerAction( + period=20.0, + actions=[single_robot_setup] + ) + + nodes_to_start = [ + # env, + rl_task, + clock_bridge, + delay_robot_control_stack + ] + return nodes_to_start + + +def generate_launch_description(): + declared_arguments = [] + declared_arguments.append( + DeclareLaunchArgument( + "robot_type", + description="Type of robot by name", + choices=[ + "rbs_arm", + "ar4", + "ur3", + "ur3e", + "ur5", + "ur5e", + "ur10", + "ur10e", + "ur16e", + ], + default_value="rbs_arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_package", + default_value="rbs_arm", + description="Description package with robot URDF/XACRO files. 
Usually the argument \ + is not set, it enables use of a custom description.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_file", + default_value="rbs_arm_modular.xacro", + description="URDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "robot_name", + default_value="arm0", + description="Name for robot, used to apply namespace for specific robot in multirobot setup", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_package", + default_value="rbs_arm", + description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \ + is not set, it enables use of a custom moveit config.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_file", + default_value="rbs_arm.srdf.xacro", + description="MoveIt SRDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_sim_time", + default_value="true", + description="Make MoveIt to use simulation time.\ + This is needed for the trajectory planing in simulation.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "with_gripper", default_value="true", description="With gripper or not?" + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_moveit", default_value="false", description="Launch moveit?" + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "launch_perception", default_value="false", description="Launch perception?" 
+ ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_controllers", + default_value="true", + description="Launch controllers?", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "scene_config_file", + default_value="", + description="Path to a scene configuration file", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "ee_link_name", + default_value="", + description="End effector name of robot arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "base_link_name", + default_value="", + description="Base link name if robot arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "control_space", + default_value="task", + choices=["task", "joint"], + description="Specify the control space for the robot (e.g., task space).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "control_strategy", + default_value="position", + choices=["position", "velocity", "effort"], + description="Specify the control strategy (e.g., position control).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "interactive", + default_value="true", + description="Wheter to run the motion_control_handle controller", + ), + ) + # training arguments + declared_arguments.append( + DeclareLaunchArgument( + "env", + default_value="Reach-Gazebo-v0", + description="Environment ID", + )) + declared_arguments.append( + DeclareLaunchArgument( + "env_kwargs", + default_value="", + description="Optional keyword argument to pass to the env constructor.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "vec_env", + default_value="dummy", + description="Type of VecEnv to use (dummy or subproc).", + )) + # Algorithm and training + declared_arguments.append( + DeclareLaunchArgument( + "algo", + default_value="sac", + description="RL algorithm to use during the training.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "n_timesteps", + default_value="-1", + description="Overwrite the 
number of timesteps.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "hyperparams", + default_value="", + description="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "num_threads", + default_value="-1", + description="Number of threads for PyTorch (-1 to use default).", + )) + # Continue training an already trained agent + declared_arguments.append( + DeclareLaunchArgument( + "trained_agent", + default_value="", + description="Path to a pretrained agent to continue training.", + )) + # Random seed + declared_arguments.append( + DeclareLaunchArgument( + "seed", + default_value="-1", + description="Random generator seed.", + )) + # Saving of model + declared_arguments.append( + DeclareLaunchArgument( + "save_freq", + default_value="10000", + description="Save the model every n steps (if negative, no checkpoint).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "save_replay_buffer", + default_value="False", + description="Save the replay buffer too (when applicable).", + )) + # Pre-load a replay buffer and start training on it + declared_arguments.append( + DeclareLaunchArgument( + "preload_replay_buffer", + default_value="", + description="Path to a replay buffer that should be preloaded before starting the training process.", + )) + # Logging + declared_arguments.append( + DeclareLaunchArgument( + "log_folder", + default_value="logs", + description="Path to the log directory.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "tensorboard_log", + default_value="tensorboard_logs", + description="Tensorboard log dir.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "log_interval", + default_value="-1", + description="Override log interval (default: -1, no change).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "uuid", + default_value="False", + description="Ensure that the run has a unique ID.", + )) + # 
Evaluation + declared_arguments.append( + DeclareLaunchArgument( + "eval_freq", + default_value="-1", + description="Evaluate the agent every n steps (if negative, no evaluation).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "eval_episodes", + default_value="5", + description="Number of episodes to use for evaluation.", + )) + # Verbosity + declared_arguments.append( + DeclareLaunchArgument( + "verbose", + default_value="1", + description="Verbose mode (0: no output, 1: INFO).", + )) + declared_arguments.append( + DeclareLaunchArgument( + "truncate_last_trajectory", + default_value="True", + description="When using HER with online sampling the last trajectory in the replay buffer will be truncated after) reloading the replay buffer." + )) + declared_arguments.append( + DeclareLaunchArgument( + "log_level", + default_value="error", + description="The level of logging that is applied to all ROS 2 nodes launched by this script.", + )) + declared_arguments.append( + DeclareLaunchArgument( + "track", + default_value="true", + description="The level of logging that is applied to all ROS 2 nodes launched by this script.", + )) + + env_variables = [ + SetEnvironmentVariable(name="OMP_DYNAMIC", value="TRUE"), + SetEnvironmentVariable(name="OMP_NUM_THREADS", value=str(cpu_count() // 2)) + ] + + return LaunchDescription(declared_arguments + [OpaqueFunction(function=launch_setup)] + env_variables) diff --git a/env_manager/rbs_gym/package.nix b/env_manager/rbs_gym/package.nix new file mode 100644 index 0000000..6411afe --- /dev/null +++ b/env_manager/rbs_gym/package.nix @@ -0,0 +1,20 @@ +# Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0 + +# Copyright 2025 None +# Distributed under the terms of the Apache-2.0 license + +{ lib, buildRosPackage, ament-copyright, ament-flake8, ament-pep257, pythonPackages }: +buildRosPackage rec { + pname = "ros-jazzy-rbs-gym"; + version = "0.0.0"; + + src = ./.; + + buildType = "ament_python"; + 
checkInputs = [ ament-copyright ament-flake8 ament-pep257 pythonPackages.pytest ]; + + meta = { + description = "TODO: Package description"; + license = with lib.licenses; [ asl20 ]; + }; +} diff --git a/env_manager/rbs_gym/package.xml b/env_manager/rbs_gym/package.xml new file mode 100644 index 0000000..c11a38c --- /dev/null +++ b/env_manager/rbs_gym/package.xml @@ -0,0 +1,18 @@ + + + + rbs_gym + 0.0.0 + TODO: Package description + narmak + Apache-2.0 + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/env_manager/rbs_gym/rbs_gym/__init__.py b/env_manager/rbs_gym/rbs_gym/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/env_manager/rbs_gym/rbs_gym/envs/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/__init__.py new file mode 100644 index 0000000..f0db77a --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/envs/__init__.py @@ -0,0 +1,184 @@ +# Note: The `open3d` and `stable_baselines3` modules must be imported prior to `gym_gz` +from env_manager.models.configs import ( + RobotData, + SceneData, + SphereObjectData, + TerrainData, +) +from env_manager.models.configs.objects import ObjectRandomizerData +from env_manager.models.configs.robot import RobotRandomizerData +from env_manager.models.configs.scene import PluginsData +# import open3d # isort:skip +import stable_baselines3 # isort:skip + +# Note: If installed, `tensorflow` module must be imported before `gym_gz`/`scenario` +# Otherwise, protobuf version incompatibility will cause an error +try: + from importlib.util import find_spec + + if find_spec("tensorflow") is not None: + import tensorflow +except: + pass + +from os import environ, path +from typing import Any, Dict, Tuple + +import numpy as np +from ament_index_python.packages import get_package_share_directory +from gymnasium.envs.registration import register +from rbs_assets_library import get_world_file + +from rbs_gym.utils.utils import str2bool + +from . 
import tasks + +###################### +# Runtime Entrypoint # +###################### + +RBS_ENVS_TASK_ENTRYPOINT: str = "gym_gz.runtimes.gazebo_runtime:GazeboRuntime" + + +################### +# Robot Specifics # +################### +# Default robot model to use in the tasks where robot can be static +RBS_ENVS_ROBOT_MODEL: str = "rbs_arm" + +###################### +# Datasets and paths # +###################### +# Path to directory containing base SDF worlds +RBS_ENVS_WORLDS_DIR: str = path.join(get_package_share_directory("rbs_gym"), "worlds") + + +########### +# Presets # +########### +# Gravity preset for Earth +# GRAVITY_EARTH: Tuple[float, float, float] = (0.0, 0.0, -9.80665) +# GRAVITY_EARTH_STD: Tuple[float, float, float] = (0.0, 0.0, 0.0232) + +############################ +# Additional Configuration # +############################ +# BROADCAST_GUI: bool = str2bool( +# environ.get("RBS_ENVS_BROADCAST_INTERACTIVE_GUI", default=True) +# ) + + +######### +# Reach # +######### +REACH_MAX_EPISODE_STEPS: int = 100 +REACH_KWARGS: Dict[str, Any] = { + "agent_rate": 4.0, + "robot_model": RBS_ENVS_ROBOT_MODEL, + "workspace_frame_id": "world", + "workspace_centre": (-0.45, 0.0, 0.35), + "workspace_volume": (0.7, 0.7, 0.7), + "ignore_new_actions_while_executing": False, + "use_servo": True, + "scaling_factor_translation": 8.0, + "scaling_factor_rotation": 3.0, + "restrict_position_goal_to_workspace": False, + "enable_gripper": False, + "sparse_reward": False, + "collision_reward": -10.0, + "act_quick_reward": -0.01, + "required_accuracy": 0.05, + "num_threads": 3, +} +REACH_KWARGS_SIM: Dict[str, Any] = { + "physics_rate": 100, + "real_time_factor": float(np.finfo(np.float32).max), + "world": get_world_file("ungravitational"), +} + + +REACH_RANDOMIZER: str = "rbs_gym.envs.randomizers:ManipulationGazeboEnvRandomizer" + +SCENE_CONFIGURATION: SceneData = SceneData( + physics_rollouts_num=0, + robot=RobotData( + name="rbs_arm", + joint_positions=[0.0, 0.5, 3.14159, 1.5, 
0.0, 1.4, 0.0], + with_gripper=True, + gripper_joint_positions=0.00, + randomizer=RobotRandomizerData(joint_positions=True), + ), + objects=[ + SphereObjectData( + name="sphere", + type="sphere", + relative_to="base_link", + position=(0.0, 0.3, 0.5), + static=True, + collision=False, + color=(0.0, 1.0, 0.0, 0.8), + randomize=ObjectRandomizerData(random_pose=True, models_rollouts_num=2), + ) + ], + plugins=PluginsData( + scene_broadcaster=True, user_commands=True, fts_broadcaster=True + ), +) + + +# Task +register( + id="Reach-v0", + entry_point=RBS_ENVS_TASK_ENTRYPOINT, + max_episode_steps=REACH_MAX_EPISODE_STEPS, + kwargs={ + "task_cls": tasks.Reach, + **REACH_KWARGS, + }, +) +register( + id="Reach-ColorImage-v0", + entry_point=RBS_ENVS_TASK_ENTRYPOINT, + max_episode_steps=REACH_MAX_EPISODE_STEPS, + kwargs={ + "task_cls": tasks.ReachColorImage, + **REACH_KWARGS, + }, +) +register( + id="Reach-DepthImage-v0", + entry_point=RBS_ENVS_TASK_ENTRYPOINT, + max_episode_steps=REACH_MAX_EPISODE_STEPS, + kwargs={ + "task_cls": tasks.ReachDepthImage, + **REACH_KWARGS, + }, +) + +# Gazebo wrapper +register( + id="Reach-Gazebo-v0", + entry_point=REACH_RANDOMIZER, + max_episode_steps=REACH_MAX_EPISODE_STEPS, + kwargs={"env": "Reach-v0", **REACH_KWARGS_SIM, "scene_args": SCENE_CONFIGURATION}, +) +register( + id="Reach-ColorImage-Gazebo-v0", + entry_point=REACH_RANDOMIZER, + max_episode_steps=REACH_MAX_EPISODE_STEPS, + kwargs={ + "env": "Reach-ColorImage-v0", + **REACH_KWARGS_SIM, + "scene_args": SCENE_CONFIGURATION, + }, +) +register( + id="Reach-DepthImage-Gazebo-v0", + entry_point=REACH_RANDOMIZER, + max_episode_steps=REACH_MAX_EPISODE_STEPS, + kwargs={ + "env": "Reach-DepthImage-v0", + **REACH_KWARGS_SIM, + "scene_args": SCENE_CONFIGURATION, + }, +) diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/control/__init__.py new file mode 100644 index 0000000..b551c3d --- /dev/null +++ 
b/env_manager/rbs_gym/rbs_gym/envs/control/__init__.py @@ -0,0 +1,3 @@ +from .cartesian_force_controller import CartesianForceController +from .grippper_controller import GripperController +from .joint_effort_controller import JointEffortController diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_force_controller.py b/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_force_controller.py new file mode 100644 index 0000000..3c60500 --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_force_controller.py @@ -0,0 +1,41 @@ +from typing import Optional +from geometry_msgs.msg import WrenchStamped +from rclpy.node import Node +from rclpy.parameter import Parameter + +class CartesianForceController: + def __init__(self, node, namespace: Optional[str] = "") -> None: + self.node = node + self.publisher = node.create_publisher(WrenchStamped, + namespace + "/cartesian_force_controller/target_wrench", 10) + self.timer = node.create_timer(0.1, self.timer_callback) + self.publish = False + self._target_wrench = WrenchStamped() + + @property + def target_wrench(self) -> WrenchStamped: + return self._target_wrench + + @target_wrench.setter + def target_wrench(self, wrench: WrenchStamped): + self._target_wrench = wrench + + def timer_callback(self): + if self.publish: + self.publisher.publish(self._target_wrench) + + +class CartesianForceControllerStandalone(Node, CartesianForceController): + def __init__(self, node_name:str = "rbs_gym_controller", use_sim_time: bool = True): + + try: + rclpy.init() + except Exception as e: + if not rclpy.ok(): + sys.exit(f"ROS 2 context could not be initialised: {e}") + + Node.__init__(self, node_name) + self.set_parameters( + [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)] + ) + CartesianForceController.__init__(self, node=self) diff --git a/env_manager/rbs_gym/rbs_gym/envs/control/cartesian_velocity_controller.py 
import rclpy
from rclpy.node import Node
import numpy as np
import quaternion
from geometry_msgs.msg import Twist
from geometry_msgs.msg import PoseStamped
import tf2_ros
import sys
import time
import threading
import os


class VelocityController:
    """Convert Twist messages to PoseStamped

    Use this node to integrate twist messages into a moving target pose in
    Cartesian space. An initial TF lookup assures that the target pose always
    starts at the robot's end-effector.
    """

    def __init__(
        self,
        node: Node,
        topic_pose: str,
        topic_twist: str,
        base_frame: str,
        ee_frame: str,
    ):
        """
        Args:
            node: Existing rclpy node used for pub/sub, timers and TF.
            topic_pose: Topic on which the integrated PoseStamped is published.
            topic_twist: Topic providing the Twist commands to integrate.
            base_frame: Reference frame the pose is expressed in.
            ee_frame: End-effector frame used for the initial TF lookup.
        """
        self.node = node

        self._frame_id = base_frame
        self._end_effector = ee_frame

        self.tf_buffer = tf2_ros.Buffer()
        # FIX: TransformListener requires a Node; `self` is a plain object,
        # so the original `TransformListener(self.tf_buffer, self)` failed.
        self.tf_listener = tf2_ros.TransformListener(self.tf_buffer, node)
        self.rot = np.quaternion(0, 0, 0, 1)
        self.pos = [0, 0, 0]

        self.pub = node.create_publisher(PoseStamped, topic_pose, 3)
        self.sub = node.create_subscription(Twist, topic_twist, self.twist_cb, 1)
        self.last = time.time()

        self.startup_done = False
        period = 1.0 / node.declare_parameter("publishing_rate", 100).value
        self.timer = node.create_timer(period, self.publish)
        self.publish_it = False

        # The initial TF lookup blocks, so run it off the main thread.
        self.thread = threading.Thread(target=self.startup, daemon=True)
        self.thread.start()

    def startup(self):
        """Make sure to start at the robot's current pose"""
        # Wait until we entered spinning in the main thread.
        time.sleep(1)
        try:
            start = self.tf_buffer.lookup_transform(
                target_frame=self._frame_id,
                source_frame=self._end_effector,
                time=rclpy.time.Time(),
            )

        except (
            tf2_ros.InvalidArgumentException,
            tf2_ros.LookupException,
            tf2_ros.ConnectivityException,
            tf2_ros.ExtrapolationException,
        ) as e:
            print(f"Startup failed: {e}")
            os._exit(1)

        self.pos[0] = start.transform.translation.x
        self.pos[1] = start.transform.translation.y
        self.pos[2] = start.transform.translation.z
        self.rot.x = start.transform.rotation.x
        self.rot.y = start.transform.rotation.y
        self.rot.z = start.transform.rotation.z
        self.rot.w = start.transform.rotation.w
        self.startup_done = True

    def twist_cb(self, data):
        """Numerically integrate twist message into a pose

        Use global self.frame_id as reference for the navigation commands.
        """
        now = time.time()
        dt = now - self.last
        self.last = now

        # Position update
        self.pos[0] += data.linear.x * dt
        self.pos[1] += data.linear.y * dt
        self.pos[2] += data.linear.z * dt

        # Orientation update
        wx = data.angular.x
        wy = data.angular.y
        wz = data.angular.z

        _, q = quaternion.integrate_angular_velocity(
            lambda _: (wx, wy, wz), 0, dt, self.rot
        )

        self.rot = q[-1]  # the last one is after dt passed

    def publish(self):
        if not self.startup_done:
            return
        if not self.publish_it:
            return
        try:
            msg = PoseStamped()
            # FIX: this class is not a Node; use the wrapped node's clock
            # (the original `self.get_clock()` raised AttributeError).
            msg.header.stamp = self.node.get_clock().now().to_msg()
            # FIX: the attribute is `_frame_id`; `self.frame_id` did not exist.
            msg.header.frame_id = self._frame_id
            msg.pose.position.x = self.pos[0]
            msg.pose.position.y = self.pos[1]
            msg.pose.position.z = self.pos[2]
            msg.pose.orientation.x = self.rot.x
            msg.pose.orientation.y = self.rot.y
            msg.pose.orientation.z = self.rot.z
            msg.pose.orientation.w = self.rot.w

            self.pub.publish(msg)
        except Exception:
            # Swallow 'publish() to closed topic' error.
            # This rarely happens on killing this node.
            pass
import time
from typing import Optional

from control_msgs.action import GripperCommand
from rclpy.action import ActionClient


class GripperController:
    """Async wrapper around the `gripper_controller/gripper_cmd` action.

    Requires the owning node to be spun by an external executor so that the
    goal/result callbacks fire.
    """

    def __init__(
        self,
        node,
        open_position: Optional[float] = 0.0,
        close_position: Optional[float] = 0.0,
        max_effort: Optional[float] = 0.0,
        namespace: Optional[str] = "",
    ):
        # FIX: keep the node for logging — the original called
        # `self.get_logger()` on this non-Node class (AttributeError).
        self._node = node
        self._action_client = ActionClient(
            node, GripperCommand, namespace + "/gripper_controller/gripper_cmd"
        )
        self._open_position = open_position
        self._close_position = close_position
        self._max_effort = max_effort
        self.is_executed = False

    def open(self):
        """Command the gripper to its configured open position."""
        self.send_goal(self._open_position)

    def close(self):
        """Command the gripper to its configured close position."""
        self.send_goal(self._close_position)

    def send_goal(self, goal: float):
        """Send an asynchronous GripperCommand goal with `goal` as position."""
        goal_msg = GripperCommand.Goal()
        # Use the public `command` field rather than the private `_command`.
        goal_msg.command.position = goal
        goal_msg.command.max_effort = self._max_effort
        # A new goal is in flight; clear the completion flag.
        self.is_executed = False
        self._action_client.wait_for_server()
        self._send_goal_future = self._action_client.send_goal_async(goal_msg)
        self._send_goal_future.add_done_callback(self.goal_response_callback)

    def goal_response_callback(self, future):
        goal_handle = future.result()
        if not goal_handle.accepted:
            self._node.get_logger().info('Goal rejected :(')
            # FIX: mark as executed on rejection so waiters do not hang.
            self.is_executed = True
            return

        self._node.get_logger().info('Goal accepted :)')

        self._get_result_future = goal_handle.get_result_async()
        self._get_result_future.add_done_callback(self.get_result_callback)

    def get_result_callback(self, future):
        result = future.result().result
        self._node.get_logger().info(f"Gripper position: {result.position}")
        # FIX: `is_executed` was never set anywhere, so wait_until_executed()
        # spun forever.
        self.is_executed = True

    def wait_until_executed(self):
        """Block until the last goal finished (node must be spun elsewhere)."""
        while not self.is_executed:
            time.sleep(0.01)  # avoid a pure busy-wait pegging a CPU core
from typing import Optional

from std_msgs.msg import Float64MultiArray


class JointEffortController:
    """Publishes joint effort commands to a `joint_effort_controller`.

    Unlike the force controller, this class does not stream: callers set
    :attr:`target_effort` and publish via :attr:`publisher` themselves.
    """

    def __init__(self, node, namespace: Optional[str] = "") -> None:
        """
        Args:
            node: Existing rclpy node used to create the publisher.
            namespace: Optional topic namespace prefix.
        """
        self.node = node
        self.publisher = node.create_publisher(
            Float64MultiArray,
            namespace + "/joint_effort_controller/commands",
            10,
        )
        # FIX: removed dead commented-out timer code that referenced a
        # non-existent `self._target_wrench` (copy-paste from the
        # cartesian force controller).
        self._effort_array = Float64MultiArray()

    @property
    def target_effort(self) -> Float64MultiArray:
        """Last effort command cached on this controller."""
        return self._effort_array

    @target_effort.setter
    def target_effort(self, data: Float64MultiArray):
        self._effort_array = data
import sys
from threading import Lock, Thread
from typing import Optional, Union

import rclpy
from rclpy.callback_groups import CallbackGroup
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from rclpy.parameter import Parameter
from rclpy.qos import (
    QoSDurabilityPolicy,
    QoSHistoryPolicy,
    QoSProfile,
    QoSReliabilityPolicy,
)
from sensor_msgs.msg import Image, PointCloud2


class CameraSubscriber:
    """Caches the most recent Image or PointCloud2 message from a topic."""

    def __init__(
        self,
        node: Node,
        topic: str,
        is_point_cloud: bool,
        callback_group: Optional[CallbackGroup] = None,
    ):
        self._node = node

        # Subscribe either to point clouds or to plain images.
        camera_msg_type = PointCloud2 if is_point_cloud else Image
        self.__observation = camera_msg_type()
        self._node.create_subscription(
            msg_type=camera_msg_type,
            topic=topic,
            callback=self.observation_callback,
            qos_profile=QoSProfile(
                reliability=QoSReliabilityPolicy.RELIABLE,
                durability=QoSDurabilityPolicy.VOLATILE,
                history=QoSHistoryPolicy.KEEP_LAST,
                depth=1,
            ),
            callback_group=callback_group,
        )
        self.__observation_mutex = Lock()
        self.__new_observation_available = False

    def observation_callback(self, msg):
        """Store the newest message and flag that a fresh observation exists."""
        with self.__observation_mutex:
            self.__observation = msg
            self.__new_observation_available = True
            self._node.get_logger().debug("New observation received.")

    def get_observation(self) -> Union[PointCloud2, Image]:
        """Return the last received observation."""
        with self.__observation_mutex:
            return self.__observation

    def reset_new_observation_checker(self):
        """Clear the flag reported by `self.new_observation_available`."""
        with self.__observation_mutex:
            self.__new_observation_available = False

    @property
    def new_observation_available(self):
        """Whether a new observation arrived since the last reset."""
        return self.__new_observation_available


class CameraSubscriberStandalone(Node, CameraSubscriber):
    """CameraSubscriber that owns its node and spins it on a background thread."""

    def __init__(
        self,
        topic: str,
        is_point_cloud: bool,
        node_name: str = "rbs_gym_camera_sub",
        use_sim_time: bool = True,
    ):
        try:
            rclpy.init()
        except Exception as e:
            if not rclpy.ok():
                sys.exit(f"ROS 2 context could not be initialised: {e}")

        Node.__init__(self, node_name)
        self.set_parameters(
            [Parameter("use_sim_time", type_=Parameter.Type.BOOL, value=use_sim_time)]
        )

        CameraSubscriber.__init__(
            self, node=self, topic=topic, is_point_cloud=is_point_cloud
        )

        # Spin the node in a separate daemon thread.
        self._executor = SingleThreadedExecutor()
        self._executor.add_node(self)
        self._executor_thread = Thread(target=self._executor.spin, daemon=True, args=())
        self._executor_thread.start()
from array import array
from threading import Lock
from typing import Optional

from rclpy.callback_groups import CallbackGroup
from rclpy.node import Node
from rclpy.qos import (
    QoSDurabilityPolicy,
    QoSHistoryPolicy,
    QoSProfile,
    QoSReliabilityPolicy,
)

from sensor_msgs.msg import JointState


class JointStates:
    """Caches the most recent JointState message from a topic."""

    def __init__(
        self,
        node: Node,
        topic: str,
        callback_group: Optional[CallbackGroup] = None,
    ):
        """
        Args:
            node: Existing rclpy node used to create the subscription.
            topic: Joint-state topic to subscribe to.
            callback_group: Optional callback group for the subscription.
        """
        self._node = node

        self.__observation = JointState()
        self._node.create_subscription(
            msg_type=JointState,
            topic=topic,
            callback=self.observation_callback,
            qos_profile=QoSProfile(
                reliability=QoSReliabilityPolicy.RELIABLE,
                durability=QoSDurabilityPolicy.VOLATILE,
                history=QoSHistoryPolicy.KEEP_LAST,
                depth=1,
            ),
            callback_group=callback_group,
        )
        self.__observation_mutex = Lock()
        self.__new_observation_available = False
        # FIX: removed stray no-op statement `self.__observation.position`
        # that ended the original constructor.

    def observation_callback(self, msg):
        """
        Callback for getting observation.
        """

        self.__observation_mutex.acquire()
        self.__observation = msg
        self.__new_observation_available = True
        self._node.get_logger().debug("New observation received.")
        self.__observation_mutex.release()

    def get_observation(self) -> JointState:
        """
        Get the last received observation.
        """
        self.__observation_mutex.acquire()
        observation = self.__observation
        self.__observation_mutex.release()
        return observation

    def get_positions(self) -> array:
        """
        Get the last recorded observation position
        """
        self.__observation_mutex.acquire()
        observation = self.__observation.position
        self.__observation_mutex.release()
        return observation

    def get_velocities(self) -> array:
        """
        Get the last recorded observation velocity
        """
        self.__observation_mutex.acquire()
        observation = self.__observation.velocity
        self.__observation_mutex.release()
        return observation

    def get_efforts(self) -> array:
        """
        Get the last recorded observation effort
        """
        self.__observation_mutex.acquire()
        observation = self.__observation.effort
        self.__observation_mutex.release()
        return observation

    def reset_new_observation_checker(self):
        """
        Reset checker of new observations, i.e. `self.new_observation_available()`
        """

        self.__observation_mutex.acquire()
        self.__new_observation_available = False
        self.__observation_mutex.release()

    @property
    def new_observation_available(self):
        """
        Check if new observation is available since `self.reset_new_observation_checker()` was called
        """

        return self.__new_observation_available
from typing import List, Tuple

import numpy as np
import ocnn
import open3d
import torch
from rclpy.node import Node
from sensor_msgs.msg import PointCloud2

from env_manager.utils import Tf2Listener, conversions


class OctreeCreator:
    """Converts ROS PointCloud2 messages into ocnn octree tensors.

    The cloud is transformed into `reference_frame_id`, cropped to the
    workspace bounds, given normal estimates, and converted via
    `ocnn.Points2Octree`.
    """

    def __init__(
        self,
        node: Node,
        tf2_listener: Tf2Listener,
        reference_frame_id: str,
        min_bound: Tuple[float, float, float] = (-1.0, -1.0, -1.0),
        max_bound: Tuple[float, float, float] = (1.0, 1.0, 1.0),
        include_color: bool = False,
        # Note: For efficiency, the first channel of RGB is used for intensity
        include_intensity: bool = False,
        depth: int = 4,
        full_depth: int = 2,
        adaptive: bool = False,
        adp_depth: int = 4,
        normals_radius: float = 0.05,
        normals_max_nn: int = 10,
        node_dis: bool = True,
        node_feature: bool = False,
        split_label: bool = False,
        th_normal: float = 0.1,
        th_distance: float = 2.0,
        extrapolate: bool = False,
        save_pts: bool = False,
        key2xyz: bool = False,
        debug_draw: bool = False,
        debug_write_octree: bool = False,
    ):
        self._node = node

        # Listener of tf2 transforms is shared with the owner
        self.__tf2_listener = tf2_listener

        # Parameters
        self._reference_frame_id = reference_frame_id
        self._min_bound = min_bound
        self._max_bound = max_bound
        self._include_color = include_color
        self._include_intensity = include_intensity
        self._normals_radius = normals_radius
        self._normals_max_nn = normals_max_nn
        self._debug_draw = debug_draw
        self._debug_write_octree = debug_write_octree

        # Create a converter between points and octree
        self._points_to_octree = ocnn.Points2Octree(
            depth=depth,
            full_depth=full_depth,
            node_dis=node_dis,
            node_feature=node_feature,
            split_label=split_label,
            adaptive=adaptive,
            adp_depth=adp_depth,
            th_normal=th_normal,
            th_distance=th_distance,
            extrapolate=extrapolate,
            save_pts=save_pts,
            key2xyz=key2xyz,
            bb_min=min_bound,
            bb_max=max_bound,
        )

    def __call__(self, ros_point_cloud2: PointCloud2) -> torch.Tensor:
        """Convert one PointCloud2 message into an octree tensor."""

        # Convert to Open3D PointCloud
        open3d_point_cloud = conversions.pointcloud2_to_open3d(
            ros_point_cloud2=ros_point_cloud2,
            include_color=self._include_color,
            include_intensity=self._include_intensity,
        )

        # Preprocess point cloud (transform to robot frame, crop to workspace and estimate normals)
        open3d_point_cloud = self.preprocess_point_cloud(
            open3d_point_cloud=open3d_point_cloud,
            camera_frame_id=ros_point_cloud2.header.frame_id,
            reference_frame_id=self._reference_frame_id,
            min_bound=self._min_bound,
            max_bound=self._max_bound,
            normals_radius=self._normals_radius,
            normals_max_nn=self._normals_max_nn,
        )

        # Draw if needed
        if self._debug_draw:
            open3d.visualization.draw_geometries(
                [
                    open3d_point_cloud,
                    open3d.geometry.TriangleMesh.create_coordinate_frame(
                        size=0.2, origin=[0.0, 0.0, 0.0]
                    ),
                ],
                point_show_normal=True,
            )

        # Construct octree from such point cloud
        octree = self.construct_octree(
            open3d_point_cloud,
            include_color=self._include_color,
            include_intensity=self._include_intensity,
        )

        # Write if needed
        if self._debug_write_octree:
            ocnn.write_octree(octree, "octree.octree")

        return octree

    def preprocess_point_cloud(
        self,
        open3d_point_cloud: open3d.geometry.PointCloud,
        camera_frame_id: str,
        reference_frame_id: str,
        min_bound: List[float],
        max_bound: List[float],
        normals_radius: float,
        normals_max_nn: int,
    ) -> open3d.geometry.PointCloud:
        """Transform, crop and estimate normals on the cloud."""

        # Check if point cloud has any points
        if not open3d_point_cloud.has_points():
            self._node.get_logger().warn(
                "Point cloud has no points. Pre-processing skipped."
            )
            return open3d_point_cloud

        # FIX: `transform_mat` was previously referenced unconditionally when
        # orienting normals, raising NameError whenever the cloud was already
        # expressed in the reference frame. Default to the reference-frame
        # origin as the camera location in that case.
        # NOTE(review): confirm the sensor actually sits at the reference
        # origin when no transform is needed.
        camera_location = (0.0, 0.0, 0.0)

        # Get transformation from camera to robot and use it to transform point
        # cloud into robot's base coordinate frame
        if camera_frame_id != reference_frame_id:
            transform = self.__tf2_listener.lookup_transform_sync(
                target_frame=reference_frame_id, source_frame=camera_frame_id
            )
            transform_mat = conversions.transform_to_matrix(transform=transform)
            open3d_point_cloud = open3d_point_cloud.transform(transform_mat)
            camera_location = transform_mat[0:3, 3]

        # Crop point cloud to include only the workspace
        open3d_point_cloud = open3d_point_cloud.crop(
            bounding_box=open3d.geometry.AxisAlignedBoundingBox(
                min_bound=min_bound, max_bound=max_bound
            )
        )

        # Check if any points remain in the area after cropping
        if not open3d_point_cloud.has_points():
            self._node.get_logger().warn(
                "Point cloud has no points after cropping it to the workspace volume."
            )
            return open3d_point_cloud

        # Estimate normal vector for each cloud point and orient these towards the camera
        open3d_point_cloud.estimate_normals(
            search_param=open3d.geometry.KDTreeSearchParamHybrid(
                radius=normals_radius, max_nn=normals_max_nn
            ),
            fast_normal_computation=True,
        )

        open3d_point_cloud.orient_normals_towards_camera_location(
            camera_location=camera_location
        )

        return open3d_point_cloud

    def construct_octree(
        self,
        open3d_point_cloud: open3d.geometry.PointCloud,
        include_color: bool,
        include_intensity: bool,
    ) -> torch.Tensor:
        """Build an octree tensor from a preprocessed Open3D point cloud."""

        # In case the point cloud has no points, add a single point
        # This is a workaround because I was not able to create an empty octree without getting a segfault
        # TODO: Figure out a better way of making an empty octree (it does not occur if setup correctly, so probably not worth it)
        if not open3d_point_cloud.has_points():
            open3d_point_cloud.points.append(
                (
                    (self._min_bound[0] + self._max_bound[0]) / 2,
                    (self._min_bound[1] + self._max_bound[1]) / 2,
                    (self._min_bound[2] + self._max_bound[2]) / 2,
                )
            )
            open3d_point_cloud.normals.append((0.0, 0.0, 0.0))
            if include_color or include_intensity:
                open3d_point_cloud.colors.append((0.0, 0.0, 0.0))

        # Convert open3d point cloud into octree points
        octree_points = conversions.open3d_point_cloud_to_octree_points(
            open3d_point_cloud=open3d_point_cloud,
            include_color=include_color,
            include_intensity=include_intensity,
        )

        # Convert octree points into 1D Tensor (via ndarray)
        # Note: Copy of points here is necessary as ndarray would otherwise be immutable
        octree_points_ndarray = np.frombuffer(np.copy(octree_points.buffer()), np.uint8)
        octree_points_tensor = torch.from_numpy(octree_points_ndarray)

        # Finally, create an octree from the points
        return self._points_to_octree(octree_points_tensor)
import sys
from threading import Lock, Thread
from typing import Optional, Union

import rclpy
from rclpy.callback_groups import CallbackGroup
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from rclpy.parameter import Parameter
from rclpy.qos import (
    QoSDurabilityPolicy,
    QoSHistoryPolicy,
    QoSProfile,
    QoSReliabilityPolicy,
)
from geometry_msgs.msg import TwistStamped


class TwistSubscriber:
    """Caches the most recent TwistStamped message from a topic."""

    def __init__(
        self,
        node: Node,
        topic: str,
        callback_group: Optional[CallbackGroup] = None,
    ):
        self._node = node

        self.__observation = TwistStamped()
        self._node.create_subscription(
            msg_type=TwistStamped,
            topic=topic,
            callback=self.observation_callback,
            qos_profile=QoSProfile(
                reliability=QoSReliabilityPolicy.RELIABLE,
                durability=QoSDurabilityPolicy.VOLATILE,
                history=QoSHistoryPolicy.KEEP_LAST,
                depth=1,
            ),
            callback_group=callback_group,
        )
        self.__observation_mutex = Lock()
        self.__new_observation_available = False

    def observation_callback(self, msg):
        """Store the newest message and flag that a fresh observation exists."""
        with self.__observation_mutex:
            self.__observation = msg
            self.__new_observation_available = True
            self._node.get_logger().debug("New observation received.")

    def get_observation(self) -> TwistStamped:
        """Return the last received observation."""
        with self.__observation_mutex:
            return self.__observation

    def reset_new_observation_checker(self):
        """Clear the flag reported by `self.new_observation_available`."""
        with self.__observation_mutex:
            self.__new_observation_available = False

    @property
    def new_observation_available(self):
        """Whether a new observation arrived since the last reset."""
        return self.__new_observation_available
+ """ + + POST_RANDOMIZATION_MAX_STEPS = 50 + + metadata = {"render_modes": ["human"]} + + def __init__( + self, + env: MakeEnvCallable, + scene_args: SceneData = SceneData(), + **kwargs, + ): + self._scene_data = scene_args + # TODO (a lot of work): Implement proper physics randomization. + if scene_args.physics_rollouts_num != 0: + raise TypeError( + "Proper physics randomization at each reset is not yet implemented. Please set `physics_rollouts_num=0`." + ) + + # Update kwargs before passing them to the task constructor (some tasks might need them) + # TODO: update logic when will need choose between cameras + cameras: list[dict[str, str | int]] = [] + for camera in self._scene_data.camera: + camera_info: dict[str, str | int] = {} + camera_info["name"] = camera.name + camera_info["type"] = camera.type + camera_info["width"] = camera.width + camera_info["height"] = camera.height + cameras.append(camera_info) + + kwargs.update({"camera_info": cameras}) + + # Initialize base classes + randomizers.abc.TaskRandomizer.__init__(self) + randomizers.abc.PhysicsRandomizer.__init__( + self, randomize_after_rollouts_num=scene_args.physics_rollouts_num + ) + gazebo_env_randomizer.GazeboEnvRandomizer.__init__( + self, env=env, physics_randomizer=self, **kwargs + ) + + self.__env_initialised = False + + ########################## + # PhysicsRandomizer impl # + ########################## + + def init_physics_preset(self, task: SupportedTasks): + self.set_gravity(task=task) + + def randomize_physics(self, task: SupportedTasks, **kwargs): + self.set_gravity(task=task) + + def set_gravity(self, task: SupportedTasks): + if not task.world.to_gazebo().set_gravity( + ( + task.np_random.normal(loc=self._gravity[0], scale=self._gravity_std[0]), + task.np_random.normal(loc=self._gravity[1], scale=self._gravity_std[1]), + task.np_random.normal(loc=self._gravity[2], scale=self._gravity_std[2]), + ) + ): + raise RuntimeError("Failed to set the gravity") + + def get_engine(self): + return 
scenario.PhysicsEngine_dart + + ####################### + # TaskRandomizer impl # + ####################### + + def randomize_task(self, task: SupportedTasks, **kwargs): + """ + Randomization of the task, which is called on each reset of the environment. + Note that this randomizer reset is called before `reset_task()`. + """ + + # Get gazebo instance associated with the task + if "gazebo" not in kwargs: + raise ValueError("Randomizer does not have access to the gazebo interface") + if isinstance(kwargs["gazebo"], GazeboSimulator): + gazebo: GazeboSimulator = kwargs["gazebo"] + else: + raise RuntimeError("Provide GazeboSimulator instance") + + # Initialise the environment on the first iteration + if not self.__env_initialised: + self.init_env(task=task, gazebo=gazebo) + self.__env_initialised = True + + # Perform pre-randomization steps + self.pre_randomization(task=task) + + self._scene.reset_scene() + + # Perform post-randomization steps + # TODO: Something in post-randomization causes GUI client to freeze during reset (the simulation server still works fine) + self.post_randomization(task, gazebo) + + ################### + # Randomizer impl # + ################### + + # Initialisation # + def init_env(self, task: SupportedTasks, gazebo: scenario.GazeboSimulator): + """ + Initialise an instance of the environment before the very first iteration + """ + + self._scene = Scene( + task, + gazebo, + self._scene_data, + task.get_parameter("robot_description").get_parameter_value().string_value, + ) + + # Init Scene for task + task.scene = self._scene + # Set log level for (Gym) Ignition + set_log_level(log_level=task.get_logger().get_effective_level().name) + + self._scene.init_scene() + + # Pre-randomization # + def pre_randomization(self, task: SupportedTasks): + """ + Perform steps that are required before randomization is performed. 
+ """ + + for obj in self._scene.objects: + # If desired, select random spawn position from the segments + # It is performed here because object spawn position might be of interest also for robot and camera randomization + + segments_len = len(obj.randomize.random_spawn_position_segments) + if segments_len > 1: + # Randomly select a segment between two points + start_index = task.np_random.randint(segments_len - 1) + segment = ( + obj.randomize.random_spawn_position_segments[start_index], + obj.randomize.random_spawn_position_segments[start_index + 1], + ) + + # Randomly select a point on the segment and use it as the new object spawn position + intersect = task.np_random.random() + direction = ( + segment[1][0] - segment[0][0], + segment[1][1] - segment[0][1], + segment[1][2] - segment[0][2], + ) + obj.position = ( + segment[0][0] + intersect * direction[0], + segment[0][1] + intersect * direction[1], + segment[0][2] + intersect * direction[2], + ) + + # TODO: add bounding box with multiple objects + + # # Update also the workspace centre (and bounding box) if desired + # if self._object_random_spawn_position_update_workspace_centre: + # task.workspace_centre = ( + # self._object_spawn_position[0], + # self._object_spawn_position[1], + # # Z workspace is currently kept the same on purpose + # task.workspace_centre[2], + # ) + # workspace_volume_half = ( + # task.workspace_volume[0] / 2, + # task.workspace_volume[1] / 2, + # task.workspace_volume[2] / 2, + # ) + # task.workspace_min_bound = ( + # task.workspace_centre[0] - workspace_volume_half[0], + # task.workspace_centre[1] - workspace_volume_half[1], + # task.workspace_centre[2] - workspace_volume_half[2], + # ) + # task.workspace_max_bound = ( + # task.workspace_centre[0] + workspace_volume_half[0], + # task.workspace_centre[1] + workspace_volume_half[1], + # task.workspace_centre[2] + workspace_volume_half[2], + # ) + + def post_randomization( + self, task: SupportedTasks, gazebo: scenario.GazeboSimulator + 
): + """ + Perform steps required once randomization is complete and the simulation can be stepped unpaused. + """ + + def perform_gazebo_step(): + if not gazebo.step(): + raise RuntimeError("Failed to execute an unpaused Gazebo step") + + def wait_for_new_observations(): + attempts = 0 + while True: + attempts += 1 + if attempts % self.POST_RANDOMIZATION_MAX_STEPS == 0: + task.get_logger().debug( + f"Waiting for new joint state after reset. Iteration #{attempts}..." + ) + else: + task.get_logger().debug("Waiting for new joint state after reset.") + + perform_gazebo_step() + + # If camera_sub is defined, ensure all new observations are available + if hasattr(task, "camera_subs") and not all( + sub.new_observation_available for sub in task.camera_subs + ): + continue + + return # Observations are ready + + attempts = 0 + processed_objects = set() + + # Early exit if the maximum number of steps is already reached + if self.POST_RANDOMIZATION_MAX_STEPS == 0: + task.get_logger().error( + "Robot keeps falling through the terrain. There is something wrong..." 
+ ) + return + + # Ensure no objects are overlapping + for obj in self._scene.objects: + if not obj.randomize.random_pose or obj.name in processed_objects: + continue + + # Try repositioning until no overlap or maximum attempts reached + for _ in range(self.POST_RANDOMIZATION_MAX_STEPS): + task.get_logger().debug( + f"Checking overlap for {obj.name}, attempt {attempts + 1}" + ) + if self._scene.check_object_overlapping(obj): + processed_objects.add(obj.name) + break # No overlap, move to next object + else: + task.get_logger().debug( + f"Objects overlapping, trying new positions for {obj.name}" + ) + + perform_gazebo_step() + + else: + task.get_logger().warn( + f"Could not place {obj.name} without overlap after {self.POST_RANDOMIZATION_MAX_STEPS} attempts" + ) + continue # Move to next object in case of failure + + # Execute steps until new observations are available + if hasattr(task, "camera_subs") or task._enable_gripper: + wait_for_new_observations() + + # Final check if observations are not available within the maximum steps + if self.POST_RANDOMIZATION_MAX_STEPS == attempts: + task.get_logger().error("Cannot obtain new observation.") + + # ============================= + # Additional features and debug + # ============================= + + # def visualise_workspace( + # self, + # task: SupportedTasks, + # gazebo: scenario.GazeboSimulator, + # color: Tuple[float, float, float, float] = (0, 1, 0, 0.8), + # ): + # # Insert a translucent box visible only in simulation with no physical interactions + # models.Box( + # world=task.world, + # name="_workspace_volume", + # position=self._object_spawn_position, + # orientation=(0, 0, 0, 1), + # size=task.workspace_volume, + # collision=False, + # visual=True, + # gui_only=True, + # static=True, + # color=color, + # ) + # # Execute a paused run to process model insertion + # if not gazebo.run(paused=True): + # raise RuntimeError("Failed to execute a paused Gazebo run") + # + # def visualise_spawn_volume( + # self, + # 
task: SupportedTasks, + # gazebo: scenario.GazeboSimulator, + # color: Tuple[float, float, float, float] = (0, 0, 1, 0.8), + # color_with_height: Tuple[float, float, float, float] = (1, 0, 1, 0.7), + # ): + # # Insert translucent boxes visible only in simulation with no physical interactions + # models.Box( + # world=task.world, + # name="_object_random_spawn_volume", + # position=self._object_spawn_position, + # orientation=(0, 0, 0, 1), + # size=self._object_random_spawn_volume, + # collision=False, + # visual=True, + # gui_only=True, + # static=True, + # color=color, + # ) + # models.Box( + # world=task.world, + # name="_object_random_spawn_volume_with_height", + # position=self._object_spawn_position, + # orientation=(0, 0, 0, 1), + # size=self._object_random_spawn_volume, + # collision=False, + # visual=True, + # gui_only=True, + # static=True, + # color=color_with_height, + # ) + # # Execute a paused run to process model insertion + # if not gazebo.run(paused=True): + # raise RuntimeError("Failed to execute a paused Gazebo run") diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/__init__.py new file mode 100644 index 0000000..15d114f --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/__init__.py @@ -0,0 +1,4 @@ +# from .curriculums import * +# from .grasp import * +# from .grasp_planetary import * +from .reach import * diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/__init__.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/__init__.py new file mode 100644 index 0000000..3e3828e --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/__init__.py @@ -0,0 +1 @@ +from .grasp import GraspCurriculum diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/common.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/common.py new file mode 100644 index 0000000..5caf7bc --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/common.py @@ -0,0 
+1,700 @@ +from __future__ import annotations + +import enum +import itertools +import math +from collections import deque +from typing import Callable, Deque, Dict, Optional, Tuple, Type + +import numpy as np +from gym_gz.base.task import Task +from gym_gz.utils.typing import Reward +from tf2_ros.buffer_interface import TypeException + +INFO_MEAN_STEP_KEY: str = "__mean_step__" +INFO_MEAN_EPISODE_KEY: str = "__mean_episode__" + + +@enum.unique +class CurriculumStage(enum.Enum): + """ + Ordered enum that represents stages of a curriculum for RL task. + """ + + @classmethod + def first(self) -> CurriculumStage: + + return self(1) + + @classmethod + def last(self) -> CurriculumStage: + + return self(len(self)) + + def next(self) -> Optional[CurriculumStage]: + + next_value = self.value + 1 + + if next_value > self.last().value: + return None + else: + return self(next_value) + + def previous(self) -> Optional[CurriculumStage]: + + previous_value = self.value - 1 + + if previous_value < self.first().value: + return None + else: + return self(previous_value) + + +class StageRewardCurriculum: + """ + Curriculum that begins to compute rewards for a stage once all previous stages are complete. 
+ """ + + PERSISTENT_ID: str = "PERSISTENT" + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + curriculum_stage: Type[CurriculumStage], + stage_reward_multiplier: float, + dense_reward: bool = False, + **kwargs, + ): + + if 0 == len(curriculum_stage): + raise TypeException(f"{curriculum_stage} has length of 0") + + self.__use_dense_reward = dense_reward + if self.__use_dense_reward: + raise ValueError( + "Dense reward is currently not implemented for any curriculum" + ) + + # Setup internals + self._stage_type = curriculum_stage + self._stage_reward_functions: Dict[curriculum_stage, Callable] = { + curriculum_stage(stage): getattr(self, f"get_reward_{stage.name}") + for stage in iter(curriculum_stage) + } + self.__stage_reward_multipliers: Dict[curriculum_stage, float] = { + curriculum_stage(stage): stage_reward_multiplier ** (stage.value - 1) + for stage in iter(curriculum_stage) + } + + self.stages_completed_this_episode: Dict[curriculum_stage, bool] = { + curriculum_stage(stage): False for stage in iter(curriculum_stage) + } + self.__stages_rewards_this_episode: Dict[curriculum_stage, float] = { + curriculum_stage(stage): 0.0 for stage in iter(curriculum_stage) + } + self.__stages_rewards_this_episode[self.PERSISTENT_ID] = 0.0 + + self.__episode_succeeded: bool = False + self.__episode_failed: bool = False + + def get_reward(self, only_last_stage: bool = False, **kwargs) -> Reward: + + reward = 0.0 + + # Determine the stage at which to start computing reward [performance - done stages give no reward] + if only_last_stage: + first_stage_to_process = self._stage_type.last() + else: + for stage in iter(self._stage_type): + if not self.stages_completed_this_episode[stage]: + first_stage_to_process = stage + break + + # Iterate over all stages that might need to be processed + for stage in range(first_stage_to_process.value, len(self._stage_type) + 1): + stage = self._stage_type(stage) + + # Compute reward for the current stage + stage_reward = 
self._stage_reward_functions[stage](**kwargs) + # Multiply by the reward multiplier + stage_reward *= self.__stage_reward_multipliers[stage] + # Add to the total step reward + reward += stage_reward + # Add reward to the list for info + self.__stages_rewards_this_episode[stage] += stage_reward + + # Break if stage is not yet completed [performance - next stages won't give any reward] + if not self.stages_completed_this_episode[stage]: + break + + # If the last stage is complete, the episode has succeeded + self.__episode_succeeded = self.stages_completed_this_episode[ + self._stage_type.last() + ] + if self.__episode_succeeded: + return reward + + # Add persistent reward that is added regardless of the episode (unless task already succeeded) + persistent_reward = self.get_persistent_reward(**kwargs) + # Add to the total step reward + reward += persistent_reward + # Add reward to the list for info + self.__stages_rewards_this_episode[self.PERSISTENT_ID] += persistent_reward + + return reward + + def is_done(self) -> bool: + + if self.__episode_succeeded: + # The episode ended with success + self.on_episode_success() + return True + elif self.__episode_failed: + # The episode ended due to failure + self.on_episode_failure() + return True + else: + # Otherwise, the episode is not yet done + return False + + def get_info(self) -> Dict: + + # Whether the episode succeeded + info = { + "is_success": self.__episode_succeeded, + } + + # What stage was reached during this episode so far + for stage in iter(self._stage_type): + reached_stage = stage + if not self.stages_completed_this_episode[stage]: + break + info.update( + { + f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}ep_reached_stage_mean": reached_stage.value, + } + ) + + # Rewards for the individual stages + info.update( + { + f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}ep_rew_mean_{stage.value}_{stage.name.lower()}": self.__stages_rewards_this_episode[ + stage + ] + for stage in 
iter(self._stage_type) + } + ) + info.update( + { + f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}ep_rew_mean_{self.PERSISTENT_ID.lower()}": self.__stages_rewards_this_episode[ + self.PERSISTENT_ID + ] + } + ) + + return info + + def reset_task(self): + + if not (self.__episode_succeeded or self.__episode_failed): + # The episode ended due to timeout + self.on_episode_timeout() + + # Reset internals + self.stages_completed_this_episode = dict.fromkeys( + self.stages_completed_this_episode, False + ) + self.__stages_rewards_this_episode = dict.fromkeys( + self.__stages_rewards_this_episode, 0.0 + ) + self.__stages_rewards_this_episode[self.PERSISTENT_ID] = 0.0 + self.__episode_succeeded = False + self.__episode_failed = False + + @property + def episode_succeeded(self) -> bool: + + return self.__episode_succeeded + + @episode_succeeded.setter + def episode_succeeded(self, value: bool): + + self.__episode_succeeded = value + + @property + def episode_failed(self) -> bool: + + return self.__episode_failed + + @episode_failed.setter + def episode_failed(self, value: bool): + + self.__episode_failed = value + + @property + def use_dense_reward(self) -> bool: + + return self.__use_dense_reward + + def get_persistent_reward(self, **kwargs) -> float: + """ + Virtual method. + """ + + reward = 0.0 + + return reward + + def on_episode_success(self): + """ + Virtual method. + """ + + pass + + def on_episode_failure(self): + """ + Virtual method. + """ + + pass + + def on_episode_timeout(self): + """ + Virtual method. + """ + + pass + + +class SuccessRateImpl: + """ + Moving average over the success rate of last N episodes. 
+ """ + + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + initial_success_rate: float = 0.0, + rolling_average_n: int = 100, + **kwargs, + ): + + self.__success_rate = initial_success_rate + self.__rolling_average_n = rolling_average_n + + # Setup internals + self.__previous_success_rate_weight: int = 0 + self.__collected_samples: int = 0 + + def get_info(self) -> Dict: + + info = { + f"{self.INFO_CURRICULUM_PREFIX}_success_rate": self.__success_rate, + } + + return info + + def update_success_rate(self, is_success: bool): + + # Until `rolling_average_n` is reached, use number of collected samples during computations + if self.__collected_samples < self.__rolling_average_n: + self.__previous_success_rate_weight = self.__collected_samples + self.__collected_samples += 1 + + self.__success_rate = ( + self.__previous_success_rate_weight * self.__success_rate + + float(is_success) + ) / self.__collected_samples + + @property + def success_rate(self) -> float: + + return self.__success_rate + + +class WorkspaceScaleCurriculum: + """ + Curriculum that increases the workspace size as the success rate increases. 
+ """ + + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + task: Task, + success_rate_impl: SuccessRateImpl, + min_workspace_scale: float, + max_workspace_volume: Tuple[float, float, float], + max_workspace_scale_success_rate_threshold: float, + **kwargs, + ): + + self.__task = task + self.__success_rate_impl = success_rate_impl + self.__min_workspace_scale = min_workspace_scale + self.__max_workspace_volume = max_workspace_volume + self.__max_workspace_scale_success_rate_threshold = ( + max_workspace_scale_success_rate_threshold + ) + + def get_info(self) -> Dict: + + info = { + f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}workspace_scale": self.__workspace_scale, + } + + return info + + def reset_task(self): + + # Update workspace size + self.__update_workspace_size() + + def __update_workspace_size(self): + + self.__workspace_scale = min( + 1.0, + max( + self.__min_workspace_scale, + self.__success_rate_impl.success_rate + / self.__max_workspace_scale_success_rate_threshold, + ), + ) + + workspace_volume_new = ( + self.__workspace_scale * self.__max_workspace_volume[0], + self.__workspace_scale * self.__max_workspace_volume[1], + # Z workspace is currently kept the same on purpose + self.__max_workspace_volume[2], + ) + workspace_volume_half_new = ( + workspace_volume_new[0] / 2, + workspace_volume_new[1] / 2, + workspace_volume_new[2] / 2, + ) + workspace_min_bound_new = ( + self.__task.workspace_centre[0] - workspace_volume_half_new[0], + self.__task.workspace_centre[1] - workspace_volume_half_new[1], + self.__task.workspace_centre[2] - workspace_volume_half_new[2], + ) + workspace_max_bound_new = ( + self.__task.workspace_centre[0] + workspace_volume_half_new[0], + self.__task.workspace_centre[1] + workspace_volume_half_new[1], + self.__task.workspace_centre[2] + workspace_volume_half_new[2], + ) + + self.__task.add_task_parameter_overrides( + { + "workspace_volume": workspace_volume_new, + "workspace_min_bound": 
workspace_min_bound_new, + "workspace_max_bound": workspace_max_bound_new, + } + ) + + +class ObjectSpawnVolumeScaleCurriculum: + """ + Curriculum that increases the object spawn volume as the success rate increases. + """ + + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + task: Task, + success_rate_impl: SuccessRateImpl, + min_object_spawn_volume_scale: float, + max_object_spawn_volume: Tuple[float, float, float], + max_object_spawn_volume_scale_success_rate_threshold: float, + **kwargs, + ): + + self.__task = task + self.__success_rate_impl = success_rate_impl + self.__min_object_spawn_volume_scale = min_object_spawn_volume_scale + self.__max_object_spawn_volume = max_object_spawn_volume + self.__max_object_spawn_volume_scale_success_rate_threshold = ( + max_object_spawn_volume_scale_success_rate_threshold + ) + + def get_info(self) -> Dict: + + info = { + f"{self.INFO_CURRICULUM_PREFIX}{INFO_MEAN_EPISODE_KEY}object_spawn_volume_scale": self.__object_spawn_volume_scale, + } + + return info + + def reset_task(self): + + # Update object_spawn_volume size + self.__update_object_spawn_volume_size() + + def __update_object_spawn_volume_size(self): + + self.__object_spawn_volume_scale = min( + 1.0, + max( + self.__min_object_spawn_volume_scale, + self.__success_rate_impl.success_rate + / self.__max_object_spawn_volume_scale_success_rate_threshold, + ), + ) + + object_spawn_volume_volume_new = ( + self.__object_spawn_volume_scale * self.__max_object_spawn_volume[0], + self.__object_spawn_volume_scale * self.__max_object_spawn_volume[1], + self.__object_spawn_volume_scale * self.__max_object_spawn_volume[2], + ) + + self.__task.add_randomizer_parameter_overrides( + { + "object_random_spawn_volume": object_spawn_volume_volume_new, + } + ) + + +class ObjectCountCurriculum: + """ + Curriculum that increases the number of objects as the success rate increases. 
+ """ + + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + task: Task, + success_rate_impl: SuccessRateImpl, + object_count_min: int, + object_count_max: int, + max_object_count_success_rate_threshold: float, + **kwargs, + ): + + self.__task = task + self.__success_rate_impl = success_rate_impl + self.__object_count_min = object_count_min + self.__object_count_max = object_count_max + self.__max_object_count_success_rate_threshold = ( + max_object_count_success_rate_threshold + ) + + self.__object_count_min_max_diff = object_count_max - object_count_min + if self.__object_count_min_max_diff < 0: + raise Exception( + "'object_count_min' cannot be larger than 'object_count_max'" + ) + + def get_info(self) -> Dict: + + info = { + f"{self.INFO_CURRICULUM_PREFIX}object_count": self.__object_count, + } + + return info + + def reset_task(self): + + # Update object count + self.__update_object_count() + + def __update_object_count(self): + + self.__object_count = min( + self.__object_count_max, + math.floor( + self.__object_count_min + + ( + self.__success_rate_impl.success_rate + / self.__max_object_count_success_rate_threshold + ) + * self.__object_count_min_max_diff + ), + ) + + self.__task.add_randomizer_parameter_overrides( + { + "object_count": self.__object_count, + } + ) + + +class ArmStuckChecker: + """ + Checker for arm getting stuck. 
+ """ + + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + task: Task, + arm_stuck_n_steps: int, + arm_stuck_min_joint_difference_norm: float, + **kwargs, + ): + + self.__task = task + self.__arm_stuck_min_joint_difference_norm = arm_stuck_min_joint_difference_norm + + # List of previous join positions (used to compute difference norm with an older previous reading) + self.__previous_joint_positions: Deque[np.ndarray] = deque( + [], maxlen=arm_stuck_n_steps + ) + # Counter of how many time the robot got stuck + self.__robot_stuck_total_counter: int = 0 + + # Initialize list of indices for the arm. + # It is assumed that these indices do not change during the operation + self.__arm_joint_indices = None + + def get_info(self) -> Dict: + + info = { + f"{self.INFO_CURRICULUM_PREFIX}robot_stuck_count": self.__robot_stuck_total_counter, + } + + return info + + def reset_task(self): + + self.__previous_joint_positions.clear() + + joint_positions = self.__get_arm_joint_positions() + if joint_positions is not None: + self.__previous_joint_positions.append(joint_positions) + + def is_robot_stuck(self) -> bool: + + # Get current position and append to the list of previous ones + current_joint_positions = self.__get_arm_joint_positions() + if current_joint_positions is not None: + self.__previous_joint_positions.append(current_joint_positions) + + # Stop checking if there is not yet enough entries in the list + if ( + len(self.__previous_joint_positions) + < self.__previous_joint_positions.maxlen + ): + return False + + # Make sure the length of joint position matches + if len(current_joint_positions) != len(self.__previous_joint_positions[0]): + return False + + # Compute joint difference norm only with the `t - arm_stuck_n_steps` entry first (performance reason) + joint_difference_norm = np.linalg.norm( + current_joint_positions - self.__previous_joint_positions[0] + ) + + # If the difference is large enough, the arm does not appear to be stuck, so 
skip computing all other entries + if joint_difference_norm > self.__arm_stuck_min_joint_difference_norm: + return False + + # If it is too small, consider all other entries as well + joint_difference_norms = np.linalg.norm( + current_joint_positions + - list(itertools.islice(self.__previous_joint_positions, 1, None)), + axis=1, + ) + + # Return true (stuck) if all joint difference entries are too small + is_stuck = all( + joint_difference_norms < self.__arm_stuck_min_joint_difference_norm + ) + self.__robot_stuck_total_counter += int(is_stuck) + return is_stuck + + def __get_arm_joint_positions(self) -> Optional[np.ndarray[float]]: + + joint_state = self.__task.moveit2.joint_state + + if joint_state is None: + return None + + if self.__arm_joint_indices is None: + self.__arm_joint_indices = [ + i + for i, joint_name in enumerate(joint_state.name) + if joint_name in self.__task.robot_arm_joint_names + ] + + return np.take(joint_state.position, self.__arm_joint_indices) + + +class AttributeCurriculum: + """ + Curriculum that increases the value of an attribute (e.g. requirement) as the success rate increases. + Currently supports only attributes that are increasing. 
+ """ + + INFO_CURRICULUM_PREFIX: str = "curriculum/" + + def __init__( + self, + success_rate_impl: SuccessRateImpl, + attribute_owner: Type, + attribute_name: str, + initial_value: float, + target_value: float, + target_value_threshold: float, + **kwargs, + ): + + self.__success_rate_impl = success_rate_impl + self.__attribute_owner = attribute_owner + self.__attribute_name = attribute_name + self.__initial_value = initial_value + self.__target_value_threshold = target_value_threshold + + # Initialise current value of the attribute + self.__current_value = initial_value + + # Store difference for faster computations + self.__value_diff = target_value - initial_value + + def get_info(self) -> Dict: + + info = { + f"{self.INFO_CURRICULUM_PREFIX}{self.__attribute_name}": self.__current_value, + } + + return info + + def reset_task(self): + + # Update object count + self.__update_attribute() + + def __update_attribute(self): + + scale = min( + 1.0, + max( + self.__initial_value, + self.__success_rate_impl.success_rate / self.__target_value_threshold, + ), + ) + + self.__current_value = self.__initial_value + (scale * self.__value_diff) + + if hasattr(self.__attribute_owner, self.__attribute_name): + setattr(self.__attribute_owner, self.__attribute_name, self.__current_value) + elif hasattr(self.__attribute_owner, f"_{self.__attribute_name}"): + setattr( + self.__attribute_owner, + f"_{self.__attribute_name}", + self.__current_value, + ) + elif hasattr(self.__attribute_owner, f"__{self.__attribute_name}"): + setattr( + self.__attribute_owner, + f"__{self.__attribute_name}", + self.__current_value, + ) + else: + raise Exception( + f"Attribute owner '{self.__attribute_owner}' does not have any attribute named {self.__attribute_name}." 
+ ) diff --git a/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/grasp.py b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/grasp.py new file mode 100644 index 0000000..bd6a751 --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/envs/tasks/curriculums/grasp.py @@ -0,0 +1,341 @@ +from typing import Dict, List, Tuple + +from gym_gz.base.task import Task + +from rbs_gym.envs.tasks.curriculums.common import * +from rbs_gym.envs.utils.math import distance_to_nearest_point + + +class GraspStage(CurriculumStage): + """ + Ordered enum that represents stages of a curriculum for Grasp (and GraspPlanetary) task. + """ + + REACH = 1 + TOUCH = 2 + GRASP = 3 + LIFT = 4 + + +class GraspCurriculum( + StageRewardCurriculum, + SuccessRateImpl, + WorkspaceScaleCurriculum, + ObjectSpawnVolumeScaleCurriculum, + ObjectCountCurriculum, + ArmStuckChecker, +): + """ + Curriculum learning implementation for grasp task that provides termination (success/fail) and reward for each stage of the task. + """ + + def __init__( + self, + task: Task, + stages_base_reward: float, + reach_required_distance: float, + lift_required_height: float, + persistent_reward_each_step: float, + persistent_reward_terrain_collision: float, + persistent_reward_all_objects_outside_workspace: float, + persistent_reward_arm_stuck: float, + enable_stage_reward_curriculum: bool, + enable_workspace_scale_curriculum: bool, + enable_object_spawn_volume_scale_curriculum: bool, + enable_object_count_curriculum: bool, + reach_required_distance_min: Optional[float] = None, + reach_required_distance_max: Optional[float] = None, + reach_required_distance_max_threshold: Optional[float] = None, + lift_required_height_min: Optional[float] = None, + lift_required_height_max: Optional[float] = None, + lift_required_height_max_threshold: Optional[float] = None, + **kwargs, + ): + + StageRewardCurriculum.__init__(self, curriculum_stage=GraspStage, **kwargs) + SuccessRateImpl.__init__(self, **kwargs) + 
WorkspaceScaleCurriculum.__init__( + self, task=task, success_rate_impl=self, **kwargs + ) + ObjectSpawnVolumeScaleCurriculum.__init__( + self, task=task, success_rate_impl=self, **kwargs + ) + ObjectCountCurriculum.__init__( + self, task=task, success_rate_impl=self, **kwargs + ) + ArmStuckChecker.__init__(self, task=task, **kwargs) + + # Grasp task/environment that will be used to extract information from the scene + self.__task = task + + # Parameters + self.__stages_base_reward = stages_base_reward + self.reach_required_distance = reach_required_distance + self.lift_required_height = lift_required_height + self.__persistent_reward_each_step = persistent_reward_each_step + self.__persistent_reward_terrain_collision = persistent_reward_terrain_collision + self.__persistent_reward_all_objects_outside_workspace = ( + persistent_reward_all_objects_outside_workspace + ) + self.__persistent_reward_arm_stuck = persistent_reward_arm_stuck + self.__enable_stage_reward_curriculum = enable_stage_reward_curriculum + self.__enable_workspace_scale_curriculum = enable_workspace_scale_curriculum + self.__enable_object_spawn_volume_scale_curriculum = ( + enable_object_spawn_volume_scale_curriculum + ) + self.__enable_object_count_curriculum = enable_object_count_curriculum + + # Make sure that the persistent rewards for each step are negative + if self.__persistent_reward_each_step > 0.0: + self.__persistent_reward_each_step *= -1.0 + if self.__persistent_reward_terrain_collision > 0.0: + self.__persistent_reward_terrain_collision *= -1.0 + if self.__persistent_reward_all_objects_outside_workspace > 0.0: + self.__persistent_reward_all_objects_outside_workspace *= -1.0 + if self.__persistent_reward_arm_stuck > 0.0: + self.__persistent_reward_arm_stuck *= -1.0 + + # Setup curriculum for Reach distance requirement (if enabled) + reach_required_distance_min = ( + reach_required_distance_min + if reach_required_distance_min is not None + else reach_required_distance + ) + 
reach_required_distance_max = ( + reach_required_distance_max + if reach_required_distance_max is not None + else reach_required_distance + ) + reach_required_distance_max_threshold = ( + reach_required_distance_max_threshold + if reach_required_distance_max_threshold is not None + else 0.5 + ) + self.__reach_required_distance_curriculum_enabled = ( + not reach_required_distance_min == reach_required_distance_max + ) + if self.__reach_required_distance_curriculum_enabled: + self.__reach_required_distance_curriculum = AttributeCurriculum( + success_rate_impl=self, + attribute_owner=self, + attribute_name="reach_required_distance", + initial_value=reach_required_distance_min, + target_value=reach_required_distance_max, + target_value_threshold=reach_required_distance_max_threshold, + ) + + # Setup curriculum for Lift height requirement (if enabled) + lift_required_height_min = ( + lift_required_height_min + if lift_required_height_min is not None + else lift_required_height + ) + lift_required_height_max = ( + lift_required_height_max + if lift_required_height_max is not None + else lift_required_height + ) + lift_required_height_max_threshold = ( + lift_required_height_max_threshold + if lift_required_height_max_threshold is not None + else 0.5 + ) + # Offset Lift height requirement by the robot base offset + lift_required_height += task.robot_model_class.BASE_LINK_Z_OFFSET + lift_required_height_min += task.robot_model_class.BASE_LINK_Z_OFFSET + lift_required_height_max += task.robot_model_class.BASE_LINK_Z_OFFSET + lift_required_height_max_threshold += task.robot_model_class.BASE_LINK_Z_OFFSET + + self.__lift_required_height_curriculum_enabled = ( + not lift_required_height_min == lift_required_height_max + ) + if self.__lift_required_height_curriculum_enabled: + self.__lift_required_height_curriculum = AttributeCurriculum( + success_rate_impl=self, + attribute_owner=self, + attribute_name="lift_required_height", + initial_value=lift_required_height_min, + 
target_value=lift_required_height_max, + target_value_threshold=lift_required_height_max_threshold, + ) + + def get_reward(self) -> Reward: + + if self.__enable_stage_reward_curriculum: + # Try to get reward from each stage + return StageRewardCurriculum.get_reward( + self, + ee_position=self.__task.get_ee_position(), + object_positions=self.__task.get_object_positions(), + touched_objects=self.__task.get_touched_objects(), + grasped_objects=self.__task.get_grasped_objects(), + ) + else: + # If curriculum for stages is disabled, compute reward only for the last stage + return StageRewardCurriculum.get_reward( + self, + only_last_stage=True, + object_positions=self.__task.get_object_positions(), + grasped_objects=self.__task.get_grasped_objects(), + ) + + def is_done(self) -> bool: + + return StageRewardCurriculum.is_done(self) + + def get_info(self) -> Dict: + + info = StageRewardCurriculum.get_info(self) + info.update(SuccessRateImpl.get_info(self)) + if self.__enable_workspace_scale_curriculum: + info.update(WorkspaceScaleCurriculum.get_info(self)) + if self.__enable_object_spawn_volume_scale_curriculum: + info.update(ObjectSpawnVolumeScaleCurriculum.get_info(self)) + if self.__enable_object_count_curriculum: + info.update(ObjectCountCurriculum.get_info(self)) + if self.__persistent_reward_arm_stuck: + info.update(ArmStuckChecker.get_info(self)) + if self.__reach_required_distance_curriculum_enabled: + info.update(self.__reach_required_distance_curriculum.get_info()) + if self.__lift_required_height_curriculum_enabled: + info.update(self.__lift_required_height_curriculum.get_info()) + + return info + + def reset_task(self): + + StageRewardCurriculum.reset_task(self) + if self.__enable_workspace_scale_curriculum: + WorkspaceScaleCurriculum.reset_task(self) + if self.__enable_object_spawn_volume_scale_curriculum: + ObjectSpawnVolumeScaleCurriculum.reset_task(self) + if self.__enable_object_count_curriculum: + ObjectCountCurriculum.reset_task(self) + if 
self.__persistent_reward_arm_stuck: + ArmStuckChecker.reset_task(self) + if self.__reach_required_distance_curriculum_enabled: + self.__reach_required_distance_curriculum.reset_task() + if self.__lift_required_height_curriculum_enabled: + self.__lift_required_height_curriculum.reset_task() + + def on_episode_success(self): + + self.update_success_rate(is_success=True) + + def on_episode_failure(self): + + self.update_success_rate(is_success=False) + + def on_episode_timeout(self): + + self.update_success_rate(is_success=False) + + def get_reward_REACH( + self, + ee_position: Tuple[float, float, float], + object_positions: Dict[str, Tuple[float, float, float]], + **kwargs, + ) -> float: + + if not object_positions: + return 0.0 + + nearest_object_distance = distance_to_nearest_point( + origin=ee_position, points=list(object_positions.values()) + ) + + self.__task.get_logger().debug( + f"[Curriculum] Distance to nearest object: {nearest_object_distance}" + ) + if nearest_object_distance < self.reach_required_distance: + self.__task.get_logger().info( + f"[Curriculum] An object is now closer than the required distance of {self.reach_required_distance}" + ) + self.stages_completed_this_episode[GraspStage.REACH] = True + return self.__stages_base_reward + else: + return 0.0 + + def get_reward_TOUCH(self, touched_objects: List[str], **kwargs) -> float: + + if touched_objects: + self.__task.get_logger().info( + f"[Curriculum] Touched objects: {touched_objects}" + ) + self.stages_completed_this_episode[GraspStage.TOUCH] = True + return self.__stages_base_reward + else: + return 0.0 + + def get_reward_GRASP(self, grasped_objects: List[str], **kwargs) -> float: + + if grasped_objects: + self.__task.get_logger().info( + f"[Curriculum] Grasped objects: {grasped_objects}" + ) + self.stages_completed_this_episode[GraspStage.GRASP] = True + return self.__stages_base_reward + else: + return 0.0 + + def get_reward_LIFT( + self, + object_positions: Dict[str, Tuple[float, float, 
float]], + grasped_objects: List[str], + **kwargs, + ) -> float: + + if not (grasped_objects or object_positions): + return 0.0 + + for grasped_object in grasped_objects: + grasped_object_height = object_positions[grasped_object][2] + + self.__task.get_logger().debug( + f"[Curriculum] Height of grasped object '{grasped_objects}': {grasped_object_height}" + ) + if grasped_object_height > self.lift_required_height: + self.__task.get_logger().info( + f"[Curriculum] Lifted object: {grasped_object}" + ) + self.stages_completed_this_episode[GraspStage.LIFT] = True + return self.__stages_base_reward + + return 0.0 + + def get_persistent_reward( + self, object_positions: Dict[str, Tuple[float, float, float]], **kwargs + ) -> float: + + # Subtract a small reward each step to provide incentive to act quickly + reward = self.__persistent_reward_each_step + + # Negative reward for colliding with terrain + if self.__persistent_reward_terrain_collision: + if self.__task.check_terrain_collision(): + self.__task.get_logger().info( + "[Curriculum] Robot collided with the terrain" + ) + reward += self.__persistent_reward_terrain_collision + + # Negative reward for having all objects outside of the workspace + if self.__persistent_reward_all_objects_outside_workspace: + if self.__task.check_all_objects_outside_workspace( + object_positions=object_positions + ): + self.__task.get_logger().warn( + "[Curriculum] All objects are outside of the workspace" + ) + reward += self.__persistent_reward_all_objects_outside_workspace + self.episode_failed = True + + # Negative reward for arm getting stuck + if self.__persistent_reward_arm_stuck: + if ArmStuckChecker.is_robot_stuck(self): + self.__task.get_logger().error( + f"[Curriculum] Robot appears to be stuck, resetting..." 
# ============================================================================
# env_manager/rbs_gym/rbs_gym/envs/tasks/manipulation.py  (new file in diff)
#
# NOTE(review): the diff hunk preceding this file (tail of a reward function
# adding ``__persistent_reward_arm_stuck`` and setting ``episode_failed``) is
# truncated in this chunk and cannot be reconstructed here.
# Large regions of commented-out legacy code (robot_model_class lookup,
# workspace bounds, relative EE pose helpers, parameter overrides) were
# removed; recover them from VCS history if needed.
# ============================================================================
import abc
import multiprocessing
import sys
from itertools import count
from threading import Thread
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import rclpy
from env_manager.models.configs import RobotData
from env_manager.models.robots import RobotWrapper, get_robot_model_class
from env_manager.scene import Scene
from env_manager.utils import Tf2Broadcaster, Tf2Listener
from env_manager.utils.conversions import orientation_6d_to_quat
from env_manager.utils.gazebo import (
    Point,
    Pose,
    Quat,
    get_model_orientation,
    get_model_pose,
    get_model_position,
    transform_change_reference_frame_orientation,
    transform_change_reference_frame_pose,
    transform_change_reference_frame_position,
)
from env_manager.utils.math import quat_mul
from gym_gz.base.task import Task
from gym_gz.scenario.model_wrapper import ModelWrapper
from gym_gz.utils.typing import (
    Action,
    ActionSpace,
    Observation,
    ObservationSpace,
    Reward,
)
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor, SingleThreadedExecutor
from rclpy.node import Node
from robot_builder.elements.robot import Robot
from scipy.spatial.transform import Rotation

from rbs_gym.envs.control import (
    CartesianForceController,
    GripperController,
    JointEffortController,
)


class Manipulation(Task, Node, abc.ABC):
    """Abstract base for Gazebo manipulation tasks.

    Combines a gym-gz ``Task`` with a ROS 2 ``Node``: it spins its own
    executor in a background daemon thread, sets up tf2 utilities and the
    force/effort/gripper controllers, and provides helpers to query
    end-effector and object poses (from Gazebo, with a tf2 fallback).
    Subclasses implement the action/observation/reward interface.
    """

    # Monotonic counter used to give each task instance a unique node name.
    _ids = count(0)

    def __init__(
        self,
        agent_rate: float,
        robot_model: str,
        workspace_frame_id: str,
        workspace_centre: Tuple[float, float, float],
        workspace_volume: Tuple[float, float, float],
        ignore_new_actions_while_executing: bool,
        use_servo: bool,
        scaling_factor_translation: float,
        scaling_factor_rotation: float,
        restrict_position_goal_to_workspace: bool,
        enable_gripper: bool,
        num_threads: int,
        **kwargs,
    ):
        # Unique ID for this task instance (also used in the node name).
        self.id = next(self._ids)

        Task.__init__(self, agent_rate=agent_rate)

        # Initialize the ROS 2 context once per process; tolerate repeated
        # init as long as the context is usable.
        try:
            rclpy.init()
        except Exception as e:
            if not rclpy.ok():
                sys.exit(f"ROS 2 context could not be initialised: {e}")

        Node.__init__(self, f"rbs_gym_{self.id}")

        # NOTE(review): this overwrites a private rclpy.Node attribute rather
        # than passing allow_undeclared_parameters=True to Node.__init__ —
        # confirm this is intentional.
        self._allow_undeclared_parameters = True
        self.declare_parameter("robot_description", "")

        # Callbacks may run in parallel without mutual exclusion.
        self._callback_group = ReentrantCallbackGroup()

        # Executor: single-threaded, fixed-size pool, or one thread per CPU
        # when num_threads is non-positive.
        if num_threads == 1:
            executor = SingleThreadedExecutor()
        elif num_threads > 1:
            executor = MultiThreadedExecutor(num_threads=num_threads)
        else:
            executor = MultiThreadedExecutor(num_threads=multiprocessing.cpu_count())
        executor.add_node(self)

        # Spin this node in a background daemon thread.
        self._executor_thread = Thread(target=executor.spin, daemon=True, args=())
        self._executor_thread.start()

        self._enable_gripper = enable_gripper
        self._scene: Scene | None = None

        # Resolve symbolic frame names ("world", "base_link", ...) up front.
        self.workspace_frame_id = self.substitute_special_frame(workspace_frame_id)

        # tf2 lookup/broadcast helpers.
        self.tf2_listener = Tf2Listener(node=self)
        self.tf2_broadcaster = Tf2Broadcaster(node=self)

        self.cartesian_control = True  # TODO: make it as an external parameter

        # Controllers: Cartesian force by default, joint effort otherwise.
        self.controller = CartesianForceController(self)
        if not self.cartesian_control:
            self.joint_controller = JointEffortController(self)
        if self._enable_gripper:
            self.gripper = GripperController(self, 0.064)

        # Parameter overrides (e.g. from a curriculum); consumed on reset.
        self.__task_parameter_overrides: dict[str, Any] = {}
        self._randomizer_parameter_overrides: dict[str, Any] = {}

    def create_spaces(self) -> Tuple[ActionSpace, ObservationSpace]:
        """Build the (action, observation) space pair for this task."""
        return self.create_action_space(), self.create_observation_space()

    def create_action_space(self) -> ActionSpace:
        """Subclasses must build the task's action space."""
        raise NotImplementedError()

    def create_observation_space(self) -> ObservationSpace:
        """Subclasses must build the task's observation space."""
        raise NotImplementedError()

    def set_action(self, action: Action):
        """Subclasses must apply an agent action."""
        raise NotImplementedError()

    def get_observation(self) -> Observation:
        """Subclasses must return the current observation."""
        raise NotImplementedError()

    def get_reward(self) -> Reward:
        """Subclasses must compute the step reward."""
        raise NotImplementedError()

    def is_done(self) -> bool:
        """Subclasses must report episode termination."""
        raise NotImplementedError()

    def reset_task(self):
        """Reset hook; parameter-override consumption is currently disabled."""
        # self.__consume_parameter_overrides()
        pass

    def get_ee_pose(self) -> Pose:
        """Return the end-effector pose relative to the arm base link.

        Tries Gazebo first, falls back to a tf2 lookup, and finally to an
        identity pose if neither source is available.
        """
        try:
            robot_model = self.world.to_gazebo().get_model(self.robot.name).to_gazebo()
            ee_position, ee_quat_xyzw = get_model_pose(
                world=self.world,
                model=robot_model,
                link=self.robot.ee_link,
                xyzw=True,
            )
            return transform_change_reference_frame_pose(
                world=self.world,
                position=ee_position,
                quat=ee_quat_xyzw,
                target_model=robot_model,
                target_link=self.robot.base_link,
                xyzw=True,
            )
        except Exception as e:
            self.get_logger().warn(
                f"Cannot get end effector pose from Gazebo ({e}), using tf2..."
            )
            transform = self.tf2_listener.lookup_transform_sync(
                source_frame=self.robot.ee_link,
                target_frame=self.robot.base_link,
                retry=False,
            )
            if transform is not None:
                return (
                    (
                        transform.translation.x,
                        transform.translation.y,
                        transform.translation.z,
                    ),
                    (
                        transform.rotation.x,
                        transform.rotation.y,
                        transform.rotation.z,
                        transform.rotation.w,
                    ),
                )
            self.get_logger().error(
                "Cannot get pose of the end effector (default values are returned)"
            )
            return ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 1.0))

    def get_ee_position(self) -> Point:
        """Return the end-effector position relative to the arm base link."""
        try:
            robot_model = self.robot_wrapper
            ee_position = get_model_position(
                world=self.world,
                model=robot_model,
                link=self.robot.ee_link,
            )
            return transform_change_reference_frame_position(
                world=self.world,
                position=ee_position,
                target_model=robot_model,
                target_link=self.robot.base_link,
            )
        except Exception as e:
            self.get_logger().debug(
                f"Cannot get end effector position from Gazebo ({e}), using tf2..."
            )
            transform = self.tf2_listener.lookup_transform_sync(
                source_frame=self.robot.ee_link,
                target_frame=self.robot.base_link,
                retry=False,
            )
            if transform is not None:
                return (
                    transform.translation.x,
                    transform.translation.y,
                    transform.translation.z,
                )
            self.get_logger().error(
                "Cannot get position of the end effector (default values are returned)"
            )
            return (0.0, 0.0, 0.0)

    def get_ee_orientation(self) -> Quat:
        """Return the end-effector xyzw quaternion relative to the arm base link."""
        try:
            robot_model = self.robot_wrapper
            ee_quat_xyzw = get_model_orientation(
                world=self.world,
                model=robot_model,
                link=self.robot.ee_link,
                xyzw=True,
            )
            return transform_change_reference_frame_orientation(
                world=self.world,
                quat=ee_quat_xyzw,
                target_model=robot_model,
                target_link=self.robot.base_link,
                xyzw=True,
            )
        except Exception as e:
            self.get_logger().warn(
                f"Cannot get end effector orientation from Gazebo ({e}), using tf2..."
            )
            transform = self.tf2_listener.lookup_transform_sync(
                source_frame=self.robot.ee_link,
                target_frame=self.robot.base_link,
                retry=False,
            )
            if transform is not None:
                return (
                    transform.rotation.x,
                    transform.rotation.y,
                    transform.rotation.z,
                    transform.rotation.w,
                )
            self.get_logger().error(
                "Cannot get orientation of the end effector (default values are returned)"
            )
            return (0.0, 0.0, 0.0, 1.0)

    def get_object_position(self, object_model: ModelWrapper | str) -> Point:
        """Return an object's position relative to the arm base link.

        Note: only simulated objects are currently supported.
        """
        try:
            object_position = get_model_position(
                world=self.world,
                model=object_model,
            )
            return transform_change_reference_frame_position(
                world=self.world,
                position=object_position,
                # NOTE(review): other helpers pass the model object itself;
                # here the model *name* is passed — confirm both are accepted.
                target_model=self.robot_wrapper.name(),
                target_link=self.robot.base_link,
            )
        except Exception as e:
            self.get_logger().error(
                f"Cannot get position of {object_model} object (default values are returned): {e}"
            )
            return (0.0, 0.0, 0.0)

    def get_object_positions(self) -> dict[str, Point]:
        """Return positions of all scene objects relative to the arm base link.

        Returns an empty dict if any lookup fails.
        Note: only simulated objects are currently supported.
        """
        object_positions: dict[str, Point] = {}
        try:
            robot_model = self.robot_wrapper
            robot_arm_base_link = robot_model.get_link(link_name=self.robot.base_link)
            # FIXME: ``self.scene.__objects_unique_names`` is name-mangled in
            # this class to ``_Manipulation__objects_unique_names``, so the
            # attribute access always raises AttributeError and this method
            # returns {}. Scene should expose a public accessor instead.
            for object_name in self.scene.__objects_unique_names:
                object_position = get_model_position(
                    world=self.world,
                    model=object_name,
                )
                object_positions[object_name] = (
                    transform_change_reference_frame_position(
                        world=self.world,
                        position=object_position,
                        target_model=robot_model,
                        target_link=robot_arm_base_link,
                    )
                )
        except Exception as e:
            self.get_logger().error(
                f"Cannot get positions of all objects (empty Dict is returned): {e}"
            )
        return object_positions

    def substitute_special_frame(self, frame_id: str) -> str:
        """Map symbolic frame names to concrete robot/world frame ids."""
        if "arm_base_link" == frame_id:
            return self.robot.base_link
        elif "base_link" == frame_id:
            return self.robot.base_link
        elif "end_effector" == frame_id:
            return self.robot.ee_link
        elif "world" == frame_id:
            try:
                # In Gazebo, where multiple worlds are allowed
                return self.world.to_gazebo().name()
            except Exception as e:
                self.get_logger().warn(f"{e}")
                # Otherwise (e.g. real world)
                return "rbs_gym_world"
        else:
            return frame_id

    def wait_until_action_executed(self):
        """Block until the last gripper action finishes (gripper only)."""
        if self._enable_gripper:
            # FIXME: code is not tested
            self.gripper.wait_until_executed()

    def move_to_initial_joint_configuration(self):
        """Move the arm to its initial configuration."""
        # TODO: Write this code for cartesian_control
        pass

    def check_terrain_collision(self) -> bool:
        """Return True if a robot arm/gripper link is in contact with the terrain."""
        robot_name_len = len(self.robot.name)
        terrain_model = self.world.get_model(self.scene.terrain.name)

        for contact in terrain_model.contacts():
            body_b = contact.body_b
            if body_b.startswith(self.robot.name) and len(body_b) > robot_name_len:
                # Strip "<robot_name>" plus a two-character separator —
                # presumably the Gazebo scoped-name "::"; TODO confirm.
                link = body_b[robot_name_len + 2 :]
                if link != self.robot.base_link and (
                    link in self.robot.actuated_joint_names
                    or link in self.robot.gripper_actuated_joint_names
                ):
                    return True
        return False

    @property
    def robot(self) -> Robot:
        """Robot description of the attached scene."""
        if self._scene:
            return self._scene.robot_wrapper.robot
        raise ValueError("Scene is not attached to the task: 'robot' is unavailable")

    @property
    def robot_data(self) -> RobotData:
        """Robot configuration data of the attached scene."""
        if self._scene:
            return self._scene.robot
        raise ValueError("Scene is not attached to the task: 'robot_data' is unavailable")

    @property
    def robot_wrapper(self) -> RobotWrapper:
        """Robot model wrapper of the attached scene."""
        if self._scene:
            return self._scene.robot_wrapper
        raise ValueError("Scene is not attached to the task: 'robot_wrapper' is unavailable")

    @property
    def scene(self) -> Scene:
        """The scene attached to this task."""
        if self._scene:
            return self._scene
        raise ValueError("Scene is not attached to the task")

    @scene.setter
    def scene(self, value: Scene):
        self._scene = value


# ============================================================================
# env_manager/rbs_gym/rbs_gym/envs/tasks/reach/__init__.py  (new file in diff)
# ============================================================================
from .reach import Reach
from .reach_color_image import ReachColorImage
from .reach_depth_image import ReachDepthImage
# from .reach_octree import ReachOctree


# ============================================================================
# env_manager/rbs_gym/rbs_gym/envs/tasks/reach/pick_and_place_imitate.py
# (new file in diff)
#
# NOTE(review): despite the file name, the class is named ReachColorImage and
# collides with reach_color_image.ReachColorImage when both are imported; it
# also imports Camera from rbs_gym.envs.models.sensors while the sibling file
# uses env_manager.models.sensors — confirm which module is canonical.
# ============================================================================
import abc

import gymnasium as gym
import numpy as np
from gym_gz.utils.typing import Observation, ObservationSpace

from rbs_gym.envs.models.sensors import Camera
from rbs_gym.envs.observation import CameraSubscriber
from rbs_gym.envs.tasks.reach import Reach


class ReachColorImage(Reach, abc.ABC):
    """Reach task variant observed through one or more RGB cameras."""

    def __init__(
        self,
        camera_info: list[dict] | None = None,
        monochromatic: bool = False,
        **kwargs,
    ):
        Reach.__init__(
            self,
            **kwargs,
        )

        # BUGFIX: the default was a shared mutable list (``= []``);
        # default to None and create a fresh list per instance instead.
        camera_info = camera_info or []

        # Cameras are concatenated side by side: widths add up, height is
        # taken from the first camera (IndexError if camera_info is empty).
        self._camera_width = sum(w.get("width", 0) for w in camera_info)
        self._camera_height = camera_info[0]["height"]
        self._monochromatic = monochromatic
        self._saved = False

        # One RGB subscriber per configured camera.
        self.camera_subs: list[CameraSubscriber] = []
        for camera in camera_info:
            self.camera_subs.append(
                CameraSubscriber(
                    node=self,
                    topic=Camera.get_color_topic(camera["name"], camera["type"]),
                    is_point_cloud=False,
                    callback_group=self._callback_group,
                )
            )

    def create_observation_space(self) -> ObservationSpace:
        """
        Creates the observation space for the Imitation Learning algorithm.

        - **image**: camera image of shape ``(camera_height, camera_width,
          channels)``; ``channels`` is 1 for monochromatic and 3 for RGB,
          pixel values in ``[0, 255]`` as ``np.uint8``.
        - **joint_states**: joint positions of the robot arm in ``[-1.0, 1.0]``.

        Returns:
            gym.spaces.Dict: observation space with image and joint_states.
        """
        return gym.spaces.Dict(
            {
                "image": gym.spaces.Box(
                    low=0,
                    high=255,
                    shape=(
                        self._camera_height,
                        self._camera_width,
                        1 if self._monochromatic else 3,
                    ),
                    dtype=np.uint8,
                ),
                # TODO(review): no shape is given, so this Box is a single
                # scalar — it does not cover the nDoF joint positions the
                # docstring describes; confirm the intended shape.
                "joint_states": gym.spaces.Box(low=-1.0, high=1.0),
            }
        )

    def get_observation(self) -> Observation:
        """Return the latest side-by-side concatenated camera image.

        FIXME: the observation space is a Dict with "image" and
        "joint_states", but only the image array is returned here.
        """
        images = [sub.get_observation() for sub in self.camera_subs]

        image_width = sum(i.width for i in images)
        image_height = images[0].height
        # BUGFIX: explicit check instead of ``assert`` — these modules run
        # under ``python -O`` (see scripts/evaluate.py), which strips asserts.
        if image_width != self._camera_width or image_height != self._camera_height:
            raise RuntimeError(
                "Error: Resolution of the input image does not match the configured "
                f"observation space. ({image_width}x{image_height} instead of "
                f"{self._camera_width}x{self._camera_height})"
            )

        # Decode each message buffer to (H, W, 3) and stitch along the width.
        color_image = np.concatenate(
            [
                np.array(i.data, dtype=np.uint8).reshape(i.height, i.width, 3)
                for i in images
            ],
            axis=1,
        )

        if self._monochromatic:
            observation = Observation(color_image[:, :, 0])
        else:
            observation = Observation(color_image)

        self.get_logger().debug(f"\nobservation: {observation}")
        return observation
# ============================================================================
# env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach.py  (new file in diff)
# ============================================================================
import abc
from typing import Tuple

from geometry_msgs.msg import WrenchStamped
import gymnasium as gym
import numpy as np
from gym_gz.utils.typing import (
    Action,
    ActionSpace,
    Observation,
    ObservationSpace,
    Reward,
)

from rbs_gym.envs.tasks.manipulation import Manipulation
from env_manager.utils.math import distance_to_nearest_point
from std_msgs.msg import Float64MultiArray
from rbs_gym.envs.observation import TwistSubscriber, JointStates
from env_manager.utils.helper import get_object_position


class Reach(Manipulation, abc.ABC):
    """Reach task: drive the end effector towards a target object."""

    def __init__(
        self,
        sparse_reward: bool,
        collision_reward: float,
        act_quick_reward: float,
        required_accuracy: float,
        **kwargs,
    ):
        Manipulation.__init__(
            self,
            **kwargs,
        )

        self._sparse_reward: bool = sparse_reward
        # Store magnitudes only; signs are applied where the terms are used.
        self._act_quick_reward = abs(act_quick_reward)
        self._collision_reward = abs(collision_reward)
        self._required_accuracy: float = required_accuracy

        # Episode flags shared between get_reward and is_terminated/is_truncated.
        self._is_done: bool = False
        self._is_truncated: bool = False
        self._is_terminated: bool = False

        # Distance to target after the previous step (or after reset);
        # used by the dense progress reward.
        self._previous_distance: float = 0.0

        self.twist = TwistSubscriber(
            self,
            topic="/cartesian_force_controller/current_twist",
            callback_group=self._callback_group,
        )
        self.joint_states = JointStates(
            self, topic="/joint_states", callback_group=self._callback_group
        )

        self._action = WrenchStamped()
        self._action_array: Action = np.array([])
        self.followed_object_name = "sphere"

    def create_action_space(self) -> ActionSpace:
        # 0:3 - (x, y, z) force command, scaled to +-30 N in set_action.
        # (The comment previously described a 3:6 torque segment, but the
        # space is 3-dimensional and torque is disabled.)
        return gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

    def create_observation_space(self) -> ObservationSpace:
        # 0:3  - (x, y, z) end effector position
        # 3:6  - (x, y, z) target position
        # 6:12 - (linear, angular) current twist
        # Note: These could theoretically be restricted to the workspace and
        # object spawn area instead of inf.
        return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32)

    def set_action(self, action: Action):
        """Publish the action as a Cartesian force command (+-30 N per axis)."""
        if isinstance(action, np.number):
            raise RuntimeError("For the task Reach the action space should be array")

        # Keep the raw action around for reward shaping / debugging.
        self._action_array = action

        self._action.header.stamp = self.get_clock().now().to_msg()
        self._action.header.frame_id = self.robot.ee_link
        self._action.wrench.force.x = float(action[0]) * 30.0
        self._action.wrench.force.y = float(action[1]) * 30.0
        self._action.wrench.force.z = float(action[2]) * 30.0
        # Torque components (action[3:6]) are currently disabled.
        self.controller.publisher.publish(self._action)

    def get_observation(self) -> Observation:
        """Concatenate EE position, target position and current twist (12,)."""
        self.object_names = []
        ee_position = self.get_ee_position()
        target_position = self.get_object_position(self.followed_object_name)

        twist = self.twist.get_observation()
        twt: tuple[float, float, float, float, float, float] = (
            twist.twist.linear.x,
            twist.twist.linear.y,
            twist.twist.linear.z,
            twist.twist.angular.x,
            twist.twist.angular.y,
            twist.twist.angular.z,
        )

        observation: Observation = np.concatenate(
            [ee_position, target_position, twt], dtype=np.float32
        )
        self.get_logger().debug(f"\nobservation: {observation}")
        return observation

    def get_reward(self) -> Reward:
        """Compute the step reward and update the termination flags."""
        reward = 0.0

        current_distance = self.get_distance_to_target()

        # Terminate (success) once the target is within the accuracy radius.
        if current_distance < self._required_accuracy:
            self._is_terminated = True
            reward += 1.0 if self._sparse_reward else 0.0
            self.get_logger().debug(f"reward_target: {reward}")

        # Dense reward: progress made towards the target this step.
        if not self._sparse_reward:
            distance_delta = self._previous_distance - current_distance
            reward += distance_delta * 10.0
            self._previous_distance = current_distance
            self.get_logger().debug(f"reward_distance: {reward}")

        # Truncate with a penalty when the robot touches the terrain.
        if self.check_terrain_collision():
            reward -= self._collision_reward
            self._is_truncated = True
            self.get_logger().debug(f"reward_collision: {reward}")

        # Subtract a small reward each step to provide incentive to act
        # quickly (if enabled).
        # BUGFIX: this was ``+=``, which *rewarded* every extra step because
        # the magnitude is normalised to be non-negative in __init__.
        reward -= self._act_quick_reward

        self.get_logger().debug(f"reward: {reward}")
        return Reward(reward)

    def is_terminated(self) -> bool:
        """Whether the episode ended successfully (target reached)."""
        self.get_logger().debug(f"terminated: {self._is_terminated}")
        return self._is_terminated

    def is_truncated(self) -> bool:
        """Whether the episode was cut short (e.g. terrain collision)."""
        self.get_logger().debug(f"truncated: {self._is_truncated}")
        return self._is_truncated

    def reset_task(self):
        """Reset episode flags and the dense-reward distance baseline."""
        Manipulation.reset_task(self)

        self._is_done = False
        self._is_truncated = False
        self._is_terminated = False

        # Compute and store the distance after reset if using dense reward
        if not self._sparse_reward:
            self._previous_distance = self.get_distance_to_target()

        self.get_logger().debug("\ntask reset")

    def get_distance_to_target(self) -> float:
        """Distance from the end effector to the followed object."""
        ee_position = self.get_ee_position()
        object_position = self.get_object_position(
            object_model=self.followed_object_name
        )
        # Broadcast the target for visualisation/debugging.
        self.tf2_broadcaster.broadcast_tf(
            "world", "object", object_position, (0.0, 0.0, 0.0, 1.0), xyzw=True
        )
        return distance_to_nearest_point(origin=ee_position, points=[object_position])


# ============================================================================
# env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_color_image.py
# (new file in diff)
# ============================================================================
import abc

import gymnasium as gym
import numpy as np
from gym_gz.utils.typing import Observation, ObservationSpace

from env_manager.models.sensors import Camera
from rbs_gym.envs.observation import CameraSubscriber
from rbs_gym.envs.tasks.reach import Reach


class ReachColorImage(Reach, abc.ABC):
    """Reach task variant observed through one or more RGB cameras."""

    def __init__(
        self,
        camera_info: list[dict] | None = None,
        monochromatic: bool = False,
        **kwargs,
    ):
        Reach.__init__(
            self,
            **kwargs,
        )

        # BUGFIX: the default was a shared mutable list (``= []``);
        # default to None and create a fresh list per instance instead.
        camera_info = camera_info or []

        # Cameras are concatenated side by side: widths add up, height is
        # taken from the first camera (IndexError if camera_info is empty).
        self._camera_width = sum(w.get("width", 0) for w in camera_info)
        self._camera_height = camera_info[0]["height"]
        self._monochromatic = monochromatic
        self._saved = False

        # One RGB subscriber per configured camera.
        self.camera_subs: list[CameraSubscriber] = []
        for camera in camera_info:
            self.camera_subs.append(
                CameraSubscriber(
                    node=self,
                    topic=Camera.get_color_topic(camera["name"], camera["type"]),
                    is_point_cloud=False,
                    callback_group=self._callback_group,
                )
            )

    def create_observation_space(self) -> ObservationSpace:
        # 0:3*height*width - rgb image
        # 0:1*height*width - monochromatic (intensity) image
        return gym.spaces.Box(
            low=0,
            high=255,
            shape=(
                self._camera_height,
                self._camera_width,
                1 if self._monochromatic else 3,
            ),
            dtype=np.uint8,
        )

    def get_observation(self) -> Observation:
        """Return the latest side-by-side concatenated camera image."""
        images = [sub.get_observation() for sub in self.camera_subs]

        image_width = sum(i.width for i in images)
        image_height = images[0].height
        # BUGFIX: explicit check instead of ``assert`` — these modules run
        # under ``python -O`` (see scripts/evaluate.py), which strips asserts.
        if image_width != self._camera_width or image_height != self._camera_height:
            raise RuntimeError(
                "Error: Resolution of the input image does not match the configured "
                f"observation space. ({image_width}x{image_height} instead of "
                f"{self._camera_width}x{self._camera_height})"
            )

        # Decode each message buffer to (H, W, 3) and stitch along the width.
        color_image = np.concatenate(
            [
                np.array(i.data, dtype=np.uint8).reshape(i.height, i.width, 3)
                for i in images
            ],
            axis=1,
        )

        if self._monochromatic:
            observation = Observation(color_image[:, :, 0])
        else:
            observation = Observation(color_image)

        self.get_logger().debug(f"\nobservation: {observation}")
        return observation


# ============================================================================
# env_manager/rbs_gym/rbs_gym/envs/tasks/reach/reach_depth_image.py
# (new file in diff)
# ============================================================================
import abc

import gymnasium as gym
import numpy as np
from gym_gz.utils.typing import Observation, ObservationSpace

from env_manager.models.sensors import Camera
from rbs_gym.envs.observation import CameraSubscriber
from rbs_gym.envs.tasks.reach import Reach


class ReachDepthImage(Reach, abc.ABC):
    """Reach task variant observed through a depth camera."""

    def __init__(
        self,
        camera_width: int,
        camera_height: int,
        camera_type: str = "depth_camera",
        **kwargs,
    ):
        Reach.__init__(
            self,
            **kwargs,
        )

        self._camera_width = camera_width
        self._camera_height = camera_height

        # Perception (depth camera)
        self.camera_sub = CameraSubscriber(
            node=self,
            topic=Camera.get_depth_topic(camera_type),
            is_point_cloud=False,
            callback_group=self._callback_group,
        )

    def create_observation_space(self) -> ObservationSpace:
        # 0:height*width - depth image
        return gym.spaces.Box(
            low=0,
            high=np.inf,
            shape=(self._camera_height, self._camera_width, 1),
            dtype=np.float32,
        )

    def get_observation(self) -> Observation:
        """Return the latest depth image with +infinity replaced by 0."""
        image = self.camera_sub.get_observation()

        # BUGFIX: copy before mutating — ``np.frombuffer`` over an immutable
        # buffer (bytes) yields a read-only array, so the in-place masking
        # below raised; the copy also avoids aliasing the ROS message data.
        depth_image = (
            np.frombuffer(image.data, dtype=np.float32)
            .reshape(self._camera_height, self._camera_width, 1)
            .copy()
        )
        # Replace all instances of infinity with 0
        depth_image[depth_image == np.inf] = 0.0

        observation = Observation(depth_image)
        self.get_logger().debug(f"\nobservation: {observation}")
        return observation


# ============================================================================
# env_manager/rbs_gym/rbs_gym/scripts/evaluate.py  (new file in diff)
# Original shebang: #!/usr/bin/env -S python3 -O
# (note the -O flag: asserts are stripped at runtime)
# ============================================================================
import argparse
import os
from typing import Dict

import numpy as np
import torch as th
import yaml
from rbs_gym import envs as gz_envs
from rbs_gym.utils import create_test_env, get_latest_run_id, get_saved_hyperparams
from rbs_gym.utils.utils import ALGOS, StoreDict, str2bool
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecEnvWrapper


def find_model_path(log_path: str, args) -> str:
    """Resolve the model file to evaluate inside ``log_path``.

    Candidates are checked in order: ``<env>.zip``; ``best_model.zip`` when
    ``--load-best`` is set; ``rl_model_<steps>_steps.zip`` when
    ``--load-checkpoint`` is given.

    Raises:
        ValueError: if none of the candidate files exist.
    """
    model_extensions = ["zip"]
    for ext in model_extensions:
        model_path = os.path.join(log_path, f"{args.env}.{ext}")
        if os.path.isfile(model_path):
            return model_path

    if args.load_best:
        best_model_path = os.path.join(log_path, "best_model.zip")
        if os.path.isfile(best_model_path):
            return best_model_path

    if args.load_checkpoint is not None:
        checkpoint_model_path = os.path.join(
            log_path, f"rl_model_{args.load_checkpoint}_steps.zip"
        )
        if os.path.isfile(checkpoint_model_path):
            return checkpoint_model_path

    raise ValueError(f"No model found for {args.algo} on {args.env}, path: {log_path}")
"best_model.zip") + if os.path.isfile(best_model_path): + return best_model_path + + if args.load_checkpoint is not None: + checkpoint_model_path = os.path.join( + log_path, f"rl_model_{args.load_checkpoint}_steps.zip" + ) + if os.path.isfile(checkpoint_model_path): + return checkpoint_model_path + + raise ValueError(f"No model found for {args.algo} on {args.env}, path: {log_path}") + + +def main(): + parser = argparse.ArgumentParser() + + # Environment and its parameters + parser.add_argument( + "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID" + ) + parser.add_argument( + "--env-kwargs", + type=str, + nargs="+", + action=StoreDict, + help="Optional keyword argument to pass to the env constructor", + ) + parser.add_argument("--n-envs", type=int, default=1, help="Number of environments") + + # Algorithm + parser.add_argument( + "--algo", + type=str, + choices=list(ALGOS.keys()), + required=False, + default="sac", + help="RL algorithm to use during the training", + ) + parser.add_argument( + "--num-threads", + type=int, + default=-1, + help="Number of threads for PyTorch (-1 to use default)", + ) + + # Test duration + parser.add_argument( + "-n", + "--n-episodes", + type=int, + default=200, + help="Number of evaluation episodes", + ) + + # Random seed + parser.add_argument("--seed", type=int, default=0, help="Random generator seed") + + # Model to test + parser.add_argument( + "-f", "--log-folder", type=str, default="logs", help="Path to the log directory" + ) + parser.add_argument( + "--exp-id", + type=int, + default=0, + help="Experiment ID (default: 0: latest, -1: no exp folder)", + ) + parser.add_argument( + "--load-best", + type=str2bool, + default=False, + help="Load best model instead of last model if available", + ) + parser.add_argument( + "--load-checkpoint", + type=int, + help="Load checkpoint instead of last model if available, you must pass the number of timesteps corresponding to it", + ) + + # Deterministic/stochastic actions + 
parser.add_argument( + "--stochastic", + type=str2bool, + default=False, + help="Use stochastic actions instead of deterministic", + ) + + # Logging + parser.add_argument( + "--reward-log", type=str, default="reward_logs", help="Where to log reward" + ) + parser.add_argument( + "--norm-reward", + type=str2bool, + default=False, + help="Normalize reward if applicable (trained with VecNormalize)", + ) + + # Disable render + parser.add_argument( + "--no-render", + type=str2bool, + default=False, + help="Do not render the environment (useful for tests)", + ) + + # Verbosity + parser.add_argument( + "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)" + ) + + args, unknown = parser.parse_known_args() + + if args.exp_id == 0: + args.exp_id = get_latest_run_id( + os.path.join(args.log_folder, args.algo), args.env + ) + print(f"Loading latest experiment, id={args.exp_id}") + + # Sanity checks + if args.exp_id > 0: + log_path = os.path.join(args.log_folder, args.algo, f"{args.env}_{args.exp_id}") + else: + log_path = os.path.join(args.log_folder, args.algo) + + assert os.path.isdir(log_path), f"The {log_path} folder was not found" + + model_path = find_model_path(log_path, args) + off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"] + + if args.algo in off_policy_algos: + args.n_envs = 1 + + set_random_seed(args.seed) + + if args.num_threads > 0: + if args.verbose > 1: + print(f"Setting torch.num_threads to {args.num_threads}") + th.set_num_threads(args.num_threads) + + stats_path = os.path.join(log_path, args.env) + hyperparams, stats_path = get_saved_hyperparams( + stats_path, norm_reward=args.norm_reward, test_mode=True + ) + + # load env_kwargs if existing + env_kwargs = {} + args_path = os.path.join(log_path, args.env, "args.yml") + if os.path.isfile(args_path): + with open(args_path, "r") as f: + loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader) + if loaded_args["env_kwargs"] is not None: + env_kwargs = 
loaded_args["env_kwargs"] + # overwrite with command line arguments + if args.env_kwargs is not None: + env_kwargs.update(args.env_kwargs) + + log_dir = args.reward_log if args.reward_log != "" else None + + env = create_test_env( + args.env, + n_envs=args.n_envs, + stats_path=stats_path, + seed=args.seed, + log_dir=log_dir, + should_render=not args.no_render, + hyperparams=hyperparams, + env_kwargs=env_kwargs, + ) + + kwargs = dict(seed=args.seed) + if args.algo in off_policy_algos: + # Dummy buffer size as we don't need memory to evaluate the trained agent + kwargs.update(dict(buffer_size=1)) + + custom_objects = { + "observation_space": env.observation_space, + "action_space": env.action_space, + } + + model = ALGOS[args.algo].load( + model_path, env=env, custom_objects=custom_objects, **kwargs + ) + + obs = env.reset() + + # Deterministic by default + stochastic = args.stochastic + deterministic = not stochastic + + print( + f"Evaluating for {args.n_episodes} episodes with a", + "deterministic" if deterministic else "stochastic", + "policy.", + ) + + state = None + episode_reward = 0.0 + episode_rewards, episode_lengths, success_episode_lengths = [], [], [] + ep_len = 0 + episode = 0 + # For HER, monitor success rate + successes = [] + while episode < args.n_episodes: + action, state = model.predict(obs, state=state, deterministic=deterministic) + obs, reward, done, infos = env.step(action) + if not args.no_render: + env.render("human") + + episode_reward += reward[0] + ep_len += 1 + + if done and args.verbose > 0: + episode += 1 + print(f"--- Episode {episode}/{args.n_episodes}") + # NOTE: for env using VecNormalize, the mean reward + # is a normalized reward when `--norm_reward` flag is passed + print(f"Episode Reward: {episode_reward:.2f}") + episode_rewards.append(episode_reward) + print("Episode Length", ep_len) + episode_lengths.append(ep_len) + if infos[0].get("is_success") is not None: + print("Success?:", infos[0].get("is_success", False)) + 
successes.append(infos[0].get("is_success", False)) + if infos[0].get("is_success"): + success_episode_lengths.append(ep_len) + print(f"Current success rate: {100 * np.mean(successes):.2f}%") + episode_reward = 0.0 + ep_len = 0 + state = None + + if args.verbose > 0 and len(successes) > 0: + print(f"Success rate: {100 * np.mean(successes):.2f}%") + + if args.verbose > 0 and len(episode_rewards) > 0: + print( + f"Mean reward: {np.mean(episode_rewards):.2f} " + f"+/- {np.std(episode_rewards):.2f}" + ) + + if args.verbose > 0 and len(episode_lengths) > 0: + print( + f"Mean episode length: {np.mean(episode_lengths):.2f} " + f"+/- {np.std(episode_lengths):.2f}" + ) + + if args.verbose > 0 and len(success_episode_lengths) > 0: + print( + f"Mean episode length of successful episodes: {np.mean(success_episode_lengths):.2f} " + f"+/- {np.std(success_episode_lengths):.2f}" + ) + + # Workaround for https://github.com/openai/gym/issues/893 + if not args.no_render: + if args.n_envs == 1 and "Bullet" not in args.env and isinstance(env, VecEnv): + # DummyVecEnv + # Unwrap env + while isinstance(env, VecEnvWrapper): + env = env.venv + if isinstance(env, DummyVecEnv): + env.envs[0].env.close() + else: + env.close() + else: + # SubprocVecEnv + env.close() + + +if __name__ == "__main__": + main() diff --git a/env_manager/rbs_gym/rbs_gym/scripts/optimize.py b/env_manager/rbs_gym/rbs_gym/scripts/optimize.py new file mode 100644 index 0000000..d840a67 --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/scripts/optimize.py @@ -0,0 +1,234 @@ +#!/usr/bin/env -S python3 -O + +import argparse +import difflib +import os +import uuid +from typing import Dict + +import gymnasium as gym +import numpy as np +import torch as th +from rbs_gym import envs as gz_envs +from rbs_gym.utils.exp_manager import ExperimentManager +from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool +from stable_baselines3.common.utils import set_random_seed + + +def main(): + parser = 
argparse.ArgumentParser() + + # Environment and its parameters + parser.add_argument( + "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID" + ) + parser.add_argument( + "--env-kwargs", + type=str, + nargs="+", + action=StoreDict, + help="Optional keyword argument to pass to the env constructor", + ) + parser.add_argument( + "--vec-env", + type=str, + choices=["dummy", "subproc"], + default="dummy", + help="Type of VecEnv to use", + ) + + # Algorithm and training + parser.add_argument( + "--algo", + type=str, + choices=list(ALGOS.keys()), + required=False, + default="sac", + help="RL algorithm to use during the training", + ) + parser.add_argument( + "-params", + "--hyperparams", + type=str, + nargs="+", + action=StoreDict, + help="Optional RL hyperparameter overwrite (e.g. learning_rate:0.01 train_freq:10)", + ) + parser.add_argument( + "-n", + "--n-timesteps", + type=int, + default=-1, + help="Overwrite the number of timesteps", + ) + parser.add_argument( + "--num-threads", + type=int, + default=-1, + help="Number of threads for PyTorch (-1 to use default)", + ) + + # Continue training an already trained agent + parser.add_argument( + "-i", + "--trained-agent", + type=str, + default="", + help="Path to a pretrained agent to continue training", + ) + + # Random seed + parser.add_argument("--seed", type=int, default=-1, help="Random generator seed") + + # Saving of model + parser.add_argument( + "--save-freq", + type=int, + default=10000, + help="Save the model every n steps (if negative, no checkpoint)", + ) + parser.add_argument( + "--save-replay-buffer", + type=str2bool, + default=False, + help="Save the replay buffer too (when applicable)", + ) + + # Pre-load a replay buffer and start training on it + parser.add_argument( + "--preload-replay-buffer", + type=empty_str2none, + default="", + help="Path to a replay buffer that should be preloaded before starting the training process", + ) + + # Logging + parser.add_argument( + "-f", "--log-folder", 
type=str, default="logs", help="Path to the log directory" + ) + parser.add_argument( + "-tb", + "--tensorboard-log", + type=empty_str2none, + default="tensorboard_logs", + help="Tensorboard log dir", + ) + parser.add_argument( + "--log-interval", + type=int, + default=-1, + help="Override log interval (default: -1, no change)", + ) + parser.add_argument( + "-uuid", + "--uuid", + type=str2bool, + default=False, + help="Ensure that the run has a unique ID", + ) + + # Evaluation + parser.add_argument( + "--eval-freq", + type=int, + default=-1, + help="Evaluate the agent every n steps (if negative, no evaluation)", + ) + parser.add_argument( + "--eval-episodes", + type=int, + default=5, + help="Number of episodes to use for evaluation", + ) + + # Verbosity + parser.add_argument( + "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)" + ) + + # HER specifics + parser.add_argument( + "--truncate-last-trajectory", + type=str2bool, + default=True, + help="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.", + ) + + args, unknown = parser.parse_known_args() + + # Check if the selected environment is valid + # If it could not be found, suggest the closest match + registered_envs = set(gym.envs.registry.keys()) + if args.env not in registered_envs: + try: + closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0] + except IndexError: + closest_match = "'no close match found...'" + raise ValueError( + f"{args.env} not found in gym registry, you maybe meant {closest_match}?" 
+ ) + + # If no specific seed is selected, choose a random one + if args.seed < 0: + args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item() + + # Set the random seed across platforms + set_random_seed(args.seed) + + # Setting num threads to 1 makes things run faster on cpu + if args.num_threads > 0: + if args.verbose > 1: + print(f"Setting torch.num_threads to {args.num_threads}") + th.set_num_threads(args.num_threads) + + # Verify that pre-trained agent exists before continuing to train it + if args.trained_agent != "": + assert args.trained_agent.endswith(".zip") and os.path.isfile( + args.trained_agent + ), "The trained_agent must be a valid path to a .zip file" + + # If enabled, ensure that the run has a unique ID + uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" + + print("=" * 10, args.env, "=" * 10) + print(f"Seed: {args.seed}") + env_kwargs = {"render_mode": "human"} + + exp_manager = ExperimentManager( + args, + args.algo, + args.env, + args.log_folder, + args.tensorboard_log, + args.n_timesteps, + args.eval_freq, + args.eval_episodes, + args.save_freq, + args.hyperparams, + args.env_kwargs, + args.trained_agent, + truncate_last_trajectory=args.truncate_last_trajectory, + uuid_str=uuid_str, + seed=args.seed, + log_interval=args.log_interval, + save_replay_buffer=args.save_replay_buffer, + preload_replay_buffer=args.preload_replay_buffer, + verbose=args.verbose, + vec_env_type=args.vec_env, + ) + + # Prepare experiment + results = exp_manager.setup_experiment() + if results is not None: + model, saved_hyperparams = results + + # Normal training + if model is not None: + exp_manager.learn(model) + exp_manager.save_trained_model(model) + else: + exp_manager.hyperparameters_optimization() + + +if __name__ == "__main__": + main() diff --git a/env_manager/rbs_gym/rbs_gym/scripts/test_agent.py b/env_manager/rbs_gym/rbs_gym/scripts/test_agent.py new file mode 100755 index 0000000..e201b4b --- /dev/null +++ 
b/env_manager/rbs_gym/rbs_gym/scripts/test_agent.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + +import argparse +from typing import Dict + +import gymnasium as gym +from stable_baselines3.common.env_checker import check_env + +from rbs_gym import envs as gz_envs +from rbs_gym.utils.utils import StoreDict, str2bool + + +def main(): + parser = argparse.ArgumentParser() + + # Environment and its parameters + parser.add_argument( + "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID" + ) + parser.add_argument( + "--env-kwargs", + type=str, + nargs="+", + action=StoreDict, + help="Optional keyword argument to pass to the env constructor", + ) + + # Number of episodes to run + parser.add_argument( + "-n", + "--n-episodes", + type=int, + default=10000, + help="Overwrite the number of episodes", + ) + + # Random seed + parser.add_argument("--seed", type=int, default=69, help="Random generator seed") + + # Flag to check environment + parser.add_argument( + "--check-env", + type=str2bool, + default=True, + help="Flag to check the environment before running the random agent", + ) + + # Flag to enable rendering + parser.add_argument( + "--render", + type=str2bool, + default=False, + help="Flag to enable rendering", + ) + + args, unknown = parser.parse_known_args() + + # Create the environment + if args.env_kwargs is not None: + env = gym.make(args.env, **args.env_kwargs) + else: + env = gym.make(args.env) + + # Initialize random seed + env.seed(args.seed) + + # Check the environment + if args.check_env: + check_env(env, warn=True, skip_render_check=True) + + # Step environment for bunch of episodes + for episode in range(args.n_episodes): + # Initialize returned values + done = False + total_reward = 0 + + # Reset the environment + observation = env.reset() + + # Step through the current episode until it is done + while not done: + # Sample random action + action = env.action_space.sample() + + # Step the environment with the random action + observation, 
reward, truncated, terminated, info = env.step(action) + + done = truncated or terminated + + # Accumulate the reward + total_reward += reward + + print(f"Episode #{episode}\n\treward: {total_reward}") + + # Cleanup once done + env.close() + + +if __name__ == "__main__": + main() diff --git a/env_manager/rbs_gym/rbs_gym/scripts/train.py b/env_manager/rbs_gym/rbs_gym/scripts/train.py new file mode 100755 index 0000000..87ea498 --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/scripts/train.py @@ -0,0 +1,265 @@ +#!/usr/bin/env -S python3 -O + +import argparse +import difflib +import os +import uuid + +import gymnasium as gym +import numpy as np +import torch as th +from stable_baselines3.common.utils import set_random_seed + +from rbs_gym import envs as gz_envs +from rbs_gym.utils.exp_manager import ExperimentManager +from rbs_gym.utils.utils import ALGOS, StoreDict, empty_str2none, str2bool + + +def main(): + + parser = argparse.ArgumentParser() + + # Environment and its parameters + parser.add_argument( + "--env", type=str, default="Reach-Gazebo-v0", help="Environment ID" + ) + parser.add_argument( + "--env-kwargs", + type=str, + nargs="+", + action=StoreDict, + help="Optional keyword argument to pass to the env constructor", + ) + parser.add_argument( + "--vec-env", + type=str, + choices=["dummy", "subproc"], + default="dummy", + help="Type of VecEnv to use", + ) + + # Algorithm and training + parser.add_argument( + "--algo", + type=str, + choices=list(ALGOS.keys()), + required=False, + default="sac", + help="RL algorithm to use during the training", + ) + parser.add_argument( + "-params", + "--hyperparams", + type=str, + nargs="+", + action=StoreDict, + help="Optional RL hyperparameter overwrite (e.g. 
learning_rate:0.01 train_freq:10)", + ) + parser.add_argument( + "-n", + "--n-timesteps", + type=int, + default=-1, + help="Overwrite the number of timesteps", + ) + parser.add_argument( + "--num-threads", + type=int, + default=-1, + help="Number of threads for PyTorch (-1 to use default)", + ) + + # Continue training an already trained agent + parser.add_argument( + "-i", + "--trained-agent", + type=str, + default="", + help="Path to a pretrained agent to continue training", + ) + + # Random seed + parser.add_argument("--seed", type=int, default=-1, help="Random generator seed") + + # Saving of model + parser.add_argument( + "--save-freq", + type=int, + default=10000, + help="Save the model every n steps (if negative, no checkpoint)", + ) + parser.add_argument( + "--save-replay-buffer", + type=str2bool, + default=False, + help="Save the replay buffer too (when applicable)", + ) + + # Pre-load a replay buffer and start training on it + parser.add_argument( + "--preload-replay-buffer", + type=empty_str2none, + default="", + help="Path to a replay buffer that should be preloaded before starting the training process", + ) + parser.add_argument( + "--track", + type=str2bool, + default=False, + help="Track experiment using wandb" + ) + # optimization parameters + parser.add_argument( + "--optimize-hyperparameters", + type=str2bool, + default=False, + help="Run optimization or not?" 
+ ) + parser.add_argument( + "--no-optim-plots", action="store_true", default=False, help="Disable hyperparameter optimization plots" + ) + + # Logging + parser.add_argument( + "-f", "--log-folder", type=str, default="logs", help="Path to the log directory" + ) + parser.add_argument( + "-tb", + "--tensorboard-log", + type=empty_str2none, + default="tensorboard_logs", + help="Tensorboard log dir", + ) + parser.add_argument( + "--log-interval", + type=int, + default=-1, + help="Override log interval (default: -1, no change)", + ) + parser.add_argument( + "-uuid", + "--uuid", + type=str2bool, + default=False, + help="Ensure that the run has a unique ID", + ) + + # Evaluation + parser.add_argument( + "--eval-freq", + type=int, + default=-1, + help="Evaluate the agent every n steps (if negative, no evaluation)", + ) + parser.add_argument( + "--eval-episodes", + type=int, + default=5, + help="Number of episodes to use for evaluation", + ) + + # Verbosity + parser.add_argument( + "--verbose", type=int, default=1, help="Verbose mode (0: no output, 1: INFO)" + ) + + # HER specifics + parser.add_argument( + "--truncate-last-trajectory", + type=str2bool, + default=True, + help="When using HER with online sampling the last trajectory in the replay buffer will be truncated after reloading the replay buffer.", + ) + + args, unknown = parser.parse_known_args() + + # Check if the selected environment is valid + # If it could not be found, suggest the closest match + registered_envs = set(gym.envs.registry.keys()) + if args.env not in registered_envs: + try: + closest_match = difflib.get_close_matches(args.env, registered_envs, n=1)[0] + except IndexError: + closest_match = "'no close match found...'" + raise ValueError( + f"{args.env} not found in gym registry, you maybe meant {closest_match}?" 
+ ) + + # If no specific seed is selected, choose a random one + if args.seed < 0: + args.seed = np.random.randint(2**32 - 1, dtype=np.int64).item() + + # Set the random seed across platforms + set_random_seed(args.seed) + + # Setting num threads to 1 makes things run faster on cpu + if args.num_threads > 0: + if args.verbose > 1: + print(f"Setting torch.num_threads to {args.num_threads}") + th.set_num_threads(args.num_threads) + + # Verify that pre-trained agent exists before continuing to train it + if args.trained_agent != "": + assert args.trained_agent.endswith(".zip") and os.path.isfile( + args.trained_agent + ), "The trained_agent must be a valid path to a .zip file" + + # If enabled, ensure that the run has a unique ID + uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" + + print("=" * 10, args.env, "=" * 10) + print(f"Seed: {args.seed}") + + if args.track: + try: + import datetime + except ImportError as e: + raise ImportError( + "if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`" + ) from e + + run_name_str = f"{args.env}__{args.algo}__{args.seed}__{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + args.tensorboard_log = "runs" + + + exp_manager = ExperimentManager( + args, + args.algo, + args.env, + args.log_folder, + args.tensorboard_log, + args.n_timesteps, + args.eval_freq, + args.eval_episodes, + args.save_freq, + args.hyperparams, + args.env_kwargs, + args.trained_agent, + args.optimize_hyperparameters, + run_name=run_name_str, + truncate_last_trajectory=args.truncate_last_trajectory, + uuid_str=uuid_str, + seed=args.seed, + log_interval=args.log_interval, + save_replay_buffer=args.save_replay_buffer, + preload_replay_buffer=args.preload_replay_buffer, + verbose=args.verbose, + vec_env_type=args.vec_env, + no_optim_plots=args.no_optim_plots, + ) + + # Prepare experiment + results = exp_manager.setup_experiment() + if results is not None: + model, saved_hyperparams = results + + # Normal 
training + if model is not None: + exp_manager.learn(model, saved_hyperparams) + exp_manager.save_trained_model(model) + else: + exp_manager.hyperparameters_optimization() + + +if __name__ == "__main__": + main() diff --git a/env_manager/rbs_gym/rbs_gym/utils/__init__.py b/env_manager/rbs_gym/rbs_gym/utils/__init__.py new file mode 100644 index 0000000..b55f8bc --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/utils/__init__.py @@ -0,0 +1,9 @@ +from .utils import ( + ALGOS, + create_test_env, + get_latest_run_id, + get_saved_hyperparams, + get_trained_models, + get_wrapper_class, + linear_schedule, +) diff --git a/env_manager/rbs_gym/rbs_gym/utils/callbacks.py b/env_manager/rbs_gym/rbs_gym/utils/callbacks.py new file mode 100644 index 0000000..66c6bdd --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/utils/callbacks.py @@ -0,0 +1,306 @@ +import os +import tempfile +import time +from copy import deepcopy +from functools import wraps +from threading import Thread +from typing import Optional + +import optuna +from sb3_contrib import TQC +from stable_baselines3 import SAC +from stable_baselines3.common.callbacks import ( + BaseCallback, + CheckpointCallback, + EvalCallback, +) +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv + + +class TrialEvalCallback(EvalCallback): + """ + Callback used for evaluating and reporting a trial. 
+ """ + + def __init__( + self, + eval_env: VecEnv, + trial: optuna.Trial, + n_eval_episodes: int = 5, + eval_freq: int = 10000, + deterministic: bool = True, + verbose: int = 0, + best_model_save_path: Optional[str] = None, + log_path: Optional[str] = None, + ): + + super(TrialEvalCallback, self).__init__( + eval_env=eval_env, + n_eval_episodes=n_eval_episodes, + eval_freq=eval_freq, + deterministic=deterministic, + verbose=verbose, + best_model_save_path=best_model_save_path, + log_path=log_path, + ) + self.trial = trial + self.eval_idx = 0 + self.is_pruned = False + + def _on_step(self) -> bool: + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + print("Evaluating trial") + super(TrialEvalCallback, self)._on_step() + self.eval_idx += 1 + # report best or report current ? + # report num_timesteps or elasped time ? + self.trial.report(self.last_mean_reward, self.eval_idx) + # Prune trial if need + if self.trial.should_prune(): + self.is_pruned = True + return False + return True + + +class ParallelTrainCallback(BaseCallback): + """ + Callback to explore (collect experience) and train (do gradient steps) + at the same time using two separate threads. + Normally used with off-policy algorithms and `train_freq=(1, "episode")`. + + - blocking mode: wait for the model to finish updating the policy before collecting new experience + at the end of a rollout + - force sync mode: stop training to update to the latest policy for collecting + new experience + + :param gradient_steps: Number of gradient steps to do before + sending the new policy + :param verbose: Verbosity level + :param sleep_time: Limit the fps in the thread collecting experience. 
+ """ + + def __init__( + self, gradient_steps: int = 100, verbose: int = 0, sleep_time: float = 0.0 + ): + super(ParallelTrainCallback, self).__init__(verbose) + self.batch_size = 0 + self._model_ready = True + self._model = None + self.gradient_steps = gradient_steps + self.process = None + self.model_class = None + self.sleep_time = sleep_time + + def _init_callback(self) -> None: + temp_file = tempfile.TemporaryFile() + + # Windows TemporaryFile is not a io Buffer + # we save the model in the logs/ folder + if os.name == "nt": + temp_file = os.path.join("logs", "model_tmp.zip") + + self.model.save(temp_file) + + # TODO (external): add support for other algorithms + for model_class in [SAC, TQC]: + if isinstance(self.model, model_class): + self.model_class = model_class + break + + assert ( + self.model_class is not None + ), f"{self.model} is not supported for parallel training" + self._model = self.model_class.load(temp_file) + + self.batch_size = self._model.batch_size + + # Disable train method + def patch_train(function): + @wraps(function) + def wrapper(*args, **kwargs): + return + + return wrapper + + # Add logger for parallel training + self._model.set_logger(self.model.logger) + self.model.train = patch_train(self.model.train) + + # Hack: Re-add correct values at save time + def patch_save(function): + @wraps(function) + def wrapper(*args, **kwargs): + return self._model.save(*args, **kwargs) + + return wrapper + + self.model.save = patch_save(self.model.save) + + def train(self) -> None: + self._model_ready = False + + self.process = Thread(target=self._train_thread, daemon=True) + self.process.start() + + def _train_thread(self) -> None: + self._model.train( + gradient_steps=self.gradient_steps, batch_size=self.batch_size + ) + self._model_ready = True + + def _on_step(self) -> bool: + if self.sleep_time > 0: + time.sleep(self.sleep_time) + return True + + def _on_rollout_end(self) -> None: + if self._model_ready: + self._model.replay_buffer = 
deepcopy(self.model.replay_buffer) + self.model.set_parameters(deepcopy(self._model.get_parameters())) + self.model.actor = self.model.policy.actor + if self.num_timesteps >= self._model.learning_starts: + self.train() + # Do not wait for the training loop to finish + # self.process.join() + + def _on_training_end(self) -> None: + # Wait for the thread to terminate + if self.process is not None: + if self.verbose > 0: + print("Waiting for training thread to terminate") + self.process.join() + + +class SaveVecNormalizeCallback(BaseCallback): + """ + Callback for saving a VecNormalize wrapper every ``save_freq`` steps + + :param save_freq: (int) + :param save_path: (str) Path to the folder where ``VecNormalize`` will be saved, as ``vecnormalize.pkl`` + :param name_prefix: (str) Common prefix to the saved ``VecNormalize``, if None (default) + only one file will be kept. + """ + + def __init__( + self, + save_freq: int, + save_path: str, + name_prefix: Optional[str] = None, + verbose: int = 0, + ): + super(SaveVecNormalizeCallback, self).__init__(verbose) + self.save_freq = save_freq + self.save_path = save_path + self.name_prefix = name_prefix + + def _init_callback(self) -> None: + # Create folder if needed + if self.save_path is not None: + os.makedirs(self.save_path, exist_ok=True) + + def _on_step(self) -> bool: + if self.n_calls % self.save_freq == 0: + if self.name_prefix is not None: + path = os.path.join( + self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps.pkl" + ) + else: + path = os.path.join(self.save_path, "vecnormalize.pkl") + if self.model.get_vec_normalize_env() is not None: + self.model.get_vec_normalize_env().save(path) + if self.verbose > 1: + print(f"Saving VecNormalize to {path}") + return True + + +class CheckpointCallbackWithReplayBuffer(CheckpointCallback): + """ + Callback for saving a model every ``save_freq`` steps + :param save_freq: + :param save_path: Path to the folder where the model will be saved. 
+ :param name_prefix: Common prefix to the saved models + :param save_replay_buffer: If enabled, save replay buffer together with model (if supported by algorithm). + :param verbose: + """ + + def __init__( + self, + save_freq: int, + save_path: str, + name_prefix: str = "rl_model", + save_replay_buffer: bool = False, + verbose: int = 0, + ): + super(CheckpointCallbackWithReplayBuffer, self).__init__( + save_freq, save_path, name_prefix, verbose + ) + self.save_replay_buffer = save_replay_buffer + # self.save_replay_buffer = hasattr(self.model, "save_replay_buffer") and save_replay_buffer + + def _on_step(self) -> bool: + if self.n_calls % self.save_freq == 0: + path = os.path.join( + self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps" + ) + self.model.save(path) + if self.verbose > 0: + print(f"Saving model checkpoint to {path}") + if self.save_replay_buffer: + path_replay_buffer = os.path.join(self.save_path, "replay_buffer.pkl") + self.model.save_replay_buffer(path_replay_buffer) + if self.verbose > 0: + print(f"Saving model checkpoint to {path_replay_buffer}") + return True + + +class CurriculumLoggerCallback(BaseCallback): + """ + Custom callback for logging curriculum values. 
+ """ + + def __init__(self, verbose=0): + super(CurriculumLoggerCallback, self).__init__(verbose) + + def _on_step(self) -> bool: + + for infos in self.locals["infos"]: + for info_key, info_value in infos.items(): + if not ( + info_key.startswith("curriculum") + and info_key.count("__mean_step__") + ): + continue + + self.logger.record_mean( + key=info_key.replace("__mean_step__", ""), value=info_value + ) + + return True + + def _on_rollout_end(self) -> None: + + for infos in self.locals["infos"]: + for info_key, info_value in infos.items(): + if not info_key.startswith("curriculum"): + continue + if info_key.count("__mean_step__"): + continue + + if info_key.count("__mean_episode__"): + self.logger.record_mean( + key=info_key.replace("__mean_episode__", ""), value=info_value + ) + else: + if isinstance(info_value, str): + exclude = "tensorboard" + else: + exclude = None + self.logger.record(key=info_key, value=info_value, exclude=exclude) + + +class MetricsCallback(BaseCallback): + def __init__(self, verbose: int = 0): + super(MetricsCallback, self).__init__(verbose) + + def _on_step(self) -> bool: + pass diff --git a/env_manager/rbs_gym/rbs_gym/utils/exp_manager.py b/env_manager/rbs_gym/rbs_gym/utils/exp_manager.py new file mode 100644 index 0000000..eaa0b9b --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/utils/exp_manager.py @@ -0,0 +1,936 @@ +import argparse +import os +import pickle as pkl +import time +import warnings +from collections import OrderedDict +from pprint import pprint +from typing import Any, Callable, Dict, List, Optional, Tuple + +import gymnasium as gym +import numpy as np +import optuna +import yaml +from optuna.integration.skopt import SkoptSampler +from optuna.pruners import BasePruner, MedianPruner, NopPruner, SuccessiveHalvingPruner +from optuna.samplers import BaseSampler, RandomSampler, TPESampler +from optuna.visualization import plot_optimization_history, plot_param_importances +from stable_baselines3 import HerReplayBuffer +from 
stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.callbacks import BaseCallback, EvalCallback +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.noise import ( + NormalActionNoise, + OrnsteinUhlenbeckActionNoise, +) +from stable_baselines3.common.preprocessing import ( + is_image_space, + is_image_space_channels_first, +) +from stable_baselines3.common.utils import constant_fn +from stable_baselines3.common.vec_env import ( + DummyVecEnv, + SubprocVecEnv, + VecEnv, + VecFrameStack, + VecNormalize, + VecTransposeImage, + is_vecenv_wrapped, +) +from torch import nn as nn + +from rbs_gym.utils.callbacks import ( + CheckpointCallbackWithReplayBuffer, + SaveVecNormalizeCallback, + TrialEvalCallback, +) +from rbs_gym.utils.hyperparams_opt import HYPERPARAMS_SAMPLER +from rbs_gym.utils.utils import ( + ALGOS, + get_callback_list, + get_latest_run_id, + get_wrapper_class, + linear_schedule, +) +from aim.sb3 import AimCallback + + +class ExperimentManager(object): + """ + Experiment manager: read the hyperparameters, + preprocess them, create the environment and the RL model. + + Please take a look at `train.py` to have the details for each argument. 
+ """ + + def __init__( + self, + args: argparse.Namespace, + algo: str, + env_id: str, + log_folder: str, + tensorboard_log: str = "", + n_timesteps: int = 0, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + save_freq: int = -1, + hyperparams: Optional[Dict[str, Any]] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + trained_agent: str = "", + optimize_hyperparameters: bool = False, + storage: Optional[str] = None, + study_name: Optional[str] = None, + n_trials: int = 1, + n_jobs: int = 1, + sampler: str = "tpe", + pruner: str = "median", + optimization_log_path: Optional[str] = None, + n_startup_trials: int = 0, + n_evaluations: int = 1, + truncate_last_trajectory: bool = False, + uuid_str: str = "", + seed: int = 0, + log_interval: int = 0, + save_replay_buffer: bool = False, + preload_replay_buffer: str = "", + verbose: int = 1, + vec_env_type: str = "dummy", + n_eval_envs: int = 1, + no_optim_plots: bool = False, + run_name: str = "rbs_gym_run", + ): + super(ExperimentManager, self).__init__() + self.algo = algo + self.env_id = env_id + # Custom params + self.custom_hyperparams = hyperparams + self.env_kwargs = {} if env_kwargs is None else env_kwargs + self.n_timesteps = n_timesteps + self.normalize = False + self.normalize_kwargs = {} + self.env_wrapper = None + self.frame_stack = None + self.seed = seed + self.optimization_log_path = optimization_log_path + self.run_name = run_name + + self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[ + vec_env_type + ] + + self.vec_env_kwargs = {} + # self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"} + + # Callbacks + self.specified_callbacks = [] + self.callbacks = [] + self.save_freq = save_freq + self.eval_freq = eval_freq + self.n_eval_episodes = n_eval_episodes + self.n_eval_envs = n_eval_envs + + self.n_envs = 1 # it will be updated when reading hyperparams + self.n_actions = None # For DDPG/TD3 action noise objects + self._hyperparams = {} + + 
self.trained_agent = trained_agent + self.continue_training = trained_agent.endswith(".zip") and os.path.isfile( + trained_agent + ) + self.truncate_last_trajectory = truncate_last_trajectory + + self.preload_replay_buffer = preload_replay_buffer + + self._is_atari = self.is_atari(env_id) + self._is_gazebo_env = self.is_gazebo_env(env_id) + + # Hyperparameter optimization config + self.optimize_hyperparameters = optimize_hyperparameters + self.storage = storage + self.study_name = study_name + self.no_optim_plots = no_optim_plots + # maximum number of trials for finding the best hyperparams + self.n_trials = n_trials + # number of parallel jobs when doing hyperparameter search + self.n_jobs = n_jobs + self.sampler = sampler + self.pruner = pruner + self.n_startup_trials = n_startup_trials + self.n_evaluations = n_evaluations + self.deterministic_eval = not self.is_atari(self.env_id) + + # Logging + self.log_folder = log_folder + self.tensorboard_log = ( + None if tensorboard_log == "" else os.path.join(tensorboard_log, run_name, env_id) + ) + self.verbose = verbose + self.args = args + self.log_interval = log_interval + self.save_replay_buffer = save_replay_buffer + + self.log_path = f"{log_folder}/{self.algo}/" + self.save_path = os.path.join( + self.log_path, + f"{self.env_id}_{get_latest_run_id(self.log_path, self.env_id) + 1}{uuid_str}", + ) + self.params_path = f"{self.save_path}/{self.env_id}" + + def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]: + """ + Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects) + create the environment and possibly the model. 
+ + :return: the initialized RL model + """ + hyperparams, saved_hyperparams = self.read_hyperparameters() + hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams( + hyperparams + ) + + # Create env to have access to action space for action noise + self._env = self.create_envs(self.n_envs, no_log=False) + + self.create_log_folder() + self.create_callbacks() + + self._hyperparams = self._preprocess_action_noise(hyperparams, self._env) + + if self.continue_training: + model = self._load_pretrained_agent(self._hyperparams, self._env) + elif self.optimize_hyperparameters: + return None + else: + # Train an agent from scratch + model = ALGOS[self.algo]( + env=self._env, + tensorboard_log=self.tensorboard_log, + seed=self.seed, + verbose=self.verbose, + **self._hyperparams, + ) + + # Pre-load replay buffer if enabled + if self.preload_replay_buffer: + if self.preload_replay_buffer.endswith(".pkl"): + replay_buffer_path = self.preload_replay_buffer + else: + replay_buffer_path = os.path.join( + self.preload_replay_buffer, "replay_buffer.pkl" + ) + if os.path.exists(replay_buffer_path): + print("Pre-loading replay buffer") + if self.algo == "her": + model.load_replay_buffer( + replay_buffer_path, self.truncate_last_trajectory + ) + else: + model.load_replay_buffer(replay_buffer_path) + else: + raise Exception(f"Replay buffer {replay_buffer_path} " "does not exist") + + self._save_config(saved_hyperparams) + return model, saved_hyperparams + + def learn(self, model: BaseAlgorithm, hyperparams: dict = {}) -> None: + """ + :param model: an initialized RL model + """ + kwargs = {} + if self.log_interval > -1: + kwargs = {"log_interval": self.log_interval} + + if len(self.callbacks) > 0: + aim_callback = AimCallback(repo=self.log_folder, experiment_name=self.run_name) + self.callbacks.append(aim_callback) + kwargs["callback"] = self.callbacks + + if self.continue_training: + kwargs["reset_num_timesteps"] = False + model.env.reset() + + try: + 
model.learn(self.n_timesteps, **kwargs) + except Exception as e: + print(f"Caught an exception during training of the model: {e}") + self.save_trained_model(model) + finally: + # Release resources + try: + model.env.close() + except EOFError: + pass + + def save_trained_model(self, model: BaseAlgorithm) -> None: + """ + Save trained model optionally with its replay buffer + and ``VecNormalize`` statistics + + :param model: + """ + print(f"Saving to {self.save_path}") + model.save(f"{self.save_path}/{self.env_id}") + + if hasattr(model, "save_replay_buffer") and self.save_replay_buffer: + print("Saving replay buffer") + model.save_replay_buffer(os.path.join(self.save_path, "replay_buffer.pkl")) + + if self.normalize: + # Important: save the running average, for testing the agent we need that normalization + model.get_vec_normalize_env().save( + os.path.join(self.params_path, "vecnormalize.pkl") + ) + + def _save_config(self, saved_hyperparams: Dict[str, Any]) -> None: + """ + Save unprocessed hyperparameters, this can be use later + to reproduce an experiment. 
+ + :param saved_hyperparams: + """ + # Save hyperparams + with open(os.path.join(self.params_path, "config.yml"), "w") as f: + yaml.dump(saved_hyperparams, f) + + # save command line arguments + with open(os.path.join(self.params_path, "args.yml"), "w") as f: + ordered_args = OrderedDict( + [(key, vars(self.args)[key]) for key in sorted(vars(self.args).keys())] + ) + yaml.dump(ordered_args, f) + + print(f"Log path: {self.save_path}") + + def read_hyperparameters(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # Load hyperparameters from yaml file + hyperparams_dir = os.path.abspath( + os.path.join( + os.path.realpath(__file__), *3 * [os.path.pardir], "hyperparams" + ) + ) + with open(f"{hyperparams_dir}/{self.algo}.yml", "r") as f: + hyperparams_dict = yaml.safe_load(f) + if self.env_id in list(hyperparams_dict.keys()): + hyperparams = hyperparams_dict[self.env_id] + elif self._is_atari: + hyperparams = hyperparams_dict["atari"] + else: + raise ValueError( + f"Hyperparameters not found for {self.algo}-{self.env_id}" + ) + + if self.custom_hyperparams is not None: + # Overwrite hyperparams if needed + hyperparams.update(self.custom_hyperparams) + # Sort hyperparams that will be saved + saved_hyperparams = OrderedDict( + [(key, hyperparams[key]) for key in sorted(hyperparams.keys())] + ) + + if self.verbose > 0: + print( + "Default hyperparameters for environment (ones being tuned will be overridden):" + ) + pprint(saved_hyperparams) + + return hyperparams, saved_hyperparams + + @staticmethod + def _preprocess_schedules(hyperparams: Dict[str, Any]) -> Dict[str, Any]: + # Create schedules + for key in ["learning_rate", "clip_range", "clip_range_vf"]: + if key not in hyperparams: + continue + if isinstance(hyperparams[key], str): + schedule, initial_value = hyperparams[key].split("_") + initial_value = float(initial_value) + hyperparams[key] = linear_schedule(initial_value) + elif isinstance(hyperparams[key], (float, int)): + # Negative value: ignore (ex: for 
clipping) + if hyperparams[key] < 0: + continue + hyperparams[key] = constant_fn(float(hyperparams[key])) + else: + raise ValueError(f"Invalid value for {key}: {hyperparams[key]}") + return hyperparams + + def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, Any]: + if "normalize" in hyperparams.keys(): + self.normalize = hyperparams["normalize"] + + # Special case, instead of both normalizing + # both observation and reward, we can normalize one of the two. + # in that case `hyperparams["normalize"]` is a string + # that can be evaluated as python, + # ex: "dict(norm_obs=False, norm_reward=True)" + if isinstance(self.normalize, str): + self.normalize_kwargs = eval(self.normalize) + self.normalize = True + + # Use the same discount factor as for the algorithm + if "gamma" in hyperparams: + self.normalize_kwargs["gamma"] = hyperparams["gamma"] + + del hyperparams["normalize"] + return hyperparams + + def _preprocess_hyperparams( + self, hyperparams: Dict[str, Any] + ) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback]]: + self.n_envs = hyperparams.get("n_envs", 1) + + if self.verbose > 0: + print(f"Using {self.n_envs} environments") + + # Convert schedule strings to objects + hyperparams = self._preprocess_schedules(hyperparams) + + # Pre-process train_freq + if "train_freq" in hyperparams and isinstance(hyperparams["train_freq"], list): + hyperparams["train_freq"] = tuple(hyperparams["train_freq"]) + + # Should we overwrite the number of timesteps? 
+ if self.n_timesteps > 0: + if self.verbose: + print(f"Overwriting n_timesteps with n={self.n_timesteps}") + else: + self.n_timesteps = int(hyperparams["n_timesteps"]) + + # Pre-process normalize config + hyperparams = self._preprocess_normalization(hyperparams) + + # Pre-process policy/buffer keyword arguments + # Convert to python object if needed + # TODO: Use the new replay_buffer_class argument of offpolicy algorithms instead of monkey patch + for kwargs_key in { + "policy_kwargs", + "replay_buffer_class", + "replay_buffer_kwargs", + }: + if kwargs_key in hyperparams.keys() and isinstance( + hyperparams[kwargs_key], str + ): + hyperparams[kwargs_key] = eval(hyperparams[kwargs_key]) + + # Delete keys so the dict can be pass to the model constructor + if "n_envs" in hyperparams.keys(): + del hyperparams["n_envs"] + del hyperparams["n_timesteps"] + + if "frame_stack" in hyperparams.keys(): + self.frame_stack = hyperparams["frame_stack"] + del hyperparams["frame_stack"] + + # obtain a class object from a wrapper name string in hyperparams + # and delete the entry + env_wrapper = get_wrapper_class(hyperparams) + if "env_wrapper" in hyperparams.keys(): + del hyperparams["env_wrapper"] + + callbacks = get_callback_list(hyperparams) + if "callback" in hyperparams.keys(): + self.specified_callbacks = hyperparams["callback"] + del hyperparams["callback"] + + return hyperparams, env_wrapper, callbacks + + def _preprocess_action_noise( + self, hyperparams: Dict[str, Any], env: VecEnv + ) -> Dict[str, Any]: + # Parse noise string + # Note: only off-policy algorithms are supported + if hyperparams.get("noise_type") is not None: + noise_type = hyperparams["noise_type"].strip() + noise_std = hyperparams["noise_std"] + + # Save for later (hyperparameter optimization) + self.n_actions = env.action_space.shape[0] + + if "normal" in noise_type: + hyperparams["action_noise"] = NormalActionNoise( + mean=np.zeros(self.n_actions), + sigma=noise_std * np.ones(self.n_actions), + ) + 
elif "ornstein-uhlenbeck" in noise_type: + hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise( + mean=np.zeros(self.n_actions), + sigma=noise_std * np.ones(self.n_actions), + ) + else: + raise RuntimeError(f'Unknown noise type "{noise_type}"') + + print(f"Applying {noise_type} noise with std {noise_std}") + + del hyperparams["noise_type"] + del hyperparams["noise_std"] + + return hyperparams + + def create_log_folder(self): + os.makedirs(self.params_path, exist_ok=True) + + def create_callbacks(self): + + if self.save_freq > 0: + # Account for the number of parallel environments + self.save_freq = max(self.save_freq // self.n_envs, 1) + self.callbacks.append( + CheckpointCallbackWithReplayBuffer( + save_freq=self.save_freq, + save_path=self.save_path, + name_prefix="rl_model", + save_replay_buffer=self.save_replay_buffer, + verbose=self.verbose, + ) + ) + + # Create test env if needed, do not normalize reward + if self.eval_freq > 0 and not self.optimize_hyperparameters: + # Account for the number of parallel environments + self.eval_freq = max(self.eval_freq // self.n_envs, 1) + + if self.verbose > 0: + print("Creating test environment") + + save_vec_normalize = SaveVecNormalizeCallback( + save_freq=1, save_path=self.params_path + ) + eval_callback = EvalCallback( + eval_env=self._env, + # TODO: Use separate environment(s) for evaluation + # self.create_envs(self.n_eval_envs, eval_env=True), + callback_on_new_best=save_vec_normalize, + best_model_save_path=self.save_path, + n_eval_episodes=self.n_eval_episodes, + log_path=self.save_path, + eval_freq=self.eval_freq, + deterministic=self.deterministic_eval, + ) + + self.callbacks.append(eval_callback) + + @staticmethod + def is_atari(env_id: str) -> bool: + entry_point = gym.envs.registry[env_id].entry_point + return "AtariEnv" in str(entry_point) + + @staticmethod + def is_bullet(env_id: str) -> bool: + entry_point = gym.envs.registry[env_id].entry_point + return "pybullet_envs" in str(entry_point) + + 
@staticmethod + def is_robotics_env(env_id: str) -> bool: + entry_point = gym.envs.registry[env_id].entry_point + return "gym.envs.robotics" in str(entry_point) or "panda_gym.envs" in str( + entry_point + ) + + @staticmethod + def is_gazebo_env(env_id: str) -> bool: + return "Gazebo" in gym.envs.registry[env_id].entry_point + + def _maybe_normalize(self, env: VecEnv, eval_env: bool) -> VecEnv: + """ + Wrap the env into a VecNormalize wrapper if needed + and load saved statistics when present. + + :param env: + :param eval_env: + :return: + """ + # Pretrained model, load normalization + path_ = os.path.join(os.path.dirname(self.trained_agent), self.env_id) + path_ = os.path.join(path_, "vecnormalize.pkl") + + if os.path.exists(path_): + print("Loading saved VecNormalize stats") + env = VecNormalize.load(path_, env) + # Deactivate training and reward normalization + if eval_env: + env.training = False + env.norm_reward = False + + elif self.normalize: + # Copy to avoid changing default values by reference + local_normalize_kwargs = self.normalize_kwargs.copy() + # Do not normalize reward for env used for evaluation + if eval_env: + if len(local_normalize_kwargs) > 0: + local_normalize_kwargs["norm_reward"] = False + else: + local_normalize_kwargs = {"norm_reward": False} + + if self.verbose > 0: + if len(local_normalize_kwargs) > 0: + print(f"Normalization activated: {local_normalize_kwargs}") + else: + print("Normalizing input and reward") + # Note: The following line was added but not sure if it is still required + env.num_envs = self.n_envs + env = VecNormalize(env, **local_normalize_kwargs) + return env + + def create_envs( + self, n_envs: int, eval_env: bool = False, no_log: bool = False + ) -> VecEnv: + """ + Create the environment and wrap it if necessary. 
+ + :param n_envs: + :param eval_env: Whether is it an environment used for evaluation or not + :param no_log: Do not log training when doing hyperparameter optim + (issue with writing the same file) + :return: the vectorized environment, with appropriate wrappers + """ + # Do not log eval env (issue with writing the same file) + log_dir = None if eval_env or no_log else self.save_path + + monitor_kwargs = {} + # Special case for GoalEnvs: log success rate too + if ( + "Neck" in self.env_id + or self.is_robotics_env(self.env_id) + or "parking-v0" in self.env_id + ): + monitor_kwargs = dict(info_keywords=("is_success",)) + + # On most env, SubprocVecEnv does not help and is quite memory hungry + # therefore we use DummyVecEnv by default + env = make_vec_env( + env_id=self.env_id, + n_envs=n_envs, + seed=self.seed, + env_kwargs=self.env_kwargs, + monitor_dir=log_dir, + wrapper_class=self.env_wrapper, + vec_env_cls=self.vec_env_class, + vec_env_kwargs=self.vec_env_kwargs, + monitor_kwargs=monitor_kwargs, + ) + + # Wrap the env into a VecNormalize wrapper if needed + # and load saved statistics when present + env = self._maybe_normalize(env, eval_env) + + # Optional Frame-stacking + if self.frame_stack is not None: + n_stack = self.frame_stack + env = VecFrameStack(env, n_stack) + if self.verbose > 0: + print(f"Stacking {n_stack} frames") + + if not is_vecenv_wrapped(env, VecTransposeImage): + wrap_with_vectranspose = False + if isinstance(env.observation_space, gym.spaces.Dict): + # If even one of the keys is a image-space in need of transpose, apply transpose + # If the image spaces are not consistent (for instance one is channel first, + # the other channel last), VecTransposeImage will throw an error + for space in env.observation_space.spaces.values(): + wrap_with_vectranspose = wrap_with_vectranspose or ( + is_image_space(space) + and not is_image_space_channels_first(space) + ) + else: + wrap_with_vectranspose = is_image_space( + env.observation_space + ) and 
not is_image_space_channels_first(env.observation_space) + + if wrap_with_vectranspose: + if self.verbose >= 1: + print("Wrapping the env in a VecTransposeImage.") + env = VecTransposeImage(env) + + return env + + def _load_pretrained_agent( + self, hyperparams: Dict[str, Any], env: VecEnv + ) -> BaseAlgorithm: + # Continue training + print( + f"Loading pretrained agent '{self.trained_agent}' to continue its training" + ) + # Policy should not be changed + del hyperparams["policy"] + + if "policy_kwargs" in hyperparams.keys(): + del hyperparams["policy_kwargs"] + + model = ALGOS[self.algo].load( + self.trained_agent, + env=env, + seed=self.seed, + tensorboard_log=self.tensorboard_log, + verbose=self.verbose, + **hyperparams, + ) + + replay_buffer_path = os.path.join( + os.path.dirname(self.trained_agent), "replay_buffer.pkl" + ) + + if not self.preload_replay_buffer and os.path.exists(replay_buffer_path): + print("Loading replay buffer") + # `truncate_last_traj` will be taken into account only if we use HER replay buffer + model.load_replay_buffer( + replay_buffer_path, truncate_last_traj=self.truncate_last_trajectory + ) + return model + + def _create_sampler(self, sampler_method: str) -> BaseSampler: + # n_warmup_steps: Disable pruner until the trial reaches the given number of step. 
+ if sampler_method == "random": + sampler = RandomSampler(seed=self.seed) + elif sampler_method == "tpe": + # TODO (external): try with multivariate=True + sampler = TPESampler(n_startup_trials=self.n_startup_trials, seed=self.seed) + elif sampler_method == "skopt": + # cf https://scikit-optimize.github.io/#skopt.Optimizer + # GP: gaussian process + # Gradient boosted regression: GBRT + sampler = SkoptSampler( + skopt_kwargs={"base_estimator": "GP", "acq_func": "gp_hedge"} + ) + else: + raise ValueError(f"Unknown sampler: {sampler_method}") + return sampler + + def _create_pruner(self, pruner_method: str) -> BasePruner: + if pruner_method == "halving": + pruner = SuccessiveHalvingPruner( + min_resource=1, reduction_factor=4, min_early_stopping_rate=0 + ) + elif pruner_method == "median": + pruner = MedianPruner( + n_startup_trials=self.n_startup_trials, + n_warmup_steps=self.n_evaluations // 3, + ) + elif pruner_method == "none": + # Do not prune + pruner = NopPruner() + else: + raise ValueError(f"Unknown pruner: {pruner_method}") + return pruner + + # Important: Objective changed in this project (rbs_gym) to evaluate on the same environment that is used for training (cannot have two existing simulatenously with the current setup) + def objective(self, trial: optuna.Trial) -> float: + + kwargs = self._hyperparams.copy() + + trial.model_class = None + + # Hack to use DDPG/TD3 noise sampler + trial.n_actions = self._env.action_space.shape[0] + + # Hack when using HerReplayBuffer + if kwargs.get("replay_buffer_class") == HerReplayBuffer: + trial.her_kwargs = kwargs.get("replay_buffer_kwargs", {}) + + # Sample candidate hyperparameters + kwargs.update(HYPERPARAMS_SAMPLER[self.algo](trial)) + print(f"\nRunning a new trial with hyperparameters: {kwargs}") + + # Write hyperparameters into a file + trial_params_path = os.path.join(self.params_path, "optimization") + os.makedirs(trial_params_path, exist_ok=True) + with open( + os.path.join( + trial_params_path, 
f"hyperparameters_trial_{trial.number}.yml" + ), + "w", + ) as f: + yaml.dump(kwargs, f) + + model = ALGOS[self.algo]( + env=self._env, + # Note: Here I enabled tensorboard logs + tensorboard_log=self.tensorboard_log, + # Note: Here I differ and I seed the trial. I want all trials to have the same starting conditions + seed=self.seed, + verbose=self.verbose, + **kwargs, + ) + + # Pre-load replay buffer if enabled + if self.preload_replay_buffer: + if self.preload_replay_buffer.endswith(".pkl"): + replay_buffer_path = self.preload_replay_buffer + else: + replay_buffer_path = os.path.join( + self.preload_replay_buffer, "replay_buffer.pkl" + ) + if os.path.exists(replay_buffer_path): + print("Pre-loading replay buffer") + if self.algo == "her": + model.load_replay_buffer( + replay_buffer_path, self.truncate_last_trajectory + ) + else: + model.load_replay_buffer(replay_buffer_path) + else: + raise Exception(f"Replay buffer {replay_buffer_path} " "does not exist") + + model.trial = trial + + eval_freq = int(self.n_timesteps / self.n_evaluations) + # Account for parallel envs + eval_freq_ = max(eval_freq // model.get_env().num_envs, 1) + # Use non-deterministic eval for Atari + callbacks = get_callback_list({"callback": self.specified_callbacks}) + path = None + if self.optimization_log_path is not None: + path = os.path.join( + self.optimization_log_path, f"trial_{str(trial.number)}" + ) + eval_callback = TrialEvalCallback( + # TODO: Use a separate environment for evaluation during trial + model.env, + model.trial, + best_model_save_path=path, + log_path=path, + n_eval_episodes=self.n_eval_episodes, + eval_freq=eval_freq_, + deterministic=self.deterministic_eval, + verbose=self.verbose, + ) + callbacks.append(eval_callback) + + try: + model.learn(self.n_timesteps, callback=callbacks) + # Reset env + self._env.reset() + except AssertionError as e: + # Reset env + self._env.reset() + print("Trial stopped:", e) + # Prune hyperparams that generate NaNs + raise 
optuna.exceptions.TrialPruned() + except Exception as err: + exception_type = type(err).__name__ + print("Trial stopped due to raised exception:", exception_type, err) + # Prune also all other exceptions + raise optuna.exceptions.TrialPruned() + is_pruned = eval_callback.is_pruned + reward = eval_callback.last_mean_reward + + print( + f"\nFinished a trial with reward={reward}, is_pruned={is_pruned} " + f"for hyperparameters: {kwargs}" + ) + + del model + + if is_pruned: + raise optuna.exceptions.TrialPruned() + + return reward + + def hyperparameters_optimization(self) -> None: + + if self.verbose > 0: + print("Optimizing hyperparameters") + + if self.storage is not None and self.study_name is None: + warnings.warn( + f"You passed a remote storage: {self.storage} but no `--study-name`." + "The study name will be generated by Optuna, make sure to re-use the same study name " + "when you want to do distributed hyperparameter optimization." + ) + + if self.tensorboard_log is not None: + warnings.warn( + "Tensorboard log is deactivated when running hyperparameter optimization" + ) + self.tensorboard_log = None + + # TODO (external): eval each hyperparams several times to account for noisy evaluation + sampler = self._create_sampler(self.sampler) + pruner = self._create_pruner(self.pruner) + + if self.verbose > 0: + print(f"Sampler: {self.sampler} - Pruner: {self.pruner}") + + study = optuna.create_study( + sampler=sampler, + pruner=pruner, + storage=self.storage, + study_name=self.study_name, + load_if_exists=True, + direction="maximize", + ) + + try: + study.optimize( + self.objective, + n_trials=self.n_trials, + n_jobs=self.n_jobs, + gc_after_trial=True, + show_progress_bar=True, + ) + except KeyboardInterrupt: + pass + + print("Number of finished trials: ", len(study.trials)) + + print("Best trial:") + trial = study.best_trial + + print("Value: ", trial.value) + + print("Params: ") + for key, value in trial.params.items(): + print(f" {key}: {value}") + + report_name 
= ( + f"report_{self.env_id}_{self.n_trials}-trials-{self.n_timesteps}" + f"-{self.sampler}-{self.pruner}_{int(time.time())}" + ) + + log_path = os.path.join(self.log_folder, self.algo, report_name) + + if self.verbose: + print(f"Writing report to {log_path}") + + # Write report + os.makedirs(os.path.dirname(log_path), exist_ok=True) + study.trials_dataframe().to_csv(f"{log_path}.csv") + + # Save python object to inspect/re-use it later + with open(f"{log_path}.pkl", "wb+") as f: + pkl.dump(study, f) + + # Skip plots + if self.no_optim_plots: + return + + # Plot optimization result + try: + fig1 = plot_optimization_history(study) + fig2 = plot_param_importances(study) + + fig1.show() + fig2.show() + except (ValueError, ImportError, RuntimeError): + pass + + def collect_demonstration(self, model): + # Any random action will do (this won't actually be used since `preload_replay_buffer` env kwarg is enabled) + action = np.array([model.env.action_space.sample()]) + # Reset env at the beginning + obs = model.env.reset() + # Collect transitions + for i in range(model.replay_buffer.buffer_size): + # Note: If `None` is passed to Grasp env, it uses custom action heuristic to reach the target + next_obs, rewards, dones, infos = model.env.unwrapped.step(action) + # Extract the actual actions from info + actual_actions = [info["actual_actions"] for info in infos] + # Add to replay buffer + model.replay_buffer.add(obs, next_obs, actual_actions, rewards, dones) + # Update current observation + obs = next_obs + + print("Saving replay buffer") + model.save_replay_buffer(os.path.join(self.save_path, "replay_buffer.pkl")) + model.env.close() + exit diff --git a/env_manager/rbs_gym/rbs_gym/utils/hyperparams_opt.py b/env_manager/rbs_gym/rbs_gym/utils/hyperparams_opt.py new file mode 100644 index 0000000..9d36cf5 --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/utils/hyperparams_opt.py @@ -0,0 +1,170 @@ +from typing import Any, Dict + +import numpy as np +import optuna +from 
stable_baselines3.common.noise import NormalActionNoise
+from torch import nn as nn
+
+
+def sample_sac_params(
+    trial: optuna.Trial,
+) -> Dict[str, Any]:
+    """
+    Sampler for SAC hyperparameters
+    """
+
+    buffer_size = 150000
+    # learning_starts = trial.suggest_categorical(
+    # "learning_starts", [5000, 10000, 20000])
+    learning_starts = 5000
+
+    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256, 512, 1024, 2048])
+    learning_rate = trial.suggest_float(
+        "learning_rate", low=0.000001, high=0.001, log=True
+    )
+
+    gamma = trial.suggest_float("gamma", low=0.98, high=1.0, log=True)
+    tau = trial.suggest_float("tau", low=0.001, high=0.025, log=True)
+
+    ent_coef = "auto_0.5_0.1"
+    target_entropy = "auto"
+
+    noise_std = trial.suggest_float("noise_std", low=0.01, high=0.2, log=True)
+    action_noise = NormalActionNoise(
+        mean=np.zeros(trial.n_actions), sigma=np.ones(trial.n_actions) * noise_std
+    )
+
+    train_freq = 1
+    gradient_steps = trial.suggest_categorical("gradient_steps", [1, 2])
+
+    policy_kwargs = dict()
+    net_arch = trial.suggest_categorical("net_arch", [256, 384, 512])
+    policy_kwargs["net_arch"] = [net_arch] * 3
+
+    return {
+        "buffer_size": buffer_size,
+        "learning_starts": learning_starts,
+        "batch_size": batch_size,
+        "learning_rate": learning_rate,
+        "gamma": gamma,
+        "tau": tau,
+        "ent_coef": ent_coef,
+        "target_entropy": target_entropy,
+        "action_noise": action_noise,
+        "train_freq": train_freq,
+        "gradient_steps": gradient_steps,
+        "policy_kwargs": policy_kwargs,
+    }
+
+
+def sample_td3_params(
+    trial: optuna.Trial,
+) -> Dict[str, Any]:
+    """
+    Sampler for TD3 hyperparameters
+    """
+
+    buffer_size = 150000
+    # learning_starts = trial.suggest_categorical(
+    # "learning_starts", [5000, 10000, 20000])
+    learning_starts = 5000
+
+    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
+    learning_rate = trial.suggest_float(
+        "learning_rate", low=0.000001, high=0.001, log=True
+    )
+
+    gamma = trial.suggest_float("gamma", low=0.98, high=1.0, log=True)
+    tau = trial.suggest_float("tau", low=0.001, high=0.025, log=True)
+
+    target_policy_noise = trial.suggest_float(
+        "target_policy_noise", low=0.00000001, high=0.5, log=True
+    )
+    target_noise_clip = 0.5
+
+    noise_std = trial.suggest_float("noise_std", low=0.025, high=0.5, log=True)
+    action_noise = NormalActionNoise(
+        mean=np.zeros(trial.n_actions), sigma=np.ones(trial.n_actions) * noise_std
+    )
+
+    train_freq = 1
+    gradient_steps = trial.suggest_categorical("gradient_steps", [1, 2])
+
+    policy_kwargs = dict()
+    net_arch = trial.suggest_categorical("net_arch", [256, 384, 512])
+    policy_kwargs["net_arch"] = [net_arch] * 3
+
+    return {
+        "buffer_size": buffer_size,
+        "learning_starts": learning_starts,
+        "batch_size": batch_size,
+        "learning_rate": learning_rate,
+        "gamma": gamma,
+        "tau": tau,
+        "target_policy_noise": target_policy_noise,
+        "target_noise_clip": target_noise_clip,
+        "action_noise": action_noise,
+        "train_freq": train_freq,
+        "gradient_steps": gradient_steps,
+        "policy_kwargs": policy_kwargs,
+    }
+
+
+def sample_tqc_params(
+    trial: optuna.Trial,
+) -> Dict[str, Any]:
+    """
+    Sampler for TQC hyperparameters
+    """
+
+    buffer_size = 25000
+    learning_starts = 0
+
+    batch_size = 32
+    learning_rate = trial.suggest_float(
+        "learning_rate", low=0.000025, high=0.00075, log=True
+    )
+
+    gamma = 1.0 - trial.suggest_float("gamma", low=0.0001, high=0.025, log=True)
+    tau = trial.suggest_float("tau", low=0.0005, high=0.025, log=True)
+
+    ent_coef = "auto_0.1_0.05"
+    target_entropy = "auto"
+
+    noise_std = trial.suggest_float("noise_std", low=0.01, high=0.1, log=True)
+    action_noise = NormalActionNoise(
+        mean=np.zeros(trial.n_actions), sigma=np.ones(trial.n_actions) * noise_std
+    )
+
+    train_freq = 1
+    gradient_steps = trial.suggest_categorical("gradient_steps", [1, 2])
+
+    policy_kwargs = dict()
+    net_arch = trial.suggest_categorical("net_arch", [128, 256, 384, 512])
+    policy_kwargs["net_arch"] = [net_arch] * 2
+    policy_kwargs["n_quantiles"] = trial.suggest_int("n_quantiles", low=20, high=40)
+    top_quantiles_to_drop_per_net = round(0.08 * policy_kwargs["n_quantiles"])
+    policy_kwargs["n_critics"] = trial.suggest_categorical("n_critics", [2, 3])
+
+    return {
+        "buffer_size": buffer_size,
+        "learning_starts": learning_starts,
+        "batch_size": batch_size,
+        "learning_rate": learning_rate,
+        "gamma": gamma,
+        "tau": tau,
+        "ent_coef": ent_coef,
+        "target_entropy": target_entropy,
+        "top_quantiles_to_drop_per_net": top_quantiles_to_drop_per_net,
+        "action_noise": action_noise,
+        "train_freq": train_freq,
+        "gradient_steps": gradient_steps,
+        "policy_kwargs": policy_kwargs,
+    }
+
+
+HYPERPARAMS_SAMPLER = {
+    "sac": sample_sac_params,
+    "td3": sample_td3_params,
+    "tqc": sample_tqc_params,
+}
diff --git a/env_manager/rbs_gym/rbs_gym/utils/utils.py b/env_manager/rbs_gym/rbs_gym/utils/utils.py
new file mode 100644
index 0000000..c073dd6
--- /dev/null
+++ b/env_manager/rbs_gym/rbs_gym/utils/utils.py
@@ -0,0 +1,411 @@
+import argparse
+import glob
+import importlib
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import gymnasium as gym
+
+# For custom activation fn
+import stable_baselines3 as sb3  # noqa: F401
+import torch as th  # noqa: F401
+import yaml
+from sb3_contrib import QRDQN, TQC
+from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.sb2_compat.rmsprop_tf_like import (  # noqa: F401
+    RMSpropTFLike,
+)
+from stable_baselines3.common.vec_env import (
+    DummyVecEnv,
+    SubprocVecEnv,
+    VecEnv,
+    VecFrameStack,
+    VecNormalize,
+)
+from torch import nn as nn  # noqa: F401 pylint: disable=unused-import
+
+ALGOS = {
+    "a2c": A2C,
+    "ddpg": DDPG,
+    "dqn": DQN,
+    "ppo": PPO,
+    "sac": SAC,
+    "td3": TD3,
+    # SB3 Contrib,
+    
"qrdqn": QRDQN, + "tqc": TQC, +} + + +def flatten_dict_observations(env: gym.Env) -> gym.Env: + assert isinstance(env.observation_space, gym.spaces.Dict) + try: + return gym.wrappers.FlattenObservation(env) + except AttributeError: + keys = env.observation_space.spaces.keys() + return gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys)) + + +def get_wrapper_class( + hyperparams: Dict[str, Any] +) -> Optional[Callable[[gym.Env], gym.Env]]: + """ + Get one or more Gym environment wrapper class specified as a hyper parameter + "env_wrapper". + e.g. + env_wrapper: gym_minigrid.wrappers.FlatObsWrapper + + for multiple, specify a list: + + env_wrapper: + - utils.wrappers.PlotActionWrapper + - utils.wrappers.TimeFeatureWrapper + + + :param hyperparams: + :return: maybe a callable to wrap the environment + with one or multiple gym.Wrapper + """ + + def get_module_name(wrapper_name): + return ".".join(wrapper_name.split(".")[:-1]) + + def get_class_name(wrapper_name): + return wrapper_name.split(".")[-1] + + if "env_wrapper" in hyperparams.keys(): + wrapper_name = hyperparams.get("env_wrapper") + + if wrapper_name is None: + return None + + if not isinstance(wrapper_name, list): + wrapper_names = [wrapper_name] + else: + wrapper_names = wrapper_name + + wrapper_classes = [] + wrapper_kwargs = [] + # Handle multiple wrappers + for wrapper_name in wrapper_names: + # Handle keyword arguments + if isinstance(wrapper_name, dict): + assert len(wrapper_name) == 1, ( + "You have an error in the formatting " + f"of your YAML file near {wrapper_name}. " + "You should check the indentation." 
+ ) + wrapper_dict = wrapper_name + wrapper_name = list(wrapper_dict.keys())[0] + kwargs = wrapper_dict[wrapper_name] + else: + kwargs = {} + wrapper_module = importlib.import_module(get_module_name(wrapper_name)) + wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name)) + wrapper_classes.append(wrapper_class) + wrapper_kwargs.append(kwargs) + + def wrap_env(env: gym.Env) -> gym.Env: + """ + :param env: + :return: + """ + for wrapper_class, kwargs in zip(wrapper_classes, wrapper_kwargs): + env = wrapper_class(env, **kwargs) + return env + + return wrap_env + else: + return None + + +def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]: + """ + Get one or more Callback class specified as a hyper-parameter + "callback". + e.g. + callback: stable_baselines3.common.callbacks.CheckpointCallback + + for multiple, specify a list: + + callback: + - utils.callbacks.PlotActionWrapper + - stable_baselines3.common.callbacks.CheckpointCallback + + :param hyperparams: + :return: + """ + + def get_module_name(callback_name): + return ".".join(callback_name.split(".")[:-1]) + + def get_class_name(callback_name): + return callback_name.split(".")[-1] + + callbacks = [] + + if "callback" in hyperparams.keys(): + callback_name = hyperparams.get("callback") + + if callback_name is None: + return callbacks + + if not isinstance(callback_name, list): + callback_names = [callback_name] + else: + callback_names = callback_name + + # Handle multiple wrappers + for callback_name in callback_names: + # Handle keyword arguments + if isinstance(callback_name, dict): + assert len(callback_name) == 1, ( + "You have an error in the formatting " + f"of your YAML file near {callback_name}. " + "You should check the indentation." 
+ ) + callback_dict = callback_name + callback_name = list(callback_dict.keys())[0] + kwargs = callback_dict[callback_name] + else: + kwargs = {} + callback_module = importlib.import_module(get_module_name(callback_name)) + callback_class = getattr(callback_module, get_class_name(callback_name)) + callbacks.append(callback_class(**kwargs)) + + return callbacks + + +def create_test_env( + env_id: str, + n_envs: int = 1, + stats_path: Optional[str] = None, + seed: int = 0, + log_dir: Optional[str] = None, + should_render: bool = True, + hyperparams: Optional[Dict[str, Any]] = None, + env_kwargs: Optional[Dict[str, Any]] = None, +) -> VecEnv: + """ + Create environment for testing a trained agent + + :param env_id: + :param n_envs: number of processes + :param stats_path: path to folder containing saved running averaged + :param seed: Seed for random number generator + :param log_dir: Where to log rewards + :param should_render: For Pybullet env, display the GUI + :param hyperparams: Additional hyperparams (ex: n_stack) + :param env_kwargs: Optional keyword argument to pass to the env constructor + :return: + """ + # Avoid circular import + from rbs_gym.utils.exp_manager import ExperimentManager + + # Create the environment and wrap it if necessary + env_wrapper = get_wrapper_class(hyperparams) + + hyperparams = {} if hyperparams is None else hyperparams + + if "env_wrapper" in hyperparams.keys(): + del hyperparams["env_wrapper"] + + vec_env_kwargs = {} + vec_env_cls = DummyVecEnv + if n_envs > 1 or (ExperimentManager.is_bullet(env_id) and should_render): + # HACK: force SubprocVecEnv for Bullet env + # as Pybullet envs does not follow gym.render() interface + vec_env_cls = SubprocVecEnv + # start_method = 'spawn' for thread safe + + env = make_vec_env( + env_id, + n_envs=n_envs, + monitor_dir=log_dir, + seed=seed, + wrapper_class=env_wrapper, + env_kwargs=env_kwargs, + vec_env_cls=vec_env_cls, + vec_env_kwargs=vec_env_kwargs, + ) + + # Load saved stats for 
normalizing input and rewards + # And optionally stack frames + if stats_path is not None: + if hyperparams["normalize"]: + print("Loading running average") + print(f"with params: {hyperparams['normalize_kwargs']}") + path_ = os.path.join(stats_path, "vecnormalize.pkl") + if os.path.exists(path_): + env = VecNormalize.load(path_, env) + # Deactivate training and reward normalization + env.training = False + env.norm_reward = False + else: + raise ValueError(f"VecNormalize stats {path_} not found") + + n_stack = hyperparams.get("frame_stack", 0) + if n_stack > 0: + print(f"Stacking {n_stack} frames") + env = VecFrameStack(env, n_stack) + return env + + +def linear_schedule(initial_value: Union[float, str]) -> Callable[[float], float]: + """ + Linear learning rate schedule. + + :param initial_value: (float or str) + :return: (function) + """ + if isinstance(initial_value, str): + initial_value = float(initial_value) + + def func(progress_remaining: float) -> float: + """ + Progress will decrease from 1 (beginning) to 0 + :param progress_remaining: (float) + :return: (float) + """ + return progress_remaining * initial_value + + return func + + +def get_trained_models(log_folder: str) -> Dict[str, Tuple[str, str]]: + """ + :param log_folder: Root log folder + :return: Dict[str, Tuple[str, str]] representing the trained agents + """ + trained_models = {} + for algo in os.listdir(log_folder): + if not os.path.isdir(os.path.join(log_folder, algo)): + continue + for env_id in os.listdir(os.path.join(log_folder, algo)): + # Retrieve env name + env_id = env_id.split("_")[0] + trained_models[f"{algo}-{env_id}"] = (algo, env_id) + return trained_models + + +def get_latest_run_id(log_path: str, env_id: str) -> int: + """ + Returns the latest run number for the given log name and log path, + by finding the greatest number in the directories. 
+ + :param log_path: path to log folder + :param env_id: + :return: latest run number + """ + max_run_id = 0 + for path in glob.glob(os.path.join(log_path, env_id + "_[0-9]*")): + file_name = os.path.basename(path) + ext = file_name.split("_")[-1] + if ( + env_id == "_".join(file_name.split("_")[:-1]) + and ext.isdigit() + and int(ext) > max_run_id + ): + max_run_id = int(ext) + return max_run_id + + +def get_saved_hyperparams( + stats_path: str, + norm_reward: bool = False, + test_mode: bool = False, +) -> Tuple[Dict[str, Any], str]: + """ + :param stats_path: + :param norm_reward: + :param test_mode: + :return: + """ + hyperparams = {} + if not os.path.isdir(stats_path): + stats_path = None + else: + config_file = os.path.join(stats_path, "config.yml") + if os.path.isfile(config_file): + # Load saved hyperparameters + with open(os.path.join(stats_path, "config.yml"), "r") as f: + hyperparams = yaml.load( + f, Loader=yaml.UnsafeLoader + ) # pytype: disable=module-attr + hyperparams["normalize"] = hyperparams.get("normalize", False) + else: + obs_rms_path = os.path.join(stats_path, "obs_rms.pkl") + hyperparams["normalize"] = os.path.isfile(obs_rms_path) + + # Load normalization params + if hyperparams["normalize"]: + if isinstance(hyperparams["normalize"], str): + normalize_kwargs = eval(hyperparams["normalize"]) + if test_mode: + normalize_kwargs["norm_reward"] = norm_reward + else: + normalize_kwargs = { + "norm_obs": hyperparams["normalize"], + "norm_reward": norm_reward, + } + hyperparams["normalize_kwargs"] = normalize_kwargs + return hyperparams, stats_path + + +class StoreDict(argparse.Action): + """ + Custom argparse action for storing dict. 
+ + In: args1:0.0 args2:"dict(a=1)" + Out: {'args1': 0.0, 'args2': dict(a=1)} + """ + + def __init__(self, option_strings, dest, nargs=None, **kwargs): + self._nargs = nargs + super(StoreDict, self).__init__(option_strings, dest, nargs=nargs, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + + arg_dict = {} + + if hasattr(namespace, self.dest): + current_arg = getattr(namespace, self.dest) + if isinstance(current_arg, Dict): + arg_dict = getattr(namespace, self.dest) + + for arguments in values: + if not arguments: + continue + key = arguments.split(":")[0] + value = ":".join(arguments.split(":")[1:]) + # Evaluate the string as python code + arg_dict[key] = eval(value) + setattr(namespace, self.dest, arg_dict) + + +def str2bool(value: Union[str, bool]) -> bool: + """ + Convert logical string to boolean. Can be used as argparse type. + """ + + if isinstance(value, bool): + return value + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + elif value.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def empty_str2none(value: Optional[str]) -> Optional[str]: + """ + If string is empty, convert to None. Can be used as argparse type. + """ + + if not value: + return None + return value diff --git a/env_manager/rbs_gym/rbs_gym/utils/wrappers.py b/env_manager/rbs_gym/rbs_gym/utils/wrappers.py new file mode 100644 index 0000000..9107117 --- /dev/null +++ b/env_manager/rbs_gym/rbs_gym/utils/wrappers.py @@ -0,0 +1,395 @@ +import gymnasium as gym +import numpy as np +from matplotlib import pyplot as plt +from scipy.signal import iirfilter, sosfilt, zpk2sos + + +class DoneOnSuccessWrapper(gym.Wrapper): + """ + Reset on success and offsets the reward. + Useful for GoalEnv. 
+ """ + + def __init__(self, env: gym.Env, reward_offset: float = 0.0, n_successes: int = 1): + super(DoneOnSuccessWrapper, self).__init__(env) + self.reward_offset = reward_offset + self.n_successes = n_successes + self.current_successes = 0 + + def reset(self): + self.current_successes = 0 + return self.env.reset() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + if info.get("is_success", False): + self.current_successes += 1 + else: + self.current_successes = 0 + # number of successes in a row + done = done or self.current_successes >= self.n_successes + reward += self.reward_offset + return obs, reward, done, info + + def compute_reward(self, achieved_goal, desired_goal, info): + reward = self.env.compute_reward(achieved_goal, desired_goal, info) + return reward + self.reward_offset + + +class ActionNoiseWrapper(gym.Wrapper): + """ + Add gaussian noise to the action (without telling the agent), + to test the robustness of the control. + + :param env: (gym.Env) + :param noise_std: (float) Standard deviation of the noise + """ + + def __init__(self, env, noise_std=0.1): + super(ActionNoiseWrapper, self).__init__(env) + self.noise_std = noise_std + + def step(self, action): + noise = np.random.normal( + np.zeros_like(action), np.ones_like(action) * self.noise_std + ) + noisy_action = action + noise + return self.env.step(noisy_action) + + +# from https://docs.obspy.org +def lowpass(data, freq, df, corners=4, zerophase=False): + """ + Butterworth-Lowpass Filter. + + Filter data removing data over certain frequency ``freq`` using ``corners`` + corners. + The filter uses :func:`scipy.signal.iirfilter` (for design) + and :func:`scipy.signal.sosfilt` (for applying the filter). + + :type data: numpy.ndarray + :param data: Data to filter. + :param freq: Filter corner frequency. + :param df: Sampling rate in Hz. + :param corners: Filter corners / order. + :param zerophase: If True, apply filter once forwards and once backwards. 
+ This results in twice the number of corners but zero phase shift in + the resulting filtered trace. + :return: Filtered data. + """ + fe = 0.5 * df + f = freq / fe + # raise for some bad scenarios + if f > 1: + f = 1.0 + msg = ( + "Selected corner frequency is above Nyquist. " + + "Setting Nyquist as high corner." + ) + print(msg) + z, p, k = iirfilter(corners, f, btype="lowpass", ftype="butter", output="zpk") + sos = zpk2sos(z, p, k) + if zerophase: + firstpass = sosfilt(sos, data) + return sosfilt(sos, firstpass[::-1])[::-1] + else: + return sosfilt(sos, data) + + +class LowPassFilterWrapper(gym.Wrapper): + """ + Butterworth-Lowpass + + :param env: (gym.Env) + :param freq: Filter corner frequency. + :param df: Sampling rate in Hz. + """ + + def __init__(self, env, freq=5.0, df=25.0): + super(LowPassFilterWrapper, self).__init__(env) + self.freq = freq + self.df = df + self.signal = [] + + def reset(self): + self.signal = [] + return self.env.reset() + + def step(self, action): + self.signal.append(action) + filtered = np.zeros_like(action) + for i in range(self.action_space.shape[0]): + smoothed_action = lowpass( + np.array(self.signal)[:, i], freq=self.freq, df=self.df + ) + filtered[i] = smoothed_action[-1] + return self.env.step(filtered) + + +class ActionSmoothingWrapper(gym.Wrapper): + """ + Smooth the action using exponential moving average. 
+ + :param env: (gym.Env) + :param smoothing_coef: (float) Smoothing coefficient (0 no smoothing, 1 very smooth) + """ + + def __init__(self, env, smoothing_coef: float = 0.0): + super(ActionSmoothingWrapper, self).__init__(env) + self.smoothing_coef = smoothing_coef + self.smoothed_action = None + # from https://github.com/rail-berkeley/softlearning/issues/3 + # for smoothing latent space + # self.alpha = self.smoothing_coef + # self.beta = np.sqrt(1 - self.alpha ** 2) / (1 - self.alpha) + + def reset(self): + self.smoothed_action = None + return self.env.reset() + + def step(self, action): + if self.smoothed_action is None: + self.smoothed_action = np.zeros_like(action) + self.smoothed_action = ( + self.smoothing_coef * self.smoothed_action + + (1 - self.smoothing_coef) * action + ) + return self.env.step(self.smoothed_action) + + +class DelayedRewardWrapper(gym.Wrapper): + """ + Delay the reward by `delay` steps, it makes the task harder but more realistic. + The reward is accumulated during those steps. + + :param env: (gym.Env) + :param delay: (int) Number of steps the reward should be delayed. + """ + + def __init__(self, env, delay=10): + super(DelayedRewardWrapper, self).__init__(env) + self.delay = delay + self.current_step = 0 + self.accumulated_reward = 0.0 + + def reset(self): + self.current_step = 0 + self.accumulated_reward = 0.0 + return self.env.reset() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + + self.accumulated_reward += reward + self.current_step += 1 + + if self.current_step % self.delay == 0 or done: + reward = self.accumulated_reward + self.accumulated_reward = 0.0 + else: + reward = 0.0 + return obs, reward, done, info + + +class HistoryWrapper(gym.Wrapper): + """ + Stack past observations and actions to give an history to the agent. + + :param env: (gym.Env) + :param horizon: (int) Number of steps to keep in the history. 
+ """ + + def __init__(self, env: gym.Env, horizon: int = 5): + assert isinstance(env.observation_space, gym.spaces.Box) + + wrapped_obs_space = env.observation_space + wrapped_action_space = env.action_space + + # TODO (external): double check, it seems wrong when we have different low and highs + low_obs = np.repeat(wrapped_obs_space.low, horizon, axis=-1) + high_obs = np.repeat(wrapped_obs_space.high, horizon, axis=-1) + + low_action = np.repeat(wrapped_action_space.low, horizon, axis=-1) + high_action = np.repeat(wrapped_action_space.high, horizon, axis=-1) + + low = np.concatenate((low_obs, low_action)) + high = np.concatenate((high_obs, high_action)) + + # Overwrite the observation space + env.observation_space = gym.spaces.Box( + low=low, high=high, dtype=wrapped_obs_space.dtype + ) + + super(HistoryWrapper, self).__init__(env) + + self.horizon = horizon + self.low_action, self.high_action = low_action, high_action + self.low_obs, self.high_obs = low_obs, high_obs + self.low, self.high = low, high + self.obs_history = np.zeros(low_obs.shape, low_obs.dtype) + self.action_history = np.zeros(low_action.shape, low_action.dtype) + + def _create_obs_from_history(self): + return np.concatenate((self.obs_history, self.action_history)) + + def reset(self): + # Flush the history + self.obs_history[...] = 0 + self.action_history[...] 
= 0 + obs = self.env.reset() + self.obs_history[..., -obs.shape[-1] :] = obs + return self._create_obs_from_history() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + last_ax_size = obs.shape[-1] + + self.obs_history = np.roll(self.obs_history, shift=-last_ax_size, axis=-1) + self.obs_history[..., -obs.shape[-1] :] = obs + + self.action_history = np.roll( + self.action_history, shift=-action.shape[-1], axis=-1 + ) + self.action_history[..., -action.shape[-1] :] = action + return self._create_obs_from_history(), reward, done, info + + +class HistoryWrapperObsDict(gym.Wrapper): + """ + History Wrapper for dict observation. + + :param env: (gym.Env) + :param horizon: (int) Number of steps to keep in the history. + """ + + def __init__(self, env, horizon=5): + assert isinstance(env.observation_space.spaces["observation"], gym.spaces.Box) + + wrapped_obs_space = env.observation_space.spaces["observation"] + wrapped_action_space = env.action_space + + # TODO (external): double check, it seems wrong when we have different low and highs + low_obs = np.repeat(wrapped_obs_space.low, horizon, axis=-1) + high_obs = np.repeat(wrapped_obs_space.high, horizon, axis=-1) + + low_action = np.repeat(wrapped_action_space.low, horizon, axis=-1) + high_action = np.repeat(wrapped_action_space.high, horizon, axis=-1) + + low = np.concatenate((low_obs, low_action)) + high = np.concatenate((high_obs, high_action)) + + # Overwrite the observation space + env.observation_space.spaces["observation"] = gym.spaces.Box( + low=low, high=high, dtype=wrapped_obs_space.dtype + ) + + super(HistoryWrapperObsDict, self).__init__(env) + + self.horizon = horizon + self.low_action, self.high_action = low_action, high_action + self.low_obs, self.high_obs = low_obs, high_obs + self.low, self.high = low, high + self.obs_history = np.zeros(low_obs.shape, low_obs.dtype) + self.action_history = np.zeros(low_action.shape, low_action.dtype) + + def _create_obs_from_history(self): + 
return np.concatenate((self.obs_history, self.action_history)) + + def reset(self): + # Flush the history + self.obs_history[...] = 0 + self.action_history[...] = 0 + obs_dict = self.env.reset() + obs = obs_dict["observation"] + self.obs_history[..., -obs.shape[-1] :] = obs + + obs_dict["observation"] = self._create_obs_from_history() + + return obs_dict + + def step(self, action): + obs_dict, reward, done, info = self.env.step(action) + obs = obs_dict["observation"] + last_ax_size = obs.shape[-1] + + self.obs_history = np.roll(self.obs_history, shift=-last_ax_size, axis=-1) + self.obs_history[..., -obs.shape[-1] :] = obs + + self.action_history = np.roll( + self.action_history, shift=-action.shape[-1], axis=-1 + ) + self.action_history[..., -action.shape[-1] :] = action + + obs_dict["observation"] = self._create_obs_from_history() + + return obs_dict, reward, done, info + + +class PlotActionWrapper(gym.Wrapper): + """ + Wrapper for plotting the taken actions. + Only works with 1D actions for now. + Optionally, it can be used to plot the observations too. 
+ + :param env: (gym.Env) + :param plot_freq: (int) Plot every `plot_freq` episodes + """ + + def __init__(self, env, plot_freq=5): + super(PlotActionWrapper, self).__init__(env) + self.plot_freq = plot_freq + self.current_episode = 0 + # Observation buffer (Optional) + # self.observations = [] + # Action buffer + self.actions = [] + + def reset(self): + self.current_episode += 1 + if self.current_episode % self.plot_freq == 0: + self.plot() + # Reset + self.actions = [] + obs = self.env.reset() + self.actions.append([]) + # self.observations.append(obs) + return obs + + def step(self, action): + obs, reward, done, info = self.env.step(action) + + self.actions[-1].append(action) + # self.observations.append(obs) + + return obs, reward, done, info + + def plot(self): + actions = self.actions + x = np.arange(sum([len(episode) for episode in actions])) + plt.figure("Actions") + plt.title("Actions during exploration", fontsize=14) + plt.xlabel("Timesteps", fontsize=14) + plt.ylabel("Action", fontsize=14) + + start = 0 + for i in range(len(self.actions)): + end = start + len(self.actions[i]) + plt.plot(x[start:end], self.actions[i]) + # Clipped actions: real behavior, note that it is between [-2, 2] for the Pendulum + # plt.scatter(x[start:end], np.clip(self.actions[i], -1, 1), s=1) + # plt.scatter(x[start:end], self.actions[i], s=1) + start = end + + plt.show() + + +class FeatureExtractorFreezeParammetersWrapper(gym.Wrapper): + """ + Freezes parameters of the feature extractor. 
+ """ + + def __init__(self, env: gym.Env): + + super(FeatureExtractorFreezeParammetersWrapper, self).__init__(env) + for param in self.feature_extractor.parameters(): + param.requires_grad = False diff --git a/env_manager/rbs_gym/resource/rbs_gym b/env_manager/rbs_gym/resource/rbs_gym new file mode 100644 index 0000000..e69de29 diff --git a/env_manager/rbs_gym/setup.cfg b/env_manager/rbs_gym/setup.cfg new file mode 100644 index 0000000..759913d --- /dev/null +++ b/env_manager/rbs_gym/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/rbs_gym +[install] +install_scripts=$base/lib/rbs_gym diff --git a/env_manager/rbs_gym/setup.py b/env_manager/rbs_gym/setup.py new file mode 100644 index 0000000..e280fa4 --- /dev/null +++ b/env_manager/rbs_gym/setup.py @@ -0,0 +1,36 @@ +import os +from glob import glob + +from setuptools import find_packages, setup + +package_name = "rbs_gym" + +setup( + name=package_name, + version="0.0.0", + packages=find_packages(exclude=["test"]), + data_files=[ + ("share/ament_index/resource_index/packages", ["resource/" + package_name]), + ("share/" + package_name, ["package.xml"]), + ( + os.path.join("share", package_name, "launch"), + glob(os.path.join("launch", "*launch.[pxy][yma]*")), + ), + ], + install_requires=["setuptools"], + zip_safe=True, + maintainer="narmak", + maintainer_email="ur.narmak@gmail.com", + description="TODO: Package description", + license="Apache-2.0", + tests_require=["pytest"], + entry_points={ + "console_scripts": [ + "train = rbs_gym.scripts.train:main", + "spawner = rbs_gym.scripts.spawner:main", + "velocity = rbs_gym.scripts.velocity:main", + "test_agent = rbs_gym.scripts.test_agent:main", + "evaluate = rbs_gym.scripts.evaluate:main", + ], + }, +) diff --git a/env_manager/rbs_gym/test/test_copyright.py b/env_manager/rbs_gym/test/test_copyright.py new file mode 100644 index 0000000..97a3919 --- /dev/null +++ b/env_manager/rbs_gym/test/test_copyright.py @@ -0,0 +1,25 @@ +# Copyright 2015 Open Source 
Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_copyright.main import main +import pytest + + +# Remove the `skip` decorator once the source file(s) have a copyright header +@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') +@pytest.mark.copyright +@pytest.mark.linter +def test_copyright(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found errors' diff --git a/env_manager/rbs_gym/test/test_flake8.py b/env_manager/rbs_gym/test/test_flake8.py new file mode 100644 index 0000000..27ee107 --- /dev/null +++ b/env_manager/rbs_gym/test/test_flake8.py @@ -0,0 +1,25 @@ +# Copyright 2017 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ament_flake8.main import main_with_errors +import pytest + + +@pytest.mark.flake8 +@pytest.mark.linter +def test_flake8(): + rc, errors = main_with_errors(argv=[]) + assert rc == 0, \ + 'Found %d code style errors / warnings:\n' % len(errors) + \ + '\n'.join(errors) diff --git a/env_manager/rbs_gym/test/test_pep257.py b/env_manager/rbs_gym/test/test_pep257.py new file mode 100644 index 0000000..b234a38 --- /dev/null +++ b/env_manager/rbs_gym/test/test_pep257.py @@ -0,0 +1,23 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_pep257.main import main +import pytest + + +@pytest.mark.linter +@pytest.mark.pep257 +def test_pep257(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found code style errors / warnings' diff --git a/env_manager/rbs_runtime/LICENSE b/env_manager/rbs_runtime/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/env_manager/rbs_runtime/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/env_manager/rbs_runtime/config/default-scene-config.yaml b/env_manager/rbs_runtime/config/default-scene-config.yaml new file mode 100644 index 0000000..7016a9a --- /dev/null +++ b/env_manager/rbs_runtime/config/default-scene-config.yaml @@ -0,0 +1,249 @@ +camera: + - clip_color: !tuple + - 0.01 + - 1000.0 + clip_depth: !tuple + - 0.05 + - 10.0 + enable: true + height: 128 + horizontal_fov: 1.0471975511965976 + image_format: R8G8B8 + name: robot_camera + noise_mean: null + noise_stddev: null + publish_color: true + publish_depth: true + publish_points: true + random_pose_focal_point_z_offset: 0.0 + random_pose_mode: orbit + random_pose_orbit_distance: 1.0 + random_pose_orbit_height_range: !tuple + - 0.1 + - 0.7 + random_pose_orbit_ignore_arc_behind_robot: 0.39269908169872414 + random_pose_rollout_counter: 0.0 + random_pose_rollouts_num: 1 + random_pose_select_position_options: [] + relative_to: base_link + spawn_position: !tuple + - 0.0 + - 0.0 + - 1.0 + spawn_quat_xyzw: !tuple + - 0.0 + - 0.70710678118 + - 0.0 + - 0.70710678118 + type: rgbd_camera + update_rate: 10 + vertical_fov: 1.0471975511965976 + width: 128 +gravity: !tuple + - 0.0 + - 0.0 + - -9.80665 +gravity_std: !tuple + - 0.0 + - 0.0 + - 0.0232 +light: + color: !tuple + - 1.0 + - 1.0 + - 1.0 + - 1.0 + direction: !tuple + - 0.5 + - -0.25 + - -0.75 + distance: 1000.0 + model_rollouts_num: 1 + radius: 25.0 + random_minmax_elevation: !tuple + - -0.15 + - -0.65 + type: sun + visual: true +objects: + - color: null + name: star + orientation: !tuple + - 1.0 + - 0.0 + - 0.0 + - 0.0 + position: !tuple + - -0.1 + - -0.40 + - 0.1 + randomize: + count: 0 + models_rollouts_num: 0 + random_color: false + random_model: false + random_orientation: false + random_pose: false + random_position: false + random_spawn_position_segments: [] + random_spawn_position_update_workspace_centre: false + random_spawn_volume: !tuple + - 0.5 + - 0.5 + - 0.5 + relative_to: world + static: false + texture: [] + type: "model" 
+ + - color: null + name: box + mass: 1.0 + size: !tuple + - 0.02 + - 0.02 + - 0.05 + orientation: !tuple + - 1.0 + - 0.0 + - 0.0 + - 0.0 + position: !tuple + - -0.1 + - -0.37 + - 0.1 + randomize: + count: 0 + models_rollouts_num: 0 + random_color: false + random_model: false + random_orientation: false + random_pose: false + random_position: false + random_spawn_position_segments: [] + random_spawn_position_update_workspace_centre: false + random_spawn_volume: !tuple + - 0.5 + - 0.5 + - 0.5 + relative_to: world + static: false + type: "box" + + + - color: null + name: cylinder + mass: 1.0 + radius: 0.01 + length: 0.05 + orientation: !tuple + - 1.0 + - 0.0 + - 0.0 + - 0.0 + position: !tuple + - -0.1 + - -0.34 + - 0.1 + randomize: + count: 0 + models_rollouts_num: 0 + random_color: false + random_model: false + random_orientation: false + random_pose: false + random_position: false + random_spawn_position_segments: [] + random_spawn_position_update_workspace_centre: false + random_spawn_volume: !tuple + - 0.5 + - 0.5 + - 0.5 + relative_to: world + static: false + type: "cylinder" + + - color: null + name: board + orientation: !tuple + - 1.0 + - 0.0 + - 0.0 + - 0.0 + position: !tuple + - 0.0 + - -0.37 + - 0.01 + randomize: + count: 0 + models_rollouts_num: 0 + random_color: false + random_model: false + random_orientation: false + random_pose: false + random_position: false + random_spawn_position_segments: [] + random_spawn_position_update_workspace_centre: false + random_spawn_volume: !tuple + - 0.5 + - 0.5 + - 0.5 + relative_to: world + static: true + texture: [] + type: "model" +physics_rollouts_num: 0 +plugins: + fts_broadcaster: false + scene_broadcaster: false + sensors_render_engine: ogre2 + user_commands: false +robot: + gripper_joint_positions: 0.0 + joint_positions: + - -1.57 + - 0.5 + - 3.14159 + - 1.5 + - 0.0 + - 1.4 + - 0.0 + name: rbs_arm + randomizer: + joint_positions: false + joint_positions_above_object_spawn: false + 
joint_positions_above_object_spawn_elevation: 0.2 + joint_positions_above_object_spawn_xy_randomness: 0.2 + joint_positions_std: 0.1 + pose: false + spawn_volume: !tuple + - 1.0 + - 1.0 + - 0.0 + spawn_position: !tuple + - 0.0 + - 0.0 + - 0.0 + spawn_quat_xyzw: !tuple + - 0.0 + - 0.0 + - 0.0 + - 1.0 + urdf_string: "" + with_gripper: true +terrain: + model_rollouts_num: 1 + name: ground + size: !tuple + - 1.5 + - 1.5 + spawn_position: !tuple + - 0 + - 0 + - 0 + spawn_quat_xyzw: !tuple + - 0 + - 0 + - 0 + - 1 + type: flat diff --git a/env_manager/rbs_runtime/launch/runtime.launch.py b/env_manager/rbs_runtime/launch/runtime.launch.py new file mode 100644 index 0000000..6f56440 --- /dev/null +++ b/env_manager/rbs_runtime/launch/runtime.launch.py @@ -0,0 +1,356 @@ +import os + +import xacro +import yaml +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import ( + DeclareLaunchArgument, + IncludeLaunchDescription, + OpaqueFunction, + TimerAction, +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch.launch_introspector import indent +from launch.substitutions import LaunchConfiguration, PathJoinSubstitution +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare +from robot_builder.external.ros2_control import ControllerManager +from robot_builder.parser.urdf import URDF_parser + + +def launch_setup(context, *args, **kwargs): + # Initialize Arguments + robot_type = LaunchConfiguration("robot_type") + # General arguments + with_gripper_condition = LaunchConfiguration("with_gripper") + description_package = LaunchConfiguration("description_package") + description_file = LaunchConfiguration("description_file") + use_moveit = LaunchConfiguration("use_moveit") + moveit_config_package = LaunchConfiguration("moveit_config_package") + moveit_config_file = LaunchConfiguration("moveit_config_file") + use_sim_time = 
LaunchConfiguration("use_sim_time") + scene_config_file = LaunchConfiguration("scene_config_file").perform(context) + ee_link_name = LaunchConfiguration("ee_link_name").perform(context) + base_link_name = LaunchConfiguration("base_link_name").perform(context) + control_space = LaunchConfiguration("control_space").perform(context) + control_strategy = LaunchConfiguration("control_strategy").perform(context) + interactive = LaunchConfiguration("interactive").perform(context) + real_robot = LaunchConfiguration("real_robot").perform(context) + + use_rbs_utils = LaunchConfiguration("use_rbs_utils") + assembly_config_name = LaunchConfiguration("assembly_config_name") + + if not scene_config_file == "": + config_file = {"config_file": scene_config_file} + else: + config_file = {} + + description_package_abs_path = get_package_share_directory( + description_package.perform(context) + ) + + controllers_file = os.path.join( + description_package_abs_path, "config", "controllers.yaml" + ) + + xacro_file = os.path.join( + description_package_abs_path, + "urdf", + description_file.perform(context), + ) + + # xacro_config_file = f"{description_package_abs_path}/config/xacro_args.yaml" + xacro_config_file = os.path.join( + description_package_abs_path, "urdf", "xacro_args.yaml" + ) + + # TODO: hide this to another place + # Load xacro_args + def param_constructor(loader, node, local_vars): + value = loader.construct_scalar(node) + return LaunchConfiguration(value).perform( + local_vars.get("context", "Launch context if not defined") + ) + + def variable_constructor(loader, node, local_vars): + value = loader.construct_scalar(node) + return local_vars.get(value, f"Variable '{value}' not found") + + def load_xacro_args(yaml_file, local_vars): + # Get valut from ros2 argument + yaml.add_constructor( + "!param", lambda loader, node: param_constructor(loader, node, local_vars) + ) + + # Get value from local variable in this code + # The local variable should be initialized before the 
loader was called + yaml.add_constructor( + "!variable", + lambda loader, node: variable_constructor(loader, node, local_vars), + ) + + with open(yaml_file, "r") as file: + return yaml.load(file, Loader=yaml.FullLoader) + + mappings_data = load_xacro_args(xacro_config_file, locals()) + + robot_description_doc = xacro.process_file(xacro_file, mappings=mappings_data) + + robot_description_semantic_content = "" + + if use_moveit.perform(context) == "true": + srdf_config_file = os.path.join( + get_package_share_directory(moveit_config_package.perform(context)), + "srdf", + "xacro_args.yaml", + ) + srdf_file = os.path.join( + get_package_share_directory(moveit_config_package.perform(context)), + "srdf", + moveit_config_file.perform(context), + ) + srdf_mappings = load_xacro_args(srdf_config_file, locals()) + robot_description_semantic_content = xacro.process_file( + srdf_file, mappings=srdf_mappings + ) + robot_description_semantic_content = ( + robot_description_semantic_content.toprettyxml(indent=" ") + ) + control_space = "joint" + control_strategy = "position" + interactive = "false" + + robot_description_content = robot_description_doc.toprettyxml(indent=" ") + robot_description = {"robot_description": robot_description_content} + + # Parse robot and configure controller's file for ControllerManager + robot = URDF_parser.load_string( + robot_description_content, ee_link_name=ee_link_name + ) + ControllerManager.save_to_yaml( + robot, description_package_abs_path, "controllers.yaml" + ) + + rbs_robot_setup = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + PathJoinSubstitution( + [FindPackageShare("rbs_bringup"), "launch", "rbs_robot.launch.py"] + ) + ] + ), + launch_arguments={ + "with_gripper": with_gripper_condition, + "controllers_file": controllers_file, + "robot_type": robot_type, + "description_package": description_package, + "description_file": description_file, + "robot_name": robot_type, + "use_moveit": use_moveit, + 
"moveit_config_package": moveit_config_package, + "moveit_config_file": moveit_config_file, + "use_sim_time": use_sim_time, + "use_controllers": "true", + "robot_description": robot_description_content, + "robot_description_semantic": robot_description_semantic_content, + "base_link_name": base_link_name, + "ee_link_name": ee_link_name, + "control_space": control_space, + "control_strategy": control_strategy, + "interactive_control": interactive, + "use_rbs_utils": use_rbs_utils, + "assembly_config_name": assembly_config_name, + }.items(), + ) + + if real_robot == "true": + controller_manager_node = Node( + package="controller_manager", + executable="ros2_control_node", + parameters=[ + robot_description, + controllers_file, + ], + output="screen", + ) + return [rbs_robot_setup, controller_manager_node] + + rbs_runtime = Node( + package="rbs_runtime", + executable="runtime", + arguments=[("--ros-args --remap log_level:=debug")], + parameters=[robot_description, config_file, {"use_sim_time": True}], + ) + + clock_bridge = Node( + package="ros_gz_bridge", + executable="parameter_bridge", + arguments=["/clock@rosgraph_msgs/msg/Clock[ignition.msgs.Clock"], + output="screen", + ) + + delay_robot_control_stack = TimerAction(period=10.0, actions=[rbs_robot_setup]) + + nodes_to_start = [rbs_runtime, clock_bridge, delay_robot_control_stack] + return nodes_to_start + + +def generate_launch_description(): + declared_arguments = [] + declared_arguments.append( + DeclareLaunchArgument( + "robot_type", + description="Type of robot by name", + choices=[ + "rbs_arm", + "ar4", + "ur3", + "ur3e", + "ur5", + "ur5e", + "ur10", + "ur10e", + "ur16e", + ], + default_value="rbs_arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_package", + default_value="rbs_arm", + description="Description package with robot URDF/XACRO files. 
Usually the argument \ + is not set, it enables use of a custom description.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "description_file", + default_value="rbs_arm_modular.xacro", + description="URDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "robot_name", + default_value="arm0", + description="Name for robot, used to apply namespace for specific robot in multirobot setup", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_package", + default_value="rbs_arm", + description="MoveIt config package with robot SRDF/XACRO files. Usually the argument \ + is not set, it enables use of a custom moveit config.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "moveit_config_file", + default_value="rbs_arm.srdf.xacro", + description="MoveIt SRDF/XACRO description file with the robot.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_sim_time", + default_value="true", + description="Make MoveIt to use simulation time.\ + This is needed for the trajectory planing in simulation.", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "with_gripper", default_value="true", description="With gripper or not?" + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_moveit", default_value="false", description="Launch moveit?" + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "launch_perception", default_value="false", description="Launch perception?" 
+ ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "use_controllers", + default_value="true", + description="Launch controllers?", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "scene_config_file", + default_value="", + description="Path to a scene configuration file", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "ee_link_name", + default_value="", + description="End effector name of robot arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "base_link_name", + default_value="", + description="Base link name if robot arm", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "control_space", + default_value="task", + choices=["task", "joint"], + description="Specify the control space for the robot (e.g., task space).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "control_strategy", + default_value="position", + choices=["position", "velocity", "effort"], + description="Specify the control strategy (e.g., position control).", + ) + ) + declared_arguments.append( + DeclareLaunchArgument( + "interactive", + default_value="true", + description="Wheter to run the motion_control_handle controller", + ), + ) + declared_arguments.append( + DeclareLaunchArgument( + "real_robot", + default_value="false", + description="Wheter to run on the real robot", + ), + ) + + declared_arguments.append( + DeclareLaunchArgument( + "use_rbs_utils", + default_value="true", + description="Wheter to use rbs_utils", + ), + ) + + declared_arguments.append( + DeclareLaunchArgument( + "assembly_config_name", + default_value="", + description="Assembly config name from rbs_assets_library", + ), + ) + + return LaunchDescription( + declared_arguments + [OpaqueFunction(function=launch_setup)] + ) diff --git a/env_manager/rbs_runtime/package.nix b/env_manager/rbs_runtime/package.nix new file mode 100644 index 0000000..706e14d --- /dev/null +++ b/env_manager/rbs_runtime/package.nix @@ -0,0 +1,29 @@ +# 
Automatically generated by: ros2nix --distro jazzy --flake --license Apache-2.0 +# Copyright 2025 None +# Distributed under the terms of the Apache-2.0 license +{ + lib, + buildRosPackage, + ament-copyright, + ament-flake8, + ament-pep257, + env-manager, + env-manager-interfaces, + pythonPackages, + scenario, +}: +buildRosPackage rec { + pname = "ros-jazzy-rbs-runtime"; + version = "0.0.0"; + + src = ./.; + + buildType = "ament_python"; + checkInputs = [ament-copyright ament-flake8 ament-pep257 pythonPackages.pytest]; + propagatedBuildInputs = [pythonPackages.dacite scenario env-manager env-manager-interfaces]; + + meta = { + description = "TODO: Package description"; + license = with lib.licenses; [asl20]; + }; +} diff --git a/env_manager/rbs_runtime/package.xml b/env_manager/rbs_runtime/package.xml new file mode 100644 index 0000000..65e2d1d --- /dev/null +++ b/env_manager/rbs_runtime/package.xml @@ -0,0 +1,21 @@ + + + + rbs_runtime + 0.0.0 + TODO: Package description + narmak + Apache-2.0 + + env_manager + env_manager_interfaces + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/env_manager/rbs_runtime/rbs_runtime/__init__.py b/env_manager/rbs_runtime/rbs_runtime/__init__.py new file mode 100644 index 0000000..029c56e --- /dev/null +++ b/env_manager/rbs_runtime/rbs_runtime/__init__.py @@ -0,0 +1,41 @@ +from pathlib import Path + +import yaml +from dacite import from_dict +from env_manager.models.configs import SceneData + +from typing import Dict, Any +from env_manager.models.configs import ( + BoxObjectData, CylinderObjectData, MeshData, ModelData, ObjectData, ObjectRandomizerData +) + +def object_factory(obj_data: Dict[str, Any]) -> ObjectData: + obj_type = obj_data.get('type', '') + + if "randomize" in obj_data and isinstance(obj_data["randomize"], dict): + obj_data["randomize"] = from_dict(data_class=ObjectRandomizerData, data=obj_data["randomize"]) + + if obj_type == 'box': + return 
from_dict(data_class=BoxObjectData, data=obj_data) + elif obj_type == 'cylinder': + return from_dict(data_class=CylinderObjectData, data=obj_data) + elif obj_type == 'mesh': + return from_dict(data_class=MeshData, data=obj_data) + elif obj_type == 'model': + return from_dict(data_class=ModelData, data=obj_data) + else: + return from_dict(data_class=ObjectData, data=obj_data) + +def scene_config_loader(file: str | Path) -> SceneData: + def tuple_constructor(loader, node): + return tuple(loader.construct_sequence(node)) + + with open(file, "r") as yaml_file: + yaml.SafeLoader.add_constructor("!tuple", tuple_constructor) + scene_cfg = yaml.load(yaml_file, Loader=yaml.SafeLoader) + + scene_data = from_dict(data_class=SceneData, data=scene_cfg) + + scene_data.objects = [object_factory(obj) for obj in scene_cfg.get('objects', [])] + + return scene_data diff --git a/env_manager/rbs_runtime/rbs_runtime/scripts/__init__.py b/env_manager/rbs_runtime/rbs_runtime/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/env_manager/rbs_runtime/rbs_runtime/scripts/runtime.py b/env_manager/rbs_runtime/rbs_runtime/scripts/runtime.py new file mode 100755 index 0000000..e234263 --- /dev/null +++ b/env_manager/rbs_runtime/rbs_runtime/scripts/runtime.py @@ -0,0 +1,73 @@ +#!/usr/bin/python3 + +import rclpy +from ament_index_python.packages import get_package_share_directory +from env_manager.scene import Scene +from env_manager_interfaces.srv import ResetEnv +from rbs_assets_library import get_world_file +from rclpy.node import Node +from scenario.bindings.gazebo import GazeboSimulator +from rclpy.executors import MultiThreadedExecutor, ExternalShutdownException + +from .. 
import scene_config_loader + +class GazeboRuntime(Node): + def __init__(self) -> None: + super().__init__(node_name="rbs_gz_runtime") + self.declare_parameter("robot_description", "") + self.declare_parameter( + "config_file", + get_package_share_directory("rbs_runtime") + + "/config/default-scene-config.yaml", + ) + + self.gazebo = GazeboSimulator(step_size=0.001, rtf=1.0, steps_per_run=1) + + self.gazebo.insert_world_from_sdf(get_world_file("default")) + if not self.gazebo.initialize(): + raise RuntimeError("Gazebo cannot be initialized") + + config_file = ( + self.get_parameter("config_file").get_parameter_value().string_value + ) + + scene_data = scene_config_loader(config_file) + + self.scene = Scene( + node=self, + gazebo=self.gazebo, + scene=scene_data, + robot_urdf_string=self.get_parameter("robot_description") + .get_parameter_value() + .string_value, + ) + self.scene.init_scene() + + self.reset_env_srv = self.create_service( + ResetEnv, "/env_manager/reset_env", self.reset_env + ) + self.is_env_reset = False + self.timer = self.create_timer(self.gazebo.step_size(), self.gazebo_run) + + def gazebo_run(self): + if self.is_env_reset: + self.scene.reset_scene() + self.is_env_reset = False + self.gazebo.run() + + def reset_env(self, req, res): + self.is_env_reset = True + res.ok = self.is_env_reset + return res + + +def main(): + rclpy.init() + + executor = MultiThreadedExecutor() + node = GazeboRuntime() + executor.add_node(node) + try: + executor.spin() + except (KeyboardInterrupt, ExternalShutdownException): + node.destroy_node() diff --git a/env_manager/rbs_runtime/resource/rbs_runtime b/env_manager/rbs_runtime/resource/rbs_runtime new file mode 100644 index 0000000..e69de29 diff --git a/env_manager/rbs_runtime/setup.cfg b/env_manager/rbs_runtime/setup.cfg new file mode 100644 index 0000000..27e6f37 --- /dev/null +++ b/env_manager/rbs_runtime/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/rbs_runtime +[install] 
+install_scripts=$base/lib/rbs_runtime diff --git a/env_manager/rbs_runtime/setup.py b/env_manager/rbs_runtime/setup.py new file mode 100644 index 0000000..12ce2cd --- /dev/null +++ b/env_manager/rbs_runtime/setup.py @@ -0,0 +1,36 @@ +import os +from glob import glob + +from setuptools import find_packages, setup + +package_name = "rbs_runtime" + +setup( + name=package_name, + version="0.0.0", + packages=find_packages(exclude=["test"]), + data_files=[ + ("share/ament_index/resource_index/packages", ["resource/" + package_name]), + ("share/" + package_name, ["package.xml"]), + ( + os.path.join("share", package_name, "launch"), + glob(os.path.join("launch", "*launch.[pxy][yma]*")), + ), + ( + os.path.join("share", package_name, "config"), + glob(os.path.join("config", "*config.[yma]*")), + ), + ], + install_requires=["setuptools", "dacite"], + zip_safe=True, + maintainer="narmak", + maintainer_email="ur.narmak@gmail.com", + description="TODO: Package description", + license="Apache-2.0", + tests_require=["pytest"], + entry_points={ + "console_scripts": [ + "runtime = rbs_runtime.scripts.runtime:main", + ], + }, +) diff --git a/env_manager/rbs_runtime/test/test_copyright.py b/env_manager/rbs_runtime/test/test_copyright.py new file mode 100644 index 0000000..97a3919 --- /dev/null +++ b/env_manager/rbs_runtime/test/test_copyright.py @@ -0,0 +1,25 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ament_copyright.main import main +import pytest + + +# Remove the `skip` decorator once the source file(s) have a copyright header +@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') +@pytest.mark.copyright +@pytest.mark.linter +def test_copyright(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found errors' diff --git a/env_manager/rbs_runtime/test/test_flake8.py b/env_manager/rbs_runtime/test/test_flake8.py new file mode 100644 index 0000000..27ee107 --- /dev/null +++ b/env_manager/rbs_runtime/test/test_flake8.py @@ -0,0 +1,25 @@ +# Copyright 2017 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_flake8.main import main_with_errors +import pytest + + +@pytest.mark.flake8 +@pytest.mark.linter +def test_flake8(): + rc, errors = main_with_errors(argv=[]) + assert rc == 0, \ + 'Found %d code style errors / warnings:\n' % len(errors) + \ + '\n'.join(errors) diff --git a/env_manager/rbs_runtime/test/test_pep257.py b/env_manager/rbs_runtime/test/test_pep257.py new file mode 100644 index 0000000..b234a38 --- /dev/null +++ b/env_manager/rbs_runtime/test/test_pep257.py @@ -0,0 +1,23 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_pep257.main import main +import pytest + + +@pytest.mark.linter +@pytest.mark.pep257 +def test_pep257(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found code style errors / warnings'