diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..ca5f8de
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+ignore = D100, D101, D102, D103, D104, D105, D107, E203, E501, W503, SIM201, SIM113, B027, C408, B008
+docstring-convention = google
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
new file mode 100644
index 0000000..5db5e14
--- /dev/null
+++ b/.github/workflows/pythonapp.yml
@@ -0,0 +1,44 @@
+name: Python application
+
+on:
+ push:
+ branches: [ main ]
+ pull_request:
+ branches: [ main ]
+ schedule:
+ - cron: "0 0 * * 0"
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ # python-version: ["3.8", "3.9", "3.10", "3.11"]
+ python-version: ["3.8", "3.10"]
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v1
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements_dev.txt
+ pip install -r requirements.txt
+ pip install -e .
+ git submodule init
+ git submodule update
+ pip install third_party/DynamixelSDK/python
+
+ - name: Black
+ run: |
+ black --check gello
+ - name: flake8
+ run: |
+ # stop the build if there are Python syntax errors or undefined names
+ flake8 gello --count --select=E9,F63,F7,F82 --show-source --statistics
+ - name: Test with pytest
+ run: |
+ pytest gello
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1b141e3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,29 @@
+# keep each top level section alphabetical
+# keep each item within the sections alphabetical
+
+# large files
+*.png
+*.gif
+
+# mac
+*.DS_Store
+
+# misc
+*logs*
+*runs*
+*tb_logs*
+*wandb*
+*wandb_logs*
+outputs/*
+
+# python
+*__pycache__
+*.egg-info
+*.hypothesis
+*.ipynb_checkpoints
+*.mypy_cache
+*.pyc
+
+# vim
+*.swp
+MUJOCO_LOG.TXT
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..b8c7d8f
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "third_party/mujoco_menagerie"]
+ path = third_party/mujoco_menagerie
+ url = https://github.com/deepmind/mujoco_menagerie.git
+[submodule "third_party/DynamixelSDK"]
+ path = third_party/DynamixelSDK
+ url = https://github.com/ROBOTIS-GIT/DynamixelSDK.git
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..04bab0f
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,6 @@
+[settings]
+use_parentheses=True
+include_trailing_comma=True
+multi_line_output=3
+ensure_newline_before_comments=True
+line_length=88
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..7c92e72
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,38 @@
+repos:
+
+# remove unused python imports
+- repo: https://github.com/myint/autoflake.git
+ rev: v2.1.1
+ hooks:
+ - id: autoflake
+ args: ["--in-place", "--remove-all-unused-imports", "--ignore-init-module-imports"]
+
+# sort imports
+- repo: https://github.com/timothycrosley/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+
+# code format according to black
+- repo: https://github.com/ambv/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+
+# check for python styling with flake8
+- repo: https://github.com/PyCQA/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+ additional_dependencies: [
+ 'flake8-docstrings',
+ 'flake8-bugbear',
+ 'flake8-comprehensions',
+ 'flake8-simplify',
+ ]
+
+# cleanup notebooks
+- repo: https://github.com/kynan/nbstripout
+ rev: 0.6.1
+ hooks:
+ - id: nbstripout
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e3f44ea
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,22 @@
+FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
+
+WORKDIR /gello
+
+# Set environment variables first (less likely to change)
+ENV PYTHONPATH=/gello:/gello/third_party/oculus_reader/
+
+# Group apt updates and installs together
+RUN apt update && apt install -y \
+ libhidapi-dev \
+ python3-pip \
+ android-tools-adb \
+ libegl1-mesa-dev && \
+ rm -rf /var/lib/apt/lists/*
+
+
+# Python alias setup
+RUN echo "alias python=python3" >> ~/.bashrc
+
+# Install Python dependencies
+COPY requirements.txt /gello
+RUN pip install -r requirements.txt
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..7b3eeaa
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Philipp Wu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 1bf2160..e929fb6 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,174 @@
-# GELLO Software
-This is the central repo that holds the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/
+# GELLO
+This is the central repo that holds all the software for GELLO. See the website for the paper and other resources for GELLO: https://wuphilipp.github.io/gello_site/
+See the GELLO hardware repo for the STL files and hardware instructions for building your own GELLO https://github.com/wuphilipp/gello_mechanical
+```
+git clone https://github.com/wuphilipp/gello_software.git
+cd gello_software
+```
-# Installation
+
+
+
-Coming Soon
+## Use your own environment
+```
+git submodule init
+git submodule update
+pip install -r requirements.txt
+pip install -e .
+pip install -e third_party/DynamixelSDK/python
+```
+
+## Use with Docker
+First install ```docker``` following this [link](https://docs.docker.com/engine/install/ubuntu/) on your host machine.
+Then you can clone the repo and build the corresponding docker environment
+
+Build the docker image and tag it as gello:latest. If you are going to name it differently, you need to change the launch.py image name
+```
+docker build . -t gello:latest
+```
+
+We have provided an entry point into the docker container
+```
+python experiments/launch.py
+```
+
+# GELLO configuration setup (PLEASE READ)
+Now that you have downloaded the code, there is some additional preparation work to properly configure the Dynamixels and GELLO.
+These instructions will guide you on how to update the motor ids of the Dynamixels and then how to extract the joint offsets to configure your GELLO.
+
+## Update motor IDs
+Install the [dynamixel_wizard](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2/).
+By default, each motor has the ID 1. In order for multiple dynamixels to be controlled by the same U2D2 controller board, each dynamixel must have a unique ID.
+This process must be done one motor at a time. Connect each motor, starting from the base motor, and assign them in increasing order until you reach the gripper.
+
+Steps:
+ * Connect a single motor to the controller and connect the controller to the computer.
+ * Open the dynamixel wizard
+ * Click scan (found at the top left corner), this should detect the dynamixel. Connect to the motor
+ * Look for the ID address and change the ID to the appropriate number.
+ * Repeat for each motor
+
+## Create the GELLO configuration and determine the joint offsets
+After the motor IDs are set, we can connect to the GELLO controller device. However, each motor has its own joint offset, which results in a joint offset between GELLO and your actual robot arm.
+Dynamixels have a symmetric 4-hole pattern, which means the joint offset is a multiple of pi/2.
+The `GelloAgent` class accepts a `DynamixelRobotConfig` (found in `gello/agents/gello_agent.py`). The Dynamixel config specifies the parameters you need to find to operate your GELLO. Look at the documentation for more details.
+
+We have created a simple script to automatically detect the joint offset:
+* Set GELLO into a known configuration, where you know what the corresponding joint angles should be. For example, we set our GELLO in this configuration, where we know the desired ground truth joints (in degrees): (0, -90, 90, -90, -90, 0)
+
+
+
+
+
+* run
+```
+python scripts/gello_get_offset.py \
+    --start-joints 0 -1.57 1.57 -1.57 -1.57 0 \
+    --joint-signs 1 1 -1 1 1 1 \
+    --port /dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6
+# replace values with your own; --start-joints is given in radians
+```
+* Use the known starting joints for `start-joints`.
+* Use the `joint-signs` for your own robot (see below).
+* Use your serial port for `port`. You can find the port id of your U2D2 Dynamixel device by running `ls /dev/serial/by-id` and looking for the path that starts with `usb-FTDI_USB__-__Serial_Converter` (on Ubuntu). On Mac, look in /dev/ and the device that starts with `cu.usbserial`
+
+`joint-signs` for each robot type:
+* UR: `1 1 -1 1 1 1`
+* Panda: `1 -1 1 1 1 -1 1`
+* xArm: `1 1 1 1 1 1 1`
+
+The script prints out a list of joint offsets. Go to `gello/agents/gello_agent.py` and add a DynamixelRobotConfig to the PORT_CONFIG_MAP. You are now ready to run your GELLO!
+
+# Using GELLO to control a robot!
+
+The code provided here is simple and only relies on python packages. The code does NOT use ROS, but a ROS wrapper can easily be adapted from this code.
+For multiprocessing, we leverage [ZMQ](https://zeromq.org/)
+
+## Testing in sim
+First test your GELLO with a simulated robot to make sure that the joint angles match as expected.
+In one terminal run
+```
+python experiments/launch_nodes.py --robot <sim_ur, sim_panda, or sim_xarm>
+```
+This launches the robot node. A simulated robot using the mujoco viewer should appear.
+
+Then, launch your GELLO (the controller node).
+```
+python experiments/run_env.py --agent=gello
+```
+You should be able to use GELLO to control the simulated robot!
+
+## Running on a real robot.
+Once you have verified that your GELLO is properly configured, you can test it on a real robot!
+
+Before you run with the real robot, you will have to install a robot specific python package.
+The supported robots are in `gello/robots`.
+ * UR: [ur_rtde](https://sdurobotics.gitlab.io/ur_rtde/installation/installation.html)
+ * panda: [polymetis](https://facebookresearch.github.io/fairo/polymetis/installation.html). If you use a different framework to control the panda, the code is easy to adapt. See/Modify `gello/robots/panda.py`
+ * xArm: [xArm python SDK](https://github.com/xArm-Developer/xArm-Python-SDK)
+
+```
+# Launch all of the nodes
+python experiments/launch_nodes.py --robot=<your robot>
+# run the environment loop
+python experiments/run_env.py --agent=gello
+```
+
+Ideally you can start your GELLO near a known configuration each time. If this is possible, you can set the `--start-joints` flag with GELLO's known starting configuration. This also enables the robot to reset before you begin teleoperation.
+
+## Collect data
+We have provided a simple example for collecting data with gello.
+To save trajectories with the keyboard, add the following flag `--use-save-interface`
+
+Data can then be processed using the demo_to_gdict script.
+```
+python gello/data_utils/demo_to_gdict.py --source-dir=<source dir location>
+```
+
+## Running a bimanual system with GELLO
+GELLO can also be used in bimanual configurations.
+For an example, see the `bimanual_ur` robot in `launch_nodes.py` and `--bimanual` flag in the `run_env.py` script.
+
+## Notes
+Due to the use of multiprocessing, Python processes are sometimes not killed properly. We have provided the kill_nodes script, which will kill the
+remaining Python processes.
+```
+./kill_nodes.sh
+```
+
+### Using a new robot!
+If you want to use a new robot you need a GELLO that is compatible. If the kinematics are close enough, you may directly use an existing GELLO. Otherwise you will have to design your own.
+To add a new robot, simply implement the `Robot` protocol found in `gello/robots/robot.py`. See `gello/robots/panda.py`, `gello/robots/ur.py`, `gello/robots/xarm_robot.py` for examples.
+
+### Contributing
+Please make a PR if you would like to contribute! The goal of this project is to enable more accessible and higher quality teleoperation devices and we would love your input!
+
+You can optionally install some dev packages.
+```
+pip install -r requirements_dev.txt
+```
+
+The code is organized as follows:
+ * `scripts`: contains some helpful python `scripts`
+ * `experiments`: contains entrypoints into the gello code
+ * `gello`: contains all of the `gello` python package code
+ * `agents`: teleoperation agents
+ * `cameras`: code to interface with camera hardware
+ * `data_utils`: data processing utils. used for imitation learning
+ * `dm_control_tasks`: dm_control utils to build a simple dm_control environment. used for demos
+ * `dynamixel`: code to interface with the dynamixel hardware
+ * `robots`: robot specific interfaces
+ * `zmq_core`: zmq utilities for enabling a multi node system
+
+
+This code base uses `isort` and `black` for code formatting.
+Pre-commit hooks are also provided; they automatically run some checks and formatting on each commit. To use the pre-commit hooks, run the following:
+```
+pip install pre-commit
+pre-commit install
+```
# Citation
@@ -15,3 +179,11 @@ Coming Soon
year={2023},
}
```
+
+# License & Acknowledgements
+This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
+
+This project builds on top of or utilizes the following third party dependencies.
+ * [google-deepmind/mujoco_menagerie](https://github.com/google-deepmind/mujoco_menagerie): Prebuilt robot models for mujoco
+ * [brentyi/tyro](https://github.com/brentyi/tyro): Argument parsing and configuration
+ * [ZMQ](https://zeromq.org/): Enables easy creation of node-like processes in Python.
diff --git a/config_hostmachine.sh b/config_hostmachine.sh
new file mode 100644
index 0000000..60051b8
--- /dev/null
+++ b/config_hostmachine.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# install nvidia driver 520
+sudo apt-get install nvidia-520 -y
+
+# install docker
+sudo apt-get update
+sudo apt-get install ca-certificates curl gnupg
+
+sudo install -m 0755 -d /etc/apt/keyrings
+curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
+sudo chmod a+r /etc/apt/keyrings/docker.gpg
+
+echo \
+ "deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
+ "$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \
+ sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+
+sudo apt-get update
+sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
diff --git a/experiments/launch_camera_clients.py b/experiments/launch_camera_clients.py
new file mode 100644
index 0000000..a89efb7
--- /dev/null
+++ b/experiments/launch_camera_clients.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+import tyro
+
+from gello.zmq_core.camera_node import ZMQClientCamera
+
+
+@dataclass
+class Args:
+ ports: Tuple[int, ...] = (5000, 5001)
+ hostname: str = "127.0.0.1"
+ # hostname: str = "128.32.175.167"
+
+
+def main(args):
+ cameras = []
+ import cv2
+
+ images_display_names = []
+ for port in args.ports:
+ cameras.append(ZMQClientCamera(port=port, host=args.hostname))
+ images_display_names.append(f"image_{port}")
+ cv2.namedWindow(images_display_names[-1], cv2.WINDOW_NORMAL)
+
+ while True:
+ for display_name, camera in zip(images_display_names, cameras):
+ image, depth = camera.read()
+ stacked_depth = np.dstack([depth, depth, depth]).astype(np.uint8)
+ image_depth = cv2.hconcat([image[:, :, ::-1], stacked_depth])
+ cv2.imshow(display_name, image_depth)
+ cv2.waitKey(1)
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/experiments/launch_camera_nodes.py b/experiments/launch_camera_nodes.py
new file mode 100644
index 0000000..b5c6168
--- /dev/null
+++ b/experiments/launch_camera_nodes.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from multiprocessing import Process
+
+import tyro
+
+from gello.cameras.realsense_camera import RealSenseCamera, get_device_ids
+from gello.zmq_core.camera_node import ZMQServerCamera
+
+
+@dataclass
+class Args:
+ # hostname: str = "127.0.0.1"
+ hostname: str = "128.32.175.167"
+
+
+def launch_server(port: int, camera_id: int, args: Args):
+ camera = RealSenseCamera(camera_id)
+ server = ZMQServerCamera(camera, port=port, host=args.hostname)
+ print(f"Starting camera server on port {port}")
+ server.serve()
+
+
+def main(args):
+ ids = get_device_ids()
+ camera_port = 5000
+ camera_servers = []
+ for camera_id in ids:
+ # start a python process for each camera
+ print(f"Launching camera {camera_id} on port {camera_port}")
+ camera_servers.append(
+ Process(target=launch_server, args=(camera_port, camera_id, args))
+ )
+ camera_port += 1
+
+ for server in camera_servers:
+ server.start()
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/experiments/launch_nodes.py b/experiments/launch_nodes.py
new file mode 100644
index 0000000..9251711
--- /dev/null
+++ b/experiments/launch_nodes.py
@@ -0,0 +1,94 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+import tyro
+
+from gello.robots.robot import BimanualRobot, PrintRobot
+from gello.zmq_core.robot_node import ZMQServerRobot
+
+
+@dataclass
+class Args:
+ robot: str = "xarm"
+ robot_port: int = 6001
+ hostname: str = "127.0.0.1"
+ robot_ip: str = "192.168.1.10"
+
+
+def launch_robot_server(args: Args):
+    """Build and start the robot server for the robot named by ``args.robot``.
+
+    The simulated robots ("sim_ur", "sim_panda", "sim_xarm") are served by a
+    ``MujocoRobotServer`` built from models in the mujoco_menagerie submodule.
+    Every other supported name constructs a robot object and serves it through
+    a ``ZMQServerRobot``.  Robot-specific modules are imported only in the
+    branch that needs them.
+
+    Args:
+        args: parsed command-line options; reads ``robot``, ``robot_port``,
+            ``hostname`` and ``robot_ip``.
+
+    Raises:
+        NotImplementedError: if ``args.robot`` is not a supported robot name.
+    """
+    port = args.robot_port
+    # robot name -> (arm xml, gripper xml or None), relative to mujoco_menagerie
+    sim_models = {
+        "sim_ur": ("universal_robots_ur5e/ur5e.xml", "robotiq_2f85/2f85.xml"),
+        "sim_panda": ("franka_emika_panda/panda.xml", None),
+        "sim_xarm": ("ufactory_xarm7/xarm7.xml", None),
+    }
+    if args.robot in sim_models:
+        from gello.robots.sim_robot import MujocoRobotServer
+
+        menagerie_root = (
+            Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
+        )
+        arm_xml, gripper_xml = sim_models[args.robot]
+        server = MujocoRobotServer(
+            xml_path=menagerie_root / arm_xml,
+            gripper_xml_path=menagerie_root / gripper_xml if gripper_xml else None,
+            port=port,
+            host=args.hostname,
+        )
+        server.serve()
+
+    else:
+        # every non-simulated robot is wrapped in a ZMQServerRobot below
+        if args.robot == "xarm":
+            from gello.robots.xarm_robot import XArmRobot
+
+            robot = XArmRobot(ip=args.robot_ip)
+        elif args.robot == "ur":
+            from gello.robots.ur import URRobot
+
+            robot = URRobot(robot_ip=args.robot_ip)
+        elif args.robot == "panda":
+            from gello.robots.panda import PandaRobot
+
+            robot = PandaRobot(robot_ip=args.robot_ip)
+        elif args.robot == "bimanual_ur":
+            from gello.robots.ur import URRobot
+
+            # IP for the bimanual robot setup is hardcoded
+            _robot_l = URRobot(robot_ip="192.168.2.10")
+            _robot_r = URRobot(robot_ip="192.168.1.10")
+            robot = BimanualRobot(_robot_l, _robot_r)
+        elif args.robot == "none" or args.robot == "print":
+            robot = PrintRobot(8)
+        else:
+            raise NotImplementedError(
+                f"Robot {args.robot} not implemented, choose one of: "
+                "sim_ur, sim_panda, sim_xarm, xarm, ur, panda, bimanual_ur, none, print"
+            )
+        server = ZMQServerRobot(robot, port=port, host=args.hostname)
+        print(f"Starting robot server on port {port}")
+        server.serve()
+
+
+def main(args):
+ launch_robot_server(args)
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/experiments/quick_run.py b/experiments/quick_run.py
new file mode 100644
index 0000000..9d1fc90
--- /dev/null
+++ b/experiments/quick_run.py
@@ -0,0 +1,192 @@
+import atexit
+import glob
+import time
+from dataclasses import dataclass
+from multiprocessing import Process
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import tyro
+
+from gello.agents.agent import DummyAgent
+from gello.agents.gello_agent import GelloAgent
+from gello.agents.spacemouse_agent import SpacemouseAgent
+from gello.env import RobotEnv
+from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
+
+
+@dataclass
+class Args:
+    """Command-line options for the single-process quick-run launcher."""
+    hz: int = 100
+    agent: str = "gello"
+    robot: str = "ur5"
+    gello_port: Optional[str] = None
+    mock: bool = False
+    verbose: bool = False
+    hostname: str = "127.0.0.1"
+    robot_port: int = 6001
+    robot_ip: str = "192.168.1.10"  # read by launch_robot_server for the "ur5" robot
+
+
+def launch_robot_server(port: int, args: Args):
+ if args.robot == "sim_ur":
+ MENAGERIE_ROOT: Path = (
+ Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
+ )
+ xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
+ gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
+ from gello.robots.sim_robot import MujocoRobotServer
+
+ server = MujocoRobotServer(
+ xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
+ )
+ server.serve()
+ elif args.robot == "sim_panda":
+ from gello.robots.sim_robot import MujocoRobotServer
+
+ MENAGERIE_ROOT: Path = (
+ Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
+ )
+ xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
+ gripper_xml = None
+ server = MujocoRobotServer(
+ xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
+ )
+ server.serve()
+
+ else:
+ if args.robot == "xarm":
+ from gello.robots.xarm_robot import XArmRobot
+
+ robot = XArmRobot()
+ elif args.robot == "ur5":
+ from gello.robots.ur import URRobot
+
+ robot = URRobot(robot_ip=args.robot_ip)
+ else:
+ raise NotImplementedError(
+ f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
+ )
+ server = ZMQServerRobot(robot, port=port, host=args.hostname)
+ print(f"Starting robot server on port {port}")
+ server.serve()
+
+
+def start_robot_process(args: Args):
+ process = Process(target=launch_robot_server, args=(args.robot_port, args))
+
+ # Function to kill the child process
+ def kill_child_process(process):
+ print("Killing child process...")
+ process.terminate()
+
+ # Register the kill_child_process function to be called at exit
+ atexit.register(kill_child_process, process)
+ process.start()
+
+
+def main(args: Args):
+ start_robot_process(args)
+
+ robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
+ env = RobotEnv(robot_client, control_rate_hz=args.hz)
+
+ if args.agent == "gello":
+ gello_port = args.gello_port
+ if gello_port is None:
+ usb_ports = glob.glob("/dev/serial/by-id/*")
+ print(f"Found {len(usb_ports)} ports")
+ if len(usb_ports) > 0:
+ gello_port = usb_ports[0]
+ print(f"using port {gello_port}")
+ else:
+ raise ValueError(
+ "No gello port found, please specify one or plug in gello"
+ )
+ agent = GelloAgent(port=gello_port)
+
+ reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
+ curr_joints = env.get_obs()["joint_positions"]
+ if reset_joints.shape == curr_joints.shape:
+ max_delta = (np.abs(curr_joints - reset_joints)).max()
+ steps = min(int(max_delta / 0.01), 100)
+
+ for jnt in np.linspace(curr_joints, reset_joints, steps):
+ env.step(jnt)
+ time.sleep(0.001)
+
+ elif args.agent == "quest":
+ from gello.agents.quest_agent import SingleArmQuestAgent
+
+ agent = SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
+ elif args.agent == "spacemouse":
+ agent = SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
+ elif args.agent == "dummy" or args.agent == "none":
+ agent = DummyAgent(num_dofs=robot_client.num_dofs())
+ else:
+ raise ValueError("Invalid agent name")
+
+ # going to start position
+ print("Going to start position")
+ start_pos = agent.act(env.get_obs())
+ obs = env.get_obs()
+ joints = obs["joint_positions"]
+
+ abs_deltas = np.abs(start_pos - joints)
+ id_max_joint_delta = np.argmax(abs_deltas)
+
+ max_joint_delta = 0.8
+ if abs_deltas[id_max_joint_delta] > max_joint_delta:
+ id_mask = abs_deltas > max_joint_delta
+ print()
+ ids = np.arange(len(id_mask))[id_mask]
+ for i, delta, joint, current_j in zip(
+ ids,
+ abs_deltas[id_mask],
+ start_pos[id_mask],
+ joints[id_mask],
+ ):
+ print(
+ f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
+ )
+ return
+
+ print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
+ assert len(start_pos) == len(
+ joints
+ ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
+
+ max_delta = 0.05
+ for _ in range(25):
+ obs = env.get_obs()
+ command_joints = agent.act(obs)
+ current_joints = obs["joint_positions"]
+ delta = command_joints - current_joints
+ max_joint_delta = np.abs(delta).max()
+ if max_joint_delta > max_delta:
+ delta = delta / max_joint_delta * max_delta
+ env.step(current_joints + delta)
+
+ obs = env.get_obs()
+ joints = obs["joint_positions"]
+ action = agent.act(obs)
+ if (action - joints > 0.5).any():
+ print("Action is too big")
+
+ # print which joints are too big
+ joint_index = np.where(action - joints > 0.8)
+ for j in joint_index:
+ print(
+ f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
+ )
+ exit()
+
+ while True:
+ action = agent.act(obs)
+ obs = env.step(action)
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/experiments/run_env.py b/experiments/run_env.py
new file mode 100644
index 0000000..bd2319a
--- /dev/null
+++ b/experiments/run_env.py
@@ -0,0 +1,246 @@
+import datetime
+import glob
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple
+
+import numpy as np
+import tyro
+
+from gello.agents.agent import BimanualAgent, DummyAgent
+from gello.agents.gello_agent import GelloAgent
+from gello.data_utils.format_obs import save_frame
+from gello.env import RobotEnv
+from gello.robots.robot import PrintRobot
+from gello.zmq_core.robot_node import ZMQClientRobot
+
+
+def print_color(*args, color=None, attrs=(), **kwargs):
+ import termcolor
+
+ if len(args) > 0:
+ args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
+ print(*args, **kwargs)
+
+
+@dataclass
+class Args:
+ agent: str = "none"
+ robot_port: int = 6001
+ wrist_camera_port: int = 5000
+ base_camera_port: int = 5001
+ hostname: str = "127.0.0.1"
+ robot_type: str = None # only needed for quest agent or spacemouse agent
+ hz: int = 100
+ start_joints: Optional[Tuple[float, ...]] = None
+
+ gello_port: Optional[str] = None
+ mock: bool = False
+ use_save_interface: bool = False
+ data_dir: str = "~/bc_data"
+ bimanual: bool = False
+ verbose: bool = False
+
+
+def main(args):
+    """Run the teleoperation loop: connect, build the agent, align, then step.
+
+    Connects to the robot (or a mock ``PrintRobot`` when ``args.mock`` is set),
+    creates the agent named by ``args.agent`` (optionally bimanual), moves the
+    robot toward the agent's pose in small clamped steps, then loops forever
+    sending agent actions to the environment, saving frames on request when
+    ``args.use_save_interface`` is set.  Returns early (or exits) when agent
+    and robot joint positions are too far apart to start safely.
+    """
+    if args.mock:
+        robot_client = PrintRobot(8, dont_print=True)
+        camera_clients = {}
+    else:
+        camera_clients = {
+            # you can optionally add camera nodes here for imitation learning purposes
+            # "wrist": ZMQClientCamera(port=args.wrist_camera_port, host=args.hostname),
+            # "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
+        }
+        robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
+    env = RobotEnv(robot_client, control_rate_hz=args.hz, camera_dict=camera_clients)
+
+    if args.bimanual:
+        if args.agent == "gello":
+            # dynamixel control box port map (to distinguish left and right gello)
+            right = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0"
+            left = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0"
+            left_agent = GelloAgent(port=left)
+            right_agent = GelloAgent(port=right)
+            agent = BimanualAgent(left_agent, right_agent)
+        elif args.agent == "quest":
+            from gello.agents.quest_agent import SingleArmQuestAgent
+            left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
+            right_agent = SingleArmQuestAgent(
+                robot_type=args.robot_type, which_hand="r"
+            )
+            agent = BimanualAgent(left_agent, right_agent)
+        elif args.agent == "spacemouse":
+            from gello.agents.spacemouse_agent import SpacemouseAgent
+            left_path = "/dev/hidraw0"
+            right_path = "/dev/hidraw1"
+            left_agent = SpacemouseAgent(
+                robot_type=args.robot_type, device_path=left_path, verbose=args.verbose
+            )
+            right_agent = SpacemouseAgent(
+                robot_type=args.robot_type,
+                device_path=right_path,
+                verbose=args.verbose,
+                invert_button=True,
+            )
+            agent = BimanualAgent(left_agent, right_agent)
+        else:
+            raise ValueError(f"Invalid agent name for bimanual: {args.agent}")
+
+        # System setup specific. This reset configuration works well on our setup. If you are mounting the robot
+        # differently, you need a separate reset joint configuration.
+        reset_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0, 0])
+        reset_joints_right = np.deg2rad([0, -90, 90, -90, -90, 0, 0])
+        reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
+        curr_joints = env.get_obs()["joint_positions"]
+        max_delta = (np.abs(curr_joints - reset_joints)).max()
+        steps = min(int(max_delta / 0.01), 100)
+        for jnt in np.linspace(curr_joints, reset_joints, steps):
+            env.step(jnt)
+    else:
+        if args.agent == "gello":
+            gello_port = args.gello_port
+            if gello_port is None:
+                usb_ports = glob.glob("/dev/serial/by-id/*")
+                print(f"Found {len(usb_ports)} ports")
+                if len(usb_ports) > 0:
+                    gello_port = usb_ports[0]
+                    print(f"using port {gello_port}")
+                else:
+                    raise ValueError(
+                        "No gello port found, please specify one or plug in gello"
+                    )
+            if args.start_joints is None:
+                # Change this to your own reset joints
+                reset_joints = np.deg2rad([0, -90, 90, -90, -90, 0, 0])
+            else:
+                # start_joints arrives as a tuple; the shape check below needs an array
+                reset_joints = np.array(args.start_joints)
+            agent = GelloAgent(port=gello_port, start_joints=args.start_joints)
+            curr_joints = env.get_obs()["joint_positions"]
+            if reset_joints.shape == curr_joints.shape:
+                max_delta = (np.abs(curr_joints - reset_joints)).max()
+                steps = min(int(max_delta / 0.01), 100)
+                for jnt in np.linspace(curr_joints, reset_joints, steps):
+                    env.step(jnt)
+                    time.sleep(0.001)
+        elif args.agent == "quest":
+            from gello.agents.quest_agent import SingleArmQuestAgent
+            agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
+        elif args.agent == "spacemouse":
+            from gello.agents.spacemouse_agent import SpacemouseAgent
+            agent = SpacemouseAgent(robot_type=args.robot_type, verbose=args.verbose)
+        elif args.agent == "dummy" or args.agent == "none":
+            agent = DummyAgent(num_dofs=robot_client.num_dofs())
+        elif args.agent == "policy":
+            raise NotImplementedError("add your imitation policy here if there is one")
+        else:
+            raise ValueError("Invalid agent name")
+
+    # going to start position
+    print("Going to start position")
+    start_pos = agent.act(env.get_obs())
+    obs = env.get_obs()
+    joints = obs["joint_positions"]
+
+    abs_deltas = np.abs(start_pos - joints)
+    id_max_joint_delta = np.argmax(abs_deltas)
+    max_joint_delta = 0.8
+    if abs_deltas[id_max_joint_delta] > max_joint_delta:
+        id_mask = abs_deltas > max_joint_delta
+        print()
+        ids = np.arange(len(id_mask))[id_mask]
+        for i, delta, joint, current_j in zip(
+            ids,
+            abs_deltas[id_mask],
+            start_pos[id_mask],
+            joints[id_mask],
+        ):
+            print(
+                f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
+            )
+        return
+
+    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
+    assert len(start_pos) == len(
+        joints
+    ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
+
+    max_delta = 0.05
+    for _ in range(25):
+        obs = env.get_obs()
+        command_joints = agent.act(obs)
+        current_joints = obs["joint_positions"]
+        delta = command_joints - current_joints
+        max_joint_delta = np.abs(delta).max()
+        if max_joint_delta > max_delta:
+            delta = delta / max_joint_delta * max_delta
+        env.step(current_joints + delta)
+
+    obs = env.get_obs()
+    joints = obs["joint_positions"]
+    action = agent.act(obs)
+    # compare magnitudes so a large jump in the negative direction is caught too
+    abs_diff = np.abs(action - joints)
+    if (abs_diff > 0.5).any():
+        print("Action is too big")
+        # print which joints are too big (same threshold as the check above)
+        for j in np.where(abs_diff > 0.5)[0]:
+            print(
+                f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
+            )
+        exit()
+
+    if args.use_save_interface:
+        from gello.data_utils.keyboard_interface import KBReset
+        kb_interface = KBReset()
+
+    print_color("\nStart 🚀🚀🚀", color="green", attrs=("bold",))
+
+    save_path = None
+    start_time = time.time()
+    while True:
+        num = time.time() - start_time
+        message = f"\rTime passed: {round(num, 2)} "
+        print_color(
+            message,
+            color="white",
+            attrs=("bold",),
+            end="",
+            flush=True,
+        )
+        action = agent.act(obs)
+        dt = datetime.datetime.now()
+        if args.use_save_interface:
+            state = kb_interface.update()
+            if state == "start":
+                dt_time = datetime.datetime.now()
+                save_path = (
+                    Path(args.data_dir).expanduser()
+                    / args.agent
+                    / dt_time.strftime("%m%d_%H%M%S")
+                )
+                save_path.mkdir(parents=True, exist_ok=True)
+                print(f"Saving to {save_path}")
+            elif state == "save":
+                assert save_path is not None, "something went wrong"
+                save_frame(save_path, dt, obs, action)
+            elif state == "normal":
+                save_path = None
+            else:
+                raise ValueError(f"Invalid state {state}")
+        obs = env.step(action)
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/gello/__init__.py b/gello/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gello/agents/agent.py b/gello/agents/agent.py
new file mode 100644
index 0000000..f07457e
--- /dev/null
+++ b/gello/agents/agent.py
@@ -0,0 +1,43 @@
+from typing import Any, Dict, Protocol
+
+import numpy as np
+
+
+class Agent(Protocol):
+    """Structural interface for anything that turns an observation into an action."""
+
+    def act(self, obs: Dict[str, Any]) -> np.ndarray:
+        """Compute the action for the given observation.
+
+        Args:
+            obs: observation from the environment.
+
+        Returns:
+            action: action to take on the environment.
+
+        Raises:
+            NotImplementedError: always, in this base definition.
+        """
+        raise NotImplementedError
+
+
+class DummyAgent(Agent):
+    """Agent that ignores the observation and always commands zeros."""
+
+    def __init__(self, num_dofs: int):
+        # Length of the action vector produced by act().
+        self.num_dofs = num_dofs
+
+    def act(self, obs: Dict[str, Any]) -> np.ndarray:
+        """Return a zero vector of length ``num_dofs``; ``obs`` is not read."""
+        zero_action = np.zeros(self.num_dofs)
+        return zero_action
+
+
+class BimanualAgent(Agent):
+    """Drive a left and a right agent from one observation split down the middle."""
+
+    def __init__(self, agent_left: Agent, agent_right: Agent):
+        self.agent_left = agent_left
+        self.agent_right = agent_right
+
+    def act(self, obs: Dict[str, Any]) -> np.ndarray:
+        """Halve every observation entry and concatenate the two agents' actions.
+
+        Args:
+            obs: observation whose values are indexable along axis 0 with an even length;
+                the first half goes to the left agent, the second half to the right agent.
+
+        Returns:
+            The left agent's action followed by the right agent's action.
+        """
+        left_obs = {}
+        right_obs = {}
+        for key, val in obs.items():
+            half_dim, remainder = divmod(val.shape[0], 2)
+            assert remainder == 0, f"{key} must be even, something is wrong"
+            left_obs[key], right_obs[key] = val[:half_dim], val[half_dim:]
+        left_action = self.agent_left.act(left_obs)
+        right_action = self.agent_right.act(right_obs)
+        return np.concatenate([left_action, right_action])
diff --git a/gello/agents/gello_agent.py b/gello/agents/gello_agent.py
new file mode 100644
index 0000000..fca60c3
--- /dev/null
+++ b/gello/agents/gello_agent.py
@@ -0,0 +1,139 @@
+import os
+from dataclasses import dataclass
+from typing import Dict, Optional, Sequence, Tuple
+
+import numpy as np
+
+from gello.agents.agent import Agent
+from gello.robots.dynamixel import DynamixelRobot
+
+
+@dataclass
+class DynamixelRobotConfig:
+    joint_ids: Sequence[int]
+    """The joint ids of GELLO (not including the gripper). Usually (1, 2, 3 ...)."""
+
+    joint_offsets: Sequence[float]
+    """The joint offsets of GELLO. There needs to be a joint offset for each joint_id and should be a multiple of pi/2."""
+
+    joint_signs: Sequence[int]
+    """The joint signs of GELLO. There needs to be a joint sign for each joint_id and should be either 1 or -1.
+
+    This will be different for each arm design. Reference the examples below for the correct signs for your robot.
+    """
+
+    gripper_config: Tuple[int, int, int]
+    """The gripper config of GELLO. This is a tuple of (gripper_joint_id, degrees in open_position, degrees in closed_position)."""
+
+    def __post_init__(self):
+        # Offsets and signs are per joint, so both must line up with joint_ids.
+        num_joints = len(self.joint_ids)
+        assert num_joints == len(self.joint_offsets)
+        assert num_joints == len(self.joint_signs)
+
+    def make_robot(
+        self, port: str = "/dev/ttyUSB0", start_joints: Optional[np.ndarray] = None
+    ) -> DynamixelRobot:
+        """Build a DynamixelRobot from this config.
+
+        Args:
+            port: serial device path forwarded to DynamixelRobot.
+            start_joints: optional value forwarded unchanged to DynamixelRobot.
+
+        Returns:
+            A DynamixelRobot constructed with ``real=True``.
+        """
+        offsets = list(self.joint_offsets)
+        signs = list(self.joint_signs)
+        return DynamixelRobot(
+            joint_ids=self.joint_ids,
+            joint_offsets=offsets,
+            real=True,
+            joint_signs=signs,
+            port=port,
+            gripper_config=self.gripper_config,
+            start_joints=start_joints,
+        )
+
+
+# Known GELLO devices, keyed by serial device path. GelloAgent looks its port up here when it is
+# constructed without an explicit DynamixelRobotConfig. Offsets are written as explicit multiples of
+# pi/2 and are left unsimplified on purpose so the calibration values stay exactly as recorded.
+PORT_CONFIG_MAP: Dict[str, DynamixelRobotConfig] = {
+    # xArm
+    # "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig(
+    #     joint_ids=(1, 2, 3, 4, 5, 6, 7),
+    #     joint_offsets=(
+    #         2 * np.pi / 2,
+    #         2 * np.pi / 2,
+    #         2 * np.pi / 2,
+    #         2 * np.pi / 2,
+    #         -1 * np.pi / 2 + 2 * np.pi,
+    #         1 * np.pi / 2,
+    #         1 * np.pi / 2,
+    #     ),
+    #     joint_signs=(1, 1, 1, 1, 1, 1, 1),
+    #     gripper_config=(8, 279, 279 - 50),
+    # ),
+    # panda
+    # "/dev/cu.usbserial-FT3M9NVB": DynamixelRobotConfig(
+    "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig(
+        joint_ids=(1, 2, 3, 4, 5, 6, 7),
+        joint_offsets=(
+            3 * np.pi / 2,
+            2 * np.pi / 2,
+            1 * np.pi / 2,
+            4 * np.pi / 2,
+            -2 * np.pi / 2 + 2 * np.pi,
+            3 * np.pi / 2,
+            4 * np.pi / 2,
+        ),
+        joint_signs=(1, -1, 1, 1, 1, -1, 1),
+        # (gripper joint id, degrees when open, degrees when closed)
+        gripper_config=(8, 195, 152),
+    ),
+    # Left UR
+    "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0": DynamixelRobotConfig(
+        joint_ids=(1, 2, 3, 4, 5, 6),
+        joint_offsets=(
+            0,
+            1 * np.pi / 2 + np.pi,
+            np.pi / 2 + 0 * np.pi,
+            0 * np.pi + np.pi / 2,
+            np.pi - 2 * np.pi / 2,
+            -1 * np.pi / 2 + 2 * np.pi,
+        ),
+        joint_signs=(1, 1, -1, 1, 1, 1),
+        gripper_config=(7, 20, -22),
+    ),
+    # Right UR
+    "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0": DynamixelRobotConfig(
+        joint_ids=(1, 2, 3, 4, 5, 6),
+        joint_offsets=(
+            np.pi + 0 * np.pi,
+            2 * np.pi + np.pi / 2,
+            2 * np.pi + np.pi / 2,
+            2 * np.pi + np.pi / 2,
+            1 * np.pi,
+            3 * np.pi / 2,
+        ),
+        joint_signs=(1, 1, -1, 1, 1, 1),
+        gripper_config=(7, 286, 248),
+    ),
+}
+
+
+class GelloAgent(Agent):
+    """Agent whose action is the current joint state read from a GELLO Dynamixel device."""
+
+    def __init__(
+        self,
+        port: str,
+        dynamixel_config: Optional[DynamixelRobotConfig] = None,
+        start_joints: Optional[np.ndarray] = None,
+    ):
+        """Connect to the GELLO device on ``port``.
+
+        Args:
+            port: serial device path of the GELLO.
+            dynamixel_config: explicit device configuration. When omitted, ``port`` must exist on
+                disk and have an entry in ``PORT_CONFIG_MAP``.
+            start_joints: optional value forwarded unchanged to ``make_robot``.
+
+        Raises:
+            AssertionError: if no config is given and ``port`` does not exist or is not in
+                ``PORT_CONFIG_MAP``.
+        """
+        if dynamixel_config is not None:
+            self._robot = dynamixel_config.make_robot(
+                port=port, start_joints=start_joints
+            )
+        else:
+            assert os.path.exists(port), port
+            assert port in PORT_CONFIG_MAP, f"Port {port} not in config map"
+
+            config = PORT_CONFIG_MAP[port]
+            self._robot = config.make_robot(port=port, start_joints=start_joints)
+
+    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
+        """Return the device's current joint state; ``obs`` is not read.
+
+        The statements that used to follow this return (a debug print and torque toggling keyed on
+        the gripper value) were unreachable and have been removed; behavior is unchanged.
+        """
+        return self._robot.get_joint_state()
diff --git a/gello/agents/quest_agent.py b/gello/agents/quest_agent.py
new file mode 100644
index 0000000..be7af59
--- /dev/null
+++ b/gello/agents/quest_agent.py
@@ -0,0 +1,178 @@
+from typing import Dict
+
+import numpy as np
+import quaternion
+from dm_control import mjcf
+from dm_control.utils.inverse_kinematics import qpos_from_site_pose
+from oculus_reader.reader import OculusReader
+
+from gello.agents.agent import Agent
+from gello.agents.spacemouse_agent import apply_transfer, mj2ur, ur2mj
+from gello.dm_control_tasks.arms.ur5e import UR5e
+
+# Cartesian-space control: the relative pose between controller and robot matters. These extrinsics
+# are based on our setup; for details please check out the project page.
+# quest2ur is a 4x4 matrix taking Quest-frame quantities into the UR frame; ur2quest is its inverse.
+quest2ur = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+ur2quest = np.linalg.inv(quest2ur)
+
+# Multiplier applied to controller translation deltas before they are added to the reference
+# end-effector position.
+translation_scaling_factor = 2.0
+
+
+class SingleArmQuestAgent(Agent):
+    """Cartesian teleoperation of one arm from one Quest controller, solved to joints with IK."""
+
+    def __init__(self, robot_type: str, which_hand: str, verbose: bool = False) -> None:
+        """Interact with the robot using the quest controller.
+
+        Holding the hand's trigger above 0.5 engages control; the first engaged frame records the
+        controller pose and end-effector pose as references, and later frames apply the controller's
+        motion relative to that reference. Y/X (left hand) or B/A (right hand) set the gripper
+        command to 1 / 0.
+
+        Args:
+            robot_type: only "ur5" is supported.
+            which_hand: "l" or "r", selecting which controller's pose and buttons are read.
+            verbose: print state transitions and computed deltas.
+
+        Raises:
+            ValueError: if ``robot_type`` is not "ur5".
+        """
+        self.which_hand = which_hand
+        assert self.which_hand in ["l", "r"]
+
+        self.oculus_reader = OculusReader()
+        if robot_type == "ur5":
+            _robot = UR5e()
+        else:
+            raise ValueError(f"Unknown robot type: {robot_type}")
+        self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model)
+        self.control_active = False
+        self.reference_quest_pose = None
+        self.reference_ee_rot_ur = None
+        self.reference_ee_pos_ur = None
+
+        self.robot_type = robot_type
+        self._verbose = verbose
+
+    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
+        """Return joint targets followed by the gripper command.
+
+        Args:
+            obs: must contain "joint_positions"; the leading entries are the arm joints and the
+                last entry is the gripper.
+
+        Returns:
+            The current arm joints (arm held still) or the IK solution, concatenated with the
+            gripper command.
+        """
+        if self.robot_type == "ur5":
+            num_dof = 6
+        current_qpos = obs["joint_positions"][:num_dof]  # last one dim is the gripper
+        current_gripper_angle = obs["joint_positions"][-1]
+        # run the fk
+        # NOTE(review): step() advances the simulation rather than only recomputing kinematics;
+        # confirm whether physics.forward() was intended here.
+        self.physics.data.qpos[:num_dof] = current_qpos
+        self.physics.step()
+
+        ee_rot_mj = np.array(
+            self.physics.named.data.site_xmat["attachment_site"]
+        ).reshape(3, 3)
+        ee_pos_mj = np.array(self.physics.named.data.site_xpos["attachment_site"])
+        if self.which_hand == "l":
+            pose_key = "l"
+            trigger_key = "leftTrig"
+            # joystick_key = "leftJS"
+            # left yx
+            gripper_open_key = "Y"
+            gripper_close_key = "X"
+        elif self.which_hand == "r":
+            pose_key = "r"
+            trigger_key = "rightTrig"
+            # joystick_key = "rightJS"
+            # right ba for the key
+            gripper_open_key = "B"
+            gripper_close_key = "A"
+        else:
+            raise ValueError(f"Unknown hand: {self.which_hand}")
+        # check the trigger button state
+        pose_data, button_data = self.oculus_reader.get_transformations_and_buttons()
+        if len(pose_data) == 0 or len(button_data) == 0:
+            print("no data, quest not yet ready")
+            return np.concatenate([current_qpos, [current_gripper_angle]])
+
+        # Close wins when both gripper buttons are held because it is checked last.
+        new_gripper_angle = current_gripper_angle
+        if button_data[gripper_open_key]:
+            new_gripper_angle = 1
+        if button_data[gripper_close_key]:
+            new_gripper_angle = 0
+        arm_not_move_return = np.concatenate([current_qpos, [new_gripper_angle]])
+        # (A second `len(pose_data) == 0` check used to sit here; the guard above already returns
+        # in that case, so it could never fire and was removed.)
+
+        trigger_state = button_data[trigger_key][0]
+        if trigger_state > 0.5:
+            if self.control_active is True:
+                if self._verbose:
+                    print("controlling the arm")
+                current_pose = pose_data[pose_key]
+                # Controller motion since the reference frame was captured.
+                delta_rot = current_pose[:3, :3] @ np.linalg.inv(
+                    self.reference_quest_pose[:3, :3]
+                )
+                delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3]
+                delta_pos_ur = (
+                    apply_transfer(quest2ur, delta_pos) * translation_scaling_factor
+                )
+                # ? is this the case?
+                delta_rot_ur = quest2ur[:3, :3] @ delta_rot @ ur2quest[:3, :3]
+                if self._verbose:
+                    print(
+                        f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}"
+                    )
+                next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur
+                next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur
+
+                target_quat = quaternion.as_float_array(
+                    quaternion.from_rotation_matrix(ur2mj[:3, :3] @ next_ee_rot_ur)
+                )
+                ik_result = qpos_from_site_pose(
+                    self.physics,
+                    "attachment_site",
+                    target_pos=apply_transfer(ur2mj, next_ee_pos_ur),
+                    target_quat=target_quat,
+                    tol=1e-14,
+                    max_steps=400,
+                )
+                self.physics.reset()
+                if ik_result.success:
+                    new_qpos = ik_result.qpos[:num_dof]
+                else:
+                    print("ik failed, using the original qpos")
+                    return arm_not_move_return
+                command = np.concatenate([new_qpos, [new_gripper_angle]])
+                return command
+
+            else:  # last state is not in active
+                # First engaged frame: capture references and hold the arm still.
+                self.control_active = True
+                if self._verbose:
+                    print("control activated!")
+                self.reference_quest_pose = pose_data[pose_key]
+
+                self.reference_ee_rot_ur = mj2ur[:3, :3] @ ee_rot_mj
+                self.reference_ee_pos_ur = apply_transfer(mj2ur, ee_pos_mj)
+                return arm_not_move_return
+        else:
+            if self._verbose:
+                print("deactive control")
+            self.control_active = False
+            self.reference_quest_pose = None
+            return arm_not_move_return
+
+
+class DualArmQuestAgent(Agent):
+    """Placeholder for two-handed Quest teleoperation; ``act`` is not implemented yet."""
+
+    def __init__(self, robot_type: str) -> None:
+        self.left_arm = SingleArmQuestAgent(robot_type, "l")
+        self.right_arm = SingleArmQuestAgent(robot_type, "r")
+
+    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
+        """Not implemented.
+
+        Raises:
+            NotImplementedError: always. Previously this silently returned ``None``, which
+                violates the Agent contract of returning an action array.
+        """
+        raise NotImplementedError("DualArmQuestAgent.act is not implemented")
+
+
+# Manual debugging loop: poll the Quest forever and print the left controller's pose matrix.
+if __name__ == "__main__":
+    oculus_reader = OculusReader()
+    while True:
+        """
+        example output:
+        ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ],
+        [-0.107737 , 0.0958919, 0.989544 , -0.833478 ],
+        [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ],
+        [ 0. , 0. , 0. , 1. ]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ],
+        [ 0.070887, 0.516083, -0.8536 , -0.238691],
+        [-0.941994, -0.246809, -0.227447, -0.370447],
+        [ 0. , 0. , 0. , 1. ]])},
+        {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)})
+
+        """
+        pose_data, button_data = oculus_reader.get_transformations_and_buttons()
+        # An empty pose dict means no data has arrived yet; keep polling.
+        if len(pose_data) == 0:
+            print("no data")
+            continue
+        else:
+            print(pose_data["l"])
diff --git a/gello/agents/spacemouse_agent.py b/gello/agents/spacemouse_agent.py
new file mode 100644
index 0000000..03a42f1
--- /dev/null
+++ b/gello/agents/spacemouse_agent.py
@@ -0,0 +1,224 @@
+import threading
+import time
+from dataclasses import dataclass, field
+from typing import Dict, Optional
+
+import numpy as np
+from dm_control import mjcf
+from dm_control.utils.inverse_kinematics import qpos_from_site_pose
+
+from gello.agents.agent import Agent
+from gello.dm_control_tasks.arms.ur5e import UR5e
+
+# mujoco has a slightly different coordinate system than UR control box
+# mj2ur takes MuJoCo-frame quantities into the UR frame; ur2mj is its inverse.
+mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+ur2mj = np.linalg.inv(mj2ur)
+
+# Cartesian-space control: the relative pose between controller and robot matters. These extrinsics
+# are based on our setup; for details please check out the project page.
+# spacemouse2ur flips the x and y axes; ur2spacemouse is its inverse.
+spacemouse2ur = np.array(
+    [
+        [-1, 0, 0, 0],
+        [0, -1, 0, 0],
+        [0, 0, 1, 0],
+        [0, 0, 0, 1],
+    ]
+)
+ur2spacemouse = np.linalg.inv(spacemouse2ur)
+
+
+def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray:
+    """Left-multiply a point by ``mat`` and return the first three components of the result.
+
+    Args:
+        mat: transform matrix (4x4 in every use in this file).
+        xyz: the point. A length-3 input gets a trailing 1 appended (homogeneous coordinate)
+            before the product; any other length is multiplied as given.
+
+    Returns:
+        The first three entries of ``mat @ xyz``.
+    """
+    # NOTE(review): an earlier comment said a rotation matrix may be passed, but a 3x3 array also
+    # has len 3 and np.append would flatten it; rotations should be multiplied by mat[:3, :3].
+    point = np.append(xyz, 1) if len(xyz) == 3 else xyz
+    transformed = np.matmul(mat, point)
+    return transformed[:3]
+
+
+@dataclass
+class SpacemouseConfig:
+    # Multiplier applied to the rotation vectors built from the roll/pitch/yaw readings.
+    angle_scale: float = 0.24
+    # Multiplier applied to the x/y/z readings before they are added to the end-effector position.
+    translation_scale: float = 0.06
+    # only control the xyz, rotation direction, not the gripper
+    # Built per instance through default_factory: an ndarray class-level default would be shared by
+    # every config and is rejected as a mutable (unhashable) default by dataclasses on Python 3.11+.
+    invert_control: np.ndarray = field(default_factory=lambda: np.ones(6))
+    # "euler" applies the rotation increment on the left of the current rotation, "rpy" on the right.
+    rotation_mode: str = "euler"
+
+
+class SpacemouseAgent(Agent):
+    """Cartesian teleoperation from a 3D SpaceMouse, solved to joint targets with IK."""
+
+    def __init__(
+        self,
+        robot_type: str,
+        config: SpacemouseConfig = SpacemouseConfig(),
+        device_path: Optional[str] = None,
+        verbose: bool = True,
+        invert_button: bool = False,
+    ) -> None:
+        """Build the kinematic model and start the background device reader.
+
+        Args:
+            robot_type: only "ur5" is supported.
+            config: scaling / sign / rotation-mode settings.
+            device_path: device path handed to ``pyspacemouse.open``; ``None`` uses its default.
+            verbose: print debugging information on every ``act`` call.
+            invert_button: swap which button opens and which closes the gripper.
+
+        Raises:
+            ValueError: if ``robot_type`` is not "ur5".
+        """
+        self.config = config
+        self.last_state_lock = threading.Lock()
+        self._invert_button = invert_button
+        # example state:SpaceNavigator(t=3.581528532, x=0.0, y=0.0, z=0.0, roll=0.0, pitch=0.0, yaw=0.0, buttons=[0, 0])
+        # all continuous inputs range from 0-1, buttons are 0 or 1
+        self.spacemouse_latest_state = None
+        self._device_path = device_path
+        self._verbose = verbose
+        if self._verbose:
+            print(f"robot_type: {robot_type}")
+        # Validate before starting the reader so a bad robot_type cannot leave a thread behind.
+        if robot_type == "ur5":
+            _robot = UR5e()
+        else:
+            raise ValueError(f"Unknown robot type: {robot_type}")
+        self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model)
+
+        # Daemon thread: the reader loops forever and must not keep the process alive on exit.
+        spacemouse_thread = threading.Thread(
+            target=self._read_from_spacemouse, daemon=True
+        )
+        spacemouse_thread.start()
+
+    def _read_from_spacemouse(self):
+        """Poll the device forever, publishing each reading under ``last_state_lock``.
+
+        Raises:
+            ValueError: (in the reader thread) if the device cannot be opened.
+        """
+        import pyspacemouse
+
+        if self._device_path is None:
+            mouse = pyspacemouse.open()
+        else:
+            mouse = pyspacemouse.open(path=self._device_path)
+        if mouse:
+            while 1:
+                state = mouse.read()
+                with self.last_state_lock:
+                    self.spacemouse_latest_state = state
+                time.sleep(0.001)
+        else:
+            raise ValueError("Failed to open spacemouse")
+
+    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
+        """Turn the latest SpaceMouse reading into joint targets plus a gripper command.
+
+        Args:
+            obs: must contain "joint_positions"; the first six entries are the arm joints and the
+                last entry is the gripper.
+
+        Returns:
+            Six joint targets followed by the gripper command. If IK fails, the current joints and
+            current gripper value are returned unchanged.
+
+        Raises:
+            AssertionError: if the reader thread has not published a reading yet.
+            NotImplementedError: if ``config.rotation_mode`` is neither "rpy" nor "euler".
+        """
+        import quaternion
+
+        # obs: the follower robot's current state
+        # in rad, 6/7 dof depends on the robot type
+        num_dof = 6
+        if self._verbose:
+            print("act invoked")
+        current_qpos = obs["joint_positions"][:num_dof]  # last one dim is the gripper
+        current_gripper_angle = obs["joint_positions"][-1]
+        self.physics.data.qpos[:num_dof] = current_qpos
+
+        self.physics.step()
+        ee_rot = np.array(self.physics.named.data.site_xmat["attachment_site"]).reshape(
+            3, 3
+        )
+        ee_pos = np.array(self.physics.named.data.site_xpos["attachment_site"])
+
+        ee_rot = mj2ur[:3, :3] @ ee_rot
+        ee_pos = apply_transfer(mj2ur, ee_pos)
+        # ^ mujoco coordinate to UR
+
+        with self.last_state_lock:
+            spacemouse_read = self.spacemouse_latest_state
+        if self._verbose:
+            print(f"spacemouse_read: {spacemouse_read}")
+
+        assert (
+            spacemouse_read is not None
+        ), "no spacemouse reading yet (device not opened or reader thread not started)"
+        spacemouse_xyz_rot_np = np.array(
+            [
+                spacemouse_read.x,
+                spacemouse_read.y,
+                spacemouse_read.z,
+                spacemouse_read.roll,
+                spacemouse_read.pitch,
+                spacemouse_read.yaw,
+            ]
+        )
+        spacemouse_button = (
+            spacemouse_read.buttons
+        )  # size 2 list binary indicating left/right button pressing
+        spacemouse_xyz_rot_np = spacemouse_xyz_rot_np * self.config.invert_control
+        # When one axis is pushed hard, zero the weaker axes so the motion stays on that axis.
+        if np.max(np.abs(spacemouse_xyz_rot_np)) > 0.9:
+            spacemouse_xyz_rot_np[np.abs(spacemouse_xyz_rot_np) < 0.6] = 0
+        tx, ty, tz, r, p, y = spacemouse_xyz_rot_np
+        # convert roll pitch yaw to rotation matrix (rpy)
+
+        trans_transform = np.eye(4)
+        # delta xyz from the spacemouse reading
+        trans_transform[:3, 3] = apply_transfer(
+            spacemouse2ur, np.array([tx, ty, tz]) * self.config.translation_scale
+        )
+
+        # break rot_transform into each axis
+        rot_transform_x = np.eye(4)
+        rot_transform_x[:3, :3] = quaternion.as_rotation_matrix(
+            quaternion.from_rotation_vector(
+                np.array([-p, 0, 0]) * self.config.angle_scale
+            )
+        )
+
+        rot_transform_y = np.eye(4)
+        rot_transform_y[:3, :3] = quaternion.as_rotation_matrix(
+            quaternion.from_rotation_vector(
+                np.array([0, r, 0]) * self.config.angle_scale
+            )
+        )
+
+        rot_transform_z = np.eye(4)
+        rot_transform_z[:3, :3] = quaternion.as_rotation_matrix(
+            quaternion.from_rotation_vector(
+                np.array([0, 0, -y]) * self.config.angle_scale
+            )
+        )
+
+        # in ur space
+        rot_transform = (
+            spacemouse2ur
+            @ rot_transform_z
+            @ rot_transform_y
+            @ rot_transform_x
+            @ ur2spacemouse
+        )
+
+        if self._verbose:
+            print(f"rot_transform: {rot_transform}")
+            print(f"new spacemounse cmd in ur space = {trans_transform[:3, 3]}")
+
+        new_ee_pos = trans_transform[:3, 3] + ee_pos
+        if self.config.rotation_mode == "rpy":
+            new_ee_rot = ee_rot @ rot_transform[:3, :3]
+        elif self.config.rotation_mode == "euler":
+            new_ee_rot = rot_transform[:3, :3] @ ee_rot
+        else:
+            raise NotImplementedError(
+                f"Unknown rotation mode: {self.config.rotation_mode}"
+            )
+
+        target_quat = quaternion.as_float_array(
+            quaternion.from_rotation_matrix(ur2mj[:3, :3] @ new_ee_rot)
+        )
+        ik_result = qpos_from_site_pose(
+            self.physics,
+            "attachment_site",
+            target_pos=apply_transfer(ur2mj, new_ee_pos),
+            target_quat=target_quat,
+            tol=1e-14,
+            max_steps=400,
+        )
+        self.physics.reset()
+        if ik_result.success:
+            new_qpos = ik_result.qpos[:num_dof]
+        else:
+            print("ik failed, using the original qpos")
+            return np.concatenate([current_qpos, [current_gripper_angle]])
+        # Button 0 is checked last in both layouts, so it wins when both buttons are pressed.
+        new_gripper_angle = current_gripper_angle
+        if self._invert_button:
+            if spacemouse_button[1]:
+                new_gripper_angle = 1
+            if spacemouse_button[0]:
+                new_gripper_angle = 0
+        else:
+            if spacemouse_button[1]:
+                new_gripper_angle = 0
+            if spacemouse_button[0]:
+                new_gripper_angle = 1
+        command = np.concatenate([new_qpos, [new_gripper_angle]])
+        return command
+
+
+# Manual smoke test: try to open the SpaceMouse at two hard-coded hidraw device paths.
+if __name__ == "__main__":
+    import pyspacemouse
+
+    # `success` is overwritten by the second call and neither result is checked.
+    success = pyspacemouse.open("/dev/hidraw4")
+    success = pyspacemouse.open("/dev/hidraw5")
diff --git a/gello/cameras/camera.py b/gello/cameras/camera.py
new file mode 100644
index 0000000..6a4de19
--- /dev/null
+++ b/gello/cameras/camera.py
@@ -0,0 +1,81 @@
+from pathlib import Path
+from typing import Optional, Protocol, Tuple
+
+import numpy as np
+
+
+class CameraDriver(Protocol):
+    """Camera protocol.
+
+    Structural interface for camera drivers, so the rest of the code does not depend on a
+    particular camera implementation.
+    """
+
+    def read(
+        self,
+        img_size: Optional[Tuple[int, int]] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Read a frame from the camera.
+
+        Args:
+            img_size: The size of the image to return. If None, the original size is returned.
+
+        Returns:
+            np.ndarray: The color image.
+            np.ndarray: The depth image.
+        """
+
+
+class DummyCamera(CameraDriver):
+    """Stand-in camera that produces random frames, for testing."""
+
+    def read(
+        self,
+        img_size: Optional[Tuple[int, int]] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Generate a random color/depth pair.
+
+        Args:
+            img_size: (rows, cols) of the frames to generate. If None, 480 x 640 is used.
+
+        Returns:
+            np.ndarray: uint8 color image of shape (rows, cols, 3), values in [0, 255).
+            np.ndarray: uint16 depth image of shape (rows, cols, 1), values in [0, 255).
+        """
+        if img_size is None:
+            rows, cols = 480, 640
+        else:
+            rows, cols = img_size[0], img_size[1]
+        color = np.random.randint(255, size=(rows, cols, 3), dtype=np.uint8)
+        depth = np.random.randint(255, size=(rows, cols, 1), dtype=np.uint16)
+        return (color, depth)
+
+
+class SavedCamera(CameraDriver):
+    """Camera that replays one saved color/depth image pair from disk."""
+
+    def __init__(self, path: str = "example"):
+        """Load ``image.png`` and ``depth.png`` from ``path``, relative to this module's folder."""
+        self.path = str(Path(__file__).parent / path)
+        from PIL import Image
+
+        self._color_img = Image.open(f"{self.path}/image.png")
+        self._depth_img = Image.open(f"{self.path}/depth.png")
+
+    def read(
+        self,
+        img_size: Optional[Tuple[int, int]] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Return the saved frames, resized with PIL when ``img_size`` is given.
+
+        Returns:
+            np.ndarray: the color image.
+            np.ndarray: the first channel of the depth image, kept as a trailing axis of size 1.
+        """
+        color_img, depth_img = self._color_img, self._depth_img
+        if img_size is not None:
+            color_img = color_img.resize(img_size)
+            depth_img = depth_img.resize(img_size)
+
+        color = np.array(color_img)
+        depth = np.array(depth_img)[:, :, 0:1]
+        return color, depth
diff --git a/gello/cameras/realsense_camera.py b/gello/cameras/realsense_camera.py
new file mode 100644
index 0000000..1fd72b8
--- /dev/null
+++ b/gello/cameras/realsense_camera.py
@@ -0,0 +1,123 @@
+import os
+import time
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from gello.cameras.camera import CameraDriver
+
+
+def get_device_ids() -> List[str]:
+    """Hardware-reset every connected RealSense device and return their serial numbers.
+
+    Sleeps two seconds after issuing the resets before returning.
+    """
+    import pyrealsense2 as rs
+
+    context = rs.context()
+    serial_numbers: List[str] = []
+    for device in context.query_devices():
+        device.hardware_reset()
+        serial = device.get_info(rs.camera_info.serial_number)
+        serial_numbers.append(serial)
+    time.sleep(2)
+    return serial_numbers
+
+
+class RealSenseCamera(CameraDriver):
+    """CameraDriver backed by a RealSense pipeline streaming 640x480 depth and color at 30 fps."""
+
+    def __repr__(self) -> str:
+        return f"RealSenseCamera(device_id={self._device_id})"
+
+    def __init__(self, device_id: Optional[str] = None, flip: bool = False):
+        """Start the streaming pipeline.
+
+        Args:
+            device_id: device to bind the pipeline to. If None, every connected device is
+                hardware-reset first and the pipeline is started without selecting a device.
+            flip: rotate returned frames by 180 degrees.
+        """
+        import pyrealsense2 as rs
+
+        self._device_id = device_id
+
+        if device_id is None:
+            ctx = rs.context()
+            for dev in ctx.query_devices():
+                dev.hardware_reset()
+            time.sleep(2)
+
+        self._pipeline = rs.pipeline()
+        config = rs.config()
+        if device_id is not None:
+            config.enable_device(device_id)
+
+        config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
+        config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
+        self._pipeline.start(config)
+        self._flip = flip
+
+    def read(
+        self,
+        img_size: Optional[Tuple[int, int]] = None,  # farthest: float = 0.12
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Read a frame from the camera.
+
+        Args:
+            img_size: The size passed to cv2.resize. If None, the original size is returned.
+
+        Returns:
+            np.ndarray: The color image, shape=(H, W, 3)
+            np.ndarray: The depth image, shape=(H, W, 1)
+        """
+        import cv2
+
+        frames = self._pipeline.wait_for_frames()
+        color_image = np.asanyarray(frames.get_color_frame().get_data())
+        depth_image = np.asanyarray(frames.get_depth_frame().get_data())
+        # depth_image = cv2.convertScaleAbs(depth_image, alpha=0.03)
+        if img_size is not None:
+            color_image = cv2.resize(color_image, img_size)
+            depth_image = cv2.resize(depth_image, img_size)
+        # The color stream is configured as bgr8; reverse the channel axis.
+        image = color_image[:, :, ::-1]
+        depth = depth_image
+
+        # rotate 180 degrees because everything is upside down in order to center the camera
+        if self._flip:
+            image = cv2.rotate(image, cv2.ROTATE_180)
+            depth = cv2.rotate(depth, cv2.ROTATE_180)
+
+        return image, depth[:, :, None]
+
+
+def _debug_read(camera, save_datastream=False):
+    """Show live color/depth windows until Esc is pressed, optionally saving frames to disk.
+
+    Args:
+        camera: object whose ``read()`` returns ``(image, depth)``.
+        save_datastream: when true, write every frame under ``stream/``.
+
+    Pressing "s" writes the current frame under ``images/``.
+    """
+    import cv2
+
+    cv2.namedWindow("image")
+    cv2.namedWindow("depth")
+    counter = 0
+    if not os.path.exists("images"):
+        os.makedirs("images")
+    if save_datastream and not os.path.exists("stream"):
+        os.makedirs("stream")
+    while True:
+        time.sleep(0.1)
+        image, depth = camera.read()
+        # Repeat the single depth channel three times so it can be shown/saved like a color image.
+        depth = np.concatenate([depth, depth, depth], axis=-1)
+        key = cv2.waitKey(1)
+        cv2.imshow("image", image[:, :, ::-1])
+        cv2.imshow("depth", depth)
+        if key == ord("s"):
+            cv2.imwrite(f"images/image_{counter}.png", image[:, :, ::-1])
+            cv2.imwrite(f"images/depth_{counter}.png", depth)
+        if save_datastream:
+            cv2.imwrite(f"stream/image_{counter}.png", image[:, :, ::-1])
+            cv2.imwrite(f"stream/depth_{counter}.png", depth)
+        # NOTE(review): assumed to advance once per loop iteration — confirm against the source file.
+        counter += 1
+        # 27 is the Esc key code.
+        if key == 27:
+            break
+
+
+# Manual check: list the connected devices, open the first one flipped, and start the debug viewer.
+if __name__ == "__main__":
+    device_ids = get_device_ids()
+    print(f"Found {len(device_ids)} devices")
+    print(device_ids)
+    # Note: `rs` here is a RealSenseCamera instance, not the pyrealsense2 module alias used elsewhere.
+    rs = RealSenseCamera(flip=True, device_id=device_ids[0])
+    im, depth = rs.read()
+    _debug_read(rs, save_datastream=True)
diff --git a/gello/data_utils/conversion_utils.py b/gello/data_utils/conversion_utils.py
new file mode 100644
index 0000000..d7f6977
--- /dev/null
+++ b/gello/data_utils/conversion_utils.py
@@ -0,0 +1,231 @@
+from typing import Dict
+
+import cv2
+import numpy as np
+import torch
+import transforms3d._gohlketransforms as ttf
+
+
+def to_torch(array, device="cpu"):
+    """Convert a tensor, ndarray, or other array-like to a torch tensor on ``device``."""
+    if isinstance(array, torch.Tensor):
+        tensor = array
+    elif isinstance(array, np.ndarray):
+        tensor = torch.from_numpy(array)
+    else:
+        tensor = torch.tensor(array)
+    return tensor.to(device)
+
+
+def to_numpy(array):
+    """Convert a torch tensor to a NumPy array; anything else is returned unchanged.
+
+    The tensor is detached first so tensors that require grad can be converted too.
+    """
+    if isinstance(array, torch.Tensor):
+        return array.detach().cpu().numpy()
+    return array
+
+
+def center_crop(rgb_frame, depth_frame):
+    """Crop both frames to the centered square of side ``min(H, W)``.
+
+    The last two axes of ``rgb_frame`` are taken as (H, W); ``depth_frame`` is cropped with the
+    same offsets on its own last two axes, so it is expected to share that spatial size.
+
+    Fixes over the previous version: slice bounds were floats (TypeError), the second branch
+    repeated the ``H > W`` test so wide frames were never cropped, and the ``-s/2 : s/2`` slice
+    would have been empty.
+
+    Returns:
+        (rgb_frame, depth_frame) cropped to ``sq_size`` x ``sq_size`` on the last two axes.
+    """
+    H, W = rgb_frame.shape[-2:]
+    sq_size = min(H, W)
+
+    # crop the center square
+    if H > W:
+        top = (H - sq_size) // 2
+        rgb_frame = rgb_frame[..., top : top + sq_size, :]
+        depth_frame = depth_frame[..., top : top + sq_size, :]
+    elif W > H:
+        left = (W - sq_size) // 2
+        rgb_frame = rgb_frame[..., :, left : left + sq_size]
+        depth_frame = depth_frame[..., :, left : left + sq_size]
+
+    return rgb_frame, depth_frame
+
+
+def resize(rgb, depth, size=224):
+    """Resize channel-first rgb and depth frames to ``size`` x ``size`` with linear interpolation.
+
+    Args:
+        rgb: array indexed (C, H, W); moved to (H, W, C) for cv2 and back afterwards.
+        depth: array whose first axis is indexed at 0 to get the 2D map that is resized.
+        size: output side length.
+
+    Returns:
+        (rgb, depth) with shapes (C, size, size) and (1, size, size).
+    """
+    target = (size, size)
+
+    rgb_hwc = rgb.transpose([1, 2, 0])
+    rgb_resized = cv2.resize(rgb_hwc, target, interpolation=cv2.INTER_LINEAR)
+    rgb_chw = rgb_resized.transpose([2, 0, 1])
+
+    depth_resized = cv2.resize(depth[0], target, interpolation=cv2.INTER_LINEAR)
+    return rgb_chw, depth_resized.reshape([1, size, size])
+
+
+def filter_depth(depth, max_depth=2.0, min_depth=0.0):
+    """Replace NaN and +/-inf with 0, then clip to ``[min_depth, max_depth]``.
+
+    Returns a new array. The previous version wrote the zeros into the caller's array before
+    clipping, silently modifying the input.
+    """
+    cleaned = np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0)
+    return np.clip(cleaned, min_depth, max_depth)
+
+
+def preproc_obs(
+    demo: Dict[str, np.ndarray], joint_only: bool = True
+) -> Dict[str, np.ndarray]:
+    """Turn one recorded frame into the rgb / depth / state dictionary used downstream.
+
+    Args:
+        demo: frame dictionary; reads "wrist_rgb", "wrist_depth", "base_rgb", "base_depth",
+            "joint_positions", "joint_velocities", "ee_pos_quat" and "gripper_position".
+        joint_only: if true the state is just the joint positions; otherwise it is the
+            concatenation of joint positions, joint velocities, ee pose and gripper position.
+
+    Returns:
+        Dict with "rgb" and "depth" (wrist first, then base, stacked on a new leading axis),
+        placeholder identity "camera_poses" (4x4) and "K_matrices" (3x3), and "state".
+    """
+
+    def _channel_first(key: str):
+        # (H, W, C) -> (C, H, W); "* 1.0" yields a new array (float for integer images).
+        return demo.get(key).transpose([2, 0, 1]) * 1.0  # type: ignore
+
+    rgb_wrist = _channel_first("wrist_rgb")
+    depth_wrist = _channel_first("wrist_depth")
+    rgb_base = _channel_first("base_rgb")
+    depth_base = _channel_first("base_depth")
+
+    # Center crop, resize, then filter depth
+    wrist_cropped = center_crop(rgb_wrist, depth_wrist)
+    rgb_wrist, depth_wrist = resize(*wrist_cropped)
+    base_cropped = center_crop(rgb_base, depth_base)
+    rgb_base, depth_base = resize(*base_cropped)
+
+    depth_wrist = filter_depth(depth_wrist)
+    depth_base = filter_depth(depth_base)
+
+    # state
+    qpos = demo.get("joint_positions")  # type: ignore
+    qvel = demo.get("joint_velocities")  # type: ignore
+    ee_pos_quat = demo.get("ee_pos_quat")  # type: ignore
+    gripper_pos = demo.get("gripper_position")  # type: ignore
+
+    if joint_only:
+        state = qpos
+    else:
+        state = np.concatenate([qpos, qvel, ee_pos_quat, gripper_pos[None]])
+
+    return {
+        "rgb": np.stack([rgb_wrist, rgb_base], axis=0),
+        "depth": np.stack([depth_wrist, depth_base], axis=0),
+        # Dummy camera pose and intrinsics
+        "camera_poses": np.eye(4),
+        "K_matrices": np.eye(3),
+        "state": state,
+    }
+
+
+class Pose(object):
+    """Rigid transform stored as a position ``p`` and a unit quaternion ``q``.
+
+    ``q`` is kept in ``[x, y, z, w]`` order with a non-negative scalar part. The
+    ``transforms3d._gohlketransforms`` functions use ``[w, x, y, z]`` order, so every call into
+    that module converts through ``_wxyz`` and its quaternion results are read back as
+    ``[w, x, y, z]``. Previously ``q`` was handed to the library unconverted, so the library-based
+    methods disagreed with the hand-written ones (``to_axis_angle``, ``from_axis_angle``,
+    ``to_quaternion``).
+    """
+
+    def __init__(self, x, y, z, qw, qx, qy, qz):
+        self.p = np.array([x, y, z])
+
+        # stored as [x, y, z, w]
+        self.q = np.array([qx, qy, qz, qw])
+
+        # make sure that the quaternion has positive scalar part
+        if self.q[3] < 0:
+            self.q *= -1
+
+        self.q = self.q / np.linalg.norm(self.q)
+
+    def _wxyz(self):
+        """Return ``q`` reordered to the library's ``[w, x, y, z]`` convention."""
+        return np.array([self.q[3], self.q[0], self.q[1], self.q[2]])
+
+    def _rotation_matrix(self):
+        """Return the 3x3 rotation matrix of ``q``."""
+        return ttf.quaternion_matrix(self._wxyz())[:3, :3]
+
+    def __mul__(self, other):
+        """Compose transforms: the result applies ``other`` first, then ``self``."""
+        assert isinstance(other, Pose)
+        p = self.p + self._rotation_matrix().dot(other.p)
+        q = ttf.quaternion_multiply(self._wxyz(), other._wxyz())  # [w, x, y, z]
+        return Pose(p[0], p[1], p[2], q[0], q[1], q[2], q[3])
+
+    def __rmul__(self, other):
+        assert isinstance(other, Pose)
+        return other * self
+
+    def __str__(self):
+        return "p: {}, q: {}".format(self.p, self.q)
+
+    def inv(self):
+        """Return the inverse transform."""
+        R = self._rotation_matrix()
+        p = -R.T.dot(self.p)
+        q = ttf.quaternion_inverse(self._wxyz())  # [w, x, y, z]
+        return Pose(p[0], p[1], p[2], q[0], q[1], q[2], q[3])
+
+    def to_quaternion(self):
+        """
+        Return ``[x, y, z, qw, qx, qy, qz]``; this satisfies Pose(*to_quaternion(p)) == p
+        """
+        q_reverted = np.array([self.q[3], self.q[0], self.q[1], self.q[2]])
+        return np.concatenate([self.p, q_reverted])
+
+    def to_axis_angle(self):
+        """
+        returns the axis-angle representation of the rotation as
+        ``[x, y, z, ax, ay, az, angle / pi]``.
+        """
+        # Clip guards arccos against a scalar part that rounds just above 1 after normalization.
+        angle = 2 * np.arccos(np.clip(self.q[3], -1.0, 1.0))
+        angle = angle / np.pi
+        if angle > 1:
+            angle -= 2
+
+        vec_norm = np.linalg.norm(self.q[:3])
+        if vec_norm < 1e-12:
+            # Identity rotation: the axis is arbitrary, so use +x instead of dividing by zero.
+            axis = np.array([1.0, 0.0, 0.0])
+        else:
+            axis = self.q[:3] / vec_norm
+
+        # keep the axes positive
+        if axis[0] < 0:
+            axis *= -1
+            angle *= -1
+
+        return np.concatenate([self.p, axis, [angle]])
+
+    def to_euler(self):
+        """Return ``[x, y, z, roll / pi, pitch / pi, yaw / pi, 0.0]``."""
+        q = np.array(ttf.euler_from_quaternion(self._wxyz()))
+        if q[0] > np.pi:
+            q[0] -= 2 * np.pi
+        if q[1] > np.pi:
+            q[1] -= 2 * np.pi
+        if q[2] > np.pi:
+            q[2] -= 2 * np.pi
+
+        q = q / np.pi
+
+        return np.concatenate([self.p, q, [0.0]])
+
+    def to_44_matrix(self):
+        """Return the 4x4 homogeneous matrix of this pose."""
+        out = np.eye(4)
+        out[:3, :3] = self._rotation_matrix()
+        out[:3, 3] = self.p
+        return out
+
+    @staticmethod
+    def from_axis_angle(x, y, z, ax, ay, az, phi):
+        """
+        returns a Pose object from the axis-angle representation of the rotation;
+        ``phi`` is the angle divided by pi.
+        """
+
+        phi = phi * np.pi
+        p = np.array([x, y, z])
+        qw = np.cos(phi / 2.0)
+        qx = ax * np.sin(phi / 2.0)
+        qy = ay * np.sin(phi / 2.0)
+        qz = az * np.sin(phi / 2.0)
+
+        return Pose(p[0], p[1], p[2], qw, qx, qy, qz)
+
+    @staticmethod
+    def from_euler(x, y, z, roll, pitch, yaw, _):
+        """
+        returns a Pose object from the euler representation of the rotation;
+        angles are given divided by pi and the last argument is ignored.
+        """
+        p = np.array([x, y, z])
+        roll, pitch, yaw = roll * np.pi, pitch * np.pi, yaw * np.pi
+        q = ttf.quaternion_from_euler(roll, pitch, yaw)  # [w, x, y, z]
+        return Pose(p[0], p[1], p[2], q[0], q[1], q[2], q[3])
+
+    @staticmethod
+    def from_quaternion(x, y, z, qw, qx, qy, qz):
+        """
+        returns a Pose object from the quaternion representation of the rotation.
+        """
+        p = np.array([x, y, z])
+        return Pose(p[0], p[1], p[2], qw, qx, qy, qz)
+
+
+def compute_inverse_action(p, p_new, ee_control=False):
+    """Return the delta pose that carries ``p`` to ``p_new``.
+
+    With ``ee_control`` the delta is ``p.inv() * p_new`` (applied on the right of ``p``);
+    otherwise it is ``p_new * p.inv()`` (applied on the left of ``p``).
+    """
+    assert isinstance(p, Pose) and isinstance(p_new, Pose)
+
+    p_inverse = p.inv()
+    return p_inverse * p_new if ee_control else p_new * p_inverse
+
+
+def compute_forward_action(p, dpose, ee_control=False):
+    """Apply the delta pose ``dpose`` to ``p`` and return the new pose.
+
+    With ``ee_control`` the delta is applied on the right (``p * dpose``); otherwise on the left
+    (``dpose * p``). ``dpose`` is first rebuilt from its own quaternion form.
+    """
+    assert isinstance(p, Pose) and isinstance(dpose, Pose)
+    delta = Pose.from_quaternion(*dpose.to_quaternion())
+
+    if not ee_control:
+        return delta * p
+    return p * delta
diff --git a/gello/data_utils/demo_to_gdict.py b/gello/data_utils/demo_to_gdict.py
new file mode 100644
index 0000000..37005b0
--- /dev/null
+++ b/gello/data_utils/demo_to_gdict.py
@@ -0,0 +1,345 @@
+import glob
+import os
+import pickle
+import shutil
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+import tyro
+from natsort import natsorted
+from tqdm import tqdm
+
+from gello.data_utils.plot_utils import plot_in_grid
+
+np.set_printoptions(precision=3, suppress=True)
+
+import mediapy as mp
+from gdict.data import DictArray, GDict
+from simple_bc.utils.visualization_utils import make_grid_video_from_numpy
+
+from gello.data_utils.conversion_utils import preproc_obs
+
+# def get_act_bounds(source_dir: str) -> np.ndarray:
+# pkls = natsorted(
+# glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
+# )
+# if len(pkls) <= 30:
+# print(f"Skipping {source_dir} because it has less than 30 frames.")
+# return None
+# pkls = pkls[:-5]
+#
+# scale_factor = None
+# for pkl in pkls:
+# try:
+# with open(pkl, "rb") as f:
+# demo = pickle.load(f)
+# except Exception as e:
+# print(f"Skipping {pkl} because it is corrupted.")
+# print(f"Error: {e}")
+# raise Exception("Corrupted pkl")
+#
+# requested_control = demo.pop("control")
+# curr_scale_factor = np.abs(requested_control)
+# if scale_factor is None:
+# scale_factor = curr_scale_factor
+# else:
+# scale_factor = np.maximum(scale_factor, curr_scale_factor)
+# assert scale_factor is not None
+# return scale_factor
+
+
def get_act_min_max(source_dir: str) -> Tuple[np.ndarray, np.ndarray]:
    """Return the element-wise min and max of the recorded controls of a demo.

    Every ``*.pkl`` file found recursively under ``source_dir`` is one frame;
    its ``"control"`` entry is the action that is aggregated.

    Args:
        source_dir: directory holding the pickled frames of a single demo.

    Returns:
        A ``(scale_min, scale_max)`` tuple of element-wise bounds.

    Raises:
        RuntimeError: if the demo has 30 frames or fewer.
        Exception: if a frame cannot be unpickled (the original error is
            chained as the cause).
    """
    pkls = natsorted(
        glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
    )
    if len(pkls) <= 30:
        print(f"Skipping {source_dir} because it has less than 30 frames.")
        raise RuntimeError("Too few frames")
    # The list is in reverse order, so trimming its tail drops the five frames
    # that sort first (same trimming as convert_single_demo).
    pkls = pkls[:-5]

    scale_min = None
    scale_max = None
    for pkl in pkls:
        try:
            with open(pkl, "rb") as f:
                demo = pickle.load(f)
        except Exception as e:
            print(f"Skipping {pkl} because it is corrupted.")
            print(f"Error: {e}")
            raise Exception("Corrupted pkl") from e

        control = demo.pop("control")
        if scale_min is None:
            assert scale_max is None
            scale_min = control
            scale_max = control
        else:
            assert scale_max is not None
            scale_min = np.minimum(scale_min, control)
            # Fold into the running maximum; the previous code folded into
            # scale_min here, which made the "max" follow the minimum.
            scale_max = np.maximum(scale_max, control)

    assert scale_min is not None
    assert scale_max is not None
    return scale_min, scale_max
+
+
def convert_single_demo(
    source_dir,
    i,
    traj_output_dir,
    rgb_output_dir,
    depth_output_dir,
    state_output_dir,
    action_output_dir,
    scale_factor,
    bias_factor,
):
    """Convert one recorded demo into a gdict trajectory plus visualizations.

    1. converts the demo into a gdict and writes it to ``traj_{i}.h5``
    2. writes RGB and depth videos for camera index 1 ("base") and 0 ("wrist")
    3. plots the state and action space of the demo
    4. returns the arrays so the caller can collate them.

    Args:
        source_dir: directory with the pickled frames of the demo.
        i: index of the demo; used in the trajectory key and output file names.
        traj_output_dir: directory for the ``.h5`` trajectory.
        rgb_output_dir: directory for the RGB videos.
        depth_output_dir: directory for the depth videos.
        state_output_dir: directory for the state plot.
        action_output_dir: directory for the action plot.
        scale_factor: divisor applied to ``action - bias_factor``.
        bias_factor: offset subtracted from every action.

    Returns:
        ``0`` when the demo is skipped (30 frames or fewer, or a frame that
        cannot be unpickled). Otherwise a tuple
        ``(all_rgbs, all_depths, all_actions, all_states)``; the rgb and depth
        entries are the camera-index-0 ("wrist") arrays, with depth tiled to
        three channels.
    """
    pkls = natsorted(
        glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
    )

    if len(pkls) <= 30:
        return 0

    # go through the demo in reverse order.
    # remove the first few frames because they are not useful.
    pkls = pkls[:-5]

    traj_key = f"traj_{i}"
    demo_stack = []
    for pkl in pkls:
        try:
            with open(pkl, "rb") as f:
                demo = pickle.load(f)
        except Exception as e:
            # `Exception` rather than a bare except, so Ctrl-C still aborts.
            print(f"Skipping {pkl} because it is corrupted.")
            print(f"Error: {e}")
            return 0

        obs = preproc_obs(demo)
        action = demo.pop("control")
        action = (action - bias_factor) / scale_factor  # normalize between -1 and 1

        curr_ts = {
            "obs": obs,
            "actions": action,
            "dones": np.zeros(1),  # random fill
            "episode_dones": np.zeros(1),  # random fill
        }
        demo_stack.append({traj_key: curr_ts})

    # Frames were visited newest-first; flip once to restore forward order
    # (replaces the quadratic prepend-per-frame).
    demo_stack.reverse()

    demo_dict = DictArray.stack(demo_stack)
    GDict.to_hdf5(demo_dict, os.path.join(traj_output_dir, f"traj_{i}.h5"))

    # Save the rgb and depth videos: base camera first, wrist camera second,
    # so the arrays left in scope (and returned) are the wrist ones.
    cameras = (
        (1, "base", 5.0),
        (0, "wrist", 2.0),
    )
    all_rgbs = None
    all_depths = None
    for cam_idx, cam_name, depth_scale in cameras:
        all_rgbs = demo_dict[traj_key]["obs"]["rgb"][:, cam_idx].transpose(
            [0, 2, 3, 1]
        )
        all_rgbs = all_rgbs.astype(np.uint8)
        _, H, W, _ = all_rgbs.shape
        all_depths = demo_dict[traj_key]["obs"]["depth"][:, cam_idx].reshape(
            [-1, H, W]
        )
        all_depths = all_depths / depth_scale  # scale to 0-1

        mp.write_video(
            os.path.join(rgb_output_dir, f"traj_{i}_rgb_{cam_name}.mp4"),
            all_rgbs,
            fps=30,
        )
        mp.write_video(
            os.path.join(depth_output_dir, f"traj_{i}_depth_{cam_name}.mp4"),
            all_depths,
            fps=30,
        )

    all_depths = np.tile(all_depths[..., None], [1, 1, 1, 3])

    # save the state and action plots
    all_actions = demo_dict[traj_key]["actions"]
    all_states = demo_dict[traj_key]["obs"]["state"]

    curr_actions = all_actions.reshape([1, *all_actions.shape])
    curr_states = all_states.reshape([-1, *all_states.shape])

    plot_in_grid(
        curr_actions, os.path.join(action_output_dir, f"traj_{i}_actions.png")
    )
    plot_in_grid(curr_states, os.path.join(state_output_dir, f"traj_{i}_states.png"))

    return all_rgbs, all_depths, all_actions, all_states
+
+
@dataclass
class Args:
    """Command-line arguments of the conversion script (parsed with ``tyro.cli``)."""

    # Directory whose immediate sub-directories are each converted as one demo.
    source_dir: str
    # When True, main() also writes the combined plots and grid videos.
    vis: bool = True
+
+
def main(args):
    """Convert every demo under ``args.source_dir`` and write the results.

    Output goes to ``<source_dir>/_conv/multiview``; an existing ``multiview``
    directory is deleted first. Each sub-directory of ``source_dir`` is one
    demo. Up to 10 demos (at most 10% of them) are randomly assigned to
    ``val``, the rest to ``train``. Action normalization constants are
    computed over all demos before conversion. The function ends by raising
    ``SystemExit(0)``.

    Args:
        args: an ``Args`` instance (``source_dir`` and ``vis``).

    Raises:
        RuntimeError: if no demo yields usable action bounds.
    """
    subdirs = natsorted(glob.glob(os.path.join(args.source_dir, "*/"), recursive=True))

    output_dir = args.source_dir
    if output_dir.endswith("/"):
        output_dir = output_dir[:-1]
    output_dir = os.path.join(output_dir, "_conv", "multiview")

    if os.path.isdir(output_dir):
        print(f"Output directory {output_dir} already exists, and will be deleted")
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    val_size = int(min(0.1 * len(subdirs), 10))
    val_indices = np.random.choice(len(subdirs), size=val_size, replace=False)
    val_indices = set(val_indices)

    print("Computing scale factors")
    pbar = tqdm(enumerate(subdirs), total=len(subdirs))
    min_scale_factor = None
    max_scale_factor = None
    for i, subdir in pbar:
        try:
            curr_min, curr_max = get_act_min_max(subdir)
        except Exception as e:
            print(f"Error: {e}")
            print(f"Skipping {subdir}")
            continue

        if min_scale_factor is None:
            min_scale_factor = curr_min
            max_scale_factor = curr_max
        else:
            min_scale_factor = np.minimum(min_scale_factor, curr_min)
            max_scale_factor = np.maximum(max_scale_factor, curr_max)
        pbar.set_description(f"t: {i}")

    if min_scale_factor is None or max_scale_factor is None:
        # Every demo was skipped; fail with a clear message instead of a
        # TypeError from arithmetic on None.
        raise RuntimeError(
            f"No valid demos found in {args.source_dir}; cannot compute scale factors"
        )

    bias_factor = (min_scale_factor + max_scale_factor) / 2.0
    scale_factor = (max_scale_factor - min_scale_factor) / 2.0
    # Constant action dimensions would otherwise cause a division by zero.
    scale_factor[scale_factor == 0] = 1.0
    print("*" * 80)
    print(f"scale factors: {scale_factor}")
    print(f"bias factor: {bias_factor}")
    # make it into a copy pasteable string where the numbers are separated by commas
    scale_factor_str = ", ".join([f"{x}" for x in scale_factor])
    print(f"scale_factor = np.array([{scale_factor_str}])")
    bias_factor_str = ", ".join([f"{x}" for x in bias_factor])
    print(f"bias_factor = np.array([{bias_factor_str}])")
    print("*" * 80)

    tot = 0

    all_rgbs = []
    all_depths = []
    all_actions = []
    all_states = []

    vis_dir = os.path.join(output_dir, "vis")
    state_output_dir = os.path.join(vis_dir, "state")
    action_output_dir = os.path.join(vis_dir, "action")
    rgb_output_dir = os.path.join(vis_dir, "rgb")
    depth_output_dir = os.path.join(vis_dir, "depth")
    for directory in (
        state_output_dir,
        action_output_dir,
        rgb_output_dir,
        depth_output_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    pbar = tqdm(enumerate(subdirs), total=len(subdirs))
    for i, subdir in pbar:
        out_dir = val_dir if i in val_indices else train_dir
        out_dir = os.path.join(out_dir, "none")
        os.makedirs(out_dir, exist_ok=True)

        ret = convert_single_demo(
            subdir,
            i,
            out_dir,
            rgb_output_dir,
            depth_output_dir,
            state_output_dir,
            action_output_dir,
            scale_factor=scale_factor,
            bias_factor=bias_factor,
        )

        # convert_single_demo returns 0 for a skipped demo, a 4-tuple otherwise.
        if ret != 0:
            all_rgbs.append(ret[0])
            all_depths.append(ret[1])
            all_actions.append(ret[2])
            all_states.append(ret[3])
            tot += 1

        pbar.set_description(f"t: {i}")

    print(
        f"Finished converting all demos to {output_dir}! (num demos: {tot} / {len(subdirs)})"
    )

    if args.vis and len(all_rgbs) > 0:
        print("Visualizing all demos...")

        plot_in_grid(all_actions, os.path.join(action_output_dir, "_all_actions.png"))
        plot_in_grid(all_states, os.path.join(state_output_dir, "_all_states.png"))
        make_grid_video_from_numpy(
            all_rgbs, 10, os.path.join(rgb_output_dir, "_all_rgb.mp4"), fps=30
        )
        make_grid_video_from_numpy(
            all_depths, 10, os.path.join(depth_output_dir, "_all_depth.mp4"), fps=30
        )

    # Same effect as exit(0), without relying on the site-provided builtin.
    raise SystemExit(0)
+
+
# Script entry point: parse Args from the command line with tyro and run main().
if __name__ == "__main__":
    main(tyro.cli(Args))
diff --git a/gello/data_utils/format_obs.py b/gello/data_utils/format_obs.py
new file mode 100644
index 0000000..005f839
--- /dev/null
+++ b/gello/data_utils/format_obs.py
@@ -0,0 +1,22 @@
+import datetime
+import pickle
+from pathlib import Path
+from typing import Dict
+
+import numpy as np
+
+
def save_frame(
    folder: Path,
    timestamp: datetime.datetime,
    obs: Dict[str, np.ndarray],
    action: np.ndarray,
) -> None:
    """Pickle one observation/action pair to ``<folder>/<timestamp>.pkl``.

    The stored object is a dict holding every entry of ``obs`` plus the action
    under the key ``"control"``. ``obs`` itself is not modified.

    Args:
        folder: destination directory; created (with parents) if missing.
        timestamp: time of the frame; its ISO-8601 form is the file name.
        obs: observation dictionary of the frame.
        action: action to store alongside the observation.
    """
    # Work on a shallow copy so the caller's dict does not gain a "control" key.
    record = dict(obs)
    record["control"] = action

    # make folder if it doesn't exist
    folder.mkdir(exist_ok=True, parents=True)
    recorded_file = folder / (timestamp.isoformat() + ".pkl")

    with open(recorded_file, "wb") as f:
        pickle.dump(record, f)
diff --git a/gello/data_utils/keyboard_interface.py b/gello/data_utils/keyboard_interface.py
new file mode 100644
index 0000000..e1f79a2
--- /dev/null
+++ b/gello/data_utils/keyboard_interface.py
@@ -0,0 +1,59 @@
+import pygame
+
+NORMAL = (128, 128, 128)
+RED = (255, 0, 0)
+GREEN = (0, 255, 0)
+
+KEY_START = pygame.K_s
+KEY_CONTINUE = pygame.K_c
+KEY_QUIT_RECORDING = pygame.K_q
+
+
class KBReset:
    """Keyboard-driven recording state backed by a pygame window.

    ``update()`` returns ``"start"`` on the call that sees KEY_START, then
    ``"save"`` on later calls until KEY_QUIT_RECORDING is seen, and
    ``"normal"`` otherwise. The window colour mirrors the state.
    """

    def __init__(self):
        pygame.init()
        self._screen = pygame.display.set_mode((800, 800))
        self._set_color(NORMAL)
        self._saved = False

    def update(self) -> str:
        """Poll the keyboard once and return the current state string."""
        keys = self._get_pressed()

        # Quitting takes priority over everything else.
        if KEY_QUIT_RECORDING in keys:
            self._set_color(RED)
            self._saved = False
            return "normal"

        # Once started, keep reporting "save" until a quit key arrives.
        if self._saved:
            return "save"

        if KEY_START in keys:
            self._set_color(GREEN)
            self._saved = True
            return "start"

        self._set_color(NORMAL)
        return "normal"

    def _get_pressed(self):
        """Return the keys of all KEYDOWN events currently in the queue."""
        pygame.event.pump()
        return [
            event.key for event in pygame.event.get() if event.type == pygame.KEYDOWN
        ]

    def _set_color(self, color):
        """Fill the window with ``color`` and refresh the display."""
        self._screen.fill(color)
        pygame.display.flip()
+
+
def main():
    """Poll a KBReset forever and print "start" whenever it reports a start."""
    keyboard = KBReset()
    while True:
        if keyboard.update() == "start":
            print("start")
+
+
# Run the interactive keyboard demo when the module is executed as a script.
if __name__ == "__main__":
    main()
diff --git a/gello/data_utils/plot_utils.py b/gello/data_utils/plot_utils.py
new file mode 100644
index 0000000..e7eb79e
--- /dev/null
+++ b/gello/data_utils/plot_utils.py
@@ -0,0 +1,103 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+
def plot_in_grid(vals: np.ndarray, save_path: str):
    """Plot the trajectories in a grid, plus 3D views of the first three dims.

    Two images are written: ``save_path`` with one subplot per dimension, and
    ``save_path[:-4] + "_3d.png"`` with four 3D views of dimensions 0-2. The
    3D image is skipped when there are fewer than three dimensions.

    Args:
        vals: B x T x N, where
            B is the number of trajectories,
            T is the number of timesteps,
            N is the dimensionality of the values.
            A sequence of ``T x N`` arrays works too, since only ``len(vals)``
            and per-trajectory indexing are used.
        save_path: path to save the plot.
    """
    N = vals[0].shape[-1]
    rows = (N // 4) + (N % 4 > 0)
    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # otherwise the [row, col] indexing below fails whenever N <= 4.
    fig, axes = plt.subplots(rows, 4, figsize=(20, 10), squeeze=False)
    for curr in vals:
        T = curr.shape[0]
        for i in range(N):
            # give them transparency
            axes[i // 4, i % 4].plot(np.arange(T), curr[:, i], alpha=0.5)

    for i in range(N):
        axes[i // 4, i % 4].set_title(f"Dim {i}")
        axes[i // 4, i % 4].set_ylim([-1.0, 1.0])

    fig.savefig(save_path)
    plt.close(fig)

    if N < 3:
        # The 3D views need x, y and z columns.
        return

    # (subplot position, view_init arguments); None keeps the default view.
    panels = (
        (141, (270, 0)),  # 2D view of the XY plane, with X pointing downwards
        (142, (0, 0)),  # 2D view of the XZ plane, with X pointing leftwards
        (143, (0, 90)),  # 2D view of the YZ plane, with Y pointing leftwards
        (144, None),
    )

    fig = plt.figure(figsize=(20, 10))
    for position, view in panels:
        ax = fig.add_subplot(position, projection="3d")
        for curr in vals:
            ax.plot(curr[:, 0], curr[:, 1], curr[:, 2], alpha=0.75)
            # scatter the start and end points
            ax.scatter(curr[0, 0], curr[0, 1], curr[0, 2], c="r")
            ax.scatter(curr[-1, 0], curr[-1, 1], curr[-1, 2], c="g")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.legend(["Trajectory", "Start", "End"])
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        if view is not None:
            ax.view_init(*view)

    fig.savefig(save_path[:-4] + "_3d.png")
    plt.close(fig)
diff --git a/gello/dm_control_tasks/arenas/__init__.py b/gello/dm_control_tasks/arenas/__init__.py
new file mode 100644
index 0000000..06ca331
--- /dev/null
+++ b/gello/dm_control_tasks/arenas/__init__.py
@@ -0,0 +1,5 @@
+from gello.dm_control_tasks.arenas.base import Arena
+
+__all__ = [
+ "Arena",
+]
diff --git a/gello/dm_control_tasks/arenas/arena.xml b/gello/dm_control_tasks/arenas/arena.xml
new file mode 100644
index 0000000..fdff4cc
--- /dev/null
+++ b/gello/dm_control_tasks/arenas/arena.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/gello/dm_control_tasks/arenas/base.py b/gello/dm_control_tasks/arenas/base.py
new file mode 100644
index 0000000..af2194d
--- /dev/null
+++ b/gello/dm_control_tasks/arenas/base.py
@@ -0,0 +1,26 @@
+"""Arena composer class."""
+
+from pathlib import Path
+
+from dm_control import composer, mjcf
+
+_ARENA_XML = Path(__file__).resolve().parent / "arena.xml"
+
+
class Arena(composer.Entity):
    """Base arena entity whose MJCF model is loaded from ``arena.xml``."""

    def _build(self, name: str = "arena") -> None:
        """Load the arena XML and, unless ``name`` is None, rename the model."""
        root = mjcf.from_path(str(_ARENA_XML))
        if name is not None:
            root.model = name
        self._mjcf_root = root

    def add_free_entity(self, entity) -> mjcf.Element:
        """Attach ``entity`` with a freejoint so it moves freely; return its frame."""
        attachment_frame = self.attach(entity)
        attachment_frame.add("freejoint")
        return attachment_frame

    @property
    def mjcf_model(self) -> mjcf.RootElement:
        """Root MJCF element of the arena."""
        return self._mjcf_root
diff --git a/gello/dm_control_tasks/arms/__init__.py b/gello/dm_control_tasks/arms/__init__.py
new file mode 100644
index 0000000..33da9e7
--- /dev/null
+++ b/gello/dm_control_tasks/arms/__init__.py
@@ -0,0 +1,7 @@
+from gello.dm_control_tasks.arms.franka import Franka
+from gello.dm_control_tasks.arms.ur5e import UR5e
+
+__all__ = [
+ "UR5e",
+ "Franka",
+]
diff --git a/gello/dm_control_tasks/arms/franka.py b/gello/dm_control_tasks/arms/franka.py
new file mode 100644
index 0000000..ef712ad
--- /dev/null
+++ b/gello/dm_control_tasks/arms/franka.py
@@ -0,0 +1,28 @@
+"""Franka composer class."""
+
+from pathlib import Path
+from typing import Optional, Union
+
+from dm_control import mjcf
+
+from gello.dm_control_tasks.arms.manipulator import Manipulator
+from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
+
+XML = MENAGERIE_ROOT / "franka_emika_panda" / "panda_nohand.xml"
+GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
+
+
class Franka(Manipulator):
    """Franka Robot."""

    def _build(
        self,
        name: str = "franka",
        xml_path: Union[str, Path] = XML,
        gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
    ) -> None:
        """Build the Franka arm, optionally with a gripper attached.

        Args:
            name: model name given to the MJCF root.
            xml_path: path of the arm MJCF file.
            gripper_xml_path: path of the gripper MJCF file; a falsy value
                builds the arm without a gripper.
        """
        # Forward the caller's arguments; previously they were ignored in
        # favour of the module-level defaults.
        super()._build(name=name, xml_path=xml_path, gripper_xml_path=gripper_xml_path)

    @property
    def flange(self) -> mjcf.Element:
        """The site named "attachment_site" in the arm model."""
        return self._mjcf_root.find("site", "attachment_site")
diff --git a/gello/dm_control_tasks/arms/manipulator.py b/gello/dm_control_tasks/arms/manipulator.py
new file mode 100644
index 0000000..6f251e8
--- /dev/null
+++ b/gello/dm_control_tasks/arms/manipulator.py
@@ -0,0 +1,229 @@
+"""Manipulator composer class."""
+import abc
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from dm_control import composer, mjcf
+from dm_control.composer.observation import observable
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.suite.utils.randomizers import random_limited_quaternion
+
+from gello.dm_control_tasks import mjcf_utils
+
+
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
    # attach_site: str,
) -> None:
    """Attaches a hand to an arm.

    The arm must have a site named "attachment_site"; the site name is
    currently hard-coded (the `attach_site` parameter is commented out).

    If the arm has a keyframe named "home", its `ctrl` and `qpos` are extended
    in place: with the hand's own "home" keyframe values when it has one,
    otherwise with zeros sized by the hand's `nu` and `nq`.

    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm.
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    # Compile the hand on its own; only used to read its nu / nq below.
    physics = mjcf.Physics.from_mjcf_model(hand_mjcf)

    # attachment_site = arm_mjcf.find("site", attach_site)
    attachment_site = arm_mjcf.find("site", "attachment_site")
    if attachment_site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Expand the ctrl and qpos keyframes to account for the new hand DoFs.
    arm_key = arm_mjcf.find("key", "home")
    if arm_key is not None:
        hand_key = hand_mjcf.find("key", "home")
        if hand_key is None:
            arm_key.ctrl = np.concatenate([arm_key.ctrl, np.zeros(physics.model.nu)])
            arm_key.qpos = np.concatenate([arm_key.qpos, np.zeros(physics.model.nq)])
        else:
            arm_key.ctrl = np.concatenate([arm_key.ctrl, hand_key.ctrl])
            arm_key.qpos = np.concatenate([arm_key.qpos, hand_key.qpos])

    attachment_site.attach(hand_mjcf)
+
+
class Manipulator(composer.Entity, abc.ABC):
    """A manipulator entity.

    Concrete subclasses must provide the `flange` property.
    """

    def _build(
        self,
        name: str,
        xml_path: Union[str, Path],
        gripper_xml_path: Optional[Union[str, Path]],
    ) -> None:
        """Builds the manipulator.

        Subclasses that define their own _build() should call this method
        from it.

        Args:
            name: Model name assigned to the MJCF root.
            xml_path: Path of the arm MJCF file.
            gripper_xml_path: Path of the gripper MJCF file; when falsy no
                gripper is attached.
        """
        self._mjcf_root = mjcf.from_path(str(xml_path))
        # Collected before the gripper is attached, so these are the joints of
        # the arm XML only.
        self._arm_joints = tuple(mjcf_utils.safe_find_all(self._mjcf_root, "joint"))
        if gripper_xml_path:
            gripper_mjcf = mjcf.from_path(str(gripper_xml_path))
            attach_hand_to_arm(self._mjcf_root, gripper_mjcf)

        self._mjcf_root.model = name
        self._add_mjcf_elements()

    def set_joints(self, physics: mjcf.Physics, joints: np.ndarray) -> None:
        """Writes `joints` into qpos for the arm joints, in order.

        `joints` must have exactly one value per arm joint (checked with an
        assert).
        """
        assert len(joints) == len(self._arm_joints)
        for joint, joint_value in zip(self._arm_joints, joints):
            joint_id = physics.bind(joint).element_id
            joint_name = physics.model.id2name(joint_id, "joint")
            physics.named.data.qpos[joint_name] = joint_value

    def randomize_joints(
        self,
        physics: mjcf.Physics,
        random: Optional[np.random.RandomState] = None,
    ) -> None:
        """Assigns random qpos values to the arm joints.

        Limited hinge/slide joints are sampled uniformly within their range
        and limited ball joints via `random_limited_quaternion`. Unlimited
        hinge joints are sampled in [-pi, pi); unlimited ball and free joints
        get a normalized random quaternion. Other combinations are left
        untouched.

        Args:
            physics: The physics whose qpos is modified in place.
            random: Random source; falls back to the `np.random` module.
        """
        random = random or np.random  # type: ignore
        assert random is not None
        hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
        slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
        ball = mjbindings.enums.mjtJoint.mjJNT_BALL
        free = mjbindings.enums.mjtJoint.mjJNT_FREE

        qpos = physics.named.data.qpos

        for joint in self._arm_joints:
            joint_id = physics.bind(joint).element_id
            # joint_id = physics.model.name2id(joint.name, "joint")
            joint_name = physics.model.id2name(joint_id, "joint")
            joint_type = physics.model.jnt_type[joint_id]
            is_limited = physics.model.jnt_limited[joint_id]
            range_min, range_max = physics.model.jnt_range[joint_id]

            if is_limited:
                if joint_type in [hinge, slide]:
                    qpos[joint_name] = random.uniform(range_min, range_max)

                elif joint_type == ball:
                    qpos[joint_name] = random_limited_quaternion(random, range_max)

            else:
                if joint_type == hinge:
                    qpos[joint_name] = random.uniform(-np.pi, np.pi)

                elif joint_type == ball:
                    quat = random.randn(4)
                    quat /= np.linalg.norm(quat)
                    qpos[joint_name] = quat

                elif joint_type == free:
                    # this should be random.randn, but changing it now could significantly
                    # affect benchmark results.
                    quat = random.rand(4)
                    quat /= np.linalg.norm(quat)
                    # Only the entries from index 3 on are replaced; the
                    # first three are kept.
                    qpos[joint_name][3:] = quat

    def _add_mjcf_elements(self) -> None:
        """Caches joints, actuators and keyframes, and adds a flange marker geom."""
        # Parse joints.
        joints = mjcf_utils.safe_find_all(self._mjcf_root, "joint")
        joints = [joint for joint in joints if joint.tag != "freejoint"]
        self._joints = tuple(joints)

        # Parse actuators.
        actuators = mjcf_utils.safe_find_all(self._mjcf_root, "actuator")
        self._actuators = tuple(actuators)

        # Parse qpos / ctrl keyframes.
        # Only qpos is stored: keyframe name -> qpos array.
        self._keyframes = {}
        keyframes = mjcf_utils.safe_find_all(self._mjcf_root, "key")
        if keyframes:
            for frame in keyframes:
                if frame.qpos is not None:
                    qpos = np.array(frame.qpos)
                    self._keyframes[frame.name] = qpos

        # Add a green sphere that visualizes the flange position; contype and
        # conaffinity are both 0.
        self.flange.parent.add(
            "geom",
            name="flange_geom",
            type="sphere",
            size="0.01",
            rgba="0 1 0 1",
            pos=self.flange.pos,
            contype="0",
            conaffinity="0",
        )

    def _build_observables(self):
        """Returns the ArmObservables for this manipulator."""
        return ArmObservables(self)

    @property
    @abc.abstractmethod
    def flange(self) -> mjcf.Element:
        """Returns the flange element.

        The flange is the end effector of the manipulator where tools can be
        attached, such as a gripper.
        """

    @property
    def mjcf_model(self) -> mjcf.RootElement:
        """Root MJCF element of the manipulator."""
        return self._mjcf_root

    @property
    def name(self) -> str:
        """Model name stored on the MJCF root."""
        return self._mjcf_root.model

    @property
    def joints(self) -> Tuple[mjcf.Element, ...]:
        """All non-freejoint joints found after the gripper was attached."""
        return self._joints

    @property
    def actuators(self) -> Tuple[mjcf.Element, ...]:
        """All actuators found in the model."""
        return self._actuators

    @property
    def keyframes(self) -> Dict[str, np.ndarray]:
        """Mapping from keyframe name to its qpos array."""
        return self._keyframes
+
+
class ArmObservables(composer.Observables):
    """Observables exposed by a manipulator arm."""

    @composer.observable
    def joints_pos(self):
        """qpos of the entity's joints."""
        return observable.MJCFFeature("qpos", self._entity.joints)

    @composer.observable
    def joints_vel(self):
        """qvel of the entity's joints."""
        return observable.MJCFFeature("qvel", self._entity.joints)

    @composer.observable
    def flange_position(self):
        """xpos of the entity's flange."""
        return observable.MJCFFeature("xpos", self._entity.flange)

    @composer.observable
    def flange_orientation(self):
        """xmat of the entity's flange."""
        return observable.MJCFFeature("xmat", self._entity.flange)

    # Semantic grouping of observables.
    def _collect_from_attachments(self, attribute_name: str):
        """Gather ``attribute_name`` from the observables of attached entities.

        Entities whose observables lack the attribute contribute nothing.
        """
        out: List[observable.MJCFFeature] = []
        for entity in self._entity.iter_entities(exclude_self=True):
            out.extend(getattr(entity.observables, attribute_name, []))
        return out

    @property
    def proprioception(self):
        """Proprioceptive observables of the arm and of everything attached."""
        return [
            self.joints_pos,
            self.joints_vel,
            self.flange_position,
            # self.flange_orientation,
            # self.flange_velocity,
            # self.flange_angular_velocity,  # not defined on this class;
            # listing it raised AttributeError on every access.
        ] + self._collect_from_attachments("proprioception")
diff --git a/gello/dm_control_tasks/arms/ur5e.py b/gello/dm_control_tasks/arms/ur5e.py
new file mode 100644
index 0000000..8ab4aba
--- /dev/null
+++ b/gello/dm_control_tasks/arms/ur5e.py
@@ -0,0 +1,26 @@
+"""UR5e composer class."""
+
+from pathlib import Path
+from typing import Optional, Union
+
+from dm_control import mjcf
+
+from gello.dm_control_tasks.arms.manipulator import Manipulator
+from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
+
+
class UR5e(Manipulator):
    """UR5e arm; by default built with the Robotiq 2F-85 gripper XML."""

    GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
    XML = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"

    def _build(
        self,
        name: str = "UR5e",
        xml_path: Union[str, Path] = XML,
        gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
    ) -> None:
        """Forward the model name and XML paths to ``Manipulator._build``."""
        super()._build(name, xml_path, gripper_xml_path)

    @property
    def flange(self) -> mjcf.Element:
        """The site named "attachment_site" in the arm model."""
        root = self._mjcf_root
        return root.find("site", "attachment_site")
diff --git a/gello/dm_control_tasks/arms/ur5e_test.py b/gello/dm_control_tasks/arms/ur5e_test.py
new file mode 100644
index 0000000..78bfdc5
--- /dev/null
+++ b/gello/dm_control_tasks/arms/ur5e_test.py
@@ -0,0 +1,22 @@
+"""Tests for ur5e.py."""
+
+from absl.testing import absltest
+from dm_control import mjcf
+
+from gello.dm_control_tasks.arms import ur5e
+
+
class UR5eTest(absltest.TestCase):
    """Smoke tests for the UR5e entity."""

    def test_compiles_and_steps(self) -> None:
        # The model must compile into a Physics and survive one step.
        model = ur5e.UR5e().mjcf_model
        sim = mjcf.Physics.from_mjcf_model(model)
        sim.step()

    def test_joints(self) -> None:
        # Every element exposed through `joints` carries the "joint" tag.
        arm = ur5e.UR5e()
        for jnt in arm.joints:
            self.assertEqual(jnt.tag, "joint")
+
+
# Run the tests through absltest when the file is executed directly.
if __name__ == "__main__":
    absltest.main()
diff --git a/gello/dm_control_tasks/arms/utils.py b/gello/dm_control_tasks/arms/utils.py
new file mode 100644
index 0000000..4e778de
--- /dev/null
+++ b/gello/dm_control_tasks/arms/utils.py
@@ -0,0 +1,261 @@
+import collections
+
+import numpy as np
+from absl import logging
+from dm_control.mujoco.wrapper import mjbindings
+
+mjlib = mjbindings.mjlib
+
+
+_INVALID_JOINT_NAMES_TYPE = (
+ "`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}."
+)
+_REQUIRE_TARGET_POS_OR_QUAT = (
+ "At least one of `target_pos` or `target_quat` must be specified."
+)
+
+IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"])
+
+
+def qpos_from_site_pose(
+ physics,
+ site_name,
+ target_pos=None,
+ target_quat=None,
+ joint_names=None,
+ tol=1e-14,
+ rot_weight=1.0,
+ regularization_threshold=0.1,
+ regularization_strength=3e-2,
+ max_update_norm=2.0,
+ progress_thresh=20.0,
+ max_steps=100,
+ inplace=False,
+):
+ """Find joint positions that satisfy a target site position and/or rotation.
+
+ Args:
+ physics: A `mujoco.Physics` instance.
+ site_name: A string specifying the name of the target site.
+ target_pos: A (3,) numpy array specifying the desired Cartesian position of
+ the site, or None if the position should be unconstrained (default).
+ One or both of `target_pos` or `target_quat` must be specified.
+ target_quat: A (4,) numpy array specifying the desired orientation of the
+ site as a quaternion, or None if the orientation should be unconstrained
+ (default). One or both of `target_pos` or `target_quat` must be specified.
+ joint_names: (optional) A list, tuple or numpy array specifying the names of
+ one or more joints that can be manipulated in order to achieve the target
+ site pose. If None (default), all joints may be manipulated.
+ tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
+ in the stopping criterion).
+ rot_weight: (optional) Determines the weight given to rotational error
+ relative to translational error.
+ regularization_threshold: (optional) L2 regularization will be used when
+ inverting the Jacobian whilst `err_norm` is greater than this value.
+ regularization_strength: (optional) Coefficient of the quadratic penalty
+ on joint movements.
+ max_update_norm: (optional) The maximum L2 norm of the update applied to
+ the joint positions on each iteration. The update vector will be scaled
+ such that its magnitude never exceeds this value.
+ progress_thresh: (optional) If `err_norm` divided by the magnitude of the
+ joint position update is greater than this value then the optimization
+ will terminate prematurely. This is a useful heuristic to avoid getting
+ stuck in local minima.
+ max_steps: (optional) The maximum number of iterations to perform.
+ inplace: (optional) If True, `physics.data` will be modified in place.
+ Default value is False, i.e. a copy of `physics.data` will be made.
+
+ Returns:
+ An `IKResult` namedtuple with the following fields:
+ qpos: An (nq,) numpy array of joint positions.
+ err_norm: A float, the weighted sum of L2 norms for the residual
+ translational and rotational errors.
+ steps: An int, the number of iterations that were performed.
+ success: Boolean, True if we converged on a solution within `max_steps`,
+ False otherwise.
+
+ Raises:
+ ValueError: If both `target_pos` and `target_quat` are None, or if
+ `joint_names` has an invalid type.
+ """
+ dtype = physics.data.qpos.dtype
+
+ if target_pos is not None and target_quat is not None:
+ jac = np.empty((6, physics.model.nv), dtype=dtype)
+ err = np.empty(6, dtype=dtype)
+ jac_pos, jac_rot = jac[:3], jac[3:]
+ err_pos, err_rot = err[:3], err[3:]
+ else:
+ jac = np.empty((3, physics.model.nv), dtype=dtype)
+ err = np.empty(3, dtype=dtype)
+ if target_pos is not None:
+ jac_pos, jac_rot = jac, None
+ err_pos, err_rot = err, None
+ elif target_quat is not None:
+ jac_pos, jac_rot = None, jac
+ err_pos, err_rot = None, err
+ else:
+ raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
+
+ update_nv = np.zeros(physics.model.nv, dtype=dtype)
+
+ if target_quat is not None:
+ site_xquat = np.empty(4, dtype=dtype)
+ neg_site_xquat = np.empty(4, dtype=dtype)
+ err_rot_quat = np.empty(4, dtype=dtype)
+
+ if not inplace:
+ physics = physics.copy(share_model=True)
+
+ # Ensure that the Cartesian position of the site is up to date.
+ mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
+
+ # Convert site name to index.
+ site_id = physics.model.name2id(site_name, "site")
+
+ # These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
+ # update them in place, so we can avoid indexing overhead in the main loop.
+ site_xpos = physics.named.data.site_xpos[site_name]
+ site_xmat = physics.named.data.site_xmat[site_name]
+
+ # This is an index into the rows of `update` and the columns of `jac`
+ # that selects DOFs associated with joints that we are allowed to manipulate.
+ if joint_names is None:
+ dof_indices = slice(None) # Update all DOFs.
+ elif isinstance(joint_names, (list, np.ndarray, tuple)):
+ if isinstance(joint_names, tuple):
+ joint_names = list(joint_names)
+ # Find the indices of the DOFs belonging to each named joint. Note that
+ # these are not necessarily the same as the joint IDs, since a single joint
+ # may have >1 DOF (e.g. ball joints).
+ indexer = physics.named.model.dof_jntid.axes.row
+ # `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
+ # indexer to map each joint name to the indices of its corresponding DOFs.
+ dof_indices = indexer.convert_key_item(joint_names)
+ else:
+ raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
+
+ steps = 0
+ success = False
+
+ for steps in range(max_steps):
+ err_norm = 0.0
+
+ if target_pos is not None:
+ # Translational error.
+ err_pos[:] = target_pos - site_xpos
+ err_norm += np.linalg.norm(err_pos)
+ if target_quat is not None:
+ # Rotational error.
+ mjlib.mju_mat2Quat(site_xquat, site_xmat)
+ mjlib.mju_negQuat(neg_site_xquat, site_xquat)
+ mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
+ mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
+ err_norm += np.linalg.norm(err_rot) * rot_weight
+
+ if err_norm < tol:
+ logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm)
+ success = True
+ break
+ else:
+ # TODO(b/112141670): Generalize this to other entities besides sites.
+ mjlib.mj_jacSite(
+ physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id
+ )
+ jac_joints = jac[:, dof_indices]
+
+ # TODO(b/112141592): This does not take joint limits into consideration.
+ reg_strength = (
+ regularization_strength if err_norm > regularization_threshold else 0.0
+ )
+ update_joints = nullspace_method(
+ jac_joints, err, regularization_strength=reg_strength
+ )
+
+ update_norm = np.linalg.norm(update_joints)
+
+ # Check whether we are still making enough progress, and halt if not.
+ progress_criterion = err_norm / update_norm
+ if progress_criterion > progress_thresh:
+ logging.debug(
+ "Step %2i: err_norm / update_norm (%3g) > "
+ "tolerance (%3g). Halting due to insufficient progress",
+ steps,
+ progress_criterion,
+ progress_thresh,
+ )
+ break
+
+ if update_norm > max_update_norm:
+ update_joints *= max_update_norm / update_norm
+
+ # Write the entries for the specified joints into the full `update_nv`
+ # vector.
+ update_nv[dof_indices] = update_joints
+
+ # Update `physics.qpos`, taking quaternions into account.
+ mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
+
+ # Compute the new Cartesian position of the site.
+ mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
+
+ logging.debug(
+ "Step %2i: err_norm=%-10.3g update_norm=%-10.3g",
+ steps,
+ err_norm,
+ update_norm,
+ )
+
+ if not success and steps == max_steps - 1:
+ logging.warning(
+ "Failed to converge after %i steps: err_norm=%3g", steps, err_norm
+ )
+
+ if not inplace:
+ # Our temporary copy of physics.data is about to go out of scope, and when
+ # it does the underlying mjData pointer will be freed and physics.data.qpos
+ # will be a view onto a block of deallocated memory. We therefore need to
+ # make a copy of physics.data.qpos while physics.data is still alive.
+ qpos = physics.data.qpos.copy()
+ else:
+ # If we're modifying physics.data in place then it's fine to return a view.
+ qpos = physics.data.qpos
+
+ return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
+
+
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
    """Calculates the joint velocities to achieve a specified end effector delta.

    Solves the normal equations `(J^T J) x = J^T delta`. With a positive
    `regularization_strength` the system is damped by adding
    `regularization_strength * I` to `J^T J` and solved directly; otherwise a
    least-squares solve is used so that a singular `J^T J` is tolerated.

    Args:
      jac_joints: The Jacobian of the end effector with respect to the joints. A
        numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
        and `nv` is the number of degrees of freedom.
      delta: The desired end-effector delta. A numpy array of shape `(3,)` or
        `(6,)` containing either position deltas, rotation deltas, or both.
      regularization_strength: (optional) Coefficient of the quadratic penalty
        on joint movements. Default is zero, i.e. no regularization.

    Returns:
      An `(nv,)` numpy array of joint velocities.

    Reference:
      Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
      transpose, pseudoinverse and damped least squares methods.
      https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
    """
    hess_approx = jac_joints.T.dot(jac_joints)
    joint_delta = jac_joints.T.dot(delta)
    if regularization_strength > 0:
        # L2 regularization. Build a new array instead of adding in place so that
        # an integer-typed Jacobian is promoted to float rather than raising a
        # casting error.
        hess_approx = (
            hess_approx + np.eye(hess_approx.shape[0]) * regularization_strength
        )
        return np.linalg.solve(hess_approx, joint_delta)
    return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
+
+
class InverseKinematics:
    """Placeholder for an inverse-kinematics helper; nothing is implemented yet."""

    def __init__(self, xml_path: str):
        """Initializes the inverse kinematics class.

        Args:
          xml_path: Currently unused. Presumably the path to a model XML file
            (TODO confirm once this class is implemented).
        """
        # TODO
        pass
diff --git a/gello/dm_control_tasks/manipulation/__init__.py b/gello/dm_control_tasks/manipulation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gello/dm_control_tasks/manipulation/arenas/floors.py b/gello/dm_control_tasks/manipulation/arenas/floors.py
new file mode 100644
index 0000000..0a67d60
--- /dev/null
+++ b/gello/dm_control_tasks/manipulation/arenas/floors.py
@@ -0,0 +1,111 @@
+"""Simple floor arenas."""
+from typing import Tuple
+
+import numpy as np
+from dm_control import mjcf
+
+from gello.dm_control_tasks.arenas import base
+
+_GROUNDPLANE_QUAD_SIZE = 0.25
+
+
class FixedManipulationArena(base.Arena):
    """Arena that exposes the site a manipulator arm is attached to."""

    @property
    def arm_attachment_site(self) -> mjcf.Element:
        """Returns the site to attach the arm to.

        This forwards `self.attachment_site`, which is presumably provided by
        `base.Arena` (not visible here; confirm against that class). The return
        annotation is `mjcf.Element` to agree with `Floor.arm_attachment_site`;
        the value is passed as the `attach_site` argument when attaching an arm.
        """
        return self.attachment_site
+
+
class Floor(base.Arena):
    """A ground-plane arena with a checker texture, a light and a top camera."""

    def _build(
        self,
        name: str = "floor",
        size: Tuple[float, float] = (8, 8),
        reflectance: float = 0.2,
        top_camera_y_padding_factor: float = 1.1,
        top_camera_distance: float = 5.0,
    ) -> None:
        """Populates the MJCF model of the arena.

        Args:
          name: Name forwarded to the base arena's `_build`.
          size: The first two entries of the ground plane geom's `size`.
          reflectance: Reflectance of the ground material.
          top_camera_y_padding_factor: Multiplier applied to `size[1]` when
            choosing the top camera's vertical field of view.
          top_camera_distance: Height (z) at which the top camera is placed.
        """
        super()._build(name=name)

        self._size = size
        self._top_camera_y_padding_factor = top_camera_y_padding_factor
        self._top_camera_distance = top_camera_distance

        assert self._mjcf_root.worldbody is not None
        world = self._mjcf_root.worldbody

        # Invisible (alpha = 0) site marking where the arm gets attached.
        z_offset = 0.00
        world.add(
            "site",
            name="arm_attachment",
            pos=(0, 0, z_offset),
            size=(0.01, 0.01, 0.01),
            type="sphere",
            rgba=(0, 0, 0, 0),
        )

        # Directional light pointing straight down.
        world.add("light", pos=(0, 0, 1.5), dir=(0, 0, -1), directional=True)

        # Checker texture and the material that tiles it over the ground.
        assets = self._mjcf_root.asset
        self._ground_texture = assets.add(
            "texture",
            rgb1=[0.2, 0.3, 0.4],
            rgb2=[0.1, 0.2, 0.3],
            type="2d",
            builtin="checker",
            name="groundplane",
            width=200,
            height=200,
            mark="edge",
            markrgb=[0.8, 0.8, 0.8],
        )
        self._ground_material = assets.add(
            "material",
            name="groundplane",
            texrepeat=[2, 2],  # Makes white squares exactly 1x1 length units.
            texuniform=True,
            reflectance=reflectance,
            texture=self._ground_texture,
        )

        # The ground plane itself.
        plane_size = [*size, _GROUNDPLANE_QUAD_SIZE]
        self._ground_geom = world.add(
            "geom",
            type="plane",
            name="groundplane",
            material=self._ground_material,
            size=plane_size,
        )

        # Pick the vertical FOV from the padded floor extent and the camera
        # height so that the floor fits in frame irrespective of its size.
        padded_extent = top_camera_y_padding_factor * size[1]
        fovy_radians = 2 * np.arctan2(padded_extent, top_camera_distance)
        self._top_camera = world.add(
            "camera",
            name="top_camera",
            pos=[0, 0, top_camera_distance],
            quat=[1, 0, 0, 0],
            fovy=np.rad2deg(fovy_radians),
        )

    @property
    def ground_geoms(self):
        """One-element tuple holding the ground plane geom."""
        return (self._ground_geom,)

    @property
    def size(self):
        """The `size` value that was passed to `_build`."""
        return self._size

    @property
    def arm_attachment_site(self) -> mjcf.Element:
        """The `arm_attachment` site looked up in the world body."""
        return self._mjcf_root.worldbody.find("site", "arm_attachment")
diff --git a/gello/dm_control_tasks/manipulation/tasks/base.py b/gello/dm_control_tasks/manipulation/tasks/base.py
new file mode 100644
index 0000000..b15fb05
--- /dev/null
+++ b/gello/dm_control_tasks/manipulation/tasks/base.py
@@ -0,0 +1,56 @@
+ """An abstract base class for all manipulation tasks."""
+
+from dm_control import composer, mjcf
+
+from gello.dm_control_tasks.arms.manipulator import Manipulator
+from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
+
+# Timestep of the physics simulation.
+_PHYSICS_TIMESTEP: float = 0.002
+
+# Interval between agent actions, in seconds.
+# We send a control signal every (_CONTROL_TIMESTEP / _PHYSICS_TIMESTEP) physics steps.
+_CONTROL_TIMESTEP: float = 0.02 # 50 Hz.
+
+
class ManipulationTask(composer.Task):
    """Base composer task for a manipulator arm attached to a fixed arena."""

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=_PHYSICS_TIMESTEP,
        control_timestep=_CONTROL_TIMESTEP,
    ) -> None:
        """Attaches the arm to the arena and configures timesteps.

        Args:
          arm: The manipulator to attach to the arena.
          arena: The arena; the arm is attached at `arena.arm_attachment_site`.
          physics_timestep: Timestep of the physics simulation.
          control_timestep: Interval between agent actions.
        """
        self._arm = arm
        self._arena = arena

        # Collect the geoms before attaching so that the arena's geom list does
        # not also contain the arm's geoms.
        self._arm_geoms = arm.root_body.find_all("geom")
        self._arena_geoms = arena.root_body.find_all("geom")
        arena.attach(arm, attach_site=arena.arm_attachment_site)

        self.set_timesteps(
            control_timestep=control_timestep, physics_timestep=physics_timestep
        )

        # Enable all robot observables. Note: No additional entity observables should
        # be added in subclasses.
        for observable in self._arm.observables.proprioception:
            observable.enabled = True

    def in_collision(self, physics: mjcf.Physics) -> bool:
        """Checks whether the arm touches the arena or itself.

        Args:
          physics: The physics instance whose active contacts are inspected.

        Returns:
          True if any active contact is between an arm geom and an arena geom
          (in either order) or has an arm geom on both sides.
        """
        # Sets give constant-time membership tests inside the contact loop.
        arm_ids = {physics.bind(g).element_id for g in self._arm_geoms}
        arena_ids = {physics.bind(g).element_id for g in self._arena_geoms}

        return any(
            (con.geom1 in arm_ids and con.geom2 in arena_ids)  # arm and arena
            or (con.geom1 in arena_ids and con.geom2 in arm_ids)  # arena and arm
            or (con.geom1 in arm_ids and con.geom2 in arm_ids)  # self collision
            for con in physics.data.contact[: physics.data.ncon]
        )

    @property
    def root_entity(self):
        """The arena, which is the root of the entity tree for this task."""
        return self._arena
diff --git a/gello/dm_control_tasks/manipulation/tasks/block_play.py b/gello/dm_control_tasks/manipulation/tasks/block_play.py
new file mode 100644
index 0000000..09589a0
--- /dev/null
+++ b/gello/dm_control_tasks/manipulation/tasks/block_play.py
@@ -0,0 +1,122 @@
+ """A task where blocks are randomly placed in the scene for a manipulator."""
+from typing import Optional
+
+import numpy as np
+from dm_control import mjcf
+from dm_control.suite.utils.randomizers import random_limited_quaternion
+
+from gello.dm_control_tasks.arms.manipulator import Manipulator
+from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
+from gello.dm_control_tasks.manipulation.tasks import base
+
+_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
+
+
class BlockPlay(base.ManipulationTask):
    """Manipulator task in which free-floating blocks are scattered randomly."""

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        num_blocks: int = 10,
        size: float = 0.03,
        reset_joints: Optional[np.ndarray] = None,
    ) -> None:
        """Builds the task and adds `num_blocks` box bodies to the world.

        Args:
          arm: The manipulator, forwarded to the base task.
          arena: The arena, forwarded to the base task.
          physics_timestep: Forwarded to the base task.
          control_timestep: Forwarded to the base task.
          num_blocks: Number of blocks to add.
          size: Value used for all three entries of each box geom's `size`.
          reset_joints: If given, joint values applied to the arm at the start
            of every episode instead of randomizing them.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)

        # Use the first key frame of the model, if there is one.
        found_keys = self.root_entity.mjcf_model.find_all("key")
        key_frames = None if len(found_keys) == 0 else found_keys[0]

        block_joints = []
        for block_index in range(num_blocks):
            # Random opaque color.
            rgb = np.random.uniform(0, 1, size=3)
            color = np.concatenate([rgb, [1.0]])

            # One body per block, placed at the origin for now.
            block_body = self.root_entity.mjcf_model.worldbody.add(
                "body", name=f"block_{block_index}", pos=(0, 0, 0)
            )

            # A free joint lets the block be repositioned.
            block_joints.append(block_body.add("freejoint"))

            block_body.add(
                "geom",
                name=f"block_geom_{block_index}",
                type="box",
                size=(size, size, size),
                rgba=color,
            )

            # Each free joint adds 7 qpos entries, so the key frame is extended
            # by 7 zeros per block.
            assert key_frames is not None
            key_frames.qpos = np.concatenate([key_frames.qpos, np.zeros(7)])

        self._block_joints = block_joints
        self._block_size = size
        self._reset_joints = reset_joints

    def initialize_episode(self, physics, random_state):
        """Sets the arm configuration and scatters the blocks."""
        if self._reset_joints is None:
            self._arm.randomize_joints(physics, random_state)
        else:
            self._arm.set_joints(physics, self._reset_joints)
        physics.forward()

        # Re-sample the arm joints until `in_collision` reports no contact.
        while self.in_collision(physics):
            self._arm.randomize_joints(physics, random_state)
            physics.forward()

        # Give every block a random pose.
        for block_joint in self._block_joints:
            randomize_pose(
                block_joint,
                physics,
                random_state=random_state,
                position_range=0.5,
                z_offset=self._block_size * 2,
            )

        physics.forward()

    def get_reward(self, physics):
        """Returns a constant reward of 0."""
        return 0
+
+
def randomize_pose(
    free_joint: mjcf.Element,
    physics: mjcf.Physics,
    random_state: np.random.RandomState,
    position_range: float = 0.5,
    z_offset: float = 0.0,
    min_distance: float = 0.3,
) -> None:
    """Randomize the pose of an entity through its free joint.

    The x/y position is drawn uniformly from
    `[-position_range, position_range)` and re-drawn until it lies at least
    `min_distance` from the origin. The orientation comes from
    `random_limited_quaternion` with a limit of pi.

    Args:
      free_joint: The free joint whose `qpos` is overwritten.
      physics: Physics instance used to bind `free_joint`.
      random_state: Source of randomness.
      position_range: Half-width of the square the x/y position is drawn from.
      z_offset: The z coordinate written into the pose.
      min_distance: Minimum allowed distance of the x/y position from the
        origin.

    Raises:
      ValueError: If no point of the sampling square is farther than
        `min_distance` from the origin, since rejection sampling would never
        terminate.
    """
    # The farthest points of the square are its corners, and the uniform
    # distribution is half-open so it never reaches them exactly.
    if min_distance > 0 and np.hypot(position_range, position_range) <= min_distance:
        raise ValueError(
            f"position_range={position_range} cannot produce a point at least "
            f"min_distance={min_distance} from the origin."
        )

    entity_pos = random_state.uniform(-position_range, position_range, size=2)
    # Keep x, y at least `min_distance` away from (0, 0).
    while np.linalg.norm(entity_pos) < min_distance:
        entity_pos = random_state.uniform(-position_range, position_range, size=2)

    entity_pos = np.concatenate([entity_pos, [z_offset]])
    entity_quat = random_limited_quaternion(random_state, limit=np.pi)

    # Pose layout written to the joint: position (3 values) then quaternion.
    qpos = np.concatenate([entity_pos, entity_quat])

    physics.bind(free_joint).qpos = qpos
diff --git a/gello/dm_control_tasks/manipulation/tasks/reach.py b/gello/dm_control_tasks/manipulation/tasks/reach.py
new file mode 100644
index 0000000..e743e46
--- /dev/null
+++ b/gello/dm_control_tasks/manipulation/tasks/reach.py
@@ -0,0 +1,63 @@
+ """A task where a manipulator must reach a target."""
+
+import numpy as np
+from dm_control.suite.utils import randomizers
+from dm_control.utils import rewards
+
+from gello.dm_control_tasks.arms.manipulator import Manipulator
+from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
+from gello.dm_control_tasks.manipulation.tasks import base
+
+_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
+
+
class Reach(base.ManipulationTask):
    """Reach task for a manipulator."""

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        distance_tolerance: float = 0.5,
    ) -> None:
        """Builds the task and adds a spherical target site to the world.

        Args:
          arm: The manipulator, forwarded to the base task.
          arena: The arena, forwarded to the base task.
          physics_timestep: Forwarded to the base task.
          control_timestep: Forwarded to the base task.
          distance_tolerance: Used as both the upper bound and the margin of
            the tolerance function in `get_reward`.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)

        self._distance_tolerance = distance_tolerance

        # Create target.
        self._target = self.root_entity.mjcf_model.worldbody.add(
            "site",
            name="target",
            type="sphere",
            pos=(0, 0, 0),
            size=(0.1,),
            rgba=_TARGET_COLOR,
        )

    def initialize_episode(self, physics, random_state):
        """Places the target at a reachable flange position, then re-randomizes."""
        # Sample a random configuration and record where the flange ends up, so
        # that the target is at a position the arm can actually reach.
        randomizers.randomize_limited_and_rotational_joints(physics, random_state)
        physics.forward()
        flange_position = physics.bind(self._arm.flange).xpos[:3]

        # set target position to flange position
        physics.bind(self._target).pos = flange_position

        # Randomize initial position of the arm.
        randomizers.randomize_limited_and_rotational_joints(physics, random_state)
        physics.forward()

    def get_reward(self, physics):
        """Returns the negated tolerance of the flange-to-target distance."""
        # Use the flange's current Cartesian position (`xpos`), the same
        # quantity `initialize_episode` copies into the target; `pos` is the
        # model-frame offset and does not move with the joints.
        flange_pos = physics.bind(self._arm.flange).xpos[:3]
        distance = np.linalg.norm(physics.bind(self._target).pos[:3] - flange_pos)
        # NOTE(review): `rewards.tolerance` is largest inside the bounds, so
        # the negation makes being close the *lowest* value - confirm the sign
        # is intended.
        return -rewards.tolerance(
            distance,
            bounds=(0, self._distance_tolerance),
            margin=self._distance_tolerance,
            value_at_margin=0,
            sigmoid="linear",
        )
diff --git a/gello/dm_control_tasks/mjcf_utils.py b/gello/dm_control_tasks/mjcf_utils.py
new file mode 100644
index 0000000..1e36cec
--- /dev/null
+++ b/gello/dm_control_tasks/mjcf_utils.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+from typing import List
+
+from dm_control import mjcf
+
+# Path to the root of the project.
+_PROJECT_ROOT: Path = Path(__file__).parent.parent.parent
+
+# Path to the Menagerie submodule.
+MENAGERIE_ROOT: Path = _PROJECT_ROOT / "third_party" / "mujoco_menagerie"
+
+
def safe_find_all(
    root: mjcf.RootElement,
    feature_name: str,
    immediate_children_only: bool = False,
    exclude_attachments: bool = False,
) -> List[mjcf.Element]:
    """Return all elements of a kind under `root`, failing loudly if there are none.

    Args:
      root: Element whose `find_all` is queried.
      feature_name: Name of the element kind to look for.
      immediate_children_only: Forwarded positionally to `root.find_all`.
      exclude_attachments: Forwarded positionally to `root.find_all`.

    Returns:
      The non-empty result of `root.find_all`.

    Raises:
      ValueError: If `root.find_all` returns an empty (falsy) result.
    """
    matches = root.find_all(feature_name, immediate_children_only, exclude_attachments)
    if matches:
        return matches
    raise ValueError(f"No {feature_name} found in the MJCF model.")
diff --git a/gello/dynamixel/__init__.py b/gello/dynamixel/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gello/dynamixel/driver.py b/gello/dynamixel/driver.py
new file mode 100644
index 0000000..360c64b
--- /dev/null
+++ b/gello/dynamixel/driver.py
@@ -0,0 +1,274 @@
+import time
+from threading import Event, Lock, Thread
+from typing import Protocol, Sequence
+
+import numpy as np
+from dynamixel_sdk.group_sync_read import GroupSyncRead
+from dynamixel_sdk.group_sync_write import GroupSyncWrite
+from dynamixel_sdk.packet_handler import PacketHandler
+from dynamixel_sdk.port_handler import PortHandler
+from dynamixel_sdk.robotis_def import (
+ COMM_SUCCESS,
+ DXL_HIBYTE,
+ DXL_HIWORD,
+ DXL_LOBYTE,
+ DXL_LOWORD,
+)
+
# Constants: control-table addresses, field lengths in bytes, and torque values.
ADDR_TORQUE_ENABLE = 64
ADDR_GOAL_POSITION = 116
LEN_GOAL_POSITION = 4
# This name was previously assigned 132 and then immediately re-assigned 140,
# so 140 is the value that has always been in effect; the dead assignment is
# dropped. NOTE(review): confirm against the servo's control table that 140
# (rather than 132) is the register that should be read.
ADDR_PRESENT_POSITION = 140
LEN_PRESENT_POSITION = 4
TORQUE_ENABLE = 1
TORQUE_DISABLE = 0
+
+
class DynamixelDriverProtocol(Protocol):
    """Structural interface shared by the real and the fake servo drivers."""

    def set_joints(self, joint_angles: Sequence[float]):
        """Command the servos to the given joint angles.

        Args:
            joint_angles (Sequence[float]): One angle per servo.
        """
        ...

    def torque_enabled(self) -> bool:
        """Report whether torque is currently enabled.

        Returns:
            bool: True when torque is enabled, False otherwise.
        """
        ...

    def set_torque_mode(self, enable: bool):
        """Switch torque on or off for the servos.

        Args:
            enable (bool): True enables torque, False disables it.
        """
        ...

    def get_joints(self) -> np.ndarray:
        """Read the current joint angles in radians.

        Returns:
            np.ndarray: The joint angles.
        """
        ...

    def close(self):
        """Close the driver."""
        ...
+
+
class FakeDynamixelDriver(DynamixelDriverProtocol):
    """In-memory stand-in for a servo driver; it only stores what it is given."""

    def __init__(self, ids: Sequence[int]):
        """Create a fake driver with all joints at zero and torque disabled.

        Args:
            ids (Sequence[int]): Servo IDs; only their count is used.
        """
        self._ids = ids
        # Joint angles are radians, so keep them as floats from the start
        # rather than as an integer array that later changes dtype.
        self._joint_angles = np.zeros(len(ids), dtype=float)
        self._torque_enabled = False

    def set_joints(self, joint_angles: Sequence[float]):
        """Store the commanded joint angles.

        Args:
            joint_angles (Sequence[float]): One angle per servo.

        Raises:
            ValueError: If the number of angles differs from the number of IDs.
            RuntimeError: If torque is not enabled.
        """
        if len(joint_angles) != len(self._ids):
            raise ValueError(
                "The length of joint_angles must match the number of servos"
            )
        if not self._torque_enabled:
            raise RuntimeError("Torque must be enabled to set joint angles")
        self._joint_angles = np.array(joint_angles, dtype=float)

    def torque_enabled(self) -> bool:
        """Return the stored torque flag."""
        return self._torque_enabled

    def set_torque_mode(self, enable: bool):
        """Store the torque flag."""
        self._torque_enabled = enable

    def get_joints(self) -> np.ndarray:
        """Return a copy of the stored joint angles."""
        return self._joint_angles.copy()

    def close(self):
        """Nothing to release for the fake driver."""
        pass
+
+
class DynamixelDriver(DynamixelDriverProtocol):
    """Driver for a set of Dynamixel servos sharing one serial port.

    A daemon thread keeps polling the present positions; `get_joints` returns
    the most recent reading. A single lock serializes every use of the port.
    """

    def __init__(
        self, ids: Sequence[int], port: str = "/dev/ttyUSB0", baudrate: int = 57600
    ):
        """Initialize the DynamixelDriver class.

        Args:
            ids (Sequence[int]): A list of IDs for the Dynamixel servos.
            port (str): The USB port to connect to the arm.
            baudrate (int): The baudrate for communication.

        Raises:
            RuntimeError: If the port cannot be opened, the baudrate cannot be
                set, or a servo ID cannot be registered for group reads.
        """
        self._ids = ids
        self._joint_angles = None
        self._lock = Lock()

        # Initialize the port handler, packet handler, and group sync read/write
        self._portHandler = PortHandler(port)
        self._packetHandler = PacketHandler(2.0)
        self._groupSyncRead = GroupSyncRead(
            self._portHandler,
            self._packetHandler,
            ADDR_PRESENT_POSITION,
            LEN_PRESENT_POSITION,
        )
        self._groupSyncWrite = GroupSyncWrite(
            self._portHandler,
            self._packetHandler,
            ADDR_GOAL_POSITION,
            LEN_GOAL_POSITION,
        )

        # Open the port and set the baudrate
        if not self._portHandler.openPort():
            raise RuntimeError("Failed to open the port")

        if not self._portHandler.setBaudRate(baudrate):
            raise RuntimeError(f"Failed to change the baudrate, {baudrate}")

        # Add parameters for each Dynamixel servo to the group sync read
        for dxl_id in self._ids:
            if not self._groupSyncRead.addParam(dxl_id):
                raise RuntimeError(
                    f"Failed to add parameter for Dynamixel with ID {dxl_id}"
                )

        # Disable torque for each Dynamixel servo. This is best effort: a
        # failure is reported but does not abort construction.
        self._torque_enabled = False
        try:
            self.set_torque_mode(self._torque_enabled)
        except Exception as e:
            print(f"port: {port}, {e}")

        self._stop_thread = Event()
        self._start_reading_thread()

    def set_joints(self, joint_angles: Sequence[float]):
        """Send goal positions to all servos in one sync-write packet.

        Args:
            joint_angles (Sequence[float]): One angle per servo, converted to
                servo counts at 2048 counts per pi.

        Raises:
            ValueError: If the number of angles differs from the number of IDs.
            RuntimeError: If torque is disabled, a parameter cannot be queued,
                or the packet cannot be transmitted.
        """
        if len(joint_angles) != len(self._ids):
            raise ValueError(
                "The length of joint_angles must match the number of servos"
            )
        if not self._torque_enabled:
            raise RuntimeError("Torque must be enabled to set joint angles")

        # Hold the lock so this write cannot interleave with the reader thread
        # or with `set_torque_mode`, which use the same port under this lock.
        with self._lock:
            try:
                for dxl_id, angle in zip(self._ids, joint_angles):
                    # Convert the angle to the appropriate value for the servo
                    position_value = int(angle * 2048 / np.pi)

                    # Allocate goal position value into byte array
                    param_goal_position = [
                        DXL_LOBYTE(DXL_LOWORD(position_value)),
                        DXL_HIBYTE(DXL_LOWORD(position_value)),
                        DXL_LOBYTE(DXL_HIWORD(position_value)),
                        DXL_HIBYTE(DXL_HIWORD(position_value)),
                    ]

                    # Add goal position value to the Syncwrite parameter storage
                    dxl_addparam_result = self._groupSyncWrite.addParam(
                        dxl_id, param_goal_position
                    )
                    if not dxl_addparam_result:
                        raise RuntimeError(
                            f"Failed to set joint angle for Dynamixel with ID {dxl_id}"
                        )

                # Syncwrite goal position
                dxl_comm_result = self._groupSyncWrite.txPacket()
                if dxl_comm_result != COMM_SUCCESS:
                    raise RuntimeError("Failed to syncwrite goal position")
            finally:
                # Always clear the parameter storage, including on failure, so
                # that queued IDs do not make every later call fail in addParam.
                self._groupSyncWrite.clearParam()

    def torque_enabled(self) -> bool:
        """Return the torque flag last set successfully."""
        return self._torque_enabled

    def set_torque_mode(self, enable: bool):
        """Enable or disable torque on every servo.

        Args:
            enable (bool): True to enable torque, False to disable.

        Raises:
            RuntimeError: If writing the torque register fails for a servo.
        """
        torque_value = TORQUE_ENABLE if enable else TORQUE_DISABLE
        with self._lock:
            for dxl_id in self._ids:
                dxl_comm_result, dxl_error = self._packetHandler.write1ByteTxRx(
                    self._portHandler, dxl_id, ADDR_TORQUE_ENABLE, torque_value
                )
                if dxl_comm_result != COMM_SUCCESS or dxl_error != 0:
                    print(dxl_comm_result)
                    print(dxl_error)
                    raise RuntimeError(
                        f"Failed to set torque mode for Dynamixel with ID {dxl_id}"
                    )

        self._torque_enabled = enable

    def _start_reading_thread(self):
        """Start the daemon thread that polls joint positions."""
        self._reading_thread = Thread(target=self._read_joint_angles)
        self._reading_thread.daemon = True
        self._reading_thread.start()

    def _read_joint_angles(self):
        """Poll present positions until `close` sets the stop event."""
        # Continuously read joint angles and update the joint_angles array
        while not self._stop_thread.is_set():
            time.sleep(0.001)
            with self._lock:
                _joint_angles = np.zeros(len(self._ids), dtype=int)
                dxl_comm_result = self._groupSyncRead.txRxPacket()
                if dxl_comm_result != COMM_SUCCESS:
                    print(f"warning, comm failed: {dxl_comm_result}")
                    continue
                for i, dxl_id in enumerate(self._ids):
                    if self._groupSyncRead.isAvailable(
                        dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION
                    ):
                        angle = self._groupSyncRead.getData(
                            dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION
                        )
                        # Reinterpret the unsigned 32-bit reading as signed.
                        angle = np.int32(np.uint32(angle))
                        _joint_angles[i] = angle
                    else:
                        # NOTE(review): raising here ends the reader thread, so
                        # `get_joints` keeps returning the last reading.
                        raise RuntimeError(
                            f"Failed to get joint angles for Dynamixel with ID {dxl_id}"
                        )
                self._joint_angles = _joint_angles
            # self._groupSyncRead.clearParam() # TODO what does this do? should i add it

    def get_joints(self) -> np.ndarray:
        """Return the latest joint angles, blocking until a first reading exists.

        Returns:
            np.ndarray: Joint angles, converted from servo counts at 2048
            counts per pi.
        """
        # Wait for the reader thread to publish its first reading.
        while self._joint_angles is None:
            time.sleep(0.1)
        # Return a copy of the joint_angles array to avoid race conditions
        with self._lock:
            _j = self._joint_angles.copy()
        return _j / 2048.0 * np.pi

    def close(self):
        """Stop the reader thread and close the port."""
        self._stop_thread.set()
        self._reading_thread.join()
        self._portHandler.closePort()
+
+
def main():
    """Connect to servo ID 1, toggle torque on and off, then print joints forever."""
    ids = [1]

    # Try the default port first and fall back to a fixed macOS device path.
    try:
        driver = DynamixelDriver(ids)
    except FileNotFoundError:
        driver = DynamixelDriver(ids, port="/dev/cu.usbserial-FT7WBMUB")

    for torque_on in (True, False):
        driver.set_torque_mode(torque_on)

    # Print readings until interrupted with Ctrl-C, then shut the driver down.
    try:
        while True:
            joint_angles = driver.get_joints()
            print(f"Joint angles for IDs {ids}: {joint_angles}")
    except KeyboardInterrupt:
        driver.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/dynamixel/tests/test_driver.py b/gello/dynamixel/tests/test_driver.py
new file mode 100644
index 0000000..218744b
--- /dev/null
+++ b/gello/dynamixel/tests/test_driver.py
@@ -0,0 +1,35 @@
+import numpy as np
+import pytest
+
+from gello.dynamixel.driver import FakeDynamixelDriver
+
+
@pytest.fixture
def fake_driver():
    """Provide a fresh two-servo fake driver for each test."""
    driver = FakeDynamixelDriver(ids=[1, 2])
    return driver
+
+
def test_set_joints(fake_driver):
    """Angles written with torque enabled are read back unchanged."""
    target = [np.pi / 2, np.pi / 2]
    fake_driver.set_torque_mode(True)
    fake_driver.set_joints(target)
    assert np.allclose(fake_driver.get_joints(), target)
+
+
def test_set_joints_wrong_length(fake_driver):
    """Passing one angle to a two-servo driver raises ValueError."""
    too_short = [np.pi / 2]
    with pytest.raises(ValueError):
        fake_driver.set_joints(too_short)
+
+
def test_set_joints_torque_disabled(fake_driver):
    """Setting joints while torque is still disabled raises RuntimeError."""
    angles = [np.pi / 2, np.pi / 2]
    with pytest.raises(RuntimeError):
        fake_driver.set_joints(angles)
+
+
def test_torque_enabled(fake_driver):
    """Torque starts disabled and reports enabled after being switched on."""
    initially_enabled = fake_driver.torque_enabled()
    assert not initially_enabled
    fake_driver.set_torque_mode(True)
    assert fake_driver.torque_enabled()
+
+
def test_get_joints(fake_driver):
    """A freshly created fake driver reports all joints at zero."""
    joints = fake_driver.get_joints()
    assert np.allclose(joints, [0, 0])
diff --git a/gello/env.py b/gello/env.py
new file mode 100644
index 0000000..f88d583
--- /dev/null
+++ b/gello/env.py
@@ -0,0 +1,88 @@
+import time
+from typing import Any, Dict, Optional
+
+import numpy as np
+
+from gello.cameras.camera import CameraDriver
+from gello.robots.robot import Robot
+
+
class Rate:
    """Keeps successive `sleep` calls at least `1 / rate` seconds apart."""

    def __init__(self, rate: float):
        """Record the current time and the target rate.

        Args:
            rate: Target number of `sleep` returns per second.
        """
        self.last = time.time()
        self.rate = rate

    def sleep(self) -> None:
        """Wait until `1 / rate` seconds have passed since the last return."""
        while True:
            # Done once the current time has reached the deadline.
            if self.last + 1.0 / self.rate <= time.time():
                break
            time.sleep(0.0001)
        self.last = time.time()
+
+
class RobotEnv:
    """Couples a robot with optional cameras and a fixed control rate."""

    def __init__(
        self,
        robot: Robot,
        control_rate_hz: float = 100.0,
        camera_dict: Optional[Dict[str, CameraDriver]] = None,
    ) -> None:
        """Store the robot, the rate limiter and the cameras.

        Args:
            robot: the robot to command and observe.
            control_rate_hz: rate used to pace `step`.
            camera_dict: cameras keyed by name; no cameras when None.
        """
        self._robot = robot
        self._rate = Rate(control_rate_hz)
        self._camera_dict = {} if camera_dict is None else camera_dict

    def robot(self) -> Robot:
        """Get the robot object.

        Returns:
            robot: the robot object.
        """
        return self._robot

    def __len__(self):
        # NOTE(review): a length of 0 makes instances falsy in boolean context.
        return 0

    def step(self, joints: np.ndarray) -> Dict[str, Any]:
        """Step the environment forward.

        Args:
            joints: joint angles command to step the environment with.

        Returns:
            obs: observation from the environment.
        """
        # A single check suffices; the same condition used to be asserted twice.
        assert len(joints) == (
            self._robot.num_dofs()
        ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}"
        self._robot.command_joint_state(joints)
        self._rate.sleep()
        return self.get_obs()

    def get_obs(self) -> Dict[str, Any]:
        """Get observation from the environment.

        Returns:
            obs: camera images and depths keyed `<name>_rgb` / `<name>_depth`,
            plus the robot's joint positions, joint velocities, end-effector
            pose and gripper position.
        """
        observations = {}
        for name, camera in self._camera_dict.items():
            image, depth = camera.read()
            observations[f"{name}_rgb"] = image
            observations[f"{name}_depth"] = depth

        robot_obs = self._robot.get_observations()
        assert "joint_positions" in robot_obs
        assert "joint_velocities" in robot_obs
        assert "ee_pos_quat" in robot_obs
        for key in (
            "joint_positions",
            "joint_velocities",
            "ee_pos_quat",
            "gripper_position",
        ):
            observations[key] = robot_obs[key]
        return observations
+
+
def main() -> None:
    """Entry point placeholder; it does nothing."""
    return None
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/robots/dynamixel.py b/gello/robots/dynamixel.py
new file mode 100644
index 0000000..d8676f6
--- /dev/null
+++ b/gello/robots/dynamixel.py
@@ -0,0 +1,134 @@
+from typing import Dict, Optional, Sequence, Tuple
+
+import numpy as np
+
+from gello.robots.robot import Robot
+
+
class DynamixelRobot(Robot):
    """A robot read through Dynamixel servos, optionally with a gripper joint."""

    def __init__(
        self,
        joint_ids: Sequence[int],
        joint_offsets: Optional[Sequence[float]] = None,
        joint_signs: Optional[Sequence[int]] = None,
        real: bool = False,
        port: str = "/dev/ttyUSB0",
        baudrate: int = 57600,
        gripper_config: Optional[Tuple[int, float, float]] = None,
        start_joints: Optional[np.ndarray] = None,
    ):
        """Set up the driver and the per-joint calibration.

        Args:
            joint_ids: Servo IDs of the joints.
            joint_offsets: Per-joint offsets subtracted from the raw readings;
                zeros when None.
            joint_signs: Per-joint signs (each +1 or -1); ones when None.
            real: Use the real serial driver when True, a fake one otherwise.
            port: Serial port for the real driver.
            baudrate: Baudrate for the real driver.
            gripper_config: `(servo_id, open_degrees, close_degrees)` of a
                gripper appended as the last joint. Requires `joint_offsets`
                and `joint_signs`.
            start_joints: If given, whole turns are added to the offsets so
                that each reported joint lands as close as possible to it.
        """
        from gello.dynamixel.driver import (
            DynamixelDriver,
            DynamixelDriverProtocol,
            FakeDynamixelDriver,
        )

        print(f"attempting to connect to port: {port}")
        self.gripper_open_close: Optional[Tuple[float, float]]
        if gripper_config is not None:
            assert joint_offsets is not None
            assert joint_signs is not None

            # Append the gripper as an extra joint with zero offset and sign +1.
            # New tuples are built so the caller's sequences are not mutated.
            joint_ids = tuple(joint_ids) + (gripper_config[0],)
            joint_offsets = tuple(joint_offsets) + (0.0,)
            joint_signs = tuple(joint_signs) + (1,)
            # Degrees to radians.
            self.gripper_open_close = (
                gripper_config[1] * np.pi / 180,
                gripper_config[2] * np.pi / 180,
            )
        else:
            self.gripper_open_close = None

        self._joint_ids = joint_ids
        self._driver: DynamixelDriverProtocol

        if joint_offsets is None:
            self._joint_offsets = np.zeros(len(joint_ids))
        else:
            self._joint_offsets = np.array(joint_offsets)

        if joint_signs is None:
            self._joint_signs = np.ones(len(joint_ids))
        else:
            self._joint_signs = np.array(joint_signs)

        assert len(self._joint_ids) == len(self._joint_offsets), (
            f"joint_ids: {len(self._joint_ids)}, "
            f"joint_offsets: {len(self._joint_offsets)}"
        )
        assert len(self._joint_ids) == len(self._joint_signs), (
            f"joint_ids: {len(self._joint_ids)}, "
            f"joint_signs: {len(self._joint_signs)}"
        )
        assert np.all(
            np.abs(self._joint_signs) == 1
        ), f"joint_signs: {self._joint_signs}"

        if real:
            self._driver = DynamixelDriver(joint_ids, port=port, baudrate=baudrate)
            self._driver.set_torque_mode(False)
        else:
            self._driver = FakeDynamixelDriver(joint_ids)
        self._torque_on = False
        self._last_pos = None
        self._alpha = 0.99

        if start_joints is not None:
            # Shift every offset by whole turns so that the reported joint is
            # as close as possible to the requested start joint.
            new_joint_offsets = []
            current_joints = self.get_joint_state()
            assert current_joints.shape == start_joints.shape
            if gripper_config is not None:
                current_joints = current_joints[:-1]
                start_joints = start_joints[:-1]
            for c_joint, s_joint, joint_offset, joint_sign in zip(
                current_joints, start_joints, self._joint_offsets, self._joint_signs
            ):
                turns = np.round((s_joint - c_joint) / (2 * np.pi))
                # The reported position is (raw - offset) * sign, so moving it
                # by +2*pi*turns needs the offset moved by -sign * 2*pi*turns.
                new_joint_offsets.append(joint_offset - joint_sign * 2 * np.pi * turns)
            if gripper_config is not None:
                new_joint_offsets.append(self._joint_offsets[-1])
            self._joint_offsets = np.array(new_joint_offsets)
            # The reading above used the old offsets; drop it so it is not
            # blended into the first smoothed reading.
            self._last_pos = None

    def num_dofs(self) -> int:
        """Return the number of joints, including the gripper if configured."""
        return len(self._joint_ids)

    def get_joint_state(self) -> np.ndarray:
        """Return the calibrated, exponentially smoothed joint positions.

        With a gripper configured, the last entry is mapped linearly from the
        open/close angles onto [0, 1] and clipped to that interval.
        """
        pos = (self._driver.get_joints() - self._joint_offsets) * self._joint_signs
        assert len(pos) == self.num_dofs()

        if self.gripper_open_close is not None:
            # map pos to [0, 1]
            g_pos = (pos[-1] - self.gripper_open_close[0]) / (
                self.gripper_open_close[1] - self.gripper_open_close[0]
            )
            g_pos = min(max(0, g_pos), 1)
            pos[-1] = g_pos

        if self._last_pos is None:
            self._last_pos = pos
        else:
            # exponential smoothing
            pos = self._last_pos * (1 - self._alpha) + pos * self._alpha
            self._last_pos = pos

        return pos

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Send `joint_state` plus the joint offsets to the driver."""
        # NOTE(review): unlike `get_joint_state`, this applies neither the
        # joint signs nor the gripper mapping - confirm that is intended.
        self._driver.set_joints((joint_state + self._joint_offsets).tolist())

    def set_torque_mode(self, mode: bool):
        """Switch torque on or off, skipping the driver call when unchanged."""
        if mode == self._torque_on:
            return
        self._driver.set_torque_mode(mode)
        self._torque_on = mode

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the joint state under the key `joint_state`."""
        return {"joint_state": self.get_joint_state()}
diff --git a/gello/robots/panda.py b/gello/robots/panda.py
new file mode 100644
index 0000000..704f16a
--- /dev/null
+++ b/gello/robots/panda.py
@@ -0,0 +1,88 @@
+import time
+from typing import Dict
+
+import numpy as np
+
+from gello.robots.robot import Robot
+
+MAX_OPEN = 0.09
+
+
+class PandaRobot(Robot):
+    """A Panda robot (arm plus gripper) controlled through the polymetis interfaces."""
+
+    def __init__(self, robot_ip: str = "100.97.47.74"):
+        """Connect to the arm and gripper, home the arm and open the gripper.
+
+        Args:
+            robot_ip (str): Address passed to the polymetis ``RobotInterface``.
+                The gripper interface always connects to ``localhost``.
+        """
+        from polymetis import GripperInterface, RobotInterface
+
+        self.robot = RobotInterface(ip_address=robot_ip)
+        self.gripper = GripperInterface(ip_address="localhost")
+        self.robot.go_home()
+        self.robot.start_joint_impedance()
+        self.gripper.goto(width=MAX_OPEN, speed=255, force=255)
+        time.sleep(1)
+
+    def num_dofs(self) -> int:
+        """Get the number of joints of the robot.
+
+        Returns:
+            int: Always 8.
+        """
+        return 8
+
+    def get_joint_state(self) -> np.ndarray:
+        """Get the arm joint positions followed by the gripper opening.
+
+        Returns:
+            np.ndarray: Arm joint positions with ``width / MAX_OPEN`` appended.
+        """
+        arm_joints = self.robot.get_joint_positions()
+        gripper_state = self.gripper.get_state()
+        return np.append(arm_joints, gripper_state.width / MAX_OPEN)
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Command the arm and gripper to a given state.
+
+        Args:
+            joint_state (np.ndarray): Arm joint targets followed by one gripper
+                value; the gripper is sent to ``MAX_OPEN * (1 - value)``.
+        """
+        import torch
+
+        # NOTE(review): get_joint_state reports width / MAX_OPEN while this maps
+        # the command through (1 - value) - confirm the two conventions agree.
+        arm_target = torch.tensor(joint_state[:-1])
+        self.robot.update_desired_joint_positions(arm_target)
+        gripper_width = MAX_OPEN * (1 - joint_state[-1])
+        self.gripper.goto(width=gripper_width, speed=1, force=1)
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Return the observation dictionary built from the current joint state.
+
+        ``joint_velocities`` repeats the joint positions and ``ee_pos_quat`` is
+        all zeros; no measured velocity or end-effector pose is read here.
+        """
+        joints = self.get_joint_state()
+        observations = {
+            "joint_positions": joints,
+            "joint_velocities": joints,
+            "ee_pos_quat": np.zeros(7),
+            "gripper_position": np.array([joints[-1]]),
+        }
+        return observations
+
+
+def main():
+    """Manual hardware check: connect to the Panda and step the gripper width.
+
+    Commands gripper widths of 1.0x, 1.05x and 1.1x ``MAX_OPEN``, one second
+    apart, at full speed and force.
+    """
+    robot = PandaRobot()
+    time.sleep(1)
+    for scale in (1, 1.05, 1.1):
+        robot.gripper.goto(scale * MAX_OPEN, speed=255, force=255)
+        time.sleep(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/robots/robot.py b/gello/robots/robot.py
new file mode 100644
index 0000000..b4a6e96
--- /dev/null
+++ b/gello/robots/robot.py
@@ -0,0 +1,128 @@
+from abc import abstractmethod
+from typing import Dict, Protocol
+
+import numpy as np
+
+
+class Robot(Protocol):
+    """Structural interface for a controllable robot.
+
+    Implementations report their joint count, expose and accept joint states,
+    and return a dictionary of observations.
+    """
+
+    @abstractmethod
+    def num_dofs(self) -> int:
+        """Report how many joints the robot has.
+
+        Returns:
+            int: The number of joints of the robot.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_joint_state(self) -> np.ndarray:
+        """Read the robot's current joint state.
+
+        Returns:
+            np.ndarray: The current joint state.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Command the robot to a given joint state.
+
+        Args:
+            joint_state (np.ndarray): The state to command the robot to.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Collect everything the robot can currently report.
+
+        This covers information such as joint positions and joint velocities,
+        and may also include data from additional sensors (cameras, force
+        sensors, etc.).
+
+        Returns:
+            Dict[str, np.ndarray]: A dictionary of observations.
+        """
+        raise NotImplementedError()
+
+
+class PrintRobot(Robot):
+    """A stand-in robot that stores, and optionally prints, the last commanded joint state."""
+
+    def __init__(self, num_dofs: int, dont_print: bool = False):
+        """Start with an all-zero joint state of length ``num_dofs``.
+
+        Args:
+            num_dofs (int): Number of joints to emulate.
+            dont_print (bool): When True, commands are stored without printing.
+        """
+        self._dont_print = dont_print
+        self._num_dofs = num_dofs
+        self._joint_state = np.zeros(num_dofs)
+
+    def num_dofs(self) -> int:
+        """Return the joint count given at construction."""
+        return self._num_dofs
+
+    def get_joint_state(self) -> np.ndarray:
+        """Return the most recently commanded joint state (zeros until commanded)."""
+        return self._joint_state
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Store ``joint_state`` and print it unless printing is disabled.
+
+        Args:
+            joint_state (np.ndarray): New state; its length must equal the joint count.
+        """
+        expected = self._num_dofs
+        assert len(joint_state) == expected, (
+            f"Expected joint state of length {self._num_dofs}, "
+            f"got {len(joint_state)}."
+        )
+        self._joint_state = joint_state
+        if self._dont_print:
+            return
+        print(self._joint_state)
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Return an observation dictionary derived from the stored joint state.
+
+        ``joint_velocities`` repeats the joint state, ``ee_pos_quat`` is all
+        zeros and ``gripper_position`` is a zero-dimensional array holding 0.
+        """
+        current = self.get_joint_state()
+        observations = {
+            "joint_positions": current,
+            "joint_velocities": current,
+            "ee_pos_quat": np.zeros(7),
+            "gripper_position": np.array(0),
+        }
+        return observations
+
+
+class BimanualRobot(Robot):
+    """Two robots presented as one; left-robot values come first, right-robot values second."""
+
+    def __init__(self, robot_l: Robot, robot_r: Robot):
+        """Wrap a left and a right robot.
+
+        Args:
+            robot_l (Robot): Robot whose values occupy the leading entries.
+            robot_r (Robot): Robot whose values occupy the trailing entries.
+        """
+        self._robot_l = robot_l
+        self._robot_r = robot_r
+
+    def num_dofs(self) -> int:
+        """Return the sum of both robots' joint counts."""
+        return self._robot_l.num_dofs() + self._robot_r.num_dofs()
+
+    def get_joint_state(self) -> np.ndarray:
+        """Return the left joint state concatenated with the right joint state."""
+        return np.concatenate(
+            (self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
+        )
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Split ``joint_state`` at the left robot's joint count and command each robot.
+
+        Args:
+            joint_state (np.ndarray): Left-robot targets followed by right-robot targets.
+        """
+        split = self._robot_l.num_dofs()
+        self._robot_l.command_joint_state(joint_state[:split])
+        self._robot_r.command_joint_state(joint_state[split:])
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Concatenate the two robots' observations key by key.
+
+        Returns:
+            Dict[str, np.ndarray]: For every key, the left value followed by
+            the right value.
+
+        Raises:
+            RuntimeError: If the values for a key cannot be concatenated.
+        """
+        l_obs = self._robot_l.get_observations()
+        r_obs = self._robot_r.get_observations()
+        assert l_obs.keys() == r_obs.keys()
+
+        merged = {}
+        for key, left_value in l_obs.items():
+            right_value = r_obs[key]
+            try:
+                # np.concatenate rejects zero-dimensional inputs (for example a
+                # scalar gripper position), so promote scalars to length-1 arrays.
+                merged[key] = np.concatenate(
+                    (np.atleast_1d(left_value), np.atleast_1d(right_value))
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    f"Could not concatenate observation {key!r}: "
+                    f"left={left_value!r}, right={right_value!r}"
+                ) from e
+        return merged
+
+
+def main():
+    """Entry point placeholder; performs no work."""
+    return None
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/robots/robotiq_gripper.py b/gello/robots/robotiq_gripper.py
new file mode 100644
index 0000000..40192df
--- /dev/null
+++ b/gello/robots/robotiq_gripper.py
@@ -0,0 +1,358 @@
+"""Module to control Robotiq's grippers - tested with HAND-E.
+
+Taken from https://github.com/githubuser0xFFFF/py_robotiq_gripper/blob/master/src/robotiq_gripper.py
+"""
+
+import socket
+import threading
+import time
+from enum import Enum
+from typing import OrderedDict, Tuple, Union
+
+
+class RobotiqGripper:
+    """Communicates with the gripper directly, via socket with string commands, leveraging string names for variables."""
+
+    # WRITE VARIABLES (CAN ALSO READ)
+    ACT = "ACT"  # act : activate (1 while activated, can be reset to clear fault status)
+    GTO = "GTO"  # gto : go to (will perform go to with the actions set in pos, for, spe)
+    ATR = "ATR"  # atr : auto-release (emergency slow move)
+    ADR = "ADR"  # adr : auto-release direction (open(1) or close(0) during auto-release)
+    FOR = "FOR"  # for : force (0-255)
+    SPE = "SPE"  # spe : speed (0-255)
+    POS = "POS"  # pos : position (0-255), 0 = open
+    # READ VARIABLES
+    STA = "STA"  # status (0 = is reset, 1 = activating, 3 = active)
+    PRE = "PRE"  # position request (echo of last commanded position)
+    OBJ = "OBJ"  # object detection (0 = moving, 1 = outer grip, 2 = inner grip, 3 = no object at rest)
+    FLT = "FLT"  # fault (0=ok, see manual for errors if not zero)
+
+    ENCODING = "UTF-8"  # ASCII and UTF-8 both seem to work
+
+    class GripperStatus(Enum):
+        """Gripper status reported by the gripper. The integer values have to match what the gripper sends."""
+
+        RESET = 0
+        ACTIVATING = 1
+        # UNUSED = 2  # This value is currently not used by the gripper firmware
+        ACTIVE = 3
+
+    class ObjectStatus(Enum):
+        """Object status reported by the gripper. The integer values have to match what the gripper sends."""
+
+        MOVING = 0
+        STOPPED_OUTER_OBJECT = 1
+        STOPPED_INNER_OBJECT = 2
+        AT_DEST = 3
+
+    def __init__(self):
+        """Create an unconnected gripper with the full 0-255 position, speed and force ranges."""
+        self.socket = None
+        # Serialises each send/receive pair so replies are matched to their command.
+        self.command_lock = threading.Lock()
+        self._min_position = 0
+        self._max_position = 255
+        self._min_speed = 0
+        self._max_speed = 255
+        self._min_force = 0
+        self._max_force = 255
+
+    def connect(self, hostname: str, port: int, socket_timeout: float = 10.0) -> None:
+        """Connects to a gripper at the given address.
+
+        :param hostname: Hostname or ip.
+        :param port: Port.
+        :param socket_timeout: Timeout for blocking socket operations, including the connection attempt itself.
+        """
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        # Set the timeout before connecting so an unreachable gripper cannot block
+        # for the operating system's (much longer) default connect timeout.
+        self.socket.settimeout(socket_timeout)
+        self.socket.connect((hostname, port))
+
+    def disconnect(self) -> None:
+        """Closes the connection with the gripper."""
+        assert self.socket is not None
+        self.socket.close()
+
+    def _set_vars(self, var_dict: OrderedDict[str, Union[int, float]]):
+        """Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response.
+
+        :param var_dict: Dictionary of variables to set (variable_name, value).
+        :return: True on successful reception of ack, false if no ack was received, indicating the set may not
+        have been effective.
+        """
+        assert self.socket is not None
+        # construct a single command; the trailing new line is required for the command to finish
+        assignments = "".join(
+            f" {variable} {value}" for variable, value in var_dict.items()
+        )
+        cmd = f"SET{assignments}\n"
+        # atomic commands send/rcv
+        with self.command_lock:
+            self.socket.sendall(cmd.encode(self.ENCODING))
+            data = self.socket.recv(1024)
+        return self._is_ack(data)
+
+    def _set_var(self, variable: str, value: Union[int, float]):
+        """Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response.
+
+        :param variable: Variable to set.
+        :param value: Value to set for the variable.
+        :return: True on successful reception of ack, false if no ack was received, indicating the set may not
+        have been effective.
+        """
+        return self._set_vars(OrderedDict([(variable, value)]))
+
+    def _get_var(self, variable: str):
+        """Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the response is received or the socket times out.
+
+        :param variable: Name of the variable to retrieve.
+        :return: Value of the variable as integer.
+        :raises ValueError: If the reply does not echo the requested variable name, or is not of the form 'VAR x'.
+        """
+        assert self.socket is not None
+        # atomic commands send/rcv
+        with self.command_lock:
+            cmd = f"GET {variable}\n"
+            self.socket.sendall(cmd.encode(self.ENCODING))
+            data = self.socket.recv(1024)
+
+        # expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value
+        # note some special variables (like FLT) may send 2 bytes, instead of an integer. We assume integer here
+        var_name, value_str = data.decode(self.ENCODING).split()
+        if var_name != variable:
+            raise ValueError(
+                f"Unexpected response {data} ({data.decode(self.ENCODING)}): does not match '{variable}'"
+            )
+        return int(value_str)
+
+    @staticmethod
+    def _is_ack(data: bytes):
+        """Returns whether the raw reply bytes are exactly the gripper's 'ack'."""
+        return data == b"ack"
+
+    def _reset(self):
+        """Reset the gripper.
+
+        The following code is executed in the corresponding script function
+        def rq_reset(gripper_socket="1"):
+            rq_set_var("ACT", 0, gripper_socket)
+            rq_set_var("ATR", 0, gripper_socket)
+
+            while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
+                rq_set_var("ACT", 0, gripper_socket)
+                rq_set_var("ATR", 0, gripper_socket)
+                sync()
+            end
+
+            sleep(0.5)
+        end
+        """
+        self._set_var(self.ACT, 0)
+        self._set_var(self.ATR, 0)
+        while self._get_var(self.ACT) != 0 or self._get_var(self.STA) != 0:
+            self._set_var(self.ACT, 0)
+            self._set_var(self.ATR, 0)
+        time.sleep(0.5)
+
+    def activate(self, auto_calibrate: bool = True):
+        """Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags.
+
+        :param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion.
+
+        The following code is executed in the corresponding script function
+        def rq_activate(gripper_socket="1"):
+            if (not rq_is_gripper_activated(gripper_socket)):
+                rq_reset(gripper_socket)
+
+                while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
+                    rq_reset(gripper_socket)
+                    sync()
+                end
+
+                rq_set_var("ACT",1, gripper_socket)
+            end
+        end
+
+        def rq_activate_and_wait(gripper_socket="1"):
+            if (not rq_is_gripper_activated(gripper_socket)):
+                rq_activate(gripper_socket)
+                sleep(1.0)
+
+                while(not rq_get_var("ACT", 1, gripper_socket) == 1 or not rq_get_var("STA", 1, gripper_socket) == 3):
+                    sleep(0.1)
+                end
+
+                sleep(0.5)
+            end
+        end
+        """
+        if not self.is_active():
+            self._reset()
+            while self._get_var(self.ACT) != 0 or self._get_var(self.STA) != 0:
+                time.sleep(0.01)
+
+            self._set_var(self.ACT, 1)
+            time.sleep(1.0)
+            while self._get_var(self.ACT) != 1 or self._get_var(self.STA) != 3:
+                time.sleep(0.01)
+
+        # auto-calibrate position range if desired
+        if auto_calibrate:
+            self.auto_calibrate()
+
+    def is_active(self):
+        """Returns whether the gripper is active."""
+        status = self._get_var(self.STA)
+        return (
+            RobotiqGripper.GripperStatus(status) == RobotiqGripper.GripperStatus.ACTIVE
+        )
+
+    def get_min_position(self) -> int:
+        """Returns the minimum position the gripper can reach (open position)."""
+        return self._min_position
+
+    def get_max_position(self) -> int:
+        """Returns the maximum position the gripper can reach (closed position)."""
+        return self._max_position
+
+    def get_open_position(self) -> int:
+        """Returns what is considered the open position for gripper (minimum position value)."""
+        return self.get_min_position()
+
+    def get_closed_position(self) -> int:
+        """Returns what is considered the closed position for gripper (maximum position value)."""
+        return self.get_max_position()
+
+    def is_open(self):
+        """Returns whether the current position is considered as being fully open."""
+        return self.get_current_position() <= self.get_open_position()
+
+    def is_closed(self):
+        """Returns whether the current position is considered as being fully closed."""
+        return self.get_current_position() >= self.get_closed_position()
+
+    def get_current_position(self) -> int:
+        """Returns the current position as returned by the physical hardware."""
+        return self._get_var(self.POS)
+
+    def auto_calibrate(self, log: bool = True) -> None:
+        """Attempts to calibrate the open and closed positions, by slowly closing and opening the gripper.
+
+        :param log: Whether to print the results to log.
+        :raises RuntimeError: If any of the three calibration moves does not end at its destination.
+        """
+        # first try to open in case we are holding an object
+        (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
+        if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
+            raise RuntimeError(f"Calibration failed opening to start: {str(status)}")
+
+        # try to close as far as possible, and record the number
+        (position, status) = self.move_and_wait_for_pos(
+            self.get_closed_position(), 64, 1
+        )
+        if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
+            raise RuntimeError(
+                f"Calibration failed because of an object: {str(status)}"
+            )
+        assert position <= self._max_position
+        self._max_position = position
+
+        # try to open as far as possible, and record the number
+        (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
+        if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
+            raise RuntimeError(
+                f"Calibration failed because of an object: {str(status)}"
+            )
+        assert position >= self._min_position
+        self._min_position = position
+
+        if log:
+            print(
+                f"Gripper auto-calibrated to [{self.get_min_position()}, {self.get_max_position()}]"
+            )
+
+    def move(self, position: int, speed: int, force: int) -> Tuple[bool, int]:
+        """Sends commands to start moving towards the given position, with the specified speed and force.
+
+        Values are clipped to their valid range and rounded to the nearest integer before being sent, so
+        non-integer inputs still produce a well-formed command.
+
+        :param position: Position to move to [min_position, max_position]
+        :param speed: Speed to move at [min_speed, max_speed]
+        :param force: Force to use [min_force, max_force]
+        :return: A tuple with a bool indicating whether the action it was successfully sent, and an integer with
+        the actual position that was requested, after being adjusted to the min/max calibrated range.
+        """
+
+        def clip_to_int(min_val, val, max_val):
+            # The protocol variables are integers and PRE echoes the integer that
+            # was requested, so a fractional value could never be matched later.
+            return int(round(max(min_val, min(val, max_val))))
+
+        clip_pos = clip_to_int(self._min_position, position, self._max_position)
+        clip_spe = clip_to_int(self._min_speed, speed, self._max_speed)
+        clip_for = clip_to_int(self._min_force, force, self._max_force)
+
+        # moves to the given position with the given speed and force
+        var_dict = OrderedDict(
+            [
+                (self.POS, clip_pos),
+                (self.SPE, clip_spe),
+                (self.FOR, clip_for),
+                (self.GTO, 1),
+            ]
+        )
+        succ = self._set_vars(var_dict)
+        time.sleep(0.008)  # need to wait (dont know why)
+        return succ, clip_pos
+
+    def move_and_wait_for_pos(
+        self, position: int, speed: int, force: int
+    ) -> Tuple[int, ObjectStatus]:  # noqa
+        """Sends commands to start moving towards the given position, with the specified speed and force, and then waits for the move to complete.
+
+        :param position: Position to move to [min_position, max_position]
+        :param speed: Speed to move at [min_speed, max_speed]
+        :param force: Force to use [min_force, max_force]
+        :return: A tuple with an integer representing the last position returned by the gripper after it notified
+        that the move had completed, a status indicating how the move ended (see ObjectStatus enum for details). Note
+        that it is possible that the position was not reached, if an object was detected during motion.
+        :raises RuntimeError: If the gripper does not acknowledge the move command.
+        """
+        set_ok, cmd_pos = self.move(position, speed, force)
+        if not set_ok:
+            raise RuntimeError("Failed to set variables for move.")
+
+        # wait until the gripper acknowledges that it will try to go to the requested position
+        while self._get_var(self.PRE) != cmd_pos:
+            time.sleep(0.001)
+
+        # wait until not moving
+        cur_obj = self._get_var(self.OBJ)
+        while (
+            RobotiqGripper.ObjectStatus(cur_obj) == RobotiqGripper.ObjectStatus.MOVING
+        ):
+            cur_obj = self._get_var(self.OBJ)
+
+        # report the actual position and the object status
+        final_pos = self._get_var(self.POS)
+        return final_pos, RobotiqGripper.ObjectStatus(cur_obj)
+
+
+def main():
+    """Manual check: open and close the gripper, printing the position around each move.
+
+    The gripper is not activated here; it is only connected, moved and disconnected.
+    """
+    gripper = RobotiqGripper()
+    gripper.connect(hostname="192.168.1.10", port=63352)
+    print(gripper.get_current_position())
+    for target in (20, 65):
+        gripper.move(target, 255, 1)
+        time.sleep(0.2)
+        print(gripper.get_current_position())
+    gripper.move(20, 255, 1)
+    gripper.disconnect()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/robots/sim_robot.py b/gello/robots/sim_robot.py
new file mode 100644
index 0000000..9ae5c26
--- /dev/null
+++ b/gello/robots/sim_robot.py
@@ -0,0 +1,256 @@
+import pickle
+import threading
+import time
+from typing import Any, Dict, Optional
+
+import mujoco
+import mujoco.viewer
+import numpy as np
+import zmq
+from dm_control import mjcf
+
+from gello.robots.robot import Robot
+
+assert mujoco.viewer is mujoco.viewer
+
+
+def attach_hand_to_arm(
+    arm_mjcf: mjcf.RootElement,
+    hand_mjcf: mjcf.RootElement,
+) -> None:
+    """Attaches a hand model to the arm's "attachment_site" site.
+
+    If the arm defines a "home" keyframe, its ctrl and qpos vectors are extended
+    with the hand's "home" keyframe values, or with zeros when the hand has none.
+
+    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm
+
+    Args:
+        arm_mjcf: The mjcf.RootElement of the arm.
+        hand_mjcf: The mjcf.RootElement of the hand.
+
+    Raises:
+        ValueError: If the arm does not have a site named "attachment_site".
+    """
+    hand_physics = mjcf.Physics.from_mjcf_model(hand_mjcf)
+
+    site = arm_mjcf.find("site", "attachment_site")
+    if site is None:
+        raise ValueError("No attachment site found in the arm model.")
+
+    # Expand the ctrl and qpos keyframes to account for the new hand DoFs.
+    arm_home = arm_mjcf.find("key", "home")
+    if arm_home is not None:
+        hand_home = hand_mjcf.find("key", "home")
+        if hand_home is not None:
+            extra_ctrl = hand_home.ctrl
+            extra_qpos = hand_home.qpos
+        else:
+            extra_ctrl = np.zeros(hand_physics.model.nu)
+            extra_qpos = np.zeros(hand_physics.model.nq)
+        arm_home.ctrl = np.concatenate([arm_home.ctrl, extra_ctrl])
+        arm_home.qpos = np.concatenate([arm_home.qpos, extra_qpos])
+
+    site.attach(hand_mjcf)
+
+
+def build_scene(robot_xml_path: str, gripper_xml_path: Optional[str] = None):
+    """Builds an empty arena containing the robot, optionally with a gripper attached.
+
+    Args:
+        robot_xml_path: Path of the robot model to load.
+        gripper_xml_path: Optional path of a gripper model; when given, it is
+            attached to the robot at its "attachment_site".
+
+    Returns:
+        The arena root element with the robot attached to its worldbody.
+    """
+    robot_model = mjcf.from_path(robot_xml_path)
+    scene = mjcf.RootElement()
+
+    has_gripper = gripper_xml_path is not None
+    if has_gripper:
+        gripper_model = mjcf.from_path(gripper_xml_path)
+        attach_hand_to_arm(robot_model, gripper_model)
+
+    scene.worldbody.attach(robot_model)
+    return scene
+
+
+class ZMQServerThread(threading.Thread):
+    """Thread that runs a server's ``serve()`` and stops it through ``stop()``."""
+
+    def __init__(self, server):
+        """Remember the server object; the thread is not started here."""
+        threading.Thread.__init__(self)
+        self._server = server
+
+    def run(self):
+        """Thread body: delegate to the wrapped server's ``serve()``."""
+        server = self._server
+        server.serve()
+
+    def terminate(self):
+        """Ask the wrapped server to stop by calling its ``stop()``."""
+        server = self._server
+        server.stop()
+
+
+class ZMQRobotServer:
+    """A class representing a ZMQ server for a robot.
+
+    Requests are pickled dictionaries with a ``"method"`` name and optional
+    ``"args"`` keyword arguments; every request receives exactly one pickled reply.
+    """
+
+    def __init__(self, robot: Robot, host: str = "127.0.0.1", port: int = 5556):
+        """Bind a REP socket on ``tcp://host:port`` that serves ``robot``."""
+        self._robot = robot
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REP)
+        addr = f"tcp://{host}:{port}"
+        self._socket.bind(addr)
+        self._stop_event = threading.Event()
+
+    def _handle_request(self, message: bytes) -> Any:
+        """Decode one request, call the matching robot method and return the reply payload.
+
+        Failures are returned as ``{"error": ...}`` instead of raised: a REP
+        socket must answer every request, and an exception escaping the serve
+        loop would end the server thread and leave the client waiting forever.
+        """
+        try:
+            # SECURITY: pickle.loads can execute arbitrary code from the peer.
+            # Only expose this server to trusted hosts.
+            request = pickle.loads(message)
+            method = request.get("method")
+            args = request.get("args", {})
+            if method == "num_dofs":
+                return self._robot.num_dofs()
+            if method == "get_joint_state":
+                return self._robot.get_joint_state()
+            if method == "command_joint_state":
+                return self._robot.command_joint_state(**args)
+            if method == "get_observations":
+                return self._robot.get_observations()
+        except Exception as e:
+            result = {"error": f"{type(e).__name__}: {e}"}
+            print(result)
+            return result
+
+        result = {"error": "Invalid method"}
+        print(result)
+        print(f"Invalid method: {method}, {args, result}")
+        return result
+
+    def serve(self) -> None:
+        """Serve the robot state and commands over ZMQ until ``stop()`` is called."""
+        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
+        while not self._stop_event.is_set():
+            try:
+                message = self._socket.recv()
+            except zmq.error.Again:
+                # Timeout occurred, loop around to check if the stop event is set
+                print("Timeout in ZMQLeaderServer serve")
+                continue
+            except zmq.error.ZMQError:
+                # stop() closes the socket from another thread; treat the
+                # resulting error as a normal shutdown rather than a crash.
+                if self._stop_event.is_set():
+                    break
+                raise
+
+            self._socket.send(pickle.dumps(self._handle_request(message)))
+
+    def stop(self) -> None:
+        """Signal the serve loop to exit, then close the socket and terminate the context."""
+        self._stop_event.set()
+        self._socket.close()
+        self._context.term()
+
+
+class MujocoRobotServer:
+    """A MuJoCo simulation of a robot that is exposed to clients over ZMQ.
+
+    ``serve()`` runs the physics and viewer loop in the calling thread while a
+    background thread answers ZMQ requests.
+    """
+
+    def __init__(
+        self,
+        xml_path: str,
+        gripper_xml_path: Optional[str] = None,
+        host: str = "127.0.0.1",
+        port: int = 5556,
+        print_joints: bool = False,
+    ):
+        """Build the scene, load it into MuJoCo and create (but not start) the ZMQ server.
+
+        Args:
+            xml_path: Path of the robot model.
+            gripper_xml_path: Optional path of a gripper model to attach.
+            host: Address the ZMQ server binds to.
+            port: Port the ZMQ server binds to.
+            print_joints: When True, the joint state is printed every simulation step.
+        """
+        self._has_gripper = gripper_xml_path is not None
+        arena = build_scene(xml_path, gripper_xml_path)
+
+        assets: Dict[str, str] = {}
+        for asset in arena.asset.all_children():
+            if asset.tag == "mesh":
+                mesh_file = asset.file
+                assets[mesh_file.get_vfs_filename()] = mesh_file.contents
+
+        xml_string = arena.to_xml_string()
+        # save xml_string to file
+        with open("arena.xml", "w") as xml_file:
+            xml_file.write(xml_string)
+
+        self._model = mujoco.MjModel.from_xml_string(xml_string, assets)
+        self._data = mujoco.MjData(self._model)
+
+        self._num_joints = self._model.nu
+
+        self._joint_state = np.zeros(self._num_joints)
+        # Separate array so the command never aliases the reported state.
+        self._joint_cmd = self._joint_state.copy()
+
+        self._zmq_server = ZMQRobotServer(robot=self, host=host, port=port)
+        self._zmq_server_thread = ZMQServerThread(self._zmq_server)
+
+        self._print_joints = print_joints
+        # Set last: stop() treats a missing flag as "nothing to shut down".
+        self._stopped = False
+
+    def num_dofs(self) -> int:
+        """Return the number of actuators in the loaded model."""
+        return self._num_joints
+
+    def get_joint_state(self) -> np.ndarray:
+        """Return the joint positions recorded after the latest simulation step."""
+        return self._joint_state
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Store the control targets applied on the next simulation step.
+
+        Args:
+            joint_state (np.ndarray): One target per actuator. With a gripper
+                attached, the last entry is multiplied by 255 before use.
+        """
+        assert len(joint_state) == self._num_joints, (
+            f"Expected joint state of length {self._num_joints}, "
+            f"got {len(joint_state)}."
+        )
+        command = joint_state.copy()
+        if self._has_gripper:
+            command[-1] = command[-1] * 255
+        self._joint_cmd = command
+
+    def freedrive_enabled(self) -> bool:
+        """Always True for the simulated robot."""
+        return True
+
+    def set_freedrive_mode(self, enable: bool):
+        """No-op for the simulated robot."""
+        pass
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Return joint positions/velocities, end-effector pose and gripper position.
+
+        The end-effector pose is read from the site named "attachment_site".
+        If that site does not exist, a zero position and identity quaternion
+        are reported.
+        """
+        joint_positions = self._data.qpos.copy()[: self._num_joints]
+        joint_velocities = self._data.qvel.copy()[: self._num_joints]
+
+        # mj_name2id returns -1 for an unknown name; indexing with -1 would
+        # silently pick the last site, so test the id explicitly.
+        site_id = mujoco.mj_name2id(
+            self._model, mujoco.mjtObj.mjOBJ_SITE, "attachment_site"
+        )
+        if site_id >= 0:
+            ee_pos = self._data.site_xpos[site_id].copy()
+            ee_mat = self._data.site_xmat[site_id].copy()
+            ee_quat = np.zeros(4)
+            mujoco.mju_mat2Quat(ee_quat, ee_mat)
+        else:
+            ee_pos = np.zeros(3)
+            ee_quat = np.zeros(4)
+            ee_quat[0] = 1
+
+        gripper_pos = self._data.qpos.copy()[self._num_joints - 1]
+        return {
+            "joint_positions": joint_positions,
+            "joint_velocities": joint_velocities,
+            "ee_pos_quat": np.concatenate([ee_pos, ee_quat]),
+            "gripper_position": gripper_pos,
+        }
+
+    def serve(self) -> None:
+        """Start the ZMQ server thread and run the simulation until the viewer closes."""
+        # start the zmq server
+        self._zmq_server_thread.start()
+        try:
+            with mujoco.viewer.launch_passive(self._model, self._data) as viewer:
+                while viewer.is_running():
+                    step_start = time.time()
+
+                    # mj_step can be replaced with code that also evaluates
+                    # a policy and applies a control signal before stepping the physics.
+                    self._data.ctrl[:] = self._joint_cmd
+                    mujoco.mj_step(self._model, self._data)
+                    self._joint_state = self._data.qpos.copy()[: self._num_joints]
+
+                    if self._print_joints:
+                        print(self._joint_state)
+
+                    # Example modification of a viewer option: toggle contact points every two seconds.
+                    with viewer.lock():
+                        # TODO remove?
+                        viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(
+                            self._data.time % 2
+                        )
+
+                    # Pick up changes to the physics state, apply perturbations, update options from GUI.
+                    viewer.sync()
+
+                    # Rudimentary time keeping, will drift relative to wall clock.
+                    time_until_next_step = self._model.opt.timestep - (
+                        time.time() - step_start
+                    )
+                    if time_until_next_step > 0:
+                        time.sleep(time_until_next_step)
+        finally:
+            # Without this the non-daemon server thread keeps running after
+            # the viewer is closed.
+            self.stop()
+
+    def stop(self) -> None:
+        """Stop the ZMQ server and wait for its thread; safe to call more than once."""
+        if getattr(self, "_stopped", True):
+            # Already stopped, or __init__ never finished.
+            return
+        self._stopped = True
+        # The serve loop only exits once the server is told to stop; joining
+        # without signalling would block forever.
+        self._zmq_server_thread.terminate()
+        # join() raises on a thread that was never started.
+        if self._zmq_server_thread.is_alive():
+            self._zmq_server_thread.join()
+
+    def __del__(self) -> None:
+        self.stop()
diff --git a/gello/robots/ur.py b/gello/robots/ur.py
new file mode 100644
index 0000000..b79cfad
--- /dev/null
+++ b/gello/robots/ur.py
@@ -0,0 +1,133 @@
+from typing import Dict
+
+import numpy as np
+
+from gello.robots.robot import Robot
+
+
+class URRobot(Robot):
+    """A class representing a UR robot."""
+
+    def __init__(self, robot_ip: str = "192.168.1.10", no_gripper: bool = False):
+        """Connect to the robot's RTDE interfaces and, unless disabled, its gripper.
+
+        Args:
+            robot_ip (str): Address of the robot; also used for the gripper socket.
+            no_gripper (bool): When True, no gripper connection is made.
+
+        Raises:
+            Exception: Whatever the RTDE control interface raises when the
+                connection fails (re-raised after printing the error and address).
+        """
+        import rtde_control
+        import rtde_receive
+
+        for _ in range(4):
+            print("in ur robot")
+        try:
+            self.robot = rtde_control.RTDEControlInterface(robot_ip)
+        except Exception as e:
+            print(e)
+            print(robot_ip)
+            # Continuing without self.robot only fails later with an unrelated
+            # AttributeError, so surface the real connection error here.
+            raise
+
+        self.r_inter = rtde_receive.RTDEReceiveInterface(robot_ip)
+        self._use_gripper = not no_gripper
+        if self._use_gripper:
+            from gello.robots.robotiq_gripper import RobotiqGripper
+
+            self.gripper = RobotiqGripper()
+            self.gripper.connect(hostname=robot_ip, port=63352)
+            print("gripper connected")
+
+        for _ in range(4):
+            print("connect")
+
+        self._free_drive = False
+        self.robot.endFreedriveMode()
+
+    def num_dofs(self) -> int:
+        """Get the number of joints of the robot.
+
+        Returns:
+            int: 7 when a gripper is used, otherwise 6.
+        """
+        return 7 if self._use_gripper else 6
+
+    def _get_gripper_pos(self) -> float:
+        """Read the gripper position and scale it from 0-255 to 0-1."""
+        import time
+
+        time.sleep(0.01)
+        gripper_pos = self.gripper.get_current_position()
+        assert 0 <= gripper_pos <= 255, "Gripper position must be between 0 and 255"
+        return gripper_pos / 255
+
+    def get_joint_state(self) -> np.ndarray:
+        """Get the current joint state of the robot.
+
+        Returns:
+            np.ndarray: The arm joint positions, followed by the scaled gripper
+            position when a gripper is used.
+        """
+        robot_joints = self.r_inter.getActualQ()
+        if self._use_gripper:
+            return np.append(robot_joints, self._get_gripper_pos())
+        # Convert so both branches return an ndarray, as annotated.
+        return np.asarray(robot_joints)
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Command the robot to a given state.
+
+        Args:
+            joint_state (np.ndarray): The first six entries are arm joint
+                targets; with a gripper, the last entry is scaled by 255 and
+                sent as the gripper position.
+        """
+        velocity = 0.5
+        acceleration = 0.5
+        dt = 1.0 / 500  # 2ms
+        lookahead_time = 0.2
+        gain = 100
+
+        robot_joints = joint_state[:6]
+        t_start = self.robot.initPeriod()
+        self.robot.servoJ(
+            robot_joints, velocity, acceleration, dt, lookahead_time, gain
+        )
+        if self._use_gripper:
+            # The gripper API takes integer positions; a fractional value
+            # would be sent verbatim in the command string.
+            gripper_pos = int(round(joint_state[-1] * 255))
+            self.gripper.move(gripper_pos, 255, 10)
+        self.robot.waitPeriod(t_start)
+
+    def freedrive_enabled(self) -> bool:
+        """Check if the robot is in freedrive mode.
+
+        Returns:
+            bool: True if the robot is in freedrive mode, False otherwise.
+        """
+        return self._free_drive
+
+    def set_freedrive_mode(self, enable: bool) -> None:
+        """Set the freedrive mode of the robot.
+
+        Args:
+            enable (bool): True to enable freedrive mode, False to disable it.
+        """
+        if enable and not self._free_drive:
+            self._free_drive = True
+            self.robot.freedriveMode()
+        elif not enable and self._free_drive:
+            self._free_drive = False
+            self.robot.endFreedriveMode()
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Return the observation dictionary built from the current joint state.
+
+        ``joint_velocities`` repeats the joint positions and ``ee_pos_quat`` is
+        all zeros; no measured velocity or end-effector pose is read here.
+        """
+        joints = self.get_joint_state()
+        pos_quat = np.zeros(7)
+        # NOTE(review): without a gripper this reports the last arm joint as
+        # the gripper position - confirm that is intended.
+        gripper_pos = np.array([joints[-1]])
+        return {
+            "joint_positions": joints,
+            "joint_velocities": joints,
+            "ee_pos_quat": pos_quat,
+            "gripper_position": gripper_pos,
+        }
+
+
+def main():
+    """Manual check: connect without a gripper, enable freedrive and print observations."""
+    ur = URRobot("192.168.1.11", no_gripper=True)
+    print(ur)
+    ur.set_freedrive_mode(True)
+    observations = ur.get_observations()
+    print(observations)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/robots/xarm_robot.py b/gello/robots/xarm_robot.py
new file mode 100644
index 0000000..366367b
--- /dev/null
+++ b/gello/robots/xarm_robot.py
@@ -0,0 +1,358 @@
+import dataclasses
+import threading
+import time
+from typing import Dict, Optional
+
+import numpy as np
+from pyquaternion import Quaternion
+
+from gello.robots.robot import Robot
+
+
+def _aa_from_quat(quat: np.ndarray) -> np.ndarray:
+    """Turn an ``(x, y, z, w)`` quaternion into an axis-angle vector.
+
+    Args:
+        quat (np.ndarray): Quaternion of shape ``(4,)`` ordered ``x, y, z, w``.
+            It is normalized before the conversion.
+
+    Returns:
+        np.ndarray: The rotation axis scaled by the rotation angle.
+    """
+    assert quat.shape == (4,), "Input quaternion must be a 4D vector."
+    magnitude = np.linalg.norm(quat)
+    assert magnitude != 0, "Input quaternion must not be a zero vector."
+    x, y, z, w = quat / magnitude  # components of the unit quaternion
+
+    rotation = Quaternion(w=w, x=x, y=y, z=z)
+    angle = rotation.angle
+    return rotation.axis * angle
+
+
+def _quat_from_aa(aa: np.ndarray) -> np.ndarray:
+    """Convert an axis-angle representation to an ``(x, y, z, w)`` quaternion.
+
+    Args:
+        aa (np.ndarray): Axis-angle vector of shape ``(3,)``; its direction is
+            the rotation axis and its norm the rotation angle.
+
+    Returns:
+        np.ndarray: The quaternion ordered ``x, y, z, w``. A zero vector maps
+            to the identity quaternion ``[0, 0, 0, 1]``.
+    """
+    assert aa.shape == (3,), "Input axis-angle must be a 3D vector."
+    angle = np.linalg.norm(aa)
+    if angle == 0:
+        # A zero vector means "no rotation". The placeholder RobotState used
+        # without hardware carries aa = zeros, so this case must not raise.
+        return np.array([0.0, 0.0, 0.0, 1.0])
+    axis = aa / angle  # unit rotation axis
+
+    Q = Quaternion(axis=axis, angle=angle)
+    return np.array([Q.x, Q.y, Q.z, Q.w])
+
+
+@dataclasses.dataclass(frozen=True)
+class RobotState:
+    """Immutable snapshot of the arm: position, gripper, joints and orientation."""
+
+    # Cartesian position components.
+    x: float
+    y: float
+    z: float
+    # Gripper reading, stored as given to the constructor.
+    gripper: float
+    # The seven joint values, in joint order.
+    j1: float
+    j2: float
+    j3: float
+    j4: float
+    j5: float
+    j6: float
+    j7: float
+    # Orientation as an axis-angle vector.
+    aa: np.ndarray
+
+    @staticmethod
+    def from_robot(
+        cartesian: np.ndarray,
+        joints: np.ndarray,
+        gripper: float,
+        aa: np.ndarray,
+    ) -> "RobotState":
+        """Build a state from raw readings.
+
+        Args:
+            cartesian (np.ndarray): Pose vector; only the first three entries
+                (the position) are used.
+            joints (np.ndarray): Joint values; the first seven entries are used.
+            gripper (float): Gripper reading.
+            aa (np.ndarray): Orientation as an axis-angle vector.
+
+        Returns:
+            RobotState: The assembled snapshot.
+        """
+        return RobotState(
+            x=cartesian[0],
+            y=cartesian[1],
+            z=cartesian[2],
+            gripper=gripper,
+            j1=joints[0],
+            j2=joints[1],
+            j3=joints[2],
+            j4=joints[3],
+            j5=joints[4],
+            j6=joints[5],
+            j7=joints[6],
+            aa=aa,
+        )
+
+    def cartesian_pos(self) -> np.ndarray:
+        """Return the position as an array ``[x, y, z]``."""
+        position = [self.x, self.y, self.z]
+        return np.array(position)
+
+    def quat(self) -> np.ndarray:
+        """Return the orientation converted from axis-angle to a quaternion."""
+        return _quat_from_aa(self.aa)
+
+    def joints(self) -> np.ndarray:
+        """Return the seven joint values as an array."""
+        joint_values = [
+            self.j1,
+            self.j2,
+            self.j3,
+            self.j4,
+            self.j5,
+            self.j6,
+            self.j7,
+        ]
+        return np.array(joint_values)
+
+    def gripper_pos(self) -> float:
+        """Return the stored gripper reading."""
+        return self.gripper
+
+
+class Rate:
+    """Pace a loop so that consecutive ``sleep`` calls are a fixed period apart."""
+
+    def __init__(self, *, duration):
+        """Start the first period now.
+
+        Args:
+            duration: Default period between ``sleep`` returns, in seconds.
+        """
+        self.duration = duration
+        # Use the monotonic clock: a wall-clock adjustment (NTP, manual change)
+        # must not produce a negative elapsed time or an oversized sleep.
+        self.last = time.monotonic()
+
+    def sleep(self, duration=None) -> None:
+        """Sleep for whatever remains of the current period, then start the next.
+
+        Args:
+            duration: Period to use for this call, in seconds. Defaults to the
+                period given at construction.
+
+        Raises:
+            ValueError: If the period is negative.
+        """
+        duration = self.duration if duration is None else duration
+        if duration < 0:
+            raise ValueError(f"duration must be non-negative, got {duration}")
+        remaining = duration - (time.monotonic() - self.last)
+        # Remainders at or below 0.1 ms are not worth a sleep call.
+        if remaining > 0.0001:
+            time.sleep(remaining)
+        self.last = time.monotonic()
+
+
+class XArmRobot(Robot):
+    """xArm driver that tracks joint targets from a background thread.
+
+    Callers publish a target with ``command_joint_state`` / ``set_command``.
+    When ``real`` is True a background thread repeatedly reads the robot state
+    and steps the arm toward the target, moving at most ``max_delta``
+    (joint-space norm) per control period. With ``real=False`` no hardware is
+    touched, no thread is started and the reported state is all zeros.
+    """
+
+    GRIPPER_OPEN = 800
+    GRIPPER_CLOSE = 0
+    # MAX_DELTA = 0.2
+    DEFAULT_MAX_DELTA = 0.05
+
+    def num_dofs(self) -> int:
+        """Return the number of degrees of freedom: seven joints plus the gripper."""
+        return 8
+
+    def get_joint_state(self) -> np.ndarray:
+        """Return the last read state as ``[j1..j7, gripper]``."""
+        state = self.get_state()
+        return np.concatenate([state.joints(), np.array([state.gripper_pos()])])
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Publish a new target for the control thread.
+
+        Args:
+            joint_state (np.ndarray): Seven joint targets, optionally followed
+                by a gripper command as the eighth entry.
+
+        Raises:
+            ValueError: If ``joint_state`` has neither 7 nor 8 entries.
+        """
+        num_values = len(joint_state)
+        if num_values == 7:
+            self.set_command(joint_state, None)
+        elif num_values == 8:
+            self.set_command(joint_state[:7], joint_state[7])
+        else:
+            raise ValueError(
+                f"Invalid joint state: {joint_state}, len={len(joint_state)}"
+            )
+
+    def stop(self):
+        """Stop the control thread, then disconnect from the robot."""
+        self.running = False
+        # Join before disconnecting: if the connection is closed while the
+        # loop is mid-iteration, its retry-until-success reads can never
+        # succeed again and the join below would block forever.
+        if self.command_thread is not None:
+            self.command_thread.join()
+
+        if self.robot is not None:
+            self.robot.disconnect()
+
+    def __init__(
+        self,
+        ip: str = "192.168.1.226",
+        real: bool = True,
+        control_frequency: float = 50.0,
+        max_delta: float = DEFAULT_MAX_DELTA,
+    ):
+        """Connect to the arm and start the control thread.
+
+        Args:
+            ip (str): Address of the xArm controller.
+            real (bool): When False, skip the hardware connection and the
+                control thread.
+            control_frequency (float): Rate of the control loop in Hz.
+            max_delta (float): Largest joint-space step (norm) commanded per
+                control period.
+        """
+        print(ip)
+        self.real = real
+        self.max_delta = max_delta
+        if real:
+            from xarm.wrapper import XArmAPI
+
+            self.robot = XArmAPI(ip, is_radian=True)
+        else:
+            self.robot = None
+
+        self._control_frequency = control_frequency
+        self._clear_error_states()
+        self._set_gripper_position(self.GRIPPER_OPEN)
+
+        self.last_state_lock = threading.Lock()
+        self.target_command_lock = threading.Lock()
+        self.last_state = self._update_last_state()
+        # Start by holding the current joints; gripper command 0 maps to
+        # GRIPPER_OPEN in the control loop.
+        self.target_command = {
+            "joints": self.last_state.joints(),
+            "gripper": 0,
+        }
+        self.running = True
+        self.command_thread = None
+        if real:
+            self.command_thread = threading.Thread(target=self._robot_thread)
+            self.command_thread.start()
+
+    def get_state(self) -> RobotState:
+        """Return the most recently read robot state."""
+        with self.last_state_lock:
+            return self.last_state
+
+    def set_command(self, joints: np.ndarray, gripper: Optional[float] = None) -> None:
+        """Replace the target tracked by the control thread.
+
+        Args:
+            joints (np.ndarray): Seven joint targets.
+            gripper (Optional[float]): Gripper command, where 0 maps to
+                ``GRIPPER_OPEN`` and 1 to ``GRIPPER_CLOSE``. None leaves the
+                gripper uncommanded.
+        """
+        with self.target_command_lock:
+            self.target_command = {
+                "joints": joints,
+                "gripper": gripper,
+            }
+
+    def _clear_error_states(self):
+        """Clear errors and warnings, then re-apply the motion and gripper setup.
+
+        Each setup call is followed by a one second pause, so this takes about
+        seven seconds on a real robot. It does nothing without hardware.
+        """
+        if self.robot is None:
+            return
+        self.robot.clean_error()
+        self.robot.clean_warn()
+        self.robot.motion_enable(True)
+        time.sleep(1)
+        # NOTE(review): mode 1 is presumably the servo mode required by
+        # set_servo_angle_j; confirm against the xArm SDK.
+        self.robot.set_mode(1)
+        time.sleep(1)
+        self.robot.set_collision_sensitivity(0)
+        time.sleep(1)
+        self.robot.set_state(state=0)
+        time.sleep(1)
+        self.robot.set_gripper_enable(True)
+        time.sleep(1)
+        self.robot.set_gripper_mode(0)
+        time.sleep(1)
+        self.robot.set_gripper_speed(3000)
+        time.sleep(1)
+
+    def _get_gripper_pos(self) -> float:
+        """Read the gripper and normalize it to 0 (open) .. 1 (closed).
+
+        The read is retried until it succeeds. Returns 0.0 without hardware.
+        """
+        if self.robot is None:
+            return 0.0
+        code, gripper_pos = self.robot.get_gripper_position()
+        while code != 0 or gripper_pos is None:
+            print(f"Error code {code} in get_gripper_position(). {gripper_pos}")
+            time.sleep(0.001)
+            code, gripper_pos = self.robot.get_gripper_position()
+            if code == 22:
+                self._clear_error_states()
+
+        normalized_gripper_pos = (gripper_pos - self.GRIPPER_OPEN) / (
+            self.GRIPPER_CLOSE - self.GRIPPER_OPEN
+        )
+        return normalized_gripper_pos
+
+    def _set_gripper_position(self, pos: int) -> None:
+        """Send a raw gripper position without waiting for the motion to finish."""
+        if self.robot is None:
+            return
+        self.robot.set_gripper_position(pos, wait=False)
+
+    def _robot_thread(self):
+        """Control loop: read state, step toward the target, pace to the loop rate."""
+        rate = Rate(
+            duration=1 / self._control_frequency
+        )  # command and update rate for robot
+        step_times = []
+        count = 0
+
+        while self.running:
+            s_t = time.perf_counter()
+            # update last state
+            self.last_state = self._update_last_state()
+            with self.target_command_lock:
+                joint_delta = (
+                    np.asarray(self.target_command["joints"])
+                    - self.last_state.joints()
+                )
+                gripper_command = self.target_command["gripper"]
+
+            norm = np.linalg.norm(joint_delta)
+
+            # Limit the step to at most max_delta in joint-space norm.
+            if norm > self.max_delta:
+                delta = joint_delta / norm * self.max_delta
+            else:
+                delta = joint_delta
+
+            # command position
+            self._set_position(
+                self.last_state.joints() + delta,
+            )
+
+            if gripper_command is not None:
+                # Map the normalized command (0 open, 1 closed) to raw units.
+                self._set_gripper_position(
+                    self.GRIPPER_OPEN
+                    + gripper_command * (self.GRIPPER_CLOSE - self.GRIPPER_OPEN)
+                )
+            self.last_state = self._update_last_state()
+
+            rate.sleep()
+            step_times.append(time.perf_counter() - s_t)
+            count += 1
+            if count % 1000 == 0:
+                frequency = 1 / np.mean(step_times)
+                # Spread statistics need the per-step rates; taking them of the
+                # scalar mean would always print std 0 and min == max.
+                step_frequencies = 1 / np.array(step_times)
+                print(
+                    f"Low Level Frequency - mean: {frequency:10.3f}, std: {np.std(step_frequencies):10.3f}, min: {np.min(step_frequencies):10.3f}, max: {np.max(step_frequencies):10.3f}"
+                )
+                step_times = []
+
+    def _update_last_state(self) -> RobotState:
+        """Read gripper, joints and pose from the robot and return them as a state.
+
+        Failed reads are retried after clearing the error state. Without
+        hardware an all-zero state is returned.
+        """
+        with self.last_state_lock:
+            if self.robot is None:
+                return RobotState(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, np.zeros(3))
+
+            gripper_pos = self._get_gripper_pos()
+
+            code, servo_angle = self.robot.get_servo_angle(is_radian=True)
+            while code != 0:
+                print(f"Error code {code} in get_servo_angle().")
+                self._clear_error_states()
+                code, servo_angle = self.robot.get_servo_angle(is_radian=True)
+
+            code, cart_pos = self.robot.get_position_aa(is_radian=True)
+            while code != 0:
+                print(f"Error code {code} in get_position().")
+                self._clear_error_states()
+                code, cart_pos = self.robot.get_position_aa(is_radian=True)
+
+            cart_pos = np.array(cart_pos)
+            aa = cart_pos[3:]
+            # NOTE(review): presumably converts millimetres to metres;
+            # confirm against the xArm SDK.
+            cart_pos[:3] /= 1000
+
+            return RobotState.from_robot(
+                cart_pos,
+                servo_angle,
+                gripper_pos,
+                aa,
+            )
+
+    def _set_position(
+        self,
+        joints: np.ndarray,
+    ) -> None:
+        """Send one joint target; clear the error state on return codes 1 or 9."""
+        if self.robot is None:
+            return
+        ret = self.robot.set_servo_angle_j(joints, wait=False, is_radian=True)
+        if ret in [1, 9]:
+            self._clear_error_states()
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Return the last read state as an observation dictionary.
+
+        Returns:
+            Dict[str, np.ndarray]: ``joint_positions`` and ``joint_velocities``
+                (both hold the joint state; no velocity is measured here),
+                ``ee_pos_quat`` (position followed by an ``x, y, z, w``
+                quaternion) and ``gripper_position`` (a zero-dimensional array).
+        """
+        state = self.get_state()
+        pos_quat = np.concatenate([state.cartesian_pos(), state.quat()])
+        joints = self.get_joint_state()
+        return {
+            "joint_positions": joints,  # rotational joint + gripper state
+            "joint_velocities": joints,
+            "ee_pos_quat": pos_quat,
+            "gripper_position": np.array(state.gripper_pos()),
+        }
+
+
+def main():
+    """Smoke test: connect to the xArm, print its state twice, then shut down."""
+    ip = "192.168.1.226"
+    robot = XArmRobot(ip)
+    try:
+        for _ in range(2):
+            time.sleep(1)
+            print(robot.get_state())
+        print("end")
+    finally:
+        # The control thread is not a daemon thread, so it must be stopped
+        # even on failure or the process would never exit.
+        robot.stop()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/gello/zmq_core/camera_node.py b/gello/zmq_core/camera_node.py
new file mode 100644
index 0000000..932212c
--- /dev/null
+++ b/gello/zmq_core/camera_node.py
@@ -0,0 +1,69 @@
+import pickle
+import threading
+from typing import Optional, Tuple
+
+import numpy as np
+import zmq
+
+from gello.cameras.camera import CameraDriver
+
+DEFAULT_CAMERA_PORT = 5000
+
+
+class ZMQClientCamera(CameraDriver):
+    """Camera driver that fetches frames from a ZMQ camera server."""
+
+    def __init__(self, port: int = DEFAULT_CAMERA_PORT, host: str = "127.0.0.1"):
+        """Connect a REQ socket to ``tcp://{host}:{port}``."""
+        endpoint = f"tcp://{host}:{port}"
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REQ)
+        self._socket.connect(endpoint)
+
+    def read(
+        self,
+        img_size: Optional[Tuple[int, int]] = None,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Request one camera reading from the server and return its reply.
+
+        Args:
+            img_size (Optional[Tuple[int, int]]): Requested image size,
+                forwarded unchanged to the server.
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray]: Whatever the server unpickled to.
+                No receive timeout is set, so the call blocks until a reply
+                arrives.
+        """
+        # The request payload is just the pickled image size.
+        self._socket.send(pickle.dumps(img_size))
+        reply = self._socket.recv()
+        # NOTE(security): pickle.loads trusts the peer; only connect to
+        # servers you control.
+        return pickle.loads(reply)
+
+
+class ZMQServerCamera:
+    """Serve readings from a ``CameraDriver`` to ZMQ clients over a REP socket."""
+
+    def __init__(
+        self,
+        camera: CameraDriver,
+        port: int = DEFAULT_CAMERA_PORT,
+        host: str = "127.0.0.1",
+    ):
+        """Bind a REP socket on ``tcp://{host}:{port}`` for ``camera``."""
+        self._camera = camera
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REP)
+        bind_address = f"tcp://{host}:{port}"
+        print(f"Camera Sever Binding to {bind_address}, Camera: {camera}")
+        self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
+        self._socket.bind(bind_address)
+        self._stop_event = threading.Event()
+
+    def serve(self) -> None:
+        """Answer camera requests until ``stop`` is called.
+
+        Each request carries a pickled image size, and the reply is the
+        pickled result of ``camera.read``. The receive timeout lets the loop
+        check the stop event about once per second.
+        """
+        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
+        while not self._stop_event.is_set():
+            try:
+                # NOTE(security): pickle.loads trusts the peer.
+                requested_size = pickle.loads(self._socket.recv())
+                reading = self._camera.read(requested_size)
+                self._socket.send(pickle.dumps(reading))
+            except zmq.Again:
+                # No request within the timeout; loop to re-check the stop event.
+                print(self._timout_message)
+
+    def stop(self) -> None:
+        """Ask the serve loop to exit after its current wait."""
+        self._stop_event.set()
diff --git a/gello/zmq_core/robot_node.py b/gello/zmq_core/robot_node.py
new file mode 100644
index 0000000..bc4a636
--- /dev/null
+++ b/gello/zmq_core/robot_node.py
@@ -0,0 +1,125 @@
+import pickle
+import threading
+from typing import Any, Dict
+
+import numpy as np
+import zmq
+
+from gello.robots.robot import Robot
+
+DEFAULT_ROBOT_PORT = 6000
+
+
+class ZMQServerRobot:
+    """Expose a ``Robot`` to ZMQ clients over a REP socket."""
+
+    def __init__(
+        self,
+        robot: Robot,
+        port: int = DEFAULT_ROBOT_PORT,
+        host: str = "127.0.0.1",
+    ):
+        """Bind a REP socket on ``tcp://{host}:{port}`` for ``robot``."""
+        self._robot = robot
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REP)
+        addr = f"tcp://{host}:{port}"
+        debug_message = f"Robot Sever Binding to {addr}, Robot: {robot}"
+        print(debug_message)
+        self._timout_message = f"Timeout in Robot Server, Robot: {robot}"
+        self._socket.bind(addr)
+        self._stop_event = threading.Event()
+
+    def serve(self) -> None:
+        """Answer robot requests until ``stop`` is called.
+
+        A request is a pickled dict with a ``method`` name and optional
+        ``args`` keyword arguments. The reply is the pickled return value of
+        the matching robot method.
+
+        Raises:
+            NotImplementedError: If a request names an unknown method. An
+                error reply is sent to the client before raising.
+        """
+        handlers = {
+            "num_dofs": lambda args: self._robot.num_dofs(),
+            "get_joint_state": lambda args: self._robot.get_joint_state(),
+            "command_joint_state": lambda args: self._robot.command_joint_state(
+                **args
+            ),
+            "get_observations": lambda args: self._robot.get_observations(),
+        }
+        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
+        while not self._stop_event.is_set():
+            try:
+                # Wait for next request from client
+                message = self._socket.recv()
+                # NOTE(security): pickle.loads trusts the peer; only expose
+                # this server to clients you control.
+                request = pickle.loads(message)
+
+                # Call the appropriate method based on the request
+                method = request.get("method")
+                args = request.get("args", {})
+                handler = handlers.get(method)
+                result: Any
+                if handler is None:
+                    result = {"error": "Invalid method"}
+                    print(result)
+                    # Reply before failing: a REQ client blocks in recv()
+                    # until it gets an answer, so raising without one would
+                    # leave it hanging forever.
+                    self._socket.send(pickle.dumps(result))
+                    raise NotImplementedError(
+                        f"Invalid method: {method}, {args, result}"
+                    )
+                result = handler(args)
+
+                self._socket.send(pickle.dumps(result))
+            except zmq.Again:
+                # Timeout occurred, check if the stop event is set
+                print(self._timout_message)
+
+    def stop(self) -> None:
+        """Signal the server to stop serving."""
+        self._stop_event.set()
+
+
+class ZMQClientRobot(Robot):
+    """Robot proxy that forwards every call to a ZMQ robot server."""
+
+    def __init__(self, port: int = DEFAULT_ROBOT_PORT, host: str = "127.0.0.1"):
+        """Connect a REQ socket to ``tcp://{host}:{port}``."""
+        endpoint = f"tcp://{host}:{port}"
+        self._context = zmq.Context()
+        self._socket = self._context.socket(zmq.REQ)
+        self._socket.connect(endpoint)
+
+    def _call(self, method: str, **args: Any) -> Any:
+        """Send one request and block until the reply arrives.
+
+        The request is a pickled dict holding ``method`` and, only when
+        keyword arguments are given, ``args``.
+        """
+        request: Dict[str, Any] = {"method": method}
+        if args:
+            request["args"] = args
+        self._socket.send(pickle.dumps(request))
+        # NOTE(security): pickle.loads trusts the peer; only connect to
+        # servers you control.
+        return pickle.loads(self._socket.recv())
+
+    def num_dofs(self) -> int:
+        """Ask the server for the robot's number of degrees of freedom.
+
+        Returns:
+            int: The value returned by the remote ``num_dofs``.
+        """
+        return self._call("num_dofs")
+
+    def get_joint_state(self) -> np.ndarray:
+        """Ask the server for the robot's current joint state.
+
+        Returns:
+            np.ndarray: The value returned by the remote ``get_joint_state``.
+        """
+        return self._call("get_joint_state")
+
+    def command_joint_state(self, joint_state: np.ndarray) -> None:
+        """Forward a joint command to the server.
+
+        Args:
+            joint_state (np.ndarray): The state to command the remote robot to.
+        """
+        return self._call("command_joint_state", joint_state=joint_state)
+
+    def get_observations(self) -> Dict[str, np.ndarray]:
+        """Ask the server for the robot's current observations.
+
+        Returns:
+            Dict[str, np.ndarray]: The value returned by the remote
+                ``get_observations``.
+        """
+        return self._call("get_observations")
diff --git a/imgs/gello_matching_joints.jpg b/imgs/gello_matching_joints.jpg
new file mode 100644
index 0000000..ed92f87
Binary files /dev/null and b/imgs/gello_matching_joints.jpg differ
diff --git a/imgs/robot_known_configuration.jpg b/imgs/robot_known_configuration.jpg
new file mode 100644
index 0000000..5f8fc96
Binary files /dev/null and b/imgs/robot_known_configuration.jpg differ
diff --git a/imgs/title.png b/imgs/title.png
new file mode 100644
index 0000000..ee82708
Binary files /dev/null and b/imgs/title.png differ
diff --git a/kill_nodes.sh b/kill_nodes.sh
new file mode 100755
index 0000000..0108a90
--- /dev/null
+++ b/kill_nodes.sh
@@ -0,0 +1 @@
+# Force-kill (SIGKILL) every process whose command line matches launch_nodes.py.
+pkill -9 -f launch_nodes.py
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..0e89563
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,8 @@
+[mypy]
+ignore_missing_imports = True
+namespace_packages = True
+no_implicit_optional = True
+exclude = third_party/
+
+[mypy-dynamixel_sdk.*]
+ignore_missing_imports = True
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 0000000..d4697f8
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,12 @@
+{
+ "include": [
+ "gello"
+ ],
+ "exclude": [
+ "**/node_modules",
+ "src/typestubs"
+ ],
+ "reportMissingImports": false,
+ "reportMissingTypeStubs": false,
+ "reportGeneralTypeIssues": false
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5ac85ad
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+# keep in alphabetical order
+dm_control
+Pillow
+pygame
+pyspacemouse
+PyQt6
+pyquaternion
+pyrealsense2
+pure-python-adb
+quaternion
+tyro
+ur-rtde
+zmq
+xarm
+xarm-python-sdk
+numpy-quaternion
+termcolor
diff --git a/requirements_dev.txt b/requirements_dev.txt
new file mode 100644
index 0000000..7d8fe63
--- /dev/null
+++ b/requirements_dev.txt
@@ -0,0 +1,13 @@
+# keep in alphabetical order
+# for development
+black
+flake8
+flake8-docstrings
+isort
+ipdb
+jupyterlab
+mypy
+neovim
+pyright
+pytest
+python-lsp-server[all]
diff --git a/scripts/arm_blocks_play.py b/scripts/arm_blocks_play.py
new file mode 100644
index 0000000..851e85c
--- /dev/null
+++ b/scripts/arm_blocks_play.py
@@ -0,0 +1,59 @@
+from dataclasses import dataclass
+
+import numpy as np
+import tyro
+from dm_control import composer, viewer
+
+from gello.agents.gello_agent import DynamixelRobotConfig
+from gello.dm_control_tasks.arms.ur5e import UR5e
+from gello.dm_control_tasks.manipulation.arenas.floors import Floor
+from gello.dm_control_tasks.manipulation.tasks.block_play import BlockPlay
+
+
+@dataclass
+class Args:
+    # When True the policy reads actions from a GELLO device; otherwise it
+    # samples random actions from the environment's action spec.
+    use_gello: bool = False
+
+
+# Dynamixel configuration used to build the GELLO device when use_gello is set.
+config = DynamixelRobotConfig(
+    joint_ids=(1, 2, 3, 4, 5, 6),
+    # One offset per joint id, each a multiple of pi/2.
+    joint_offsets=(
+        -np.pi / 2,
+        1 * np.pi / 2 + np.pi,
+        np.pi / 2 + 0 * np.pi,
+        0 * np.pi + np.pi / 2,
+        np.pi - 2 * np.pi / 2,
+        -1 * np.pi / 2 + 2 * np.pi,
+    ),
+    joint_signs=(1, 1, -1, 1, 1, 1),
+    # NOTE(review): presumably (gripper id, open reading, closed reading);
+    # confirm against DynamixelRobotConfig.
+    gripper_config=(7, 20, -22),
+)
+
+
+def main(args: Args) -> None:
+    """Open the dm_control viewer on the UR5e block-play task.
+
+    Args:
+        args (Args): With ``use_gello`` set, actions are read from a GELLO
+            device built from the module-level ``config``; otherwise random
+            actions are sampled from the action spec.
+    """
+    reset_joints_left = np.deg2rad([90, -90, -90, -90, 90, 0, 0])
+    arm = UR5e()
+    # The task takes only the arm joints; the last reset value is dropped.
+    task = BlockPlay(arm, Floor(), reset_joints=reset_joints_left[:-1])
+    env = composer.Environment(task=task)
+
+    action_space = env.action_spec()
+    if args.use_gello:
+        gello = config.make_robot(
+            port="/dev/cu.usbserial-FT7WBEIA", start_joints=reset_joints_left
+        )
+
+    def policy(timestep) -> np.ndarray:
+        """Return the GELLO joint state as the action, or a random action."""
+        if not args.use_gello:
+            return np.random.uniform(action_space.minimum, action_space.maximum)
+        joint_command = np.array(gello.get_joint_state()).copy()
+        # Scale the last entry (the gripper) by 255.
+        joint_command[-1] = joint_command[-1] * 255
+        return joint_command
+
+    viewer.launch(env, policy=policy)
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/scripts/gello_get_offset.py b/scripts/gello_get_offset.py
new file mode 100644
index 0000000..4e5aa8c
--- /dev/null
+++ b/scripts/gello_get_offset.py
@@ -0,0 +1,98 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import tyro
+
+from gello.dynamixel.driver import DynamixelDriver
+
+MENAGERIE_ROOT: Path = Path(__file__).parent / "third_party" / "mujoco_menagerie"
+
+
+@dataclass
+class Args:
+    """Command-line options for the GELLO joint-offset calibration script."""
+
+    port: str = "/dev/ttyUSB0"
+    """The port that GELLO is connected to."""
+
+    start_joints: Tuple[float, ...] = (0, 0, 0, 0, 0, 0)
+    """The joint angles that the GELLO is placed in at (in radians)."""
+
+    joint_signs: Tuple[float, ...] = (1, 1, -1, 1, 1, 1)
+    """The sign applied to each joint reading; every entry must be 1 or -1."""
+
+    gripper: bool = True
+    """Whether or not the gripper is attached."""
+
+    def __post_init__(self):
+        """Check that there is one valid sign per start joint."""
+        assert len(self.joint_signs) == len(self.start_joints), (
+            f"Got {len(self.joint_signs)} joint signs for "
+            f"{len(self.start_joints)} start joints."
+        )
+        for idx, j in enumerate(self.joint_signs):
+            assert j in (-1, 1), f"Joint idx: {idx} should be -1 or 1, but got {j}."
+
+    @property
+    def num_robot_joints(self) -> int:
+        """Number of arm joints, excluding the gripper."""
+        return len(self.start_joints)
+
+    @property
+    def num_joints(self) -> int:
+        """Number of arm joints plus one when the gripper is attached."""
+        extra_joints = 1 if self.gripper else 0
+        return self.num_robot_joints + extra_joints
+
+
+def get_config(args: Args) -> None:
+    """Read the GELLO once and print joint offsets and gripper limits.
+
+    The device is assumed to be held at ``args.start_joints``. For each arm
+    joint the offset, chosen among multiples of pi/2 in ``[-8*pi, 8*pi]``,
+    that brings the signed reading closest to its start value is reported.
+
+    Args:
+        args (Args): Port, expected start joints, joint signs and gripper flag.
+    """
+    joint_ids = list(range(1, args.num_joints + 1))
+    driver = DynamixelDriver(joint_ids, port=args.port, baudrate=57600)
+
+    def get_error(offset: float, index: int, joint_state: np.ndarray) -> float:
+        """Distance between the corrected reading of a joint and its start value."""
+        corrected = args.joint_signs[index] * (joint_state[index] - offset)
+        return np.abs(corrected - args.start_joints[index])
+
+    # Discard the first few readings before sampling.
+    for _ in range(10):
+        driver.get_joints()  # warmup
+
+    curr_joints = driver.get_joints()
+    best_offsets = []
+    for joint_index in range(args.num_robot_joints):
+        best_offset, best_error = 0, 1e6
+        # Brute-force search over intervals of pi/2.
+        for candidate in np.linspace(-8 * np.pi, 8 * np.pi, 8 * 4 + 1):
+            candidate_error = get_error(candidate, joint_index, curr_joints)
+            if candidate_error < best_error:
+                best_offset, best_error = candidate, candidate_error
+        best_offsets.append(best_offset)
+
+    formatted_offsets = [f"{x:.3f}" for x in best_offsets]
+    pi_multiples = [f"{int(np.round(x/(np.pi/2)))}*np.pi/2" for x in best_offsets]
+    print()
+    print("best offsets : ", formatted_offsets)
+    print(
+        "best offsets function of pi: [" + ", ".join(pi_multiples) + " ]",
+    )
+    if args.gripper:
+        # Each line takes a fresh reading of the last (gripper) joint.
+        print(
+            "gripper open (degrees) ",
+            np.rad2deg(driver.get_joints()[-1]) - 0.2,
+        )
+        print(
+            "gripper close (degrees) ",
+            np.rad2deg(driver.get_joints()[-1]) - 42,
+        )
+
+
+def main(args: Args) -> None:
+    """Entry point: run the offset calibration for the parsed arguments."""
+    get_config(args)
+
+
+if __name__ == "__main__":
+ main(tyro.cli(Args))
diff --git a/scripts/launch.py b/scripts/launch.py
new file mode 100644
index 0000000..5ba980d
--- /dev/null
+++ b/scripts/launch.py
@@ -0,0 +1,39 @@
+import os
+import subprocess
+
+current_file_path = os.path.abspath(__file__)
+
+
+def run_docker_container():
+    """Start an interactive ``gello:latest`` container with the repo at ``/gello``.
+
+    The container is named ``gello_<user>``, runs privileged on the host
+    network, and installs the bundled DynamixelSDK before opening a shell.
+    """
+    import getpass
+
+    # USER may be unset (cron, minimal environments), which would name the
+    # container "gello_None"; fall back to the login name instead.
+    user = os.getenv("USER")
+    if not user:
+        try:
+            user = getpass.getuser()
+        except (KeyError, ImportError, OSError):
+            user = "unknown"
+    container_name = f"gello_{user}"
+    # Two levels up from this file is the repository root.
+    gello_path = os.path.abspath(os.path.join(current_file_path, "../../"))
+    volume_mapping = f"{gello_path}:/gello"
+
+    cmd = [
+        "docker",
+        "run",
+        "--runtime=nvidia",
+        "--rm",
+        "--name",
+        container_name,
+        "--privileged",
+        "--volume",
+        volume_mapping,
+        "--volume",
+        "/home/gello:/homefolder",
+        "--net=host",
+        "--volume",
+        "/dev/serial/by-id/:/dev/serial/by-id/",
+        "-it",
+        "gello:latest",
+        "bash",
+        "-c",
+        "pip install -e third_party/DynamixelSDK/python && exec bash",
+    ]
+
+    subprocess.run(cmd)
+
+
+if __name__ == "__main__":
+ run_docker_container()
diff --git a/scripts/visualize_example.py b/scripts/visualize_example.py
new file mode 100644
index 0000000..ba2b160
--- /dev/null
+++ b/scripts/visualize_example.py
@@ -0,0 +1,41 @@
+import time
+from pathlib import Path
+
+import mujoco
+import mujoco.viewer
+
+
+def main():
+    """Step the Franka Panda MuJoCo model in a passive viewer until it is closed."""
+    project_root: Path = Path(__file__).parent.parent
+    menagerie_root: Path = project_root / "third_party" / "mujoco_menagerie"
+    xml = menagerie_root / "franka_emika_panda" / "panda.xml"
+
+    model = mujoco.MjModel.from_xml_path(xml.as_posix())
+    data = mujoco.MjData(model)
+
+    with mujoco.viewer.launch_passive(model, data) as viewer:
+        # Run for as long as the viewer window stays open.
+        while viewer.is_running():
+            step_start = time.time()
+
+            # mj_step can be replaced with code that also evaluates
+            # a policy and applies a control signal before stepping the physics.
+            mujoco.mj_step(model, data)
+
+            # Toggle contact-point rendering every simulated second.
+            show_contacts = int(data.time % 2)
+            with viewer.lock():
+                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = show_contacts
+
+            # Pick up changes to the physics state, apply perturbations,
+            # update options from GUI.
+            viewer.sync()
+
+            # Rudimentary time keeping, will drift relative to wall clock.
+            elapsed = time.time() - step_start
+            time_until_next_step = model.opt.timestep - elapsed
+            if time_until_next_step > 0:
+                time.sleep(time_until_next_step)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..7278068
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,25 @@
+import setuptools
+
+# Read the README explicitly as UTF-8 so the build does not depend on the
+# platform's default encoding.
+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+setuptools.setup(
+    name="gello",
+    version="0.0.1",
+    author="Philipp Wu",
+    author_email="philippwu@berkeley.edu",
+    description="software for GELLO",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/wuphilipp/gello_software",
+    packages=setuptools.find_packages(),
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+    ],
+    python_requires=">=3.8",
+    license="MIT",
+    install_requires=[
+        "numpy",
+    ],
+)
diff --git a/third_party/DynamixelSDK b/third_party/DynamixelSDK
new file mode 160000
index 0000000..3450c70
--- /dev/null
+++ b/third_party/DynamixelSDK
@@ -0,0 +1 @@
+Subproject commit 3450c7078917b262d9b36042c15444047aae226e
diff --git a/third_party/mujoco_menagerie b/third_party/mujoco_menagerie
new file mode 160000
index 0000000..c118f3d
--- /dev/null
+++ b/third_party/mujoco_menagerie
@@ -0,0 +1 @@
+Subproject commit c118f3d5d5ec9fb27f832ac78b3c4971234f0f4f