initial commit, add gello software code and instructions

This commit is contained in:
Philipp Wu 2023-11-13 09:17:27 -08:00
parent e7d842ad35
commit 18cc23a38e
70 changed files with 5875 additions and 4 deletions

3
.flake8 Normal file
View file

@ -0,0 +1,3 @@
[flake8]
ignore = D100, D101, D102, D103, D104, D105, D107, E203, E501, W503, SIM201, SIM113, B027, C408, B008
docstring-convention = google

44
.github/workflows/pythonapp.yml vendored Normal file
View file

@ -0,0 +1,44 @@
name: Python application
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
schedule:
- cron: "0 0 * * 0"
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
# python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements_dev.txt
pip install -r requirements.txt
pip install -e .
git submodule init
git submodule update
pip install third_party/DynamixelSDK/python
- name: Black
run: |
black --check gello
- name: flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 gello --count --select=E9,F63,F7,F82 --show-source --statistics
- name: Test with pytest
run: |
pytest gello

29
.gitignore vendored Normal file
View file

@ -0,0 +1,29 @@
# keep each top level section alphabetical
# keep each item within the sections alphabetical
# large files
*.png
*.gif
# mac
*.DS_Store
# misc
*logs*
*runs*
*tb_logs*
*wandb*
*wandb_logs*
outputs/*
# python
*__pycache__
*.egg-info
*.hypothesis
*.ipynb_checkpoints
*.mypy_cache
*.pyc
# vim
*.swp
MUJOCO_LOG.TXT

6
.gitmodules vendored Normal file
View file

@ -0,0 +1,6 @@
[submodule "third_party/mujoco_menagerie"]
path = third_party/mujoco_menagerie
url = https://github.com/deepmind/mujoco_menagerie.git
[submodule "third_party/DynamixelSDK"]
path = third_party/DynamixelSDK
url = https://github.com/ROBOTIS-GIT/DynamixelSDK.git

6
.isort.cfg Normal file
View file

@ -0,0 +1,6 @@
[settings]
use_parentheses=True
include_trailing_comma=True
multi_line_output=3
ensure_newline_before_comments=True
line_length=88

38
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,38 @@
repos:
# remove unused python imports
- repo: https://github.com/myint/autoflake.git
rev: v2.1.1
hooks:
- id: autoflake
args: ["--in-place", "--remove-all-unused-imports", "--ignore-init-module-imports"]
# sort imports
- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
hooks:
- id: isort
# code format according to black
- repo: https://github.com/ambv/black
rev: 23.3.0
hooks:
- id: black
# check for python styling with flake8
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies: [
'flake8-docstrings',
'flake8-bugbear',
'flake8-comprehensions',
'flake8-simplify',
]
# cleanup notebooks
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
- id: nbstripout

22
Dockerfile Normal file
View file

@ -0,0 +1,22 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
WORKDIR /gello
# Set environment variables first (less likely to change)
ENV PYTHONPATH=/gello:/gello/third_party/oculus_reader/
# Group apt updates and installs together
RUN apt update && apt install -y \
libhidapi-dev \
python3-pip \
android-tools-adb \
libegl1-mesa-dev && \
rm -rf /var/lib/apt/lists/*
# Python alias setup
RUN echo "alias python=python3" >> ~/.bashrc
# Install Python dependencies
COPY requirements.txt /gello
RUN pip install -r requirements.txt

21
LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Philipp Wu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

180
README.md
View file

@ -1,10 +1,174 @@
# GELLO Software
This is the central repo that holds the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/
# GELLO
This is the central repo that holds all the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/
See the GELLO hardware repo for the STL files and hardware instructions for building your own GELLO https://github.com/wuphilipp/gello_mechanical
```
git clone https://github.com/wuphilipp/gello_software.git
cd gello_software
```
# Installation
<p align="center">
<img src="imgs/title.png" />
</p>
Coming Soon
## Use your own environment
```
git submodule init
git submodule update
pip install -r requirements.txt
pip install -e .
pip install -e third_party/DynamixelSDK/python
```
## Use with Docker
First install ```docker``` following this [link](https://docs.docker.com/engine/install/ubuntu/) on your host machine.
Then you can clone the repo and build the corresponding docker environment
Build the docker image and tag it as gello:latest. If you are going to name it differently, you need to change the launch.py image name
```
docker build . -t gello:latest
```
We have provided an entry point into the docker container
```
python experiments/launch.py
```
# GELLO configuration setup (PLEASE READ)
Now that you have downloaded the code, there is some additional preparation work to properly configure the Dynamixels and GELLO.
These instructions will guide you on how to update the motor ids of the Dynamixels and then how to extract the joint offsets to configure your GELLO.
## Update motor IDs
Install the [dynamixel_wizard](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2/).
By default, each motor has the ID 1. In order for multiple dynamixels to be controlled by the same U2D2 controller board, each dynamixel must have a unique ID.
This process must be done one motor at a time. Connect each motor, starting from the base motor, and assign them in increasing order until you reach the gripper.
Steps:
* Connect a single motor to the controller and connect the controller to the computer.
* Open the dynamixel wizard
* Click scan (found at the top left corner), this should detect the dynamixel. Connect to the motor
* Look for the ID address and change the ID to the appropriate number.
* Repeat for each motor
## Create the GELLO configuration and determine joint offsets
After the motor IDs are set, we can now connect to the GELLO controller device. However, each motor has its own joint offset, which will result in a joint offset between GELLO and your actual robot arm.
Dynamixels have a symmetric 4 hole pattern, which means the joint offset is a multiple of pi/2.
The `GelloAgent` class accepts a `DynamixelRobotConfig` (found in `gello/agents/gello_agent.py`). The Dynamixel config specifies the parameters you need to find to operate your GELLO. Look at the documentation for more details.
We have created a simple script to automatically detect the joint offset:
* Set GELLO into a known configuration, where you know what the corresponding joint angles should be. For example, we set our GELLO in this configuration, where we know the desired ground truth joints. (0, -90, 90, -90, -90, 0)
<p align="center">
<img src="imgs/gello_matching_joints.jpg" width="45%"/>
<img src="imgs/robot_known_configuration.jpg" width="45%"/>
</p>
* run
```
# --start-joints is given in radians; replace all values with your own
python scripts/gello_get_offset.py \
--start-joints 0 -1.57 1.57 -1.57 -1.57 0 \
--joint-signs 1 1 -1 1 1 1 \
--port /dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6
```
* Use the known starting joints for `start-joints`.
* Use the `joint-signs` for your own robot (see below).
* Use your serial port for `port`. You can find the port id of your U2D2 Dynamixel device by running `ls /dev/serial/by-id` and looking for the path that starts with `usb-FTDI_USB__-__Serial_Converter` (on Ubuntu). On Mac, look in `/dev/` for the device that starts with `cu.usbserial`.
`joint-signs` for each robot type:
* UR: `1 1 -1 1 1 1`
* Panda: `1 -1 1 1 1 -1 1`
* xArm: `1 1 1 1 1 1 1`
The script prints out a list of joint offsets. Go to `gello/agents/gello_agent.py` and add a DynamixelRobotConfig to the PORT_CONFIG_MAP. You are now ready to run your GELLO!
# Using GELLO to control a robot!
The code provided here is simple and only relies on python packages. The code does NOT use ROS, but a ROS wrapper can easily be adapted from this code.
For multiprocessing, we leverage [ZMQ](https://zeromq.org/)
## Testing in sim
First test your GELLO with a simulated robot to make sure that the joint angles match as expected.
In one terminal run
```
python experiments/launch_nodes.py --robot <sim_ur, sim_panda, or sim_xarm>
```
This launches the robot node. A simulated robot using the mujoco viewer should appear.
Then, launch your GELLO (the controller node).
```
python experiments/run_env.py --agent=gello
```
You should be able to use GELLO to control the simulated robot!
## Running on a real robot.
Once you have verified that your GELLO is properly configured, you can test it on a real robot!
Before you run with the real robot, you will have to install a robot specific python package.
The supported robots are in `gello/robots`.
* UR: [ur_rtde](https://sdurobotics.gitlab.io/ur_rtde/installation/installation.html)
* panda: [polymetis](https://facebookresearch.github.io/fairo/polymetis/installation.html). If you use a different framework to control the panda, the code is easy to adapt. See/modify `gello/robots/panda.py`
* xArm: [xArm python SDK](https://github.com/xArm-Developer/xArm-Python-SDK)
```
# Launch all of the nodes
python experiments/launch_nodes.py --robot=<your robot>
# run the environment loop
python experiments/run_env.py --agent=gello
```
Ideally you can start your GELLO near a known configuration each time. If this is possible, you can set the `--start-joints` flag with GELLO's known starting configuration. This also enables the robot to reset before you begin teleoperation.
## Collect data
We have provided a simple example for collecting data with gello.
To save trajectories with the keyboard, add the following flag `--use-save-interface`
Data can then be processed using the demo_to_gdict script.
```
python gello/data_utils/demo_to_gdict.py --source-dir=<source dir location>
```
## Running a bimanual system with GELLO
GELLO can also be used in bimanual configurations.
For an example, see the `bimanual_ur` robot in `launch_nodes.py` and `--bimanual` flag in the `run_env.py` script.
## Notes
Due to the use of multiprocessing, sometimes python processes are not killed properly. We have provided the kill_nodes script which will kill the
python processes.
```
./kill_nodes.sh
```
### Using a new robot!
If you want to use a new robot you need a GELLO that is compatible. If the kinematics are close enough, you may directly use an existing GELLO. Otherwise you will have to design your own.
To add a new robot, simply implement the `Robot` protocol found in `gello/robots/robot`. See `gello/robots/panda.py`, `gello/robots/ur.py`, `gello/robots/xarm_robot.py` for examples.
### Contributing
Please make a PR if you would like to contribute! The goal of this project is to enable more accessible and higher quality teleoperation devices and we would love your input!
You can optionally install some dev packages.
```
pip install -r requirements_dev.txt
```
The code is organized as follows:
* `scripts`: contains some helpful python `scripts`
* `experiments`: contains entrypoints into the gello code
* `gello`: contains all of the `gello` python package code
* `agents`: teleoperation agents
* `cameras`: code to interface with camera hardware
* `data_utils`: data processing utils. used for imitation learning
* `dm_control_tasks`: dm_control utils to build a simple dm_control environment. used for demos
* `dynamixel`: code to interface with the dynamixel hardware
* `robots`: robot specific interfaces
* `zmq_core`: zmq utilities for enabling a multi node system
This code base uses `isort` and `black` for code formatting.
Pre-commit hooks are great: they automatically do some checking/formatting. To use the pre-commit hooks, run the following:
```
pip install pre-commit
pre-commit install
```
# Citation
@ -15,3 +179,11 @@ Coming Soon
year={2023},
}
```
# License & Acknowledgements
This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.
This project builds on top of or utilizes the following third party dependencies.
* [google-deepmind/mujoco_menagerie](https://github.com/google-deepmind/mujoco_menagerie): Prebuilt robot models for mujoco
* [brentyi/tyro](https://github.com/brentyi/tyro): Argument parsing and configuration
* [ZMQ](https://zeromq.org/): Enables easy creation of node-like processes in python.

20
config_hostmachine.sh Normal file
View file

@ -0,0 +1,20 @@
#!/bin/bash
# Prepare an Ubuntu host machine for the GELLO docker workflow: installs the
# NVIDIA driver and Docker Engine from Docker's apt repository.
# Every `apt-get install` passes -y so the script can run unattended.

# refresh the package index before the first install
sudo apt-get update

# install nvidia driver 520
sudo apt-get install -y nvidia-520

# install docker
sudo apt-get install -y ca-certificates curl gnupg
sudo install -m 0755 -d /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
sudo chmod a+r /etc/apt/keyrings/docker.gpg

# register Docker's apt repository for this machine's architecture and release
echo \
"deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null

# the index must be refreshed again so the new repository is visible
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

View file

@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import tyro
from gello.zmq_core.camera_node import ZMQClientCamera
@dataclass
class Args:
    """Command-line options for the camera viewer script."""

    # ZMQ ports of the camera servers to connect to.
    ports: Tuple[int, ...] = (5000, 5001)
    # Host on which the camera servers are reachable.
    hostname: str = "127.0.0.1"
    # hostname: str = "128.32.175.167"
def main(args):
    """Show the color and depth stream of every configured camera, forever.

    One ZMQ camera client and one resizable OpenCV window are created per port
    in ``args.ports``; the windows are then refreshed in an endless loop.
    """
    import cv2

    cameras = []
    images_display_names = []
    for port in args.ports:
        window_name = f"image_{port}"
        cameras.append(ZMQClientCamera(port=port, host=args.hostname))
        images_display_names.append(window_name)
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

    streams = list(zip(images_display_names, cameras))
    while True:
        for display_name, camera in streams:
            image, depth = camera.read()
            # replicate depth three times along the last axis so it can be
            # concatenated with the color image
            depth_3ch = np.dstack((depth, depth, depth)).astype(np.uint8)
            # image[:, :, ::-1] reverses the channel order
            # (presumably RGB -> BGR for OpenCV; confirm against the camera driver)
            side_by_side = cv2.hconcat([image[:, :, ::-1], depth_3ch])
            cv2.imshow(display_name, side_by_side)
            cv2.waitKey(1)


if __name__ == "__main__":
    main(tyro.cli(Args))

View file

@ -0,0 +1,40 @@
from dataclasses import dataclass
from multiprocessing import Process
import tyro
from gello.cameras.realsense_camera import RealSenseCamera, get_device_ids
from gello.zmq_core.camera_node import ZMQServerCamera
@dataclass
class Args:
    """Command-line options for the camera server launcher."""

    # Address the camera servers are started with. Defaults to localhost so the
    # script works out of the box and matches the default of the camera
    # client; pass --hostname <this machine's IP> to serve other machines.
    hostname: str = "127.0.0.1"
def launch_server(port: int, camera_id: int, args: Args):
    """Wrap one RealSense camera in a ZMQ camera server and serve it.

    Args:
        port: port handed to the ZMQ camera server.
        camera_id: identifier of the RealSense device to open.
        args: parsed command-line options; only ``hostname`` is read.
    """
    server = ZMQServerCamera(
        RealSenseCamera(camera_id), port=port, host=args.hostname
    )
    print(f"Starting camera server on port {port}")
    server.serve()
def main(args):
    """Start one camera-server process per detected device.

    The devices returned by ``get_device_ids()`` are assigned consecutive
    ports starting at 5000. All processes are created first and started
    afterwards.
    """
    first_port = 5000
    camera_servers = []
    for offset, camera_id in enumerate(get_device_ids()):
        camera_port = first_port + offset
        # start a python process for each camera
        print(f"Launching camera {camera_id} on port {camera_port}")
        camera_servers.append(
            Process(target=launch_server, args=(camera_port, camera_id, args))
        )

    for server in camera_servers:
        server.start()


if __name__ == "__main__":
    main(tyro.cli(Args))

View file

@ -0,0 +1,94 @@
from dataclasses import dataclass
from pathlib import Path
import tyro
from gello.robots.robot import BimanualRobot, PrintRobot
from gello.zmq_core.robot_node import ZMQServerRobot
@dataclass
class Args:
    """Command-line options for launching a robot server node."""

    # Robot to serve; launch_robot_server lists the accepted names.
    robot: str = "xarm"
    # Port handed to the robot server.
    robot_port: int = 6001
    # Host handed to the robot server.
    hostname: str = "127.0.0.1"
    # IP of the real robot (not used for simulated robots or bimanual_ur).
    robot_ip: str = "192.168.1.10"


def launch_robot_server(args: Args):
    """Create the robot selected by ``args.robot`` and serve it.

    ``sim_ur``, ``sim_panda`` and ``sim_xarm`` are served by a
    ``MujocoRobotServer`` built from the mujoco_menagerie models; ``xarm``,
    ``ur``, ``panda``, ``bimanual_ur`` and ``none``/``print`` are wrapped in a
    ``ZMQServerRobot``. Robot-specific packages are imported lazily so only
    the selected backend has to be installed.

    Args:
        args: parsed command-line options.

    Raises:
        NotImplementedError: if ``args.robot`` is not one of the names above.
    """
    port = args.robot_port

    # robot name -> (arm model, optional gripper model), each given as path
    # components relative to third_party/mujoco_menagerie
    sim_models = {
        "sim_ur": (("universal_robots_ur5e", "ur5e.xml"), ("robotiq_2f85", "2f85.xml")),
        "sim_panda": (("franka_emika_panda", "panda.xml"), None),
        "sim_xarm": (("ufactory_xarm7", "xarm7.xml"), None),
    }

    if args.robot in sim_models:
        from gello.robots.sim_robot import MujocoRobotServer

        menagerie_root = (
            Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
        )
        arm_parts, gripper_parts = sim_models[args.robot]
        xml = menagerie_root.joinpath(*arm_parts)
        gripper_xml = (
            None if gripper_parts is None else menagerie_root.joinpath(*gripper_parts)
        )
        server = MujocoRobotServer(
            xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
        )
        server.serve()
        return

    if args.robot == "xarm":
        from gello.robots.xarm_robot import XArmRobot

        robot = XArmRobot(ip=args.robot_ip)
    elif args.robot == "ur":
        from gello.robots.ur import URRobot

        robot = URRobot(robot_ip=args.robot_ip)
    elif args.robot == "panda":
        from gello.robots.panda import PandaRobot

        robot = PandaRobot(robot_ip=args.robot_ip)
    elif args.robot == "bimanual_ur":
        from gello.robots.ur import URRobot

        # IP for the bimanual robot setup is hardcoded
        _robot_l = URRobot(robot_ip="192.168.2.10")
        _robot_r = URRobot(robot_ip="192.168.1.10")
        robot = BimanualRobot(_robot_l, _robot_r)
    elif args.robot == "none" or args.robot == "print":
        robot = PrintRobot(8)
    else:
        # list every name this function actually accepts
        raise NotImplementedError(
            f"Robot {args.robot} not implemented, choose one of: "
            "sim_ur, sim_panda, sim_xarm, xarm, ur, panda, bimanual_ur, none, print"
        )

    server = ZMQServerRobot(robot, port=port, host=args.hostname)
    print(f"Starting robot server on port {port}")
    server.serve()
def main(args):
    """Entry point: hand the parsed options straight to ``launch_robot_server``."""
    launch_robot_server(args)


if __name__ == "__main__":
    main(tyro.cli(Args))

192
experiments/quick_run.py Normal file
View file

@ -0,0 +1,192 @@
import atexit
import glob
import time
from dataclasses import dataclass
from multiprocessing import Process
from pathlib import Path
from typing import Optional
import numpy as np
import tyro
from gello.agents.agent import DummyAgent
from gello.agents.gello_agent import GelloAgent
from gello.agents.spacemouse_agent import SpacemouseAgent
from gello.env import RobotEnv
from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
@dataclass
class Args:
    """Command-line options for the single-command quick-run script."""

    # Control rate passed to RobotEnv.
    hz: int = 100
    # Teleoperation agent: gello, quest, spacemouse, dummy or none.
    agent: str = "gello"
    # Robot to launch: sim_ur, sim_panda, xarm or ur5.
    robot: str = "ur5"
    # Serial port of the GELLO; the first port found is used when None.
    gello_port: Optional[str] = None
    mock: bool = False
    verbose: bool = False
    # Host and port shared by the robot server and its client.
    hostname: str = "127.0.0.1"
    robot_port: int = 6001
    # IP of the real UR robot. launch_robot_server reads this field for the
    # "ur5" robot, so it has to exist on Args.
    robot_ip: str = "192.168.1.10"


def launch_robot_server(port: int, args: Args):
    """Create the robot selected by ``args.robot`` and serve it on ``port``.

    ``sim_ur`` and ``sim_panda`` are served by a ``MujocoRobotServer`` built
    from the mujoco_menagerie models; ``xarm`` and ``ur5`` are wrapped in a
    ``ZMQServerRobot``. Robot-specific packages are imported lazily.

    Args:
        port: port handed to the robot server.
        args: parsed command-line options.

    Raises:
        NotImplementedError: if ``args.robot`` is not one of the names above.
    """
    # robot name -> (arm model, optional gripper model), each given as path
    # components relative to third_party/mujoco_menagerie
    sim_models = {
        "sim_ur": (("universal_robots_ur5e", "ur5e.xml"), ("robotiq_2f85", "2f85.xml")),
        "sim_panda": (("franka_emika_panda", "panda.xml"), None),
    }

    if args.robot in sim_models:
        from gello.robots.sim_robot import MujocoRobotServer

        menagerie_root = (
            Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
        )
        arm_parts, gripper_parts = sim_models[args.robot]
        xml = menagerie_root.joinpath(*arm_parts)
        gripper_xml = (
            None if gripper_parts is None else menagerie_root.joinpath(*gripper_parts)
        )
        server = MujocoRobotServer(
            xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
        )
        server.serve()
        return

    if args.robot == "xarm":
        from gello.robots.xarm_robot import XArmRobot

        robot = XArmRobot()
    elif args.robot == "ur5":
        from gello.robots.ur import URRobot

        robot = URRobot(robot_ip=args.robot_ip)
    else:
        # list only the names this script actually accepts
        raise NotImplementedError(
            f"Robot {args.robot} not implemented, choose one of: "
            "sim_ur, sim_panda, xarm, ur5"
        )

    server = ZMQServerRobot(robot, port=port, host=args.hostname)
    print(f"Starting robot server on port {port}")
    server.serve()
def start_robot_process(args: Args):
    """Run ``launch_robot_server`` in a child process that is terminated at exit."""
    server_proc = Process(target=launch_robot_server, args=(args.robot_port, args))

    def _terminate(proc):
        # Function to kill the child process
        print("Killing child process...")
        proc.terminate()

    # Register the cleanup so it is called at interpreter exit, then start
    atexit.register(_terminate, server_proc)
    server_proc.start()
def main(args: Args):
    """Launch a robot server and teleoperate the robot with the selected agent.

    Steps:
      1. Start the robot server in a child process and connect a client to it.
      2. Build the agent named by ``args.agent`` (for ``gello`` the robot is
         first moved to a fixed reset pose when the joint counts match).
      3. Give up (return) if any joint of the agent's first command is more
         than 0.8 away from the robot's current joint position.
      4. Soft start: 25 steps towards the command, each limited to 0.05.
      5. Exit if any joint is still more than 0.5 away in either direction.
      6. Forward agent actions to the environment forever.

    Raises:
        ValueError: for an unknown agent name, or when no GELLO port is found.
    """
    start_robot_process(args)

    robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
    env = RobotEnv(robot_client, control_rate_hz=args.hz)

    if args.agent == "gello":
        gello_port = args.gello_port
        if gello_port is None:
            usb_ports = glob.glob("/dev/serial/by-id/*")
            print(f"Found {len(usb_ports)} ports")
            if len(usb_ports) > 0:
                gello_port = usb_ports[0]
                print(f"using port {gello_port}")
            else:
                raise ValueError(
                    "No gello port found, please specify one or plug in gello"
                )
        agent = GelloAgent(port=gello_port)

        reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
        curr_joints = env.get_obs()["joint_positions"]
        # only reset when the hard-coded pose has the robot's joint count
        if reset_joints.shape == curr_joints.shape:
            max_delta = (np.abs(curr_joints - reset_joints)).max()
            steps = min(int(max_delta / 0.01), 100)

            for jnt in np.linspace(curr_joints, reset_joints, steps):
                env.step(jnt)
                time.sleep(0.001)
    elif args.agent == "quest":
        from gello.agents.quest_agent import SingleArmQuestAgent

        agent = SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
    elif args.agent == "spacemouse":
        agent = SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
    elif args.agent == "dummy" or args.agent == "none":
        agent = DummyAgent(num_dofs=robot_client.num_dofs())
    else:
        raise ValueError("Invalid agent name")

    # going to start position
    print("Going to start position")
    start_pos = agent.act(env.get_obs())
    obs = env.get_obs()
    joints = obs["joint_positions"]

    # check the dimensions before subtracting, otherwise a mismatch surfaces
    # as a broadcasting error and this message is never shown
    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
    assert len(start_pos) == len(
        joints
    ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"

    abs_deltas = np.abs(start_pos - joints)
    id_max_joint_delta = np.argmax(abs_deltas)

    max_joint_delta = 0.8
    if abs_deltas[id_max_joint_delta] > max_joint_delta:
        id_mask = abs_deltas > max_joint_delta
        print()
        ids = np.arange(len(id_mask))[id_mask]
        for i, delta, joint, current_j in zip(
            ids,
            abs_deltas[id_mask],
            start_pos[id_mask],
            joints[id_mask],
        ):
            print(
                f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
            )
        return

    # soft start: approach the command in steps of at most max_delta
    max_delta = 0.05
    for _ in range(25):
        obs = env.get_obs()
        command_joints = agent.act(obs)
        current_joints = obs["joint_positions"]
        delta = command_joints - current_joints
        max_joint_delta = np.abs(delta).max()
        if max_joint_delta > max_delta:
            delta = delta / max_joint_delta * max_delta
        env.step(current_joints + delta)

    obs = env.get_obs()
    joints = obs["joint_positions"]
    action = agent.act(obs)
    # use the absolute error so large negative differences are caught too, and
    # report with the same threshold that triggers the abort
    action_errors = np.abs(action - joints)
    max_action_delta = 0.5
    if (action_errors > max_action_delta).any():
        print("Action is too big")

        # print which joints are too big
        for j in np.where(action_errors > max_action_delta)[0]:
            print(
                f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
            )
        exit()

    while True:
        action = agent.act(obs)
        obs = env.step(action)


if __name__ == "__main__":
    main(tyro.cli(Args))

246
experiments/run_env.py Normal file
View file

@ -0,0 +1,246 @@
import datetime
import glob
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import tyro
from gello.agents.agent import BimanualAgent, DummyAgent
from gello.agents.gello_agent import GelloAgent
from gello.data_utils.format_obs import save_frame
from gello.env import RobotEnv
from gello.robots.robot import PrintRobot
from gello.zmq_core.robot_node import ZMQClientRobot
def print_color(*args, color=None, attrs=(), **kwargs):
    """Print like ``print``, with every positional argument styled by termcolor.

    Args:
        *args: values to print; each is passed through ``termcolor.colored``.
        color: color name forwarded to ``termcolor.colored``.
        attrs: attributes forwarded to ``termcolor.colored``.
        **kwargs: forwarded unchanged to ``print`` (e.g. ``end``, ``flush``).
    """
    import termcolor

    if args:
        args = tuple(
            termcolor.colored(text, color=color, attrs=attrs) for text in args
        )
    print(*args, **kwargs)
@dataclass
class Args:
    """Command-line options for the teleoperation / data-collection loop."""

    # Agent to use: gello, quest, spacemouse, dummy, none or policy.
    agent: str = "none"
    robot_port: int = 6001
    wrist_camera_port: int = 5000
    base_camera_port: int = 5001
    hostname: str = "127.0.0.1"
    # Annotated Optional because the default is None.
    robot_type: Optional[str] = None  # only needed for quest agent or spacemouse agent
    hz: int = 100
    # Known starting joint configuration of the GELLO, if any.
    start_joints: Optional[Tuple[float, ...]] = None

    # Serial port of the GELLO; the first port found is used when None.
    gello_port: Optional[str] = None
    # Use a PrintRobot instead of connecting to a robot server.
    mock: bool = False
    # Enable the keyboard interface for saving trajectories under data_dir.
    use_save_interface: bool = False
    data_dir: str = "~/bc_data"
    bimanual: bool = False
    verbose: bool = False
def main(args):
    """Teleoperate a robot with the selected agent, optionally saving frames.

    Steps:
      1. Build the environment (a ``PrintRobot`` when ``args.mock`` is set,
         otherwise a ZMQ client for the robot server).
      2. Build the agent; for bimanual setups and for ``gello`` the robot is
         first moved to a reset pose.
      3. Give up (return) if any joint of the agent's first command is more
         than 0.8 away from the robot's current joint position.
      4. Soft start: 25 steps towards the command, each limited to 0.05.
      5. Exit if any joint is still more than 0.5 away in either direction.
      6. Forward agent actions to the environment forever; with
         ``args.use_save_interface`` the keyboard interface decides when
         frames are written below ``args.data_dir``.

    Raises:
        ValueError: for an unknown agent name, a missing GELLO port, or an
            unknown keyboard-interface state.
        NotImplementedError: for the ``policy`` agent placeholder.
    """
    if args.mock:
        robot_client = PrintRobot(8, dont_print=True)
        camera_clients = {}
    else:
        camera_clients = {
            # you can optionally add camera nodes here for imitation learning purposes
            # "wrist": ZMQClientCamera(port=args.wrist_camera_port, host=args.hostname),
            # "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
        }
        robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
    env = RobotEnv(robot_client, control_rate_hz=args.hz, camera_dict=camera_clients)

    if args.bimanual:
        if args.agent == "gello":
            # dynamixel control box port map (to distinguish left and right gello)
            right = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0"
            left = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0"
            left_agent = GelloAgent(port=left)
            right_agent = GelloAgent(port=right)
            agent = BimanualAgent(left_agent, right_agent)
        elif args.agent == "quest":
            from gello.agents.quest_agent import SingleArmQuestAgent

            left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
            right_agent = SingleArmQuestAgent(
                robot_type=args.robot_type, which_hand="r"
            )
            agent = BimanualAgent(left_agent, right_agent)
            # raise NotImplementedError
        elif args.agent == "spacemouse":
            from gello.agents.spacemouse_agent import SpacemouseAgent

            left_path = "/dev/hidraw0"
            right_path = "/dev/hidraw1"
            left_agent = SpacemouseAgent(
                robot_type=args.robot_type, device_path=left_path, verbose=args.verbose
            )
            right_agent = SpacemouseAgent(
                robot_type=args.robot_type,
                device_path=right_path,
                verbose=args.verbose,
                invert_button=True,
            )
            agent = BimanualAgent(left_agent, right_agent)
        else:
            raise ValueError(f"Invalid agent name for bimanual: {args.agent}")

        # System setup specific. This reset configuration works well on our setup. If you are mounting the robot
        # differently, you need a separate reset joint configuration.
        reset_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0, 0])
        reset_joints_right = np.deg2rad([0, -90, 90, -90, -90, 0, 0])
        reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
        curr_joints = env.get_obs()["joint_positions"]
        max_delta = (np.abs(curr_joints - reset_joints)).max()
        steps = min(int(max_delta / 0.01), 100)

        for jnt in np.linspace(curr_joints, reset_joints, steps):
            env.step(jnt)
    else:
        if args.agent == "gello":
            gello_port = args.gello_port
            if gello_port is None:
                usb_ports = glob.glob("/dev/serial/by-id/*")
                print(f"Found {len(usb_ports)} ports")
                if len(usb_ports) > 0:
                    gello_port = usb_ports[0]
                    print(f"using port {gello_port}")
                else:
                    raise ValueError(
                        "No gello port found, please specify one or plug in gello"
                    )
            if args.start_joints is None:
                reset_joints = np.deg2rad(
                    [0, -90, 90, -90, -90, 0, 0]
                )  # Change this to your own reset joints
            else:
                # start_joints is declared as a tuple, which has no .shape;
                # convert it so the shape comparison below works
                reset_joints = np.asarray(args.start_joints)
            agent = GelloAgent(port=gello_port, start_joints=args.start_joints)
            curr_joints = env.get_obs()["joint_positions"]
            # only reset when the pose has the robot's joint count
            if reset_joints.shape == curr_joints.shape:
                max_delta = (np.abs(curr_joints - reset_joints)).max()
                steps = min(int(max_delta / 0.01), 100)

                for jnt in np.linspace(curr_joints, reset_joints, steps):
                    env.step(jnt)
                    time.sleep(0.001)
        elif args.agent == "quest":
            from gello.agents.quest_agent import SingleArmQuestAgent

            agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
        elif args.agent == "spacemouse":
            from gello.agents.spacemouse_agent import SpacemouseAgent

            agent = SpacemouseAgent(robot_type=args.robot_type, verbose=args.verbose)
        elif args.agent == "dummy" or args.agent == "none":
            agent = DummyAgent(num_dofs=robot_client.num_dofs())
        elif args.agent == "policy":
            raise NotImplementedError("add your imitation policy here if there is one")
        else:
            raise ValueError("Invalid agent name")

    # going to start position
    print("Going to start position")
    start_pos = agent.act(env.get_obs())
    obs = env.get_obs()
    joints = obs["joint_positions"]

    # check the dimensions before subtracting, otherwise a mismatch surfaces
    # as a broadcasting error and this message is never shown
    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
    assert len(start_pos) == len(
        joints
    ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"

    abs_deltas = np.abs(start_pos - joints)
    id_max_joint_delta = np.argmax(abs_deltas)

    max_joint_delta = 0.8
    if abs_deltas[id_max_joint_delta] > max_joint_delta:
        id_mask = abs_deltas > max_joint_delta
        print()
        ids = np.arange(len(id_mask))[id_mask]
        for i, delta, joint, current_j in zip(
            ids,
            abs_deltas[id_mask],
            start_pos[id_mask],
            joints[id_mask],
        ):
            print(
                f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
            )
        return

    # soft start: approach the command in steps of at most max_delta
    max_delta = 0.05
    for _ in range(25):
        obs = env.get_obs()
        command_joints = agent.act(obs)
        current_joints = obs["joint_positions"]
        delta = command_joints - current_joints
        max_joint_delta = np.abs(delta).max()
        if max_joint_delta > max_delta:
            delta = delta / max_joint_delta * max_delta
        env.step(current_joints + delta)

    obs = env.get_obs()
    joints = obs["joint_positions"]
    action = agent.act(obs)
    # use the absolute error so large negative differences are caught too, and
    # report with the same threshold that triggers the abort
    action_errors = np.abs(action - joints)
    max_action_delta = 0.5
    if (action_errors > max_action_delta).any():
        print("Action is too big")

        # print which joints are too big
        for j in np.where(action_errors > max_action_delta)[0]:
            print(
                f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
            )
        exit()

    if args.use_save_interface:
        from gello.data_utils.keyboard_interface import KBReset

        kb_interface = KBReset()

    print_color("\nStart 🚀🚀🚀", color="green", attrs=("bold",))

    save_path = None
    start_time = time.time()
    while True:
        num = time.time() - start_time
        message = f"\rTime passed: {round(num, 2)}          "
        print_color(
            message,
            color="white",
            attrs=("bold",),
            end="",
            flush=True,
        )
        action = agent.act(obs)
        # timestamp handed to save_frame for this step
        dt = datetime.datetime.now()
        if args.use_save_interface:
            state = kb_interface.update()
            if state == "start":
                # a new recording gets its own timestamped directory
                dt_time = datetime.datetime.now()
                save_path = (
                    Path(args.data_dir).expanduser()
                    / args.agent
                    / dt_time.strftime("%m%d_%H%M%S")
                )
                save_path.mkdir(parents=True, exist_ok=True)
                print(f"Saving to {save_path}")
            elif state == "save":
                assert save_path is not None, "something went wrong"
                save_frame(save_path, dt, obs, action)
            elif state == "normal":
                save_path = None
            else:
                raise ValueError(f"Invalid state {state}")
        obs = env.step(action)


if __name__ == "__main__":
    main(tyro.cli(Args))

0
gello/__init__.py Normal file
View file

43
gello/agents/agent.py Normal file
View file

@ -0,0 +1,43 @@
from typing import Any, Dict, Protocol
import numpy as np
class Agent(Protocol):
    """Structural interface for anything that maps an observation to an action."""

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Compute the action for the given observation.

        Args:
            obs: observation from the environment.

        Returns:
            action: action to take on the environment.
        """
        raise NotImplementedError
class DummyAgent(Agent):
    """Agent that ignores the observation and always commands zeros."""

    def __init__(self, num_dofs: int):
        # length of the action vector returned by act()
        self.num_dofs = num_dofs

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Return a zero vector with ``num_dofs`` entries; ``obs`` is unused."""
        zero_action = np.zeros(self.num_dofs)
        return zero_action
class BimanualAgent(Agent):
    """Combines a left and a right agent into one agent for a two-arm setup."""

    def __init__(self, agent_left: Agent, agent_right: Agent):
        self.agent_left = agent_left
        self.agent_right = agent_right

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Split every observation entry in half and concatenate both actions.

        The first half of each value (along its first axis) goes to the left
        agent, the second half to the right agent. Every value must expose
        ``.shape`` and have an even first dimension.

        Args:
            obs: observation of the combined system.

        Returns:
            The left agent's action followed by the right agent's action.
        """
        left_obs, right_obs = {}, {}
        for key, val in obs.items():
            half_dim, remainder = divmod(val.shape[0], 2)
            assert remainder == 0, f"{key} must be even, something is wrong"
            left_obs[key], right_obs[key] = val[:half_dim], val[half_dim:]

        left_action = self.agent_left.act(left_obs)
        right_action = self.agent_right.act(right_obs)
        return np.concatenate([left_action, right_action])

139
gello/agents/gello_agent.py Normal file
View file

@ -0,0 +1,139 @@
import os
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple
import numpy as np
from gello.agents.agent import Agent
from gello.robots.dynamixel import DynamixelRobot
@dataclass
class DynamixelRobotConfig:
    """Calibration for one GELLO device and a factory for its DynamixelRobot."""

    joint_ids: Sequence[int]
    """The joint ids of GELLO (not including the gripper). Usually (1, 2, 3 ...)."""

    joint_offsets: Sequence[float]
    """The joint offsets of GELLO. There needs to be a joint offset for each joint_id and should be a multiple of pi/2."""

    joint_signs: Sequence[int]
    """The joint signs of GELLO. There needs to be a joint sign for each joint_id and should be either 1 or -1.

    This will be different for each arm design. Reference the examples below for the correct signs for your robot.
    """

    gripper_config: Tuple[int, int, int]
    """The gripper config of GELLO. This is a tuple of (gripper_joint_id, degrees in open_position, degrees in closed_position)."""

    def __post_init__(self):
        """Check that there is exactly one offset and one sign per joint id."""
        num_joints = len(self.joint_ids)
        assert num_joints == len(self.joint_offsets)
        assert num_joints == len(self.joint_signs)

    def make_robot(
        self, port: str = "/dev/ttyUSB0", start_joints: Optional[np.ndarray] = None
    ) -> DynamixelRobot:
        """Build a DynamixelRobot on ``port`` from this configuration.

        Args:
            port: serial device path handed to the robot.
            start_joints: optional joint vector forwarded unchanged to the robot.

        Returns:
            A DynamixelRobot constructed with ``real=True``.
        """
        offsets = list(self.joint_offsets)
        signs = list(self.joint_signs)
        return DynamixelRobot(
            joint_ids=self.joint_ids,
            joint_offsets=offsets,
            real=True,
            joint_signs=signs,
            port=port,
            gripper_config=self.gripper_config,
            start_joints=start_joints,
        )
# Known GELLO devices, keyed by serial-device path. GelloAgent looks the port up
# in this table when it is constructed without an explicit dynamixel_config.
# Offsets are written as multiples of pi/2; signs are +1 or -1 per joint.
PORT_CONFIG_MAP: Dict[str, DynamixelRobotConfig] = {
    # xArm
    # "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig(
    #     joint_ids=(1, 2, 3, 4, 5, 6, 7),
    #     joint_offsets=(
    #         2 * np.pi / 2,
    #         2 * np.pi / 2,
    #         2 * np.pi / 2,
    #         2 * np.pi / 2,
    #         -1 * np.pi / 2 + 2 * np.pi,
    #         1 * np.pi / 2,
    #         1 * np.pi / 2,
    #     ),
    #     joint_signs=(1, 1, 1, 1, 1, 1, 1),
    #     gripper_config=(8, 279, 279 - 50),
    # ),
    # panda
    # "/dev/cu.usbserial-FT3M9NVB": DynamixelRobotConfig(
    "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig(
        joint_ids=(1, 2, 3, 4, 5, 6, 7),
        joint_offsets=(
            3 * np.pi / 2,
            2 * np.pi / 2,
            1 * np.pi / 2,
            4 * np.pi / 2,
            -2 * np.pi / 2 + 2 * np.pi,
            3 * np.pi / 2,
            4 * np.pi / 2,
        ),
        joint_signs=(1, -1, 1, 1, 1, -1, 1),
        gripper_config=(8, 195, 152),
    ),
    # Left UR
    "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0": DynamixelRobotConfig(
        joint_ids=(1, 2, 3, 4, 5, 6),
        joint_offsets=(
            0,
            1 * np.pi / 2 + np.pi,
            np.pi / 2 + 0 * np.pi,
            0 * np.pi + np.pi / 2,
            np.pi - 2 * np.pi / 2,
            -1 * np.pi / 2 + 2 * np.pi,
        ),
        joint_signs=(1, 1, -1, 1, 1, 1),
        gripper_config=(7, 20, -22),
    ),
    # Right UR
    "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0": DynamixelRobotConfig(
        joint_ids=(1, 2, 3, 4, 5, 6),
        joint_offsets=(
            np.pi + 0 * np.pi,
            2 * np.pi + np.pi / 2,
            2 * np.pi + np.pi / 2,
            2 * np.pi + np.pi / 2,
            1 * np.pi,
            3 * np.pi / 2,
        ),
        joint_signs=(1, 1, -1, 1, 1, 1),
        gripper_config=(7, 286, 248),
    ),
}
class GelloAgent(Agent):
    """Agent whose action is the current joint state of a GELLO controller."""

    def __init__(
        self,
        port: str,
        dynamixel_config: Optional[DynamixelRobotConfig] = None,
        start_joints: Optional[np.ndarray] = None,
    ):
        """Create the underlying Dynamixel robot for the device on ``port``.

        Args:
            port: serial device path of the GELLO.
            dynamixel_config: explicit configuration; when omitted, ``port``
                must exist on disk and be a key of ``PORT_CONFIG_MAP``.
            start_joints: optional joint vector forwarded to ``make_robot``.

        Raises:
            AssertionError: if no config is given and ``port`` does not exist
                or is not listed in ``PORT_CONFIG_MAP``.
        """
        if dynamixel_config is not None:
            self._robot = dynamixel_config.make_robot(
                port=port, start_joints=start_joints
            )
        else:
            assert os.path.exists(port), port
            assert port in PORT_CONFIG_MAP, f"Port {port} not in config map"

            config = PORT_CONFIG_MAP[port]
            self._robot = config.make_robot(port=port, start_joints=start_joints)

    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
        """Return the GELLO's current joint state; ``obs`` is not used.

        The previous body carried an unreachable gripper-threshold experiment
        after this return statement; it never executed and has been removed.
        """
        return self._robot.get_joint_state()

178
gello/agents/quest_agent.py Normal file
View file

@ -0,0 +1,178 @@
from typing import Dict
import numpy as np
import quaternion
from dm_control import mjcf
from dm_control.utils.inverse_kinematics import qpos_from_site_pose
from oculus_reader.reader import OculusReader
from gello.agents.agent import Agent
from gello.agents.spacemouse_agent import apply_transfer, mj2ur, ur2mj
from gello.dm_control_tasks.arms.ur5e import UR5e
# Cartesian-space control: the relative pose between the controller and the robot matters.
# These extrinsics are based on our setup; for details please check out the project page.
quest2ur = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
ur2quest = np.linalg.inv(quest2ur)
# Multiplier applied to controller translation deltas after they are mapped with quest2ur.
translation_scaling_factor = 2.0
class SingleArmQuestAgent(Agent):
    """Teleoperation agent that maps one Quest controller to a single arm.

    While the controller trigger is held, the controller's motion since the
    trigger was pressed is applied to the end-effector pose recorded at that
    moment, and the resulting target is converted to joint angles with
    ``qpos_from_site_pose``.
    """

    def __init__(self, robot_type: str, which_hand: str, verbose: bool = False) -> None:
        """Interact with the robot using the quest controller.

        The trigger of the selected hand ("leftTrig" / "rightTrig") enables
        control; the first frame on which it is pressed records the controller
        pose and the end-effector pose as references. The gripper is commanded
        with the Y (open) / X (close) buttons for the left hand and the
        B (open) / A (close) buttons for the right hand.

        Args:
            robot_type: robot model name; only "ur5" is accepted.
            which_hand: "l" or "r", selecting which controller is read.
            verbose: when True, print control-state changes and pose deltas.

        Raises:
            ValueError: if ``robot_type`` is not "ur5".
        """
        self.which_hand = which_hand
        assert self.which_hand in ["l", "r"]

        self.oculus_reader = OculusReader()
        if robot_type == "ur5":
            _robot = UR5e()
        else:
            raise ValueError(f"Unknown robot type: {robot_type}")
        self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model)
        # Teleoperation state: references are captured when the trigger is first pressed.
        self.control_active = False
        self.reference_quest_pose = None
        self.reference_ee_rot_ur = None
        self.reference_ee_pos_ur = None

        self.robot_type = robot_type
        self._verbose = verbose

    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
        """Return the next joint command for the arm and gripper.

        Args:
            obs: must contain "joint_positions"; the first ``num_dof`` entries
                are used as arm joints and the last entry as the gripper value.

        Returns:
            The arm joint targets followed by the gripper value. The arm part
            equals the observed joints whenever the controller has no data,
            the trigger is released, control was just activated, or IK fails.
        """
        # num_dof is only bound for "ur5"; __init__ rejects every other robot type.
        if self.robot_type == "ur5":
            num_dof = 6
        current_qpos = obs["joint_positions"][:num_dof]  # last one dim is the gripper
        current_gripper_angle = obs["joint_positions"][-1]
        # run the fk
        self.physics.data.qpos[:num_dof] = current_qpos
        self.physics.step()

        ee_rot_mj = np.array(
            self.physics.named.data.site_xmat["attachment_site"]
        ).reshape(3, 3)
        ee_pos_mj = np.array(self.physics.named.data.site_xpos["attachment_site"])
        if self.which_hand == "l":
            pose_key = "l"
            trigger_key = "leftTrig"
            # joystick_key = "leftJS"
            # left yx
            gripper_open_key = "Y"
            gripper_close_key = "X"
        elif self.which_hand == "r":
            pose_key = "r"
            trigger_key = "rightTrig"
            # joystick_key = "rightJS"
            # right ba for the key
            gripper_open_key = "B"
            gripper_close_key = "A"
        else:
            raise ValueError(f"Unknown hand: {self.which_hand}")
        # check the trigger button state
        pose_data, button_data = self.oculus_reader.get_transformations_and_buttons()
        if len(pose_data) == 0 or len(button_data) == 0:
            print("no data, quest not yet ready")
            return np.concatenate([current_qpos, [current_gripper_angle]])

        # Gripper: open button sets 1, close button sets 0 (close wins if both are pressed).
        new_gripper_angle = current_gripper_angle
        if button_data[gripper_open_key]:
            new_gripper_angle = 1
        if button_data[gripper_close_key]:
            new_gripper_angle = 0
        arm_not_move_return = np.concatenate([current_qpos, [new_gripper_angle]])
        # NOTE(review): the guard above already returned for empty pose_data, so
        # this second check cannot trigger.
        if len(pose_data) == 0:
            print("no data, quest not yet ready")
            return arm_not_move_return

        trigger_state = button_data[trigger_key][0]
        if trigger_state > 0.5:
            if self.control_active is True:
                if self._verbose:
                    print("controlling the arm")
                # Controller motion relative to the pose captured at activation.
                current_pose = pose_data[pose_key]
                delta_rot = current_pose[:3, :3] @ np.linalg.inv(
                    self.reference_quest_pose[:3, :3]
                )
                delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3]
                delta_pos_ur = (
                    apply_transfer(quest2ur, delta_pos) * translation_scaling_factor
                )
                # ? is this the case?
                delta_rot_ur = quest2ur[:3, :3] @ delta_rot @ ur2quest[:3, :3]
                if self._verbose:
                    print(
                        f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}"
                    )
                # Apply the delta to the end-effector pose captured at activation.
                next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur
                next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur

                target_quat = quaternion.as_float_array(
                    quaternion.from_rotation_matrix(ur2mj[:3, :3] @ next_ee_rot_ur)
                )
                ik_result = qpos_from_site_pose(
                    self.physics,
                    "attachment_site",
                    target_pos=apply_transfer(ur2mj, next_ee_pos_ur),
                    target_quat=target_quat,
                    tol=1e-14,
                    max_steps=400,
                )
                self.physics.reset()
                if ik_result.success:
                    new_qpos = ik_result.qpos[:num_dof]
                else:
                    print("ik failed, using the original qpos")
                    return arm_not_move_return
                command = np.concatenate([new_qpos, [new_gripper_angle]])
                return command

            else:  # last state is not in active
                # First frame with the trigger pressed: capture references, do not move.
                self.control_active = True
                if self._verbose:
                    print("control activated!")
                self.reference_quest_pose = pose_data[pose_key]

                self.reference_ee_rot_ur = mj2ur[:3, :3] @ ee_rot_mj
                self.reference_ee_pos_ur = apply_transfer(mj2ur, ee_pos_mj)
                return arm_not_move_return
        else:
            # Trigger released: drop the controller reference and hold the arm.
            if self._verbose:
                print("deactive control")
            self.control_active = False
            self.reference_quest_pose = None
            return arm_not_move_return
class DualArmQuestAgent(Agent):
    """Holds one SingleArmQuestAgent per hand; ``act`` is still a stub."""

    def __init__(self, robot_type: str) -> None:
        """Create the left-hand agent, then the right-hand agent, for ``robot_type``."""
        self.left_arm = SingleArmQuestAgent(robot_type, which_hand="l")
        self.right_arm = SingleArmQuestAgent(robot_type, which_hand="r")

    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
        """Not implemented: ignores ``obs`` and returns None."""
        return None
if __name__ == "__main__":
    # Debug loop: keep polling the Quest and print the left controller pose.
    reader = OculusReader()
    while True:
        """
        example output:
        ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ],
        [-0.107737 , 0.0958919, 0.989544 , -0.833478 ],
        [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ],
        [ 0. , 0. , 0. , 1. ]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ],
        [ 0.070887, 0.516083, -0.8536 , -0.238691],
        [-0.941994, -0.246809, -0.227447, -0.370447],
        [ 0. , 0. , 0. , 1. ]])},
        {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)})
        """
        poses, buttons = reader.get_transformations_and_buttons()
        if len(poses) == 0:
            print("no data")
            continue
        print(poses["l"])

View file

@ -0,0 +1,224 @@
import threading
import time
from dataclasses import dataclass, field
from typing import Dict, Optional

import numpy as np
from dm_control import mjcf
from dm_control.utils.inverse_kinematics import qpos_from_site_pose

from gello.agents.agent import Agent
from gello.dm_control_tasks.arms.ur5e import UR5e
# MuJoCo has a slightly different coordinate system than the UR control box.
mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
ur2mj = np.linalg.inv(mj2ur)
# Cartesian-space control: the relative pose between the controller and the robot matters.
# These extrinsics are based on our setup; for details please check out the project page.
spacemouse2ur = np.array(
    [
        [-1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
)
ur2spacemouse = np.linalg.inv(spacemouse2ur)
def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray:
    """Apply a homogeneous transform to a point and return its first three components.

    Args:
        mat: transform matrix multiplied from the left (4x4 in this module).
        xyz: a 3-element point, which gets a trailing 1 appended, or any other
            length, which is multiplied as given (e.g. an already homogeneous
            4-vector).

    Returns:
        The first three entries of ``mat @ point``.
    """
    point = np.append(xyz, 1) if len(xyz) == 3 else xyz
    transformed = np.matmul(mat, point)
    return transformed[:3]
@dataclass
class SpacemouseConfig:
    """Scaling and sign settings applied to raw spacemouse readings."""

    # Multiplier applied to the rotation-vector components read from the device.
    angle_scale: float = 0.24
    # Multiplier applied to the x/y/z translation read from the device.
    translation_scale: float = 0.06
    # only control the xyz, rotation direction, not the gripper
    # A default_factory gives every config its own array: a bare ``np.ones(6)``
    # default is shared between instances, and Python 3.11+ rejects unhashable
    # dataclass defaults with ValueError at class-creation time.
    invert_control: np.ndarray = field(default_factory=lambda: np.ones(6))
    # "euler" pre-multiplies the rotation delta, "rpy" post-multiplies it.
    rotation_mode: str = "euler"
class SpacemouseAgent(Agent):
    """Teleoperation agent that turns spacemouse readings into arm joint commands.

    A background thread keeps the most recent device state; ``act`` converts it
    into an end-effector pose delta and solves for joint angles with
    ``qpos_from_site_pose``.
    """

    def __init__(
        self,
        robot_type: str,
        config: SpacemouseConfig = SpacemouseConfig(),
        device_path: Optional[str] = None,
        verbose: bool = True,
        invert_button: bool = False,
    ) -> None:
        """Start the reader thread and build the physics model used for IK.

        Args:
            robot_type: robot model name; only "ur5" is accepted.
            config: scaling / inversion settings for the device readings.
            device_path: device path passed to ``pyspacemouse.open``; when
                None the library is asked to open without a path.
            verbose: when True, print readings and intermediate transforms.
            invert_button: swaps which of the two buttons opens / closes the gripper.

        Raises:
            ValueError: if ``robot_type`` is not "ur5".
        """
        self.config = config
        self.last_state_lock = threading.Lock()
        self._invert_button = invert_button
        # example state:SpaceNavigator(t=3.581528532, x=0.0, y=0.0, z=0.0, roll=0.0, pitch=0.0, yaw=0.0, buttons=[0, 0])
        # all continuous inputs range from 0-1, buttons are 0 or 1
        self.spacemouse_latest_state = None
        self._device_path = device_path
        # NOTE(review): the thread is not created with daemon=True and its loop
        # never exits, so it presumably keeps the process alive - confirm intent.
        spacemouse_thread = threading.Thread(target=self._read_from_spacemouse)
        spacemouse_thread.start()
        self._verbose = verbose
        if self._verbose:
            print(f"robot_type: {robot_type}")
        if robot_type == "ur5":
            _robot = UR5e()
        else:
            raise ValueError(f"Unknown robot type: {robot_type}")
        self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model)

    def _read_from_spacemouse(self):
        """Thread target: open the device and keep storing its latest state.

        Raises:
            ValueError: if ``pyspacemouse.open`` returns a falsy value.
        """
        import pyspacemouse

        if self._device_path is None:
            mouse = pyspacemouse.open()
        else:
            mouse = pyspacemouse.open(path=self._device_path)
        if mouse:
            while 1:
                state = mouse.read()
                # Publish the reading under the lock that act() takes when reading it.
                with self.last_state_lock:
                    self.spacemouse_latest_state = state
                time.sleep(0.001)
        else:
            raise ValueError("Failed to open spacemouse")

    def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
        """Return the next joint command for the arm and gripper.

        Args:
            obs: must contain "joint_positions"; the first six entries are used
                as arm joints and the last entry as the gripper value.

        Returns:
            The six arm joint targets followed by the gripper value. When IK
            fails, the observed joints and the observed gripper value are
            returned unchanged.

        Raises:
            AssertionError: if no device state has been read yet.
            NotImplementedError: if ``config.rotation_mode`` is neither "rpy"
                nor "euler".
        """
        import quaternion

        # obs: the follower robot's current state
        # in rad, 6/7 dof depends on the robot type
        num_dof = 6
        if self._verbose:
            print("act invoked")
        current_qpos = obs["joint_positions"][:num_dof]  # last one dim is the gripper
        current_gripper_angle = obs["joint_positions"][-1]
        # Load the observed joints into the model to read the end-effector site pose.
        self.physics.data.qpos[:num_dof] = current_qpos

        self.physics.step()
        ee_rot = np.array(self.physics.named.data.site_xmat["attachment_site"]).reshape(
            3, 3
        )
        ee_pos = np.array(self.physics.named.data.site_xpos["attachment_site"])

        ee_rot = mj2ur[:3, :3] @ ee_rot
        ee_pos = apply_transfer(mj2ur, ee_pos)
        # ^ mujoco coordinate to UR

        with self.last_state_lock:
            spacemouse_read = self.spacemouse_latest_state
            if self._verbose:
                print(f"spacemouse_read: {spacemouse_read}")

        # Fails if act() runs before the reader thread has stored a first state.
        assert spacemouse_read is not None
        spacemouse_xyz_rot_np = np.array(
            [
                spacemouse_read.x,
                spacemouse_read.y,
                spacemouse_read.z,
                spacemouse_read.roll,
                spacemouse_read.pitch,
                spacemouse_read.yaw,
            ]
        )
        spacemouse_button = (
            spacemouse_read.buttons
        )  # size 2 list binary indicating left/right button pressing
        spacemouse_xyz_rot_np = spacemouse_xyz_rot_np * self.config.invert_control
        # When any axis is pushed past 0.9, zero every axis below 0.6.
        if np.max(np.abs(spacemouse_xyz_rot_np)) > 0.9:
            spacemouse_xyz_rot_np[np.abs(spacemouse_xyz_rot_np) < 0.6] = 0
        tx, ty, tz, r, p, y = spacemouse_xyz_rot_np
        # convert roll pitch yaw to rotation matrix (rpy)
        trans_transform = np.eye(4)
        # delta xyz from the spacemouse reading
        trans_transform[:3, 3] = apply_transfer(
            spacemouse2ur, np.array([tx, ty, tz]) * self.config.translation_scale
        )
        # break rot_transform into each axis
        rot_transform_x = np.eye(4)
        rot_transform_x[:3, :3] = quaternion.as_rotation_matrix(
            quaternion.from_rotation_vector(
                np.array([-p, 0, 0]) * self.config.angle_scale
            )
        )
        rot_transform_y = np.eye(4)
        rot_transform_y[:3, :3] = quaternion.as_rotation_matrix(
            quaternion.from_rotation_vector(
                np.array([0, r, 0]) * self.config.angle_scale
            )
        )
        rot_transform_z = np.eye(4)
        rot_transform_z[:3, :3] = quaternion.as_rotation_matrix(
            quaternion.from_rotation_vector(
                np.array([0, 0, -y]) * self.config.angle_scale
            )
        )
        # in ur space
        rot_transform = (
            spacemouse2ur
            @ rot_transform_z
            @ rot_transform_y
            @ rot_transform_x
            @ ur2spacemouse
        )
        if self._verbose:
            print(f"rot_transform: {rot_transform}")
            print(f"new spacemounse cmd in ur space = {trans_transform[:3, 3]}")
        # import pdb; pdb.set_trace()
        new_ee_pos = trans_transform[:3, 3] + ee_pos
        # "rpy" post-multiplies the delta onto the current rotation, "euler" pre-multiplies it.
        if self.config.rotation_mode == "rpy":
            new_ee_rot = ee_rot @ rot_transform[:3, :3]
        elif self.config.rotation_mode == "euler":
            new_ee_rot = rot_transform[:3, :3] @ ee_rot
        else:
            raise NotImplementedError(
                f"Unknown rotation mode: {self.config.rotation_mode}"
            )
        target_quat = quaternion.as_float_array(
            quaternion.from_rotation_matrix(ur2mj[:3, :3] @ new_ee_rot)
        )
        ik_result = qpos_from_site_pose(
            self.physics,
            "attachment_site",
            target_pos=apply_transfer(ur2mj, new_ee_pos),
            target_quat=target_quat,
            tol=1e-14,
            max_steps=400,
        )
        self.physics.reset()
        if ik_result.success:
            new_qpos = ik_result.qpos[:num_dof]
        else:
            print("ik failed, using the original qpos")
            return np.concatenate([current_qpos, [current_gripper_angle]])
        # Gripper: by default button 0 sets 1 and button 1 sets 0; invert_button swaps them.
        new_gripper_angle = current_gripper_angle
        if self._invert_button:
            if spacemouse_button[1]:
                new_gripper_angle = 1
            if spacemouse_button[0]:
                new_gripper_angle = 0
        else:
            if spacemouse_button[1]:
                new_gripper_angle = 0
            if spacemouse_button[0]:
                new_gripper_angle = 1
        command = np.concatenate([new_qpos, [new_gripper_angle]])
        return command
if __name__ == "__main__":
    import pyspacemouse

    # Manual check: try to open both hidraw devices, one after the other.
    for _hid_path in ("/dev/hidraw4", "/dev/hidraw5"):
        success = pyspacemouse.open(_hid_path)

81
gello/cameras/camera.py Normal file
View file

@ -0,0 +1,81 @@
from pathlib import Path
from typing import Optional, Protocol, Tuple
import numpy as np
class CameraDriver(Protocol):
    """Camera protocol.

    Structural interface for camera drivers, so the rest of the code does not
    depend on a specific camera implementation.
    """

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Read a frame from the camera.

        Args:
            img_size: The size of the image to return. If None, the original size is returned.

        Returns:
            np.ndarray: The color image.
            np.ndarray: The depth image.
        """
        ...
class DummyCamera(CameraDriver):
    """A dummy camera for testing that returns random images."""

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return a random color image and a random depth image.

        Args:
            img_size: first two dimensions of the returned arrays, taken from
                ``img_size[0]`` and ``img_size[1]``. If None, 480 x 640 is used.

        Returns:
            np.ndarray: uint8 color image with 3 channels, values in [0, 255).
            np.ndarray: uint16 depth image with 1 channel, values in [0, 255).
        """
        if img_size is None:
            rows, cols = 480, 640
        else:
            rows, cols = img_size[0], img_size[1]

        # Color is drawn before depth, matching the order of the RNG calls.
        color = np.random.randint(255, size=(rows, cols, 3), dtype=np.uint8)
        depth = np.random.randint(255, size=(rows, cols, 1), dtype=np.uint16)
        return (color, depth)
class SavedCamera(CameraDriver):
    """Camera that replays one stored color/depth image pair from disk."""

    def __init__(self, path: str = "example"):
        """Open ``image.png`` and ``depth.png`` from ``path`` next to this module.

        Args:
            path: directory name, resolved relative to this file's parent directory.
        """
        self.path = str(Path(__file__).parent / path)
        from PIL import Image

        self._color_img = Image.open(f"{self.path}/image.png")
        self._depth_img = Image.open(f"{self.path}/depth.png")

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the stored images as arrays, optionally resized.

        Args:
            img_size: size passed to PIL's ``resize`` for both images; if None
                the stored images are used as they are.

        Returns:
            np.ndarray: the color image.
            np.ndarray: the first channel of the depth image, kept as a
                trailing axis of length 1.
        """
        color_img, depth_img = self._color_img, self._depth_img
        if img_size is not None:
            color_img = color_img.resize(img_size)
            depth_img = depth_img.resize(img_size)
        depth_arr = np.array(depth_img)
        return np.array(color_img), depth_arr[:, :, 0:1]

View file

@ -0,0 +1,123 @@
import os
import time
from typing import List, Optional, Tuple
import numpy as np
from gello.cameras.camera import CameraDriver
def get_device_ids() -> List[str]:
    """Hardware-reset every connected RealSense device and return their serial numbers.

    Returns:
        The serial number of each device reported by the pyrealsense2 context,
        in enumeration order. The function sleeps two seconds after the resets.
    """
    import pyrealsense2 as rs

    context = rs.context()
    serials = []
    for device in context.query_devices():
        device.hardware_reset()
        serials.append(device.get_info(rs.camera_info.serial_number))
    time.sleep(2)
    return serials
class RealSenseCamera(CameraDriver):
    """CameraDriver backed by a pyrealsense2 pipeline streaming 640x480 color and depth."""

    def __repr__(self) -> str:
        return f"RealSenseCamera(device_id={self._device_id})"

    def __init__(self, device_id: Optional[str] = None, flip: bool = False):
        """Start a color + depth pipeline.

        Args:
            device_id: device to enable on the pipeline config. If None, every
                connected device is hardware-reset first and no specific
                device is selected.
            flip: when True, ``read`` rotates both images by 180 degrees.
        """
        import pyrealsense2 as rs

        self._device_id = device_id

        if device_id is None:
            ctx = rs.context()
            for dev in ctx.query_devices():
                dev.hardware_reset()
            time.sleep(2)

        self._pipeline = rs.pipeline()
        config = rs.config()
        if device_id is not None:
            config.enable_device(device_id)

        config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
        config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
        self._pipeline.start(config)
        self._flip = flip

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,  # farthest: float = 0.12
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Read a frame from the camera.

        Args:
            img_size: The size passed to ``cv2.resize`` for both images. If None, the original size is returned.

        Returns:
            np.ndarray: The color image with its channel order reversed, shape=(H, W, 3)
            np.ndarray: The depth image, shape=(H, W, 1)
        """
        import cv2

        frames = self._pipeline.wait_for_frames()
        color_frame = frames.get_color_frame()
        color_image = np.asanyarray(color_frame.get_data())
        depth_frame = frames.get_depth_frame()
        depth_image = np.asanyarray(depth_frame.get_data())

        if img_size is not None:
            color_image = cv2.resize(color_image, img_size)
            depth_image = cv2.resize(depth_image, img_size)
        image = color_image[:, :, ::-1]
        depth = depth_image

        # rotate 180 degrees because everything is upside down in order to center the camera
        if self._flip:
            image = cv2.rotate(image, cv2.ROTATE_180)
            depth = cv2.rotate(depth, cv2.ROTATE_180)

        return image, depth[:, :, None]
def _debug_read(camera, save_datastream=False):
    """Show the camera stream in two OpenCV windows until Escape is pressed.

    Pressing "s" writes the current frame pair into ``images/``. With
    ``save_datastream`` every frame pair is additionally written to ``stream/``.

    Args:
        camera: object whose ``read()`` returns ``(image, depth)`` arrays.
        save_datastream: whether to save every frame.
    """
    import cv2

    for window_name in ("image", "depth"):
        cv2.namedWindow(window_name)

    counter = 0
    if not os.path.exists("images"):
        os.makedirs("images")
    if save_datastream and not os.path.exists("stream"):
        os.makedirs("stream")

    while True:
        time.sleep(0.1)
        image, depth = camera.read()
        # Repeat the depth channel three times for display and saving.
        depth = np.concatenate([depth] * 3, axis=-1)
        flipped_channels = image[:, :, ::-1]

        key = cv2.waitKey(1)
        cv2.imshow("image", flipped_channels)
        cv2.imshow("depth", depth)

        if key == ord("s"):
            cv2.imwrite(f"images/image_{counter}.png", flipped_channels)
            cv2.imwrite(f"images/depth_{counter}.png", depth)
        if save_datastream:
            cv2.imwrite(f"stream/image_{counter}.png", flipped_channels)
            cv2.imwrite(f"stream/depth_{counter}.png", depth)
        counter += 1

        if key == 27:  # Escape
            break
if __name__ == "__main__":
    # Manual check: list the devices, open the first one flipped, and show its stream.
    device_ids = get_device_ids()
    print(f"Found {len(device_ids)} devices")
    print(device_ids)

    rs = RealSenseCamera(device_id=device_ids[0], flip=True)
    im, depth = rs.read()
    _debug_read(rs, save_datastream=True)

View file

@ -0,0 +1,231 @@
from typing import Dict
import cv2
import numpy as np
import torch
import transforms3d._gohlketransforms as ttf
def to_torch(array, device="cpu"):
    """Convert ``array`` to a torch tensor on ``device``.

    Tensors are moved as they are, numpy arrays go through ``torch.from_numpy``,
    and anything else is passed to ``torch.tensor``.

    Args:
        array: tensor, numpy array, or any value ``torch.tensor`` accepts.
        device: target device handed to ``Tensor.to``.

    Returns:
        The resulting tensor on ``device``.
    """
    if isinstance(array, torch.Tensor):
        tensor = array
    elif isinstance(array, np.ndarray):
        tensor = torch.from_numpy(array)
    else:
        tensor = torch.tensor(array)
    return tensor.to(device)
def to_numpy(array):
    """Convert a torch tensor to a numpy array; return anything else unchanged.

    The tensor is detached before conversion because ``Tensor.numpy()`` raises
    on tensors that require grad.

    Args:
        array: a torch tensor or any other object.

    Returns:
        A numpy array for tensor input, otherwise ``array`` itself.
    """
    if isinstance(array, torch.Tensor):
        return array.detach().cpu().numpy()
    return array
def center_crop(rgb_frame, depth_frame):
    """Crop the centered square out of the last two axes of both frames.

    The square's side is the smaller of the last two dimensions of
    ``rgb_frame``; the same window is applied to ``depth_frame``. Square input
    is returned uncropped.

    The earlier implementation sliced with float bounds (a TypeError for
    portrait frames) and tested ``H > W`` twice, so landscape frames were
    never cropped at all.

    Args:
        rgb_frame: array whose last two axes are (H, W).
        depth_frame: array whose last two axes match those of ``rgb_frame``.

    Returns:
        The cropped ``(rgb_frame, depth_frame)`` pair.
    """
    H, W = rgb_frame.shape[-2:]
    sq_size = min(H, W)

    # Integer offsets of the centered square; zero along the shorter axis.
    top = (H - sq_size) // 2
    left = (W - sq_size) // 2

    rgb_frame = rgb_frame[..., top : top + sq_size, left : left + sq_size]
    depth_frame = depth_frame[..., top : top + sq_size, left : left + sq_size]
    return rgb_frame, depth_frame
def resize(rgb, depth, size=224):
    """Resize a channel-first RGB image and a one-channel depth image to a square.

    Args:
        rgb: array laid out as (C, H, W).
        depth: array whose first entry ``depth[0]`` is the (H, W) depth map.
        size: side length of the output square.

    Returns:
        ``(rgb, depth)`` with shapes (C, size, size) and (1, size, size), both
        resized with bilinear interpolation.
    """
    target = (size, size)

    # cv2 works on channel-last images, so move channels to the back and forth again.
    rgb_hwc = cv2.resize(
        rgb.transpose([1, 2, 0]), target, interpolation=cv2.INTER_LINEAR
    )
    rgb_chw = rgb_hwc.transpose([2, 0, 1])

    depth_hw = cv2.resize(depth[0], target, interpolation=cv2.INTER_LINEAR)
    return rgb_chw, depth_hw.reshape([1, size, size])
def filter_depth(depth, max_depth=2.0, min_depth=0.0):
    """Zero out non-finite depth values and clip the rest to a range.

    NaN and +/-inf entries of ``depth`` are set to 0.0 in place (the caller's
    array is modified); clipping then produces a new array.

    Args:
        depth: numpy array of depth values.
        max_depth: upper clip bound.
        min_depth: lower clip bound.

    Returns:
        A new array with values limited to [min_depth, max_depth].
    """
    invalid = ~np.isfinite(depth)
    depth[invalid] = 0.0
    return np.clip(depth, min_depth, max_depth)
def preproc_obs(
    demo: Dict[str, np.ndarray], joint_only: bool = True
) -> Dict[str, np.ndarray]:
    """Turn one recorded frame into the observation dict used for training.

    The wrist and base camera images are moved to channel-first layout, passed
    through ``center_crop`` and ``resize``, depth is passed through
    ``filter_depth``, and both cameras are stacked with the wrist camera first.

    Args:
        demo: frame dictionary providing "wrist_rgb", "wrist_depth",
            "base_rgb", "base_depth" and "joint_positions" (plus
            "joint_velocities", "ee_pos_quat" and "gripper_position" when
            ``joint_only`` is False).
        joint_only: if True the state is just the joint positions; otherwise
            velocities, end-effector pose and gripper position are appended.

    Returns:
        Dict with "rgb", "depth", "camera_poses" (4x4 identity placeholder),
        "K_matrices" (3x3 identity placeholder) and "state".
    """

    def _channel_first(key: str) -> np.ndarray:
        # h w c -> c h w, promoted to float by the multiplication
        return demo.get(key).transpose([2, 0, 1]) * 1.0  # type: ignore

    rgb_wrist, depth_wrist = _channel_first("wrist_rgb"), _channel_first("wrist_depth")
    rgb_base, depth_base = _channel_first("base_rgb"), _channel_first("base_depth")

    # Center crop, resize, then filter depth
    rgb_wrist, depth_wrist = resize(*center_crop(rgb_wrist, depth_wrist))
    rgb_base, depth_base = resize(*center_crop(rgb_base, depth_base))
    depth_wrist = filter_depth(depth_wrist)
    depth_base = filter_depth(depth_base)

    # state
    qpos = demo.get("joint_positions")  # type: ignore
    qvel = demo.get("joint_velocities")  # type: ignore
    ee_pos_quat = demo.get("ee_pos_quat")  # type: ignore
    gripper_pos = demo.get("gripper_position")  # type: ignore

    state: np.ndarray
    if joint_only:
        state = qpos
    else:
        state = np.concatenate([qpos, qvel, ee_pos_quat, gripper_pos[None]])

    return {
        "rgb": np.stack([rgb_wrist, rgb_base], axis=0),
        "depth": np.stack([depth_wrist, depth_base], axis=0),
        # Dummy camera extrinsics / intrinsics
        "camera_poses": np.eye(4),
        "K_matrices": np.eye(3),
        "state": state,
    }
class Pose(object):
    """Rigid pose stored as a position ``p`` and a unit quaternion ``q``.

    ``q`` is kept in [x, y, z, w] order with a non-negative scalar part, while
    the constructor and ``to_quaternion`` use [w, x, y, z] order.

    NOTE(review): ``q`` is handed in [x, y, z, w] order to the ``ttf`` helpers
    (imported from ``transforms3d._gohlketransforms``); that module is believed
    to document quaternions as [w, x, y, z] - confirm the ordering matches.
    """

    def __init__(self, x, y, z, qw, qx, qy, qz):
        """Build a pose from a position and a scalar-first quaternion.

        The quaternion is sign-flipped if its scalar part is negative and then
        normalized to unit length.
        """
        self.p = np.array([x, y, z])

        # we internally use tf.transformations, which uses [x, y, z, w] for quaternions.
        self.q = np.array([qx, qy, qz, qw])

        # make sure that the quaternion has positive scalar part
        if self.q[3] < 0:
            self.q *= -1

        self.q = self.q / np.linalg.norm(self.q)

    def __mul__(self, other):
        """Compose two poses: rotate ``other.p`` by this pose, add ``p``, multiply the quaternions."""
        assert isinstance(other, Pose)
        p = self.p + ttf.quaternion_matrix(self.q)[:3, :3].dot(other.p)
        q = ttf.quaternion_multiply(self.q, other.q)
        return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2])

    def __rmul__(self, other):
        """Reflected multiplication; evaluates ``other * self``."""
        assert isinstance(other, Pose)
        return other * self

    def __str__(self):
        return "p: {}, q: {}".format(self.p, self.q)

    def inv(self):
        """Return the inverse pose: position ``-R^T p`` and the inverted quaternion."""
        R = ttf.quaternion_matrix(self.q)[:3, :3]
        p = -R.T.dot(self.p)
        q = ttf.quaternion_inverse(self.q)
        return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2])

    def to_quaternion(self):
        """
        Returns [x, y, z, qw, qx, qy, qz].

        this satisfies Pose(*to_quaternion(p)) == p
        """
        q_reverted = np.array([self.q[3], self.q[0], self.q[1], self.q[2]])
        return np.concatenate([self.p, q_reverted])

    def to_axis_angle(self):
        """
        returns the axis-angle representation of the rotation.

        The result is [x, y, z, ax, ay, az, angle] where the angle is expressed
        as a fraction of pi and wrapped by subtracting 2 when it exceeds 1.
        """
        angle = 2 * np.arccos(self.q[3])
        angle = angle / np.pi
        if angle > 1:
            angle -= 2

        # NOTE(review): for a zero vector part (no rotation) this divides by a
        # zero norm - confirm callers never hit that case.
        axis = self.q[:3] / np.linalg.norm(self.q[:3])

        # keep the axes positive
        if axis[0] < 0:
            axis *= -1
            angle *= -1

        return np.concatenate([self.p, axis, [angle]])

    def to_euler(self):
        """Returns [x, y, z, e0, e1, e2, 0.0] with the Euler angles divided by pi.

        The trailing 0.0 pads the vector to seven entries.
        """
        q = np.array(ttf.euler_from_quaternion(self.q))
        if q[0] > np.pi:
            q[0] -= 2 * np.pi
        if q[1] > np.pi:
            q[1] -= 2 * np.pi
        if q[2] > np.pi:
            q[2] -= 2 * np.pi

        q = q / np.pi

        return np.concatenate([self.p, q, [0.0]])

    def to_44_matrix(self):
        """Returns the 4x4 homogeneous matrix with the rotation block and ``p`` as translation."""
        out = np.eye(4)
        out[:3, :3] = ttf.quaternion_matrix(self.q)[:3, :3]
        out[:3, 3] = self.p
        return out

    @staticmethod
    def from_axis_angle(x, y, z, ax, ay, az, phi):
        """
        returns a Pose object from the axis-angle representation of the rotation.

        ``phi`` is given as a fraction of pi and is multiplied by pi before use.
        """

        phi = phi * np.pi
        p = np.array([x, y, z])
        qw = np.cos(phi / 2.0)
        qx = ax * np.sin(phi / 2.0)
        qy = ay * np.sin(phi / 2.0)
        qz = az * np.sin(phi / 2.0)

        return Pose(p[0], p[1], p[2], qw, qx, qy, qz)

    @staticmethod
    def from_euler(x, y, z, roll, pitch, yaw, _):
        """
        returns a Pose object from the euler representation of the rotation.

        The angles are given as fractions of pi; the last argument is ignored.
        """
        p = np.array([x, y, z])
        roll, pitch, yaw = roll * np.pi, pitch * np.pi, yaw * np.pi
        q = ttf.quaternion_from_euler(roll, pitch, yaw)
        return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2])

    @staticmethod
    def from_quaternion(x, y, z, qw, qx, qy, qz):
        """
        returns a Pose object from the quaternion representation of the rotation.
        """
        p = np.array([x, y, z])
        return Pose(p[0], p[1], p[2], qw, qx, qy, qz)
def compute_inverse_action(p, p_new, ee_control=False):
    """Return the delta pose that takes ``p`` to ``p_new``.

    Args:
        p: current Pose.
        p_new: target Pose.
        ee_control: if True the delta is ``p.inv() * p_new``; otherwise it is
            ``p_new * p.inv()``.

    Returns:
        The delta Pose.
    """
    assert isinstance(p, Pose) and isinstance(p_new, Pose)

    p_inverse = p.inv()
    return p_inverse * p_new if ee_control else p_new * p_inverse
def compute_forward_action(p, dpose, ee_control=False):
    """Apply a delta pose to ``p`` and return the resulting pose.

    Args:
        p: current Pose.
        dpose: delta Pose; it is rebuilt through its quaternion form before use.
        ee_control: if True the result is ``p * dpose``; otherwise it is
            ``dpose * p``.

    Returns:
        The new Pose.
    """
    assert isinstance(p, Pose) and isinstance(dpose, Pose)

    delta = Pose.from_quaternion(*dpose.to_quaternion())
    return p * delta if ee_control else delta * p

View file

@ -0,0 +1,345 @@
import glob
import os
import pickle
import shutil
from dataclasses import dataclass
from typing import Tuple
import numpy as np
import tyro
from natsort import natsorted
from tqdm import tqdm
from gello.data_utils.plot_utils import plot_in_grid
np.set_printoptions(precision=3, suppress=True)
import mediapy as mp
from gdict.data import DictArray, GDict
from simple_bc.utils.visualization_utils import make_grid_video_from_numpy
from gello.data_utils.conversion_utils import preproc_obs
# def get_act_bounds(source_dir: str) -> np.ndarray:
# pkls = natsorted(
# glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
# )
# if len(pkls) <= 30:
# print(f"Skipping {source_dir} because it has less than 30 frames.")
# return None
# pkls = pkls[:-5]
#
# scale_factor = None
# for pkl in pkls:
# try:
# with open(pkl, "rb") as f:
# demo = pickle.load(f)
# except Exception as e:
# print(f"Skipping {pkl} because it is corrupted.")
# print(f"Error: {e}")
# raise Exception("Corrupted pkl")
#
# requested_control = demo.pop("control")
# curr_scale_factor = np.abs(requested_control)
# if scale_factor is None:
# scale_factor = curr_scale_factor
# else:
# scale_factor = np.maximum(scale_factor, curr_scale_factor)
# assert scale_factor is not None
# return scale_factor
def get_act_min_max(source_dir: str) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the element-wise min and max of the "control" entry over one demo.

    Frames are the ``*.pkl`` files found recursively under ``source_dir``. They
    are naturally sorted in reverse order and the last five entries of that
    list are dropped before scanning.

    Args:
        source_dir: directory holding the pickled frames of one demo.

    Returns:
        ``(scale_min, scale_max)``: element-wise minimum and maximum of the
        control values over the scanned frames.

    Raises:
        RuntimeError: if the directory holds 30 or fewer frames.
        Exception: if a frame cannot be unpickled (chained to the original error).
    """
    pkls = natsorted(
        glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
    )
    if len(pkls) <= 30:
        print(f"Skipping {source_dir} because it has less than 30 frames.")
        raise RuntimeError("Too few frames")
    pkls = pkls[:-5]

    scale_min = None
    scale_max = None
    for pkl in pkls:
        try:
            # NOTE: pickle.load executes arbitrary code; only use on trusted recordings.
            with open(pkl, "rb") as f:
                demo = pickle.load(f)
        except Exception as e:
            print(f"Skipping {pkl} because it is corrupted.")
            print(f"Error: {e}")
            raise Exception("Corrupted pkl") from e

        requested_control = demo.pop("control")
        if scale_min is None:
            assert scale_max is None
            scale_min = requested_control
            scale_max = requested_control
        else:
            assert scale_max is not None
            scale_min = np.minimum(scale_min, requested_control)
            # Accumulate against the running maximum (previously the running
            # minimum was used here, so the maximum was never tracked).
            scale_max = np.maximum(scale_max, requested_control)

    assert scale_min is not None
    assert scale_max is not None
    return scale_min, scale_max
def convert_single_demo(
    source_dir,
    i,
    traj_output_dir,
    rgb_output_dir,
    depth_output_dir,
    state_output_dir,
    action_output_dir,
    scale_factor,
    bias_factor,
):
    """Convert one recorded demo and write its trajectory, videos and plots.

    1. converts the demo into a gdict
    2. visualizes the RGB of the demo
    3. visualizes the state + action space of the demo
    4. returns these to be collated by the caller.

    Args:
        source_dir: directory holding the pickled frames of the demo.
        i: trajectory index used in the output file names and the ``traj_{i}`` key.
        traj_output_dir: directory for the ``traj_{i}.h5`` file.
        rgb_output_dir: directory for the RGB videos.
        depth_output_dir: directory for the depth videos.
        state_output_dir: directory for the state plot.
        action_output_dir: directory for the action plot.
        scale_factor: divisor used to normalize actions.
        bias_factor: offset subtracted from actions before scaling.

    Returns:
        0 if the demo has 30 or fewer frames or a frame cannot be loaded.
        Otherwise ``(all_rgbs, all_depths, all_actions, all_states)``, where the
        images are those of the wrist camera and depth is tiled to 3 channels.
    """
    pkls = natsorted(
        glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
    )
    demo_stack = []

    if len(pkls) <= 30:
        return 0

    # go through the demo in reverse order.
    # remove the first few frames because they are not useful.
    pkls = pkls[:-5]

    for pkl in pkls:
        curr_ts = {}
        try:
            # NOTE: pickle.load executes arbitrary code; only use on trusted recordings.
            with open(pkl, "rb") as f:
                demo = pickle.load(f)
        except Exception:
            # Narrowed from a bare ``except`` so Ctrl-C / SystemExit still propagate.
            print(f"Skipping {pkl} because it is corrupted.")
            return 0

        obs = preproc_obs(demo)
        action = demo.pop("control")
        action = (action - bias_factor) / scale_factor  # normalize between -1 and 1

        curr_ts["obs"] = obs
        curr_ts["actions"] = action
        curr_ts["dones"] = np.zeros(1)  # random fill
        curr_ts["episode_dones"] = np.zeros(1)  # random fill

        curr_ts_wrapped = dict()
        curr_ts_wrapped[f"traj_{i}"] = curr_ts
        # Frames are visited newest-first, so prepend to restore chronological order.
        demo_stack = [curr_ts_wrapped] + demo_stack

    demo_dict = DictArray.stack(demo_stack)
    GDict.to_hdf5(demo_dict, os.path.join(traj_output_dir + "", f"traj_{i}.h5"))

    ## save the base videos
    # save the base rgb and depth videos (camera index 1)
    all_rgbs = demo_dict[f"traj_{i}"]["obs"]["rgb"][:, 1].transpose([0, 2, 3, 1])
    all_rgbs = all_rgbs.astype(np.uint8)
    _, H, W, _ = all_rgbs.shape
    all_depths = demo_dict[f"traj_{i}"]["obs"]["depth"][:, 1].reshape([-1, H, W])
    all_depths = all_depths / 5.0  # scale to 0-1
    mp.write_video(
        os.path.join(rgb_output_dir + "", f"traj_{i}_rgb_base.mp4"), all_rgbs, fps=30
    )
    mp.write_video(
        os.path.join(depth_output_dir + "", f"traj_{i}_depth_base.mp4"),
        all_depths,
        fps=30,
    )

    ## save the wrist videos
    # save the rgb and depth videos (camera index 0)
    all_rgbs = demo_dict[f"traj_{i}"]["obs"]["rgb"][:, 0].transpose([0, 2, 3, 1])
    all_rgbs = all_rgbs.astype(np.uint8)
    _, H, W, _ = all_rgbs.shape
    all_depths = demo_dict[f"traj_{i}"]["obs"]["depth"][:, 0].reshape([-1, H, W])
    all_depths = all_depths / 2.0  # scale to 0-1
    mp.write_video(
        os.path.join(rgb_output_dir + "", f"traj_{i}_rgb_wrist.mp4"), all_rgbs, fps=30
    )
    mp.write_video(
        os.path.join(depth_output_dir + "", f"traj_{i}_depth_wrist.mp4"),
        all_depths,
        fps=30,
    )

    ##
    all_depths = np.tile(all_depths[..., None], [1, 1, 1, 3])

    # save the state and action plots
    all_actions = demo_dict[f"traj_{i}"]["actions"]
    all_states = demo_dict[f"traj_{i}"]["obs"]["state"]

    curr_actions = all_actions.reshape([1, *all_actions.shape])
    curr_states = all_states.reshape([-1, *all_states.shape])

    plot_in_grid(
        curr_actions, os.path.join(action_output_dir + "", f"traj_{i}_actions.png")
    )
    plot_in_grid(
        curr_states, os.path.join(state_output_dir + "", f"traj_{i}_states.png")
    )

    return all_rgbs, all_depths, all_actions, all_states
@dataclass
class Args:
    """Command-line options of the demo conversion script."""

    # Directory whose immediate sub-directories are treated as demos.
    source_dir: str
    # Whether to write the aggregate plots/videos once all demos are converted.
    vis: bool = True
def main(args):
    """Convert every demo directory under ``args.source_dir``.

    The converted data is written to ``<source_dir>/_conv/multiview`` with
    ``train/none`` and ``val/none`` sub-directories for the trajectories and a
    ``vis`` directory for per-demo and aggregate plots/videos. An existing
    ``multiview`` directory is deleted first.

    Args:
        args: parsed ``Args`` (``source_dir`` and ``vis``).

    Raises:
        RuntimeError: if there is no demo directory, or action ranges could
            not be computed for any of them.
    """
    source_dir = args.source_dir
    subdirs = natsorted(glob.glob(os.path.join(source_dir, "*/"), recursive=True))
    # A previous run leaves its output in <source_dir>/_conv; that directory is
    # not a demo and must not be fed back in as input.
    subdirs = [
        d for d in subdirs if os.path.basename(os.path.normpath(d)) != "_conv"
    ]
    if not subdirs:
        raise RuntimeError(f"No demo directories found in {source_dir!r}")

    output_dir = source_dir[:-1] if source_dir.endswith("/") else source_dir
    output_dir = os.path.join(output_dir, "_conv", "multiview")
    if os.path.isdir(output_dir):
        print(f"Output directory {output_dir} already exists, and will be deleted")
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Hold out 10% of the demos (at most 10) for validation.
    val_size = int(min(0.1 * len(subdirs), 10))
    val_indices = set(np.random.choice(len(subdirs), size=val_size, replace=False))

    print("Computing scale factors")
    pbar = tqdm(range(len(subdirs)))
    min_scale_factor = None
    max_scale_factor = None
    for i in pbar:
        try:
            curr_min, curr_max = get_act_min_max(subdirs[i])
            if min_scale_factor is None:
                min_scale_factor = curr_min
                max_scale_factor = curr_max
            else:
                min_scale_factor = np.minimum(min_scale_factor, curr_min)
                max_scale_factor = np.maximum(max_scale_factor, curr_max)
            pbar.set_description(f"t: {i}")
        except Exception as e:
            # Best effort: a demo that cannot be read is skipped, not fatal.
            print(f"Error: {e}")
            print(f"Skipping {subdirs[i]}")
            continue

    if min_scale_factor is None:
        # Every demo failed above; without this check the arithmetic below
        # would fail with an opaque TypeError on None.
        raise RuntimeError(
            f"Could not compute action ranges for any demo in {source_dir!r}"
        )

    # Affine map that sends [min, max] of every action dimension to [-1, 1].
    bias_factor = (min_scale_factor + max_scale_factor) / 2.0
    scale_factor = (max_scale_factor - min_scale_factor) / 2.0
    scale_factor[scale_factor == 0] = 1.0  # constant dimensions: avoid dividing by 0
    print("*" * 80)
    print(f"scale factors: {scale_factor}")
    print(f"bias factor: {bias_factor}")
    # make it into a copy pasteable string where the numbers are separated by commas
    scale_factor_str = ", ".join([f"{x}" for x in scale_factor])
    print(f"scale_factor = np.array([{scale_factor_str}])")
    bias_factor_str = ", ".join([f"{x}" for x in bias_factor])
    print(f"bias_factor = np.array([{bias_factor_str}])")
    print("*" * 80)

    tot = 0
    all_rgbs = []
    all_depths = []
    all_actions = []
    all_states = []

    vis_dir = os.path.join(output_dir, "vis")
    state_output_dir = os.path.join(vis_dir, "state")
    action_output_dir = os.path.join(vis_dir, "action")
    rgb_output_dir = os.path.join(vis_dir, "rgb")
    depth_output_dir = os.path.join(vis_dir, "depth")
    for directory in (
        state_output_dir,
        action_output_dir,
        rgb_output_dir,
        depth_output_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    pbar = tqdm(range(len(subdirs)))
    for i in pbar:
        out_dir = val_dir if i in val_indices else train_dir
        out_dir = os.path.join(out_dir, "none")
        os.makedirs(out_dir, exist_ok=True)

        ret = convert_single_demo(
            subdirs[i],
            i,
            out_dir,
            rgb_output_dir,
            depth_output_dir,
            state_output_dir,
            action_output_dir,
            scale_factor=scale_factor,
            bias_factor=bias_factor,
        )

        # Anything other than 0 is treated as the (rgbs, depths, actions, states) result.
        if ret != 0:
            all_rgbs.append(ret[0])
            all_depths.append(ret[1])
            all_actions.append(ret[2])
            all_states.append(ret[3])
            tot += 1
        pbar.set_description(f"t: {i}")

    print(
        f"Finished converting all demos to {output_dir}! (num demos: {tot} / {len(subdirs)})"
    )

    if args.vis and len(all_rgbs) > 0:
        print("Visualizing all demos...")
        plot_in_grid(all_actions, os.path.join(action_output_dir, "_all_actions.png"))
        plot_in_grid(all_states, os.path.join(state_output_dir, "_all_states.png"))
        make_grid_video_from_numpy(
            all_rgbs, 10, os.path.join(rgb_output_dir, "_all_rgb.mp4"), fps=30
        )
        make_grid_video_from_numpy(
            all_depths, 10, os.path.join(depth_output_dir, "_all_depth.mp4"), fps=30
        )

    # Same effect as exit(0), without relying on the `site`-provided builtin.
    raise SystemExit(0)
# Script entry point: build ``Args`` from the command line with tyro and run.
if __name__ == "__main__":
    main(tyro.cli(Args))

View file

@ -0,0 +1,22 @@
import datetime
import pickle
from pathlib import Path
from typing import Dict
import numpy as np
def save_frame(
    folder: Path,
    timestamp: datetime.datetime,
    obs: Dict[str, np.ndarray],
    action: np.ndarray,
) -> None:
    """Pickle one observation together with its action.

    The record is the observation dict plus a ``"control"`` entry holding
    ``action``; it is written to ``folder/<timestamp.isoformat()>.pkl``.

    Args:
        folder: destination directory; created (with parents) if missing.
        timestamp: time of the frame, used as the file name.
        obs: observation dict. It is not modified.
        action: action associated with the observation.
    """
    # Build a new dict instead of inserting into `obs`, so the caller's
    # observation does not silently gain a "control" key.
    record = {**obs, "control": action}

    # make folder if it doesn't exist
    folder.mkdir(exist_ok=True, parents=True)
    recorded_file = folder / (timestamp.isoformat() + ".pkl")

    with open(recorded_file, "wb") as f:
        pickle.dump(record, f)

View file

@ -0,0 +1,59 @@
import pygame
# Colour tuples used to fill the pygame window (see KBReset._set_color).
NORMAL = (128, 128, 128)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
# Keys compared against pygame KEYDOWN events in KBReset.update.
KEY_START = pygame.K_s
# NOTE(review): KEY_CONTINUE is not referenced by the code in this module.
KEY_CONTINUE = pygame.K_c
KEY_QUIT_RECORDING = pygame.K_q
class KBReset:
    """Small keyboard-driven state machine backed by a pygame window.

    The window colour mirrors the state: NORMAL when idle, GREEN once the
    start key was pressed, RED after the quit-recording key.
    """

    def __init__(self):
        pygame.init()
        self._screen = pygame.display.set_mode((800, 800))
        self._set_color(NORMAL)
        self._saved = False

    def update(self) -> str:
        """Poll the keyboard and return ``"start"``, ``"save"`` or ``"normal"``."""
        keys = self._get_pressed()

        # The quit key always wins and ends the saving state.
        if KEY_QUIT_RECORDING in keys:
            self._set_color(RED)
            self._saved = False
            return "normal"

        # Once started, keep reporting "save" until the quit key is pressed.
        if self._saved:
            return "save"

        if KEY_START not in keys:
            self._set_color(NORMAL)
            return "normal"

        self._set_color(GREEN)
        self._saved = True
        return "start"

    def _get_pressed(self):
        """Return the keys of all KEYDOWN events currently queued."""
        pygame.event.pump()
        return [
            event.key
            for event in pygame.event.get()
            if event.type == pygame.KEYDOWN
        ]

    def _set_color(self, color):
        """Fill the window with ``color`` and refresh the display."""
        self._screen.fill(color)
        pygame.display.flip()
def main():
    """Poll a ``KBReset`` forever and print whenever it reports ``"start"``."""
    keyboard = KBReset()
    while True:
        if keyboard.update() == "start":
            print("start")
# Run the keyboard polling demo when executed as a script.
if __name__ == "__main__":
    main()

View file

@ -0,0 +1,103 @@
import matplotlib.pyplot as plt
import numpy as np
def plot_in_grid(vals: np.ndarray, save_path: str):
    """Plot the trajectories in a grid.

    Two images are written: ``save_path`` with one line plot per value
    dimension (four per row), and ``<save_path without extension>_3d.png``
    with four 3D views of the first three dimensions. The 3D image is skipped
    when there are fewer than three dimensions.

    Args:
        vals: B x T x N, where
            B is the number of trajectories,
            T is the number of timesteps,
            N is the dimensionality of the values.
        save_path: path to save the plot.
    """
    import os  # local import: this module only imports matplotlib and numpy

    N = vals[0].shape[-1]
    rows = (N // 4) + (N % 4 > 0)
    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # row (N <= 4); otherwise the [row, col] indexing below fails.
    fig, axes = plt.subplots(rows, 4, figsize=(20, 10), squeeze=False)
    for curr in vals:
        T = curr.shape[0]
        for i in range(N):
            # give them transparency
            axes[i // 4, i % 4].plot(np.arange(T), curr[:, i], alpha=0.5)

    for i in range(N):
        ax = axes[i // 4, i % 4]
        ax.set_title(f"Dim {i}")
        ax.set_ylim([-1.0, 1.0])

    fig.savefig(save_path)
    plt.close(fig)

    if N < 3:
        # The 3D views need x, y and z columns.
        return

    # (elev, azim) passed to view_init for each panel; None keeps the default view.
    views = [
        (270, 0),  # 2D view of the XY plane, with X pointing downwards
        (0, 0),  # 2D view of the XZ plane, with X pointing leftwards
        (0, 90),  # 2D view of the YZ plane, with Y pointing leftwards
        None,
    ]
    fig = plt.figure(figsize=(20, 10))
    for panel, view in enumerate(views, start=1):
        ax = fig.add_subplot(1, 4, panel, projection="3d")
        for curr in vals:
            ax.plot(curr[:, 0], curr[:, 1], curr[:, 2], alpha=0.75)
            # scatter the start and end points
            ax.scatter(curr[0, 0], curr[0, 1], curr[0, 2], c="r")
            ax.scatter(curr[-1, 0], curr[-1, 1], curr[-1, 2], c="g")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.legend(["Trajectory", "Start", "End"])
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        if view is not None:
            ax.view_init(*view)

    stem, _ = os.path.splitext(save_path)
    fig.savefig(stem + "_3d.png")
    plt.close(fig)

View file

@ -0,0 +1,5 @@
from gello.dm_control_tasks.arenas.base import Arena
# Names exported by this module.
__all__ = [
    "Arena",
]

View file

@ -0,0 +1,13 @@
<mujoco model="arena">
  <!-- Global visual settings: headlight colours and haze colour. -->
  <visual>
    <headlight diffuse="0.6 0.6 0.6" ambient="0.3 0.3 0.3" specular="0 0 0"/>
    <rgba haze="0.15 0.25 0.35 1"/>
  </visual>
  <asset>
    <!-- Built-in gradient skybox texture blending rgb1 into rgb2. -->
    <texture type="skybox" builtin="gradient" rgb1="0.3 0.5 0.7" rgb2="0 0 0" width="512"
        height="3072"/>
  </asset>
</mujoco>

View file

@ -0,0 +1,26 @@
"""Arena composer class."""
from pathlib import Path
from dm_control import composer, mjcf
# Absolute path of the arena.xml file that sits next to this module.
_ARENA_XML = Path(__file__).resolve().parent / "arena.xml"
class Arena(composer.Entity):
    """Composer entity wrapping the MJCF model loaded from ``arena.xml``."""

    def _build(self, name: str = "arena") -> None:
        """Load the arena MJCF and, unless ``name`` is None, rename the model."""
        root = mjcf.from_path(str(_ARENA_XML))
        self._mjcf_root = root
        if name is None:
            return
        root.model = name

    def add_free_entity(self, entity) -> mjcf.Element:
        """Attach ``entity`` and give its attachment frame a freejoint.

        Returns the frame created by the attachment.
        """
        attachment_frame = self.attach(entity)
        attachment_frame.add("freejoint")
        return attachment_frame

    @property
    def mjcf_model(self) -> mjcf.RootElement:
        """Root element of the arena's MJCF model."""
        return self._mjcf_root

View file

@ -0,0 +1,7 @@
from gello.dm_control_tasks.arms.franka import Franka
from gello.dm_control_tasks.arms.ur5e import UR5e
# Names exported by this module.
__all__ = [
    "UR5e",
    "Franka",
]

View file

@ -0,0 +1,28 @@
"""Franka composer class."""
from pathlib import Path
from typing import Optional, Union
from dm_control import mjcf
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
# Default MJCF files for the arm and the gripper, relative to MENAGERIE_ROOT.
XML = MENAGERIE_ROOT / "franka_emika_panda" / "panda_nohand.xml"
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
class Franka(Manipulator):
    """Franka Robot."""

    def _build(
        self,
        name: str = "franka",
        xml_path: Union[str, Path] = XML,
        gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
    ) -> None:
        """Build the arm from ``xml_path`` and attach the optional gripper.

        Args:
            name: name given to the MJCF model.
            xml_path: MJCF file of the arm.
            gripper_xml_path: MJCF file of the gripper; a falsy value means
                no gripper is attached.
        """
        # Forward the caller's arguments; previously the module defaults were
        # always passed, so custom names/paths were silently ignored.
        super()._build(name=name, xml_path=xml_path, gripper_xml_path=gripper_xml_path)

    @property
    def flange(self) -> mjcf.Element:
        """The site named ``attachment_site`` in the arm model."""
        return self._mjcf_root.find("site", "attachment_site")

View file

@ -0,0 +1,229 @@
"""Manipulator composer class."""
import abc
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from dm_control import composer, mjcf
from dm_control.composer.observation import observable
from dm_control.mujoco.wrapper import mjbindings
from dm_control.suite.utils.randomizers import random_limited_quaternion
from gello.dm_control_tasks import mjcf_utils
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
) -> None:
    """Attach a hand model to the arm's ``attachment_site`` site.

    If the arm has a keyframe named "home", its ``ctrl`` and ``qpos`` are
    extended for the hand: with the hand's own "home" keyframe values when it
    has one, otherwise with zeros sized by the hand's ``nu`` / ``nq``.

    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm.
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    # Compiled up front: only needed for the hand's actuator / qpos counts.
    hand_physics = mjcf.Physics.from_mjcf_model(hand_mjcf)

    site = arm_mjcf.find("site", "attachment_site")
    if site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Expand the ctrl and qpos keyframes to account for the new hand DoFs.
    home_key = arm_mjcf.find("key", "home")
    if home_key is not None:
        hand_home = hand_mjcf.find("key", "home")
        if hand_home is not None:
            extra_ctrl, extra_qpos = hand_home.ctrl, hand_home.qpos
        else:
            extra_ctrl = np.zeros(hand_physics.model.nu)
            extra_qpos = np.zeros(hand_physics.model.nq)
        home_key.ctrl = np.concatenate([home_key.ctrl, extra_ctrl])
        home_key.qpos = np.concatenate([home_key.qpos, extra_qpos])

    site.attach(hand_mjcf)
class Manipulator(composer.Entity, abc.ABC):
    """Abstract composer entity for a robot arm with an optional gripper."""

    def _build(
        self,
        name: str,
        xml_path: Union[str, Path],
        gripper_xml_path: Optional[Union[str, Path]],
    ) -> None:
        """Load the arm model, optionally attach a gripper, and index elements.

        Subclasses are expected to call this from their own ``_build``.

        Args:
            name: name assigned to the MJCF model.
            xml_path: MJCF file of the arm.
            gripper_xml_path: MJCF file of the gripper; falsy to attach none.
        """
        root = mjcf.from_path(str(xml_path))
        self._mjcf_root = root
        # Collected before any gripper is attached to the model.
        self._arm_joints = tuple(mjcf_utils.safe_find_all(root, "joint"))

        if gripper_xml_path:
            attach_hand_to_arm(root, mjcf.from_path(str(gripper_xml_path)))

        root.model = name
        self._add_mjcf_elements()

    def set_joints(self, physics: mjcf.Physics, joints: np.ndarray) -> None:
        """Write ``joints`` into ``qpos`` of the arm joints, in order."""
        assert len(joints) == len(self._arm_joints)
        for arm_joint, value in zip(self._arm_joints, joints):
            bound_id = physics.bind(arm_joint).element_id
            bound_name = physics.model.id2name(bound_id, "joint")
            physics.named.data.qpos[bound_name] = value

    def randomize_joints(
        self,
        physics: mjcf.Physics,
        random: Optional[np.random.RandomState] = None,
    ) -> None:
        """Assign a random position to every arm joint.

        Limited hinge/slide joints are sampled uniformly inside their range
        and limited ball joints via ``random_limited_quaternion``. Unlimited
        hinges are sampled in [-pi, pi]; unlimited ball and free joints get a
        normalised random quaternion.

        Args:
            physics: physics instance whose ``qpos`` is modified.
            random: random state to draw from; ``np.random`` when omitted.
        """
        rng = random or np.random  # type: ignore
        assert rng is not None

        joint_kinds = mjbindings.enums.mjtJoint
        hinge = joint_kinds.mjJNT_HINGE
        slide = joint_kinds.mjJNT_SLIDE
        ball = joint_kinds.mjJNT_BALL
        free = joint_kinds.mjJNT_FREE

        model = physics.model
        qpos = physics.named.data.qpos

        for arm_joint in self._arm_joints:
            joint_id = physics.bind(arm_joint).element_id
            joint_name = model.id2name(joint_id, "joint")
            joint_type = model.jnt_type[joint_id]
            limited = model.jnt_limited[joint_id]
            lower, upper = model.jnt_range[joint_id]

            if limited:
                if joint_type in (hinge, slide):
                    qpos[joint_name] = rng.uniform(lower, upper)
                elif joint_type == ball:
                    qpos[joint_name] = random_limited_quaternion(rng, upper)
            elif joint_type == hinge:
                qpos[joint_name] = rng.uniform(-np.pi, np.pi)
            elif joint_type == ball:
                quat = rng.randn(4)
                quat /= np.linalg.norm(quat)
                qpos[joint_name] = quat
            elif joint_type == free:
                # this should be random.randn, but changing it now could significantly
                # affect benchmark results.
                quat = rng.rand(4)
                quat /= np.linalg.norm(quat)
                qpos[joint_name][3:] = quat

    def _add_mjcf_elements(self) -> None:
        """Cache joints, actuators and keyframes and add the flange marker."""
        root = self._mjcf_root

        # Joints, without freejoints.
        found_joints = mjcf_utils.safe_find_all(root, "joint")
        self._joints = tuple(j for j in found_joints if j.tag != "freejoint")

        # Actuators.
        self._actuators = tuple(mjcf_utils.safe_find_all(root, "actuator"))

        # qpos keyframes, keyed by keyframe name.
        self._keyframes = {}
        for frame in mjcf_utils.safe_find_all(root, "key") or ():
            if frame.qpos is not None:
                self._keyframes[frame.name] = np.array(frame.qpos)

        # Add a small green, non-colliding sphere to visualize the flange position.
        flange = self.flange
        flange.parent.add(
            "geom",
            name="flange_geom",
            type="sphere",
            size="0.01",
            rgba="0 1 0 1",
            pos=flange.pos,
            contype="0",
            conaffinity="0",
        )

    def _build_observables(self):
        return ArmObservables(self)

    @property
    @abc.abstractmethod
    def flange(self) -> mjcf.Element:
        """Returns the flange element.

        The flange is the end effector of the manipulator where tools can be
        attached, such as a gripper.
        """

    @property
    def mjcf_model(self) -> mjcf.RootElement:
        """Root element of the manipulator's MJCF model."""
        return self._mjcf_root

    @property
    def name(self) -> str:
        """Name of the MJCF model."""
        return self._mjcf_root.model

    @property
    def joints(self) -> Tuple[mjcf.Element, ...]:
        """All non-free joints found in the model."""
        return self._joints

    @property
    def actuators(self) -> Tuple[mjcf.Element, ...]:
        """All actuators found in the model."""
        return self._actuators

    @property
    def keyframes(self) -> Dict[str, np.ndarray]:
        """Keyframe ``qpos`` arrays keyed by keyframe name."""
        return self._keyframes
class ArmObservables(composer.Observables):
    """Observables exposed by a manipulator arm."""

    @composer.observable
    def joints_pos(self):
        """``qpos`` of the entity's joints."""
        return observable.MJCFFeature("qpos", self._entity.joints)

    @composer.observable
    def joints_vel(self):
        """``qvel`` of the entity's joints."""
        return observable.MJCFFeature("qvel", self._entity.joints)

    @composer.observable
    def flange_position(self):
        """``xpos`` of the entity's flange."""
        return observable.MJCFFeature("xpos", self._entity.flange)

    @composer.observable
    def flange_orientation(self):
        """``xmat`` of the entity's flange."""
        return observable.MJCFFeature("xmat", self._entity.flange)

    # Semantic grouping of observables.
    def _collect_from_attachments(self, attribute_name: str):
        """Concatenate ``attribute_name`` from the observables of attached entities."""
        out: List[observable.MJCFFeature] = []
        for entity in self._entity.iter_entities(exclude_self=True):
            out.extend(getattr(entity.observables, attribute_name, []))
        return out

    @property
    def proprioception(self):
        """Proprioceptive observables of the arm and of its attachments."""
        # `flange_angular_velocity` is not defined on this class, so listing it
        # made every access to this property raise AttributeError. It stays
        # disabled, like the other flange observables below, until it exists.
        return [
            self.joints_pos,
            self.joints_vel,
            self.flange_position,
            # self.flange_orientation,
            # self.flange_velocity,
            # self.flange_angular_velocity,
        ] + self._collect_from_attachments("proprioception")

View file

@ -0,0 +1,26 @@
"""UR5e composer class."""
from pathlib import Path
from typing import Optional, Union
from dm_control import mjcf
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
class UR5e(Manipulator):
    """UR5e arm, by default fitted with the Robotiq 2F-85 gripper model."""

    # Default model files; they are used as defaults of `_build` below, so they
    # must be defined before it in the class body.
    GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
    XML = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"

    def _build(
        self,
        name: str = "UR5e",
        xml_path: Union[str, Path] = XML,
        gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
    ) -> None:
        """Build the arm from ``xml_path`` and attach the optional gripper."""
        super()._build(name, xml_path, gripper_xml_path)

    @property
    def flange(self) -> mjcf.Element:
        """The site named ``attachment_site`` in the arm model."""
        attachment = self._mjcf_root.find("site", "attachment_site")
        return attachment

View file

@ -0,0 +1,22 @@
"""Tests for ur5e.py."""
from absl.testing import absltest
from dm_control import mjcf
from gello.dm_control_tasks.arms import ur5e
class UR5eTest(absltest.TestCase):
    """Smoke tests for the UR5e entity."""

    def test_compiles_and_steps(self) -> None:
        """The model compiles into a Physics instance and can be stepped once."""
        arm = ur5e.UR5e()
        mjcf.Physics.from_mjcf_model(arm.mjcf_model).step()

    def test_joints(self) -> None:
        """Every element reported by ``joints`` has the tag ``joint``."""
        arm = ur5e.UR5e()
        for element in arm.joints:
            self.assertEqual(element.tag, "joint")
# Run the tests with absltest when executed as a script.
if __name__ == "__main__":
    absltest.main()

View file

@ -0,0 +1,261 @@
import collections
import numpy as np
from absl import logging
from dm_control.mujoco.wrapper import mjbindings
# Shorthand for the MuJoCo library bindings called by the solver below.
mjlib = mjbindings.mjlib
# Error messages raised by `qpos_from_site_pose`.
_INVALID_JOINT_NAMES_TYPE = (
    "`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}."
)
_REQUIRE_TARGET_POS_OR_QUAT = (
    "At least one of `target_pos` or `target_quat` must be specified."
)
# Result record returned by `qpos_from_site_pose`.
IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"])
def qpos_from_site_pose(
    physics,
    site_name,
    target_pos=None,
    target_quat=None,
    joint_names=None,
    tol=1e-14,
    rot_weight=1.0,
    regularization_threshold=0.1,
    regularization_strength=3e-2,
    max_update_norm=2.0,
    progress_thresh=20.0,
    max_steps=100,
    inplace=False,
):
    """Find joint positions that satisfy a target site position and/or rotation.

    Args:
        physics: A `mujoco.Physics` instance.
        site_name: A string specifying the name of the target site.
        target_pos: A (3,) numpy array specifying the desired Cartesian position of
            the site, or None if the position should be unconstrained (default).
            One or both of `target_pos` or `target_quat` must be specified.
        target_quat: A (4,) numpy array specifying the desired orientation of the
            site as a quaternion, or None if the orientation should be unconstrained
            (default). One or both of `target_pos` or `target_quat` must be specified.
        joint_names: (optional) A list, tuple or numpy array specifying the names of
            one or more joints that can be manipulated in order to achieve the target
            site pose. If None (default), all joints may be manipulated.
        tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
            in the stopping criterion).
        rot_weight: (optional) Determines the weight given to rotational error
            relative to translational error.
        regularization_threshold: (optional) L2 regularization will be used when
            inverting the Jacobian whilst `err_norm` is greater than this value.
        regularization_strength: (optional) Coefficient of the quadratic penalty
            on joint movements.
        max_update_norm: (optional) The maximum L2 norm of the update applied to
            the joint positions on each iteration. The update vector will be scaled
            such that its magnitude never exceeds this value.
        progress_thresh: (optional) If `err_norm` divided by the magnitude of the
            joint position update is greater than this value then the optimization
            will terminate prematurely. This is a useful heuristic to avoid getting
            stuck in local minima.
        max_steps: (optional) The maximum number of iterations to perform.
        inplace: (optional) If True, `physics.data` will be modified in place.
            Default value is False, i.e. a copy of `physics.data` will be made.

    Returns:
        An `IKResult` namedtuple with the following fields:
            qpos: An (nq,) numpy array of joint positions.
            err_norm: A float, the weighted sum of L2 norms for the residual
                translational and rotational errors.
            steps: An int, the number of iterations that were performed.
            success: Boolean, True if we converged on a solution within `max_steps`,
                False otherwise.

    Raises:
        ValueError: If both `target_pos` and `target_quat` are None, or if
            `joint_names` has an invalid type.
    """
    dtype = physics.data.qpos.dtype

    # `jac` and `err` are allocated once; the *_pos / *_rot names are views
    # (or aliases) into them, so filling the views fills `jac` / `err`.
    if target_pos is not None and target_quat is not None:
        jac = np.empty((6, physics.model.nv), dtype=dtype)
        err = np.empty(6, dtype=dtype)
        jac_pos, jac_rot = jac[:3], jac[3:]
        err_pos, err_rot = err[:3], err[3:]
    else:
        jac = np.empty((3, physics.model.nv), dtype=dtype)
        err = np.empty(3, dtype=dtype)
        if target_pos is not None:
            jac_pos, jac_rot = jac, None
            err_pos, err_rot = err, None
        elif target_quat is not None:
            jac_pos, jac_rot = None, jac
            err_pos, err_rot = None, err
        else:
            raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)

    # Full-size update vector; only the entries in `dof_indices` are written.
    update_nv = np.zeros(physics.model.nv, dtype=dtype)

    # Scratch quaternions for the rotational error, only needed with a target_quat.
    if target_quat is not None:
        site_xquat = np.empty(4, dtype=dtype)
        neg_site_xquat = np.empty(4, dtype=dtype)
        err_rot_quat = np.empty(4, dtype=dtype)

    if not inplace:
        physics = physics.copy(share_model=True)

    # Ensure that the Cartesian position of the site is up to date.
    mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)

    # Convert site name to index.
    site_id = physics.model.name2id(site_name, "site")

    # These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
    # update them in place, so we can avoid indexing overhead in the main loop.
    site_xpos = physics.named.data.site_xpos[site_name]
    site_xmat = physics.named.data.site_xmat[site_name]

    # This is an index into the rows of `update` and the columns of `jac`
    # that selects DOFs associated with joints that we are allowed to manipulate.
    if joint_names is None:
        dof_indices = slice(None)  # Update all DOFs.
    elif isinstance(joint_names, (list, np.ndarray, tuple)):
        if isinstance(joint_names, tuple):
            joint_names = list(joint_names)
        # Find the indices of the DOFs belonging to each named joint. Note that
        # these are not necessarily the same as the joint IDs, since a single joint
        # may have >1 DOF (e.g. ball joints).
        indexer = physics.named.model.dof_jntid.axes.row
        # `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
        # indexer to map each joint name to the indices of its corresponding DOFs.
        dof_indices = indexer.convert_key_item(joint_names)
    else:
        raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))

    steps = 0
    success = False

    # NOTE(review): `err_norm` is only assigned inside the loop, so
    # `max_steps <= 0` would leave it unbound at the return statement.
    for steps in range(max_steps):
        err_norm = 0.0

        if target_pos is not None:
            # Translational error.
            err_pos[:] = target_pos - site_xpos
            err_norm += np.linalg.norm(err_pos)
        if target_quat is not None:
            # Rotational error.
            mjlib.mju_mat2Quat(site_xquat, site_xmat)
            mjlib.mju_negQuat(neg_site_xquat, site_xquat)
            mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
            mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
            err_norm += np.linalg.norm(err_rot) * rot_weight

        if err_norm < tol:
            logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm)
            success = True
            break
        else:
            # TODO(b/112141670): Generalize this to other entities besides sites.
            mjlib.mj_jacSite(
                physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id
            )
            jac_joints = jac[:, dof_indices]

            # TODO(b/112141592): This does not take joint limits into consideration.
            reg_strength = (
                regularization_strength if err_norm > regularization_threshold else 0.0
            )
            update_joints = nullspace_method(
                jac_joints, err, regularization_strength=reg_strength
            )

            update_norm = np.linalg.norm(update_joints)

            # Check whether we are still making enough progress, and halt if not.
            # NOTE(review): `update_norm` can be zero here; there is no guard
            # before the division.
            progress_criterion = err_norm / update_norm
            if progress_criterion > progress_thresh:
                logging.debug(
                    "Step %2i: err_norm / update_norm (%3g) > "
                    "tolerance (%3g). Halting due to insufficient progress",
                    steps,
                    progress_criterion,
                    progress_thresh,
                )
                break

            # Clip the step length to `max_update_norm`, keeping its direction.
            if update_norm > max_update_norm:
                update_joints *= max_update_norm / update_norm

            # Write the entries for the specified joints into the full `update_nv`
            # vector.
            update_nv[dof_indices] = update_joints

            # Update `physics.qpos`, taking quaternions into account.
            mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)

            # Compute the new Cartesian position of the site.
            mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)

            logging.debug(
                "Step %2i: err_norm=%-10.3g update_norm=%-10.3g",
                steps,
                err_norm,
                update_norm,
            )

    # Only warn when the loop ran out of iterations, not when it halted early
    # because of insufficient progress.
    if not success and steps == max_steps - 1:
        logging.warning(
            "Failed to converge after %i steps: err_norm=%3g", steps, err_norm
        )

    if not inplace:
        # Our temporary copy of physics.data is about to go out of scope, and when
        # it does the underlying mjData pointer will be freed and physics.data.qpos
        # will be a view onto a block of deallocated memory. We therefore need to
        # make a copy of physics.data.qpos while physics.data is still alive.
        qpos = physics.data.qpos.copy()
    else:
        # If we're modifying physics.data in place then it's fine to return a view.
        qpos = physics.data.qpos

    return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
    """Solve for the joint update that produces an end-effector delta.

    Solves the normal equations ``(J^T J) x = J^T delta``. With a positive
    ``regularization_strength`` that value is added to the diagonal of
    ``J^T J`` and the system is solved directly; otherwise a least-squares
    solution is returned.

    Args:
        jac_joints: The Jacobian of the end effector with respect to the joints. A
            numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
            and `nv` is the number of degrees of freedom.
        delta: The desired end-effector delta. A numpy array of shape `(3,)` or
            `(6,)` containing either position deltas, rotation deltas, or both.
        regularization_strength: (optional) Coefficient of the quadratic penalty
            on joint movements. Default is zero, i.e. no regularization.

    Returns:
        An `(nv,)` numpy array of joint velocities.

    Reference:
        Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
        transpose, pseudoinverse and damped least squares methods.
        https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
    """
    jac_t = jac_joints.T
    normal_matrix = jac_t.dot(jac_joints)
    rhs = jac_t.dot(delta)

    if not regularization_strength > 0:
        # Unregularized: least-squares solution of the normal equations.
        return np.linalg.lstsq(normal_matrix, rhs, rcond=-1)[0]

    # L2 regularization
    normal_matrix += np.eye(normal_matrix.shape[0]) * regularization_strength
    return np.linalg.solve(normal_matrix, rhs)
class InverseKinematics:
    """Placeholder for an inverse kinematics helper; not implemented yet."""

    def __init__(self, xml_path: str):
        """Initializes the inverse kinematics class.

        ``xml_path`` is accepted but currently unused.
        """
        # TODO
        ...

View file

@ -0,0 +1,111 @@
"""Simple floor arenas."""
from typing import Tuple
import numpy as np
from dm_control import mjcf
from gello.dm_control_tasks.arenas import base
# Third component of the ground plane geom's `size`, appended to the floor
# size in `Floor._build`.
_GROUNDPLANE_QUAD_SIZE = 0.25
class FixedManipulationArena(base.Arena):
    """Arena exposing an ``arm_attachment_site`` property."""

    @property
    def arm_attachment_site(self) -> np.ndarray:
        # NOTE(review): `attachment_site` is not defined anywhere in this
        # module -- confirm a base class provides it. The `np.ndarray`
        # annotation also differs from `Floor.arm_attachment_site`, which is
        # annotated as returning an `mjcf.Element`.
        return self.attachment_site
class Floor(base.Arena):
    """Checkered ground-plane arena with a light, a top camera and an arm site."""

    def _build(
        self,
        name: str = "floor",
        size: Tuple[float, float] = (8, 8),
        reflectance: float = 0.2,
        top_camera_y_padding_factor: float = 1.1,
        top_camera_distance: float = 5.0,
    ) -> None:
        """Create the floor scene.

        Args:
            name: name of the MJCF model.
            size: first two components of the ground plane geom's size.
            reflectance: reflectance of the ground material.
            top_camera_y_padding_factor: factor applied to ``size[1]`` when
                choosing the top camera's field of view.
            top_camera_distance: height of the top camera above the origin.
        """
        super()._build(name=name)
        self._size = size
        self._top_camera_y_padding_factor = top_camera_y_padding_factor
        self._top_camera_distance = top_camera_distance

        assert self._mjcf_root.worldbody is not None
        worldbody = self._mjcf_root.worldbody
        assets = self._mjcf_root.asset

        # Add arm attachment site (invisible: alpha is 0).
        z_offset = 0.00
        worldbody.add(
            "site",
            name="arm_attachment",
            pos=(0, 0, z_offset),
            size=(0.01, 0.01, 0.01),
            type="sphere",
            rgba=(0, 0, 0, 0),
        )

        # Add light.
        worldbody.add(
            "light",
            pos=(0, 0, 1.5),
            dir=(0, 0, -1),
            directional=True,
        )

        # Checker texture and the material that uses it.
        self._ground_texture = assets.add(
            "texture",
            rgb1=[0.2, 0.3, 0.4],
            rgb2=[0.1, 0.2, 0.3],
            type="2d",
            builtin="checker",
            name="groundplane",
            width=200,
            height=200,
            mark="edge",
            markrgb=[0.8, 0.8, 0.8],
        )
        self._ground_material = assets.add(
            "material",
            name="groundplane",
            texrepeat=[2, 2],  # Makes white squares exactly 1x1 length units.
            texuniform=True,
            reflectance=reflectance,
            texture=self._ground_texture,
        )

        # Build groundplane.
        plane_size = [*size, _GROUNDPLANE_QUAD_SIZE]
        self._ground_geom = worldbody.add(
            "geom",
            type="plane",
            name="groundplane",
            material=self._ground_material,
            size=plane_size,
        )

        # Choose the FOV so that the floor always fits nicely within the frame
        # irrespective of actual floor size.
        padded_half_extent = top_camera_y_padding_factor * size[1]
        fovy_radians = 2 * np.arctan2(padded_half_extent, top_camera_distance)
        self._top_camera = worldbody.add(
            "camera",
            name="top_camera",
            pos=[0, 0, top_camera_distance],
            quat=[1, 0, 0, 0],
            fovy=np.rad2deg(fovy_radians),
        )

    @property
    def ground_geoms(self):
        """One-element tuple holding the ground plane geom."""
        return (self._ground_geom,)

    @property
    def size(self):
        """The ``size`` passed to ``_build``."""
        return self._size

    @property
    def arm_attachment_site(self) -> mjcf.Element:
        """The ``arm_attachment`` site added to the worldbody."""
        return self._mjcf_root.worldbody.find("site", "arm_attachment")

View file

@ -0,0 +1,56 @@
"""A abstract class for all walker tasks."""
from dm_control import composer, mjcf
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
# Timestep of the physics simulation.
# Timestep of the physics simulation, in seconds.
_PHYSICS_TIMESTEP: float = 0.002
# Interval between agent actions, in seconds.
# We send a control signal every (_CONTROL_TIMESTEP / _PHYSICS_TIMESTEP) physics
# steps, i.e. every 10 steps with these defaults.
_CONTROL_TIMESTEP: float = 0.02  # 50 Hz.
class ManipulationTask(composer.Task):
    """Base composer task that attaches a manipulator arm to an arena."""

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=_PHYSICS_TIMESTEP,
        control_timestep=_CONTROL_TIMESTEP,
    ) -> None:
        """Attach ``arm`` to ``arena`` and configure timesteps and observables.

        Args:
            arm: the manipulator entity.
            arena: the arena; the arm is attached at its ``arm_attachment_site``.
            physics_timestep: physics timestep passed to ``set_timesteps``.
            control_timestep: control timestep passed to ``set_timesteps``.
        """
        self._arm = arm
        self._arena = arena

        # Geoms are collected before the arm is attached to the arena.
        self._arena_geoms = arena.root_body.find_all("geom")
        self._arm_geoms = arm.root_body.find_all("geom")

        arena.attach(arm, attach_site=arena.arm_attachment_site)

        self.set_timesteps(
            control_timestep=control_timestep, physics_timestep=physics_timestep
        )

        # Enable all robot observables. Note: No additional entity observables should
        # be added in subclasses.
        for proprio_observable in self._arm.observables.proprioception:
            proprio_observable.enabled = True

    def in_collision(self, physics: mjcf.Physics) -> bool:
        """Return True if the arm touches the arena or itself."""
        arm_ids = [physics.bind(g).element_id for g in self._arm_geoms]
        arena_ids = [physics.bind(g).element_id for g in self._arena_geoms]

        for contact in physics.data.contact[: physics.data.ncon]:
            first, second = contact.geom1, contact.geom2
            if first in arm_ids:
                # arm and arena, or self collision
                if second in arena_ids or second in arm_ids:
                    return True
            elif first in arena_ids and second in arm_ids:
                # arena and arm
                return True
        return False

    @property
    def root_entity(self):
        """The arena, which is the root of the entity tree."""
        return self._arena

View file

@ -0,0 +1,122 @@
"""A task where a walker must learn to stand."""
from typing import Optional
import numpy as np
from dm_control import mjcf
from dm_control.suite.utils.randomizers import random_limited_quaternion
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
from gello.dm_control_tasks.manipulation.tasks import base
# Four-component colour tuple (presumably RGBA -- TODO confirm).
# NOTE(review): not referenced by the code visible in this module.
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
class BlockPlay(base.ManipulationTask):
    """Manipulator task whose scene contains randomly colored, randomly placed boxes."""

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        num_blocks: int = 10,
        size: float = 0.03,
        reset_joints: Optional[np.ndarray] = None,
    ) -> None:
        """Add the blocks to the scene.

        Args:
            arm: The manipulator that the task controls.
            arena: The arena the arm is attached to.
            physics_timestep: Physics timestep forwarded to the base task.
            control_timestep: Control timestep forwarded to the base task.
            num_blocks: Number of box bodies added to the world body.
            size: Value used for all three entries of each box geom's ``size``.
            reset_joints: If given, joint values applied to the arm at the start
                of every episode; otherwise the arm joints are randomized.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)
        model = self.root_entity.mjcf_model

        # Use the first keyframe of the model, if it has one.
        found_keys = model.find_all("key")
        key_frame = found_keys[0] if len(found_keys) != 0 else None

        joints = []
        for index in range(num_blocks):
            # Random RGB with full opacity.
            rgba = np.concatenate([np.random.uniform(0, 1, size=3), [1.0]])
            body = model.worldbody.add("body", name=f"block_{index}", pos=(0, 0, 0))
            # The freejoint is what lets the block be repositioned later.
            joints.append(body.add("freejoint"))
            body.add(
                "geom",
                name=f"block_geom_{index}",
                type="box",
                size=(size, size, size),
                rgba=rgba,
            )
            # Grow the keyframe qpos by seven zeros for the new freejoint.
            assert key_frame is not None
            key_frame.qpos = np.concatenate([key_frame.qpos, np.zeros(7)])

        self._block_joints = joints
        self._block_size = size
        self._reset_joints = reset_joints

    def initialize_episode(self, physics, random_state):
        """Place the arm in a collision-free configuration and scatter the blocks."""
        if self._reset_joints is None:
            self._arm.randomize_joints(physics, random_state)
        else:
            self._arm.set_joints(physics, self._reset_joints)
        physics.forward()

        # Resample the arm joints until the arm no longer collides.
        while self.in_collision(physics):
            self._arm.randomize_joints(physics, random_state)
            physics.forward()

        for joint in self._block_joints:
            randomize_pose(
                joint,
                physics,
                random_state=random_state,
                position_range=0.5,
                z_offset=self._block_size * 2,
            )
        physics.forward()

    def get_reward(self, physics):
        """Return a constant reward of 0."""
        return 0
def randomize_pose(
    free_joint: mjcf.Element,
    physics: mjcf.Physics,
    random_state: np.random.RandomState,
    position_range: float = 0.5,
    z_offset: float = 0.0,
    min_distance: float = 0.3,
) -> None:
    """Write a random pose into the qpos of a free joint.

    The x/y position is drawn uniformly from
    ``[-position_range, position_range]`` and redrawn until it lies at least
    ``min_distance`` from the origin; z is set to ``z_offset``. The orientation
    comes from ``random_limited_quaternion`` with a limit of pi.

    Args:
        free_joint: The free joint element whose qpos is overwritten.
        physics: The physics instance the joint is bound to.
        random_state: Source of randomness.
        position_range: Half-width of the square the x/y position is drawn from.
        z_offset: The z coordinate written to the joint.
        min_distance: Minimum planar distance from the origin.

    Raises:
        ValueError: If no point of the sampling square is farther than
            ``min_distance`` from the origin (rejection sampling would never end).
    """
    # The farthest point of the sampling square is one of its corners.
    if min_distance >= abs(position_range) * np.sqrt(2):
        raise ValueError(
            f"min_distance={min_distance} cannot be satisfied with "
            f"position_range={position_range}."
        )
    entity_pos = random_state.uniform(-position_range, position_range, size=2)
    # Reject samples that fall closer than min_distance to (0, 0).
    while np.linalg.norm(entity_pos) < min_distance:
        entity_pos = random_state.uniform(-position_range, position_range, size=2)
    entity_pos = np.concatenate([entity_pos, [z_offset]])
    entity_quat = random_limited_quaternion(random_state, limit=np.pi)
    # qpos layout written here: position (3 values) followed by the quaternion.
    physics.bind(free_joint).qpos = np.concatenate([entity_pos, entity_quat])

View file

@ -0,0 +1,63 @@
"""A task where a walker must learn to stand."""
import numpy as np
from dm_control.suite.utils import randomizers
from dm_control.utils import rewards
from gello.dm_control_tasks.arms.manipulator import Manipulator
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
from gello.dm_control_tasks.manipulation.tasks import base
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
class Reach(base.ManipulationTask):
    """Reach task for a manipulator: move the flange to a target site."""

    def __init__(
        self,
        arm: Manipulator,
        arena: FixedManipulationArena,
        physics_timestep=base._PHYSICS_TIMESTEP,
        control_timestep=base._CONTROL_TIMESTEP,
        distance_tolerance: float = 0.5,
    ) -> None:
        """Create the task and add the target site to the world body.

        Args:
            arm: The manipulator that the task controls.
            arena: The arena the arm is attached to.
            physics_timestep: Physics timestep forwarded to the base task.
            control_timestep: Control timestep forwarded to the base task.
            distance_tolerance: Used as both the upper bound and the margin of
                the reward's tolerance function.
        """
        super().__init__(arm, arena, physics_timestep, control_timestep)
        self._distance_tolerance = distance_tolerance
        # Create target.
        self._target = self.root_entity.mjcf_model.worldbody.add(
            "site",
            name="target",
            type="sphere",
            pos=(0, 0, 0),
            size=(0.1,),
            rgba=_TARGET_COLOR,
        )

    def initialize_episode(self, physics, random_state):
        """Pick a reachable target, then randomize the arm's start configuration."""
        # Randomize the joints once and use the resulting flange position as
        # the target, so that the target is feasible.
        randomizers.randomize_limited_and_rotational_joints(physics, random_state)
        physics.forward()
        flange_position = physics.bind(self._arm.flange).xpos[:3]
        # The target site is a child of the world body, so its pos is set from
        # the flange's world-frame position.
        physics.bind(self._target).pos = flange_position
        # Randomize initial position of the arm.
        randomizers.randomize_limited_and_rotational_joints(physics, random_state)
        physics.forward()

    def get_reward(self, physics):
        """Return the (negated) tolerance reward on the flange-to-target distance."""
        # Use the world-frame flange position (xpos), matching how the target
        # was placed in initialize_episode; `.pos` is the local offset.
        flange_pos = physics.bind(self._arm.flange).xpos[:3]
        distance = np.linalg.norm(physics.bind(self._target).pos[:3] - flange_pos)
        # NOTE(review): the sign is kept as in the original, but negating
        # rewards.tolerance looks like it yields the lowest reward at the
        # target -- confirm whether the minus sign is intended.
        return -rewards.tolerance(
            distance,
            bounds=(0, self._distance_tolerance),
            margin=self._distance_tolerance,
            value_at_margin=0,
            sigmoid="linear",
        )

View file

@ -0,0 +1,23 @@
from pathlib import Path
from typing import List
from dm_control import mjcf
# Path to the root of the project.
_PROJECT_ROOT: Path = Path(__file__).parent.parent.parent
# Path to the Menagerie submodule.
MENAGERIE_ROOT: Path = _PROJECT_ROOT / "third_party" / "mujoco_menagerie"
def safe_find_all(
    root: mjcf.RootElement,
    feature_name: str,
    immediate_children_only: bool = False,
    exclude_attachments: bool = False,
) -> List[mjcf.Element]:
    """Return every ``feature_name`` element under ``root``, failing if there are none.

    Args:
        root: Element whose ``find_all`` is queried.
        feature_name: Kind of element to look for.
        immediate_children_only: Forwarded positionally to ``find_all``.
        exclude_attachments: Forwarded positionally to ``find_all``.

    Returns:
        The non-empty result of ``root.find_all``.

    Raises:
        ValueError: If ``find_all`` returns an empty result.
    """
    matches = root.find_all(feature_name, immediate_children_only, exclude_attachments)
    if matches:
        return matches
    raise ValueError(f"No {feature_name} found in the MJCF model.")

View file

274
gello/dynamixel/driver.py Normal file
View file

@ -0,0 +1,274 @@
import time
from threading import Event, Lock, Thread
from typing import Protocol, Sequence
import numpy as np
from dynamixel_sdk.group_sync_read import GroupSyncRead
from dynamixel_sdk.group_sync_write import GroupSyncWrite
from dynamixel_sdk.packet_handler import PacketHandler
from dynamixel_sdk.port_handler import PortHandler
from dynamixel_sdk.robotis_def import (
COMM_SUCCESS,
DXL_HIBYTE,
DXL_HIWORD,
DXL_LOBYTE,
DXL_LOWORD,
)
# Constants
ADDR_TORQUE_ENABLE = 64
ADDR_GOAL_POSITION = 116
LEN_GOAL_POSITION = 4
# NOTE(review): ADDR_PRESENT_POSITION is assigned twice. The second assignment
# wins, so every read in this module uses address 140, not 132. Confirm against
# the servo's control table which register is intended and delete the other line.
ADDR_PRESENT_POSITION = 132
ADDR_PRESENT_POSITION = 140
LEN_PRESENT_POSITION = 4
TORQUE_ENABLE = 1
TORQUE_DISABLE = 0
class DynamixelDriverProtocol(Protocol):
    """Interface shared by the real and the fake Dynamixel drivers."""

    def set_joints(self, joint_angles: Sequence[float]):
        """Command the servos to the given joint angles.

        Args:
            joint_angles (Sequence[float]): One angle per servo.
        """
        ...

    def torque_enabled(self) -> bool:
        """Report whether servo torque is currently enabled.

        Returns:
            bool: True when torque is enabled, False when it is disabled.
        """
        ...

    def set_torque_mode(self, enable: bool):
        """Switch servo torque on or off.

        Args:
            enable (bool): True to enable torque, False to disable it.
        """
        ...

    def get_joints(self) -> np.ndarray:
        """Read the current joint angles in radians.

        Returns:
            np.ndarray: The joint angles, one entry per servo.
        """
        ...

    def close(self):
        """Shut the driver down."""
        ...
class FakeDynamixelDriver(DynamixelDriverProtocol):
    """In-memory stand-in for a Dynamixel driver; it stores the last commanded angles."""

    def __init__(self, ids: Sequence[int]):
        """Create a fake driver with all joints at zero and torque disabled.

        Args:
            ids (Sequence[int]): The servo IDs being simulated.
        """
        self._ids = ids
        # Float storage so that angles in radians are never truncated when
        # callers modify the returned array in place.
        self._joint_angles = np.zeros(len(ids), dtype=float)
        self._torque_enabled = False

    def set_joints(self, joint_angles: Sequence[float]):
        """Store the commanded joint angles.

        Raises:
            ValueError: If the number of angles differs from the number of servos.
            RuntimeError: If torque is disabled.
        """
        if len(joint_angles) != len(self._ids):
            raise ValueError(
                "The length of joint_angles must match the number of servos"
            )
        if not self._torque_enabled:
            raise RuntimeError("Torque must be enabled to set joint angles")
        # np.array copies, so later changes to the caller's data do not leak in.
        self._joint_angles = np.array(joint_angles, dtype=float)

    def torque_enabled(self) -> bool:
        """Return whether torque is currently enabled."""
        return self._torque_enabled

    def set_torque_mode(self, enable: bool):
        """Enable or disable the simulated torque."""
        self._torque_enabled = enable

    def get_joints(self) -> np.ndarray:
        """Return a copy of the stored joint angles."""
        return self._joint_angles.copy()

    def close(self):
        """Nothing to release for the fake driver."""
        pass
class DynamixelDriver(DynamixelDriverProtocol):
    """Driver for Dynamixel servos on a serial port.

    A background thread continuously sync-reads the servo positions; every
    access to the port is serialized through ``self._lock``.
    """

    def __init__(
        self, ids: Sequence[int], port: str = "/dev/ttyUSB0", baudrate: int = 57600
    ):
        """Initialize the DynamixelDriver class.

        Args:
            ids (Sequence[int]): A list of IDs for the Dynamixel servos.
            port (str): The USB port to connect to the arm.
            baudrate (int): The baudrate for communication.

        Raises:
            RuntimeError: If the port cannot be opened, the baudrate cannot be
                set, or a servo cannot be added to the sync-read group.
        """
        self._ids = ids
        self._joint_angles = None
        self._lock = Lock()

        # Initialize the port handler, packet handler, and group sync read/write
        self._portHandler = PortHandler(port)
        self._packetHandler = PacketHandler(2.0)
        self._groupSyncRead = GroupSyncRead(
            self._portHandler,
            self._packetHandler,
            ADDR_PRESENT_POSITION,
            LEN_PRESENT_POSITION,
        )
        self._groupSyncWrite = GroupSyncWrite(
            self._portHandler,
            self._packetHandler,
            ADDR_GOAL_POSITION,
            LEN_GOAL_POSITION,
        )

        # Open the port and set the baudrate
        if not self._portHandler.openPort():
            raise RuntimeError("Failed to open the port")
        if not self._portHandler.setBaudRate(baudrate):
            raise RuntimeError(f"Failed to change the baudrate, {baudrate}")

        # Add parameters for each Dynamixel servo to the group sync read
        for dxl_id in self._ids:
            if not self._groupSyncRead.addParam(dxl_id):
                raise RuntimeError(
                    f"Failed to add parameter for Dynamixel with ID {dxl_id}"
                )

        # Disable torque for each Dynamixel servo. A failure here is reported
        # but deliberately does not abort construction.
        self._torque_enabled = False
        try:
            self.set_torque_mode(self._torque_enabled)
        except Exception as e:
            print(f"port: {port}, {e}")

        self._stop_thread = Event()
        self._start_reading_thread()

    def set_joints(self, joint_angles: Sequence[float]):
        """Sync-write goal positions to all servos.

        Args:
            joint_angles (Sequence[float]): One angle per servo, in radians.

        Raises:
            ValueError: If the number of angles differs from the number of servos.
            RuntimeError: If torque is disabled or the write fails.
        """
        if len(joint_angles) != len(self._ids):
            raise ValueError(
                "The length of joint_angles must match the number of servos"
            )
        if not self._torque_enabled:
            raise RuntimeError("Torque must be enabled to set joint angles")

        # The reading thread uses the same port, so hold the lock for the
        # whole transaction (as set_torque_mode does).
        with self._lock:
            try:
                for dxl_id, angle in zip(self._ids, joint_angles):
                    # Convert the angle to the appropriate value for the servo
                    position_value = int(angle * 2048 / np.pi)

                    # Allocate goal position value into byte array
                    param_goal_position = [
                        DXL_LOBYTE(DXL_LOWORD(position_value)),
                        DXL_HIBYTE(DXL_LOWORD(position_value)),
                        DXL_LOBYTE(DXL_HIWORD(position_value)),
                        DXL_HIBYTE(DXL_HIWORD(position_value)),
                    ]

                    # Add goal position value to the Syncwrite parameter storage
                    if not self._groupSyncWrite.addParam(dxl_id, param_goal_position):
                        raise RuntimeError(
                            f"Failed to set joint angle for Dynamixel with ID {dxl_id}"
                        )

                # Syncwrite goal position
                if self._groupSyncWrite.txPacket() != COMM_SUCCESS:
                    raise RuntimeError("Failed to syncwrite goal position")
            finally:
                # Always clear the parameter storage so a failed call cannot
                # leave stale entries behind for the next one.
                self._groupSyncWrite.clearParam()

    def torque_enabled(self) -> bool:
        """Return the last torque mode that was successfully set."""
        return self._torque_enabled

    def set_torque_mode(self, enable: bool):
        """Enable or disable torque on every servo.

        Raises:
            RuntimeError: If writing the torque register fails for any servo.
        """
        torque_value = TORQUE_ENABLE if enable else TORQUE_DISABLE
        with self._lock:
            for dxl_id in self._ids:
                dxl_comm_result, dxl_error = self._packetHandler.write1ByteTxRx(
                    self._portHandler, dxl_id, ADDR_TORQUE_ENABLE, torque_value
                )
                if dxl_comm_result != COMM_SUCCESS or dxl_error != 0:
                    print(dxl_comm_result)
                    print(dxl_error)
                    raise RuntimeError(
                        f"Failed to set torque mode for Dynamixel with ID {dxl_id}"
                    )

        self._torque_enabled = enable

    def _start_reading_thread(self):
        """Start the daemon thread that polls the servo positions."""
        self._reading_thread = Thread(target=self._read_joint_angles)
        self._reading_thread.daemon = True
        self._reading_thread.start()

    def _read_joint_angles(self):
        """Thread body: sync-read positions until ``_stop_thread`` is set."""
        # Continuously read joint angles and update the joint_angles array
        while not self._stop_thread.is_set():
            time.sleep(0.001)
            with self._lock:
                _joint_angles = np.zeros(len(self._ids), dtype=int)
                dxl_comm_result = self._groupSyncRead.txRxPacket()
                if dxl_comm_result != COMM_SUCCESS:
                    print(f"warning, comm failed: {dxl_comm_result}")
                    continue
                for i, dxl_id in enumerate(self._ids):
                    if self._groupSyncRead.isAvailable(
                        dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION
                    ):
                        angle = self._groupSyncRead.getData(
                            dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION
                        )
                        # Reinterpret the unsigned 32-bit register as signed.
                        angle = np.int32(np.uint32(angle))
                        _joint_angles[i] = angle
                    else:
                        # NOTE(review): raising here ends the reading thread, after
                        # which get_joints keeps returning the last stored values.
                        raise RuntimeError(
                            f"Failed to get joint angles for Dynamixel with ID {dxl_id}"
                        )
                self._joint_angles = _joint_angles
            # self._groupSyncRead.clearParam() # TODO what does this do? should i add it

    def get_joints(self) -> np.ndarray:
        """Return the most recently read joint angles in radians.

        Blocks until the reading thread has stored a first set of values.
        """
        # Return a copy of the joint_angles array to avoid race conditions
        while self._joint_angles is None:
            time.sleep(0.1)
        with self._lock:
            _j = self._joint_angles.copy()
        return _j / 2048.0 * np.pi

    def close(self):
        """Stop the reading thread and close the port."""
        self._stop_thread.set()
        self._reading_thread.join()
        self._portHandler.closePort()
def main():
    """Connect to servo ID 1, toggle torque once, and print joint angles until Ctrl-C."""
    # Set the port, baudrate, and servo IDs
    ids = [1]

    # Create a DynamixelDriver instance
    try:
        driver = DynamixelDriver(ids)
    except FileNotFoundError:
        driver = DynamixelDriver(ids, port="/dev/cu.usbserial-FT7WBMUB")

    # The finally clause makes sure the reading thread is stopped and the port
    # is closed however the loop ends, not only on Ctrl-C.
    try:
        driver.set_torque_mode(True)
        driver.set_torque_mode(False)

        # Print the joint angles
        while True:
            joint_angles = driver.get_joints()
            print(f"Joint angles for IDs {ids}: {joint_angles}")
            # print(f"Joint angles for IDs {ids[1]}: {joint_angles[1]}")
    except KeyboardInterrupt:
        pass
    finally:
        driver.close()


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,35 @@
import numpy as np
import pytest
from gello.dynamixel.driver import FakeDynamixelDriver
@pytest.fixture
def fake_driver():
    """Provide a fresh two-servo fake driver (IDs 1 and 2) for each test."""
    driver = FakeDynamixelDriver(ids=[1, 2])
    return driver
def test_set_joints(fake_driver):
    """Angles commanded with torque enabled are read back unchanged."""
    target = [np.pi / 2, np.pi / 2]
    fake_driver.set_torque_mode(True)
    fake_driver.set_joints(target)
    readback = fake_driver.get_joints()
    assert np.allclose(readback, target)
def test_set_joints_wrong_length(fake_driver):
    """Commanding one angle to a two-servo driver is rejected with ValueError."""
    too_short = [np.pi / 2]
    with pytest.raises(ValueError):
        fake_driver.set_joints(too_short)
def test_set_joints_torque_disabled(fake_driver):
    """Commanding joints while torque is off raises RuntimeError."""
    target = [np.pi / 2, np.pi / 2]
    with pytest.raises(RuntimeError):
        fake_driver.set_joints(target)
def test_torque_enabled(fake_driver):
    """Torque starts disabled and reports enabled after being switched on."""
    before = fake_driver.torque_enabled()
    assert not before
    fake_driver.set_torque_mode(True)
    after = fake_driver.torque_enabled()
    assert after
def test_get_joints(fake_driver):
    """A fresh fake driver reports zero for every joint."""
    initial = fake_driver.get_joints()
    assert np.allclose(initial, [0, 0])

88
gello/env.py Normal file
View file

@ -0,0 +1,88 @@
import time
from typing import Any, Dict, Optional
import numpy as np
from gello.cameras.camera import CameraDriver
from gello.robots.robot import Robot
class Rate:
    """Paces a loop so that consecutive ``sleep`` calls return ``1 / rate`` seconds apart."""

    def __init__(self, rate: float):
        """Remember the target frequency (in Hz) and start timing from now."""
        self.last = time.time()
        self.rate = rate

    def sleep(self) -> None:
        """Block until one period has passed since the previous tick, then record a new tick."""
        deadline = self.last + 1.0 / self.rate
        # Poll in short naps until the deadline has been reached.
        while deadline > time.time():
            time.sleep(0.0001)
        self.last = time.time()
class RobotEnv:
    """Environment wrapper that steps a robot at a fixed rate and gathers observations."""

    def __init__(
        self,
        robot: Robot,
        control_rate_hz: float = 100.0,
        camera_dict: Optional[Dict[str, CameraDriver]] = None,
    ) -> None:
        """Create the environment.

        Args:
            robot: The robot that receives joint commands.
            control_rate_hz: Frequency at which ``step`` is paced.
            camera_dict: Optional mapping from camera name to camera driver;
                defaults to no cameras.
        """
        self._robot = robot
        self._rate = Rate(control_rate_hz)
        self._camera_dict = {} if camera_dict is None else camera_dict

    def robot(self) -> Robot:
        """Get the robot object.

        Returns:
            robot: the robot object.
        """
        return self._robot

    def __len__(self):
        # NOTE(review): returning 0 here also makes `bool(env)` False, so
        # `if env:` style checks on an environment will not behave as expected.
        return 0

    def step(self, joints: np.ndarray) -> Dict[str, Any]:
        """Step the environment forward.

        Args:
            joints: joint angles command to step the environment with.

        Returns:
            obs: observation from the environment.
        """
        assert len(joints) == (
            self._robot.num_dofs()
        ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}"
        self._robot.command_joint_state(joints)
        # Wait out the remainder of the control period before observing.
        self._rate.sleep()
        return self.get_obs()

    def get_obs(self) -> Dict[str, Any]:
        """Get observation from the environment.

        Returns:
            obs: camera images as ``<name>_rgb`` / ``<name>_depth`` plus the
            robot's joint positions, joint velocities, end-effector pose and
            gripper position.
        """
        observations = {}
        for name, camera in self._camera_dict.items():
            image, depth = camera.read()
            observations[f"{name}_rgb"] = image
            observations[f"{name}_depth"] = depth

        robot_obs = self._robot.get_observations()
        assert "joint_positions" in robot_obs
        assert "joint_velocities" in robot_obs
        assert "ee_pos_quat" in robot_obs
        observations["joint_positions"] = robot_obs["joint_positions"]
        observations["joint_velocities"] = robot_obs["joint_velocities"]
        observations["ee_pos_quat"] = robot_obs["ee_pos_quat"]
        observations["gripper_position"] = robot_obs["gripper_position"]
        return observations
def main() -> None:
    """Script entry point; currently a no-op."""
    return None


if __name__ == "__main__":
    main()

134
gello/robots/dynamixel.py Normal file
View file

@ -0,0 +1,134 @@
from typing import Dict, Optional, Sequence, Tuple
import numpy as np
from gello.robots.robot import Robot
class DynamixelRobot(Robot):
    """A robot whose joints are Dynamixel servos accessed through a driver.

    Joint readings are reported as ``(raw - offset) * sign``. When a gripper is
    configured, the last joint is reported as a value clipped to ``[0, 1]``.
    """

    def __init__(
        self,
        joint_ids: Sequence[int],
        joint_offsets: Optional[Sequence[float]] = None,
        joint_signs: Optional[Sequence[int]] = None,
        real: bool = False,
        port: str = "/dev/ttyUSB0",
        baudrate: int = 57600,
        gripper_config: Optional[Tuple[int, float, float]] = None,
        start_joints: Optional[np.ndarray] = None,
    ):
        """Create the robot and connect to (or fake) the servo driver.

        Args:
            joint_ids: Servo IDs of the arm joints.
            joint_offsets: Per-joint offsets subtracted from the raw readings;
                defaults to zeros.
            joint_signs: Per-joint signs (+1 or -1); defaults to ones.
            real: If True, use the real driver on ``port``; otherwise a fake one.
            port: Serial port used by the real driver.
            baudrate: Baudrate used by the real driver.
            gripper_config: Optional ``(servo_id, open_degrees, closed_degrees)``;
                the gripper is appended as the last joint.
            start_joints: Optional joint values; each offset is shifted by the
                multiple of 2*pi that brings the current reading closest to them.
        """
        from gello.dynamixel.driver import (
            DynamixelDriver,
            DynamixelDriverProtocol,
            FakeDynamixelDriver,
        )

        print(f"attempting to connect to port: {port}")
        self.gripper_open_close: Optional[Tuple[float, float]]
        if gripper_config is not None:
            assert joint_offsets is not None
            assert joint_signs is not None

            # New tuples are built instead of appending, so the caller's
            # sequences are left untouched.
            joint_ids = tuple(joint_ids) + (gripper_config[0],)
            joint_offsets = tuple(joint_offsets) + (0.0,)
            joint_signs = tuple(joint_signs) + (1,)
            # Open/closed angles are given in degrees; store them in radians.
            self.gripper_open_close = (
                gripper_config[1] * np.pi / 180,
                gripper_config[2] * np.pi / 180,
            )
        else:
            self.gripper_open_close = None

        self._joint_ids = joint_ids
        self._driver: DynamixelDriverProtocol

        if joint_offsets is None:
            self._joint_offsets = np.zeros(len(joint_ids))
        else:
            self._joint_offsets = np.array(joint_offsets)

        if joint_signs is None:
            self._joint_signs = np.ones(len(joint_ids))
        else:
            self._joint_signs = np.array(joint_signs)

        assert len(self._joint_ids) == len(self._joint_offsets), (
            f"joint_ids: {len(self._joint_ids)}, "
            f"joint_offsets: {len(self._joint_offsets)}"
        )
        assert len(self._joint_ids) == len(self._joint_signs), (
            f"joint_ids: {len(self._joint_ids)}, "
            f"joint_signs: {len(self._joint_signs)}"
        )
        assert np.all(
            np.abs(self._joint_signs) == 1
        ), f"joint_signs: {self._joint_signs}"

        if real:
            self._driver = DynamixelDriver(joint_ids, port=port, baudrate=baudrate)
            self._driver.set_torque_mode(False)
        else:
            self._driver = FakeDynamixelDriver(joint_ids)
        self._torque_on = False
        self._last_pos = None
        self._alpha = 0.99

        if start_joints is not None:
            # loop through all joints and add +- 2pi to the joint offsets to get the closest to start joints
            new_joint_offsets = []
            current_joints = self.get_joint_state()
            assert current_joints.shape == start_joints.shape
            if gripper_config is not None:
                # The gripper is not an angle here; leave its offset alone.
                current_joints = current_joints[:-1]
                start_joints = start_joints[:-1]
            for c_joint, s_joint, joint_offset in zip(
                current_joints, start_joints, self._joint_offsets
            ):
                new_joint_offsets.append(
                    np.pi * 2 * np.round((s_joint - c_joint) / (2 * np.pi))
                    + joint_offset
                )
            if gripper_config is not None:
                new_joint_offsets.append(self._joint_offsets[-1])
            self._joint_offsets = np.array(new_joint_offsets)
            # The reading above was taken with the old offsets; drop it so it
            # is not blended into the first smoothed reading.
            self._last_pos = None

    def num_dofs(self) -> int:
        """Return the number of joints, including the gripper if configured."""
        return len(self._joint_ids)

    def get_joint_state(self) -> np.ndarray:
        """Return the smoothed joint state.

        Returns:
            np.ndarray: Offset- and sign-corrected joint values; with a gripper
            configured the last entry is its position clipped to ``[0, 1]``.
        """
        pos = (self._driver.get_joints() - self._joint_offsets) * self._joint_signs
        assert len(pos) == self.num_dofs()

        if self.gripper_open_close is not None:
            # map pos to [0, 1]
            g_pos = (pos[-1] - self.gripper_open_close[0]) / (
                self.gripper_open_close[1] - self.gripper_open_close[0]
            )
            g_pos = min(max(0, g_pos), 1)
            pos[-1] = g_pos

        if self._last_pos is None:
            self._last_pos = pos
        else:
            # exponential smoothing
            pos = self._last_pos * (1 - self._alpha) + pos * self._alpha
            self._last_pos = pos

        return pos

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the servos to the given joint state.

        Applies the inverse of the transform used by ``get_joint_state``:
        since every sign is +1 or -1, ``raw = joint_state * sign + offset``.

        Args:
            joint_state (np.ndarray): Desired joint values, one per joint.
        """
        # NOTE(review): the [0, 1] gripper mapping of get_joint_state is not
        # inverted here; the last entry is treated like any other joint.
        raw = np.asarray(joint_state) * self._joint_signs + self._joint_offsets
        self._driver.set_joints(raw.tolist())

    def set_torque_mode(self, mode: bool):
        """Set the driver's torque mode, skipping the call if it would not change."""
        if mode == self._torque_on:
            return
        self._driver.set_torque_mode(mode)
        self._torque_on = mode

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the joint state under the key ``"joint_state"``."""
        return {"joint_state": self.get_joint_state()}

88
gello/robots/panda.py Normal file
View file

@ -0,0 +1,88 @@
import time
from typing import Dict
import numpy as np
from gello.robots.robot import Robot
MAX_OPEN = 0.09
class PandaRobot(Robot):
    """A Panda arm and its gripper, controlled through polymetis interfaces."""

    def __init__(self, robot_ip: str = "100.97.47.74"):
        """Connect to the arm and gripper, home the arm and open the gripper.

        Args:
            robot_ip: Address passed to the polymetis ``RobotInterface``; the
                gripper interface always connects to ``localhost``.
        """
        from polymetis import GripperInterface, RobotInterface

        self.robot = RobotInterface(ip_address=robot_ip)
        self.gripper = GripperInterface(ip_address="localhost")
        self.robot.go_home()
        self.robot.start_joint_impedance()
        self.gripper.goto(width=MAX_OPEN, speed=255, force=255)
        time.sleep(1)

    def num_dofs(self) -> int:
        """Return the number of controlled values.

        Returns:
            int: Always 8 (the arm joints plus one gripper entry).
        """
        return 8

    def get_joint_state(self) -> np.ndarray:
        """Read the arm joints and the gripper.

        Returns:
            np.ndarray: The arm joint positions followed by the gripper width
            divided by ``MAX_OPEN``.
        """
        arm_positions = self.robot.get_joint_positions()
        gripper_state = self.gripper.get_state()
        normalized_width = gripper_state.width / MAX_OPEN
        return np.append(arm_positions, normalized_width)

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Send a joint target to the arm and a width target to the gripper.

        Args:
            joint_state (np.ndarray): Arm joint targets followed by one gripper
                value; the gripper width is ``MAX_OPEN * (1 - value)``.
        """
        import torch

        arm_target = torch.tensor(joint_state[:-1])
        self.robot.update_desired_joint_positions(arm_target)
        # NOTE(review): this maps 1 to a closed gripper, while get_joint_state
        # reports 1 for a fully open one -- confirm the intended convention.
        target_width = MAX_OPEN * (1 - joint_state[-1])
        self.gripper.goto(width=target_width, speed=1, force=1)

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the observation dictionary for this robot.

        ``joint_velocities`` carries the joint positions and ``ee_pos_quat`` is
        all zeros; neither is measured here.
        """
        state = self.get_joint_state()
        observations = {
            "joint_positions": state,
            "joint_velocities": state,
            "ee_pos_quat": np.zeros(7),
            "gripper_position": np.array([state[-1]]),
        }
        return observations
def main():
    """Manual check: connect to the robot and step the gripper through three widths."""
    robot = PandaRobot()
    # Read the state once as a connectivity check; the value itself is not used.
    robot.get_joint_state()

    time.sleep(1)
    # Command 100%, 105% and 110% of the nominal maximum opening, pausing one
    # second after each command.
    for scale in (1, 1.05, 1.1):
        robot.gripper.goto(scale * MAX_OPEN, speed=255, force=255)
        time.sleep(1)


if __name__ == "__main__":
    main()

128
gello/robots/robot.py Normal file
View file

@ -0,0 +1,128 @@
from abc import abstractmethod
from typing import Dict, Protocol
import numpy as np
class Robot(Protocol):
    """Robot protocol.

    The interface every controllable robot in this package implements.
    """

    @abstractmethod
    def num_dofs(self) -> int:
        """Report how many joints the robot has.

        Returns:
            int: The number of joints of the robot.
        """
        raise NotImplementedError

    @abstractmethod
    def get_joint_state(self) -> np.ndarray:
        """Read the robot's current joint state.

        Returns:
            np.ndarray: The current joint state.
        """
        raise NotImplementedError

    @abstractmethod
    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the robot to the given joint state.

        Args:
            joint_state (np.ndarray): The state to command the robot to.
        """
        raise NotImplementedError

    @abstractmethod
    def get_observations(self) -> Dict[str, np.ndarray]:
        """Collect the robot's current observations.

        This gathers everything the robot can report, such as joint positions
        and joint velocities, and may include data from additional sensors
        such as cameras or force sensors.

        Returns:
            Dict[str, np.ndarray]: A dictionary of observations.
        """
        raise NotImplementedError
class PrintRobot(Robot):
    """A robot that prints the commanded joint state."""

    def __init__(self, num_dofs: int, dont_print: bool = False):
        """Create the robot with all joints at zero.

        Args:
            num_dofs: Number of joints to emulate.
            dont_print: If True, commanded states are stored but not printed.
        """
        self._num_dofs = num_dofs
        self._joint_state = np.zeros(num_dofs)
        self._dont_print = dont_print

    def num_dofs(self) -> int:
        """Return the number of joints given at construction."""
        return self._num_dofs

    def get_joint_state(self) -> np.ndarray:
        """Return the most recently commanded joint state."""
        return self._joint_state

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Store the commanded state and print it unless printing is disabled.

        Args:
            joint_state (np.ndarray): The state to store; its length must equal
                the number of joints.
        """
        assert len(joint_state) == (self._num_dofs), (
            f"Expected joint state of length {self._num_dofs}, "
            f"got {len(joint_state)}."
        )
        self._joint_state = joint_state
        if not self._dont_print:
            print(self._joint_state)

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return placeholder observations built from the stored joint state.

        ``joint_velocities`` carries the joint state, ``ee_pos_quat`` is all
        zeros and ``gripper_position`` is a single zero.
        """
        joint_state = self.get_joint_state()
        pos_quat = np.zeros(7)
        return {
            "joint_positions": joint_state,
            "joint_velocities": joint_state,
            "ee_pos_quat": pos_quat,
            # One-element array (not 0-d) so it can be concatenated with the
            # observation of another robot.
            "gripper_position": np.array([0.0]),
        }
class BimanualRobot(Robot):
    """Two robots presented as one: left-robot values first, right-robot values after."""

    def __init__(self, robot_l: Robot, robot_r: Robot):
        """Wrap a left and a right robot.

        Args:
            robot_l: The robot whose values come first.
            robot_r: The robot whose values come second.
        """
        self._robot_l = robot_l
        self._robot_r = robot_r

    def num_dofs(self) -> int:
        """Return the combined number of joints of both robots."""
        return self._robot_l.num_dofs() + self._robot_r.num_dofs()

    def get_joint_state(self) -> np.ndarray:
        """Return the left joint state followed by the right joint state."""
        return np.concatenate(
            (self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
        )

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Split the command at the left robot's joint count and forward both halves."""
        split = self._robot_l.num_dofs()
        self._robot_l.command_joint_state(joint_state[:split])
        self._robot_r.command_joint_state(joint_state[split:])

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Concatenate the observations of both robots key by key.

        Returns:
            Dict[str, np.ndarray]: For every key, the left observation followed
            by the right one.

        Raises:
            RuntimeError: If the observations for a key cannot be concatenated.
        """
        l_obs = self._robot_l.get_observations()
        r_obs = self._robot_r.get_observations()
        assert l_obs.keys() == r_obs.keys()
        return_obs = {}
        for k in l_obs:
            try:
                # atleast_1d lets scalar (0-d) observations be joined as well.
                return_obs[k] = np.concatenate(
                    (np.atleast_1d(l_obs[k]), np.atleast_1d(r_obs[k]))
                )
            except Exception as e:
                raise RuntimeError(
                    f"Failed to concatenate observation {k!r}: "
                    f"left={l_obs[k]!r}, right={r_obs[k]!r}"
                ) from e

        return return_obs
def main():
    """Script entry point; currently a no-op."""
    return None


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,358 @@
"""Module to control Robotiq's grippers - tested with HAND-E.
Taken from https://github.com/githubuser0xFFFF/py_robotiq_gripper/blob/master/src/robotiq_gripper.py
"""
import socket
import threading
import time
from enum import Enum
from typing import OrderedDict, Tuple, Union
class RobotiqGripper:
"""Communicates with the gripper directly, via socket with string commands, leveraging string names for variables."""
# WRITE VARIABLES (CAN ALSO READ)
ACT = (
"ACT" # act : activate (1 while activated, can be reset to clear fault status)
)
GTO = (
"GTO" # gto : go to (will perform go to with the actions set in pos, for, spe)
)
ATR = "ATR" # atr : auto-release (emergency slow move)
ADR = (
"ADR" # adr : auto-release direction (open(1) or close(0) during auto-release)
)
FOR = "FOR" # for : force (0-255)
SPE = "SPE" # spe : speed (0-255)
POS = "POS" # pos : position (0-255), 0 = open
# READ VARIABLES
STA = "STA" # status (0 = is reset, 1 = activating, 3 = active)
PRE = "PRE" # position request (echo of last commanded position)
OBJ = "OBJ" # object detection (0 = moving, 1 = outer grip, 2 = inner grip, 3 = no object at rest)
FLT = "FLT" # fault (0=ok, see manual for errors if not zero)
ENCODING = "UTF-8" # ASCII and UTF-8 both seem to work
class GripperStatus(Enum):
    """Activation status as reported by the gripper; values must equal what the gripper sends."""

    RESET = 0
    ACTIVATING = 1
    # 2 is intentionally absent: the gripper firmware does not currently use it.
    ACTIVE = 3
class ObjectStatus(Enum):
    """Object-detection status as reported by the gripper; values must equal what the gripper sends."""

    MOVING = 0
    STOPPED_OUTER_OBJECT = 1
    STOPPED_INNER_OBJECT = 2
    AT_DEST = 3
def __init__(self):
"""Constructor."""
self.socket = None
self.command_lock = threading.Lock()
self._min_position = 0
self._max_position = 255
self._min_speed = 0
self._max_speed = 255
self._min_force = 0
self._max_force = 255
def connect(self, hostname: str, port: int, socket_timeout: float = 10.0) -> None:
    """Connects to a gripper at the given address.

    The timeout is applied before connecting, so the connection attempt itself
    is bounded by ``socket_timeout`` as well.

    :param hostname: Hostname or ip.
    :param port: Port.
    :param socket_timeout: Timeout for blocking socket operations.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.settimeout(socket_timeout)
        sock.connect((hostname, port))
    except Exception:
        # Do not leak the descriptor when the connection cannot be made.
        sock.close()
        raise
    self.socket = sock
def disconnect(self) -> None:
"""Closes the connection with the gripper."""
assert self.socket is not None
self.socket.close()
def _set_vars(self, var_dict: OrderedDict[str, Union[int, float]]):
"""Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response.
:param var_dict: Dictionary of variables to set (variable_name, value).
:return: True on successful reception of ack, false if no ack was received, indicating the set may not
have been effective.
"""
assert self.socket is not None
# construct unique command
cmd = "SET"
for variable, value in var_dict.items():
cmd += f" {variable} {str(value)}"
cmd += "\n" # new line is required for the command to finish
# atomic commands send/rcv
with self.command_lock:
self.socket.sendall(cmd.encode(self.ENCODING))
data = self.socket.recv(1024)
return self._is_ack(data)
def _set_var(self, variable: str, value: Union[int, float]):
"""Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response.
:param variable: Variable to set.
:param value: Value to set for the variable.
:return: True on successful reception of ack, false if no ack was received, indicating the set may not
have been effective.
"""
return self._set_vars(OrderedDict([(variable, value)]))
def _get_var(self, variable: str):
"""Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the response is received or the socket times out.
:param variable: Name of the variable to retrieve.
:return: Value of the variable as integer.
"""
assert self.socket is not None
# atomic commands send/rcv
with self.command_lock:
cmd = f"GET {variable}\n"
self.socket.sendall(cmd.encode(self.ENCODING))
data = self.socket.recv(1024)
# expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value
# note some special variables (like FLT) may send 2 bytes, instead of an integer. We assume integer here
var_name, value_str = data.decode(self.ENCODING).split()
if var_name != variable:
raise ValueError(
f"Unexpected response {data} ({data.decode(self.ENCODING)}): does not match '{variable}'"
)
value = int(value_str)
return value
@staticmethod
def _is_ack(data: str):
return data == b"ack"
def _reset(self):
"""Reset the gripper.
The following code is executed in the corresponding script function
def rq_reset(gripper_socket="1"):
rq_set_var("ACT", 0, gripper_socket)
rq_set_var("ATR", 0, gripper_socket)
while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
rq_set_var("ACT", 0, gripper_socket)
rq_set_var("ATR", 0, gripper_socket)
sync()
end
sleep(0.5)
end
"""
self._set_var(self.ACT, 0)
self._set_var(self.ATR, 0)
while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0:
self._set_var(self.ACT, 0)
self._set_var(self.ATR, 0)
time.sleep(0.5)
def activate(self, auto_calibrate: bool = True):
"""Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags.
:param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion.
The following code is executed in the corresponding script function
def rq_activate(gripper_socket="1"):
if (not rq_is_gripper_activated(gripper_socket)):
rq_reset(gripper_socket)
while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
rq_reset(gripper_socket)
sync()
end
rq_set_var("ACT",1, gripper_socket)
end
end
def rq_activate_and_wait(gripper_socket="1"):
if (not rq_is_gripper_activated(gripper_socket)):
rq_activate(gripper_socket)
sleep(1.0)
while(not rq_get_var("ACT", 1, gripper_socket) == 1 or not rq_get_var("STA", 1, gripper_socket) == 3):
sleep(0.1)
end
sleep(0.5)
end
end
"""
if not self.is_active():
self._reset()
while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0:
time.sleep(0.01)
self._set_var(self.ACT, 1)
time.sleep(1.0)
while not self._get_var(self.ACT) == 1 or not self._get_var(self.STA) == 3:
time.sleep(0.01)
# auto-calibrate position range if desired
if auto_calibrate:
self.auto_calibrate()
def is_active(self):
"""Returns whether the gripper is active."""
status = self._get_var(self.STA)
return (
RobotiqGripper.GripperStatus(status) == RobotiqGripper.GripperStatus.ACTIVE
)
def get_min_position(self) -> int:
"""Returns the minimum position the gripper can reach (open position)."""
return self._min_position
def get_max_position(self) -> int:
"""Returns the maximum position the gripper can reach (closed position)."""
return self._max_position
def get_open_position(self) -> int:
"""Returns what is considered the open position for gripper (minimum position value)."""
return self.get_min_position()
def get_closed_position(self) -> int:
"""Returns what is considered the closed position for gripper (maximum position value)."""
return self.get_max_position()
def is_open(self):
"""Returns whether the current position is considered as being fully open."""
return self.get_current_position() <= self.get_open_position()
def is_closed(self):
"""Returns whether the current position is considered as being fully closed."""
return self.get_current_position() >= self.get_closed_position()
def get_current_position(self) -> int:
"""Returns the current position as returned by the physical hardware."""
return self._get_var(self.POS)
def auto_calibrate(self, log: bool = True) -> None:
    """Calibrate the usable position range by opening, closing and re-opening the gripper.

    The sequence is: open (in case an object is being held), close as far as
    possible and store the reached position as the new maximum, then open as
    far as possible and store the reached position as the new minimum. Every
    move uses speed 64 and force 1.

    :param log: Whether to print the resulting range.
    :raises RuntimeError: If any of the three moves does not end with AT_DEST.
    """
    calib_speed, calib_force = 64, 1

    # Step 1: open first, in case we are holding an object.
    position, status = self.move_and_wait_for_pos(
        self.get_open_position(), calib_speed, calib_force
    )
    if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
        raise RuntimeError(f"Calibration failed opening to start: {str(status)}")

    # Step 2: close as far as possible; the reached position is the new maximum.
    position, status = self.move_and_wait_for_pos(
        self.get_closed_position(), calib_speed, calib_force
    )
    if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
        raise RuntimeError(f"Calibration failed because of an object: {str(status)}")
    assert position <= self._max_position
    self._max_position = position

    # Step 3: open as far as possible; the reached position is the new minimum.
    position, status = self.move_and_wait_for_pos(
        self.get_open_position(), calib_speed, calib_force
    )
    if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
        raise RuntimeError(f"Calibration failed because of an object: {str(status)}")
    assert position >= self._min_position
    self._min_position = position

    if not log:
        return
    print(
        f"Gripper auto-calibrated to [{self.get_min_position()}, {self.get_max_position()}]"
    )
def move(self, position: int, speed: int, force: int) -> Tuple[bool, int]:
    """Send the commands that start a move towards ``position``.

    Each argument is clamped into the corresponding [min, max] range stored
    on the instance before being sent.

    :param position: Position to move to [min_position, max_position]
    :param speed: Speed to move at [min_speed, max_speed]
    :param force: Force to use [min_force, max_force]
    :return: A tuple of (whether the command was sent successfully, the
        clamped position that was actually requested).
    """
    target_pos = max(self._min_position, min(position, self._max_position))
    target_spe = max(self._min_speed, min(speed, self._max_speed))
    target_for = max(self._min_force, min(force, self._max_force))

    # Build the command in a fixed order: position, speed, force, then GTO=1.
    command = OrderedDict()
    command[self.POS] = target_pos
    command[self.SPE] = target_spe
    command[self.FOR] = target_for
    command[self.GTO] = 1

    sent_ok = self._set_vars(command)
    time.sleep(0.008)  # need to wait (dont know why)
    return sent_ok, target_pos
def move_and_wait_for_pos(
    self, position: int, speed: int, force: int
) -> Tuple[int, ObjectStatus]:  # noqa
    """Start a move towards ``position`` and block until the gripper stops moving.

    :param position: Position to move to [min_position, max_position]
    :param speed: Speed to move at [min_speed, max_speed]
    :param force: Force to use [min_force, max_force]
    :return: A tuple of (the position read from the gripper after the move
        ended, the ObjectStatus describing how it ended). The requested
        position may not have been reached, e.g. if an object was detected
        during the motion.
    :raises RuntimeError: If the move command could not be sent.
    """
    sent_ok, requested_pos = self.move(position, speed, force)
    if not sent_ok:
        raise RuntimeError("Failed to set variables for move.")

    # Poll until the gripper acknowledges the requested position via PRE.
    while self._get_var(self.PRE) != requested_pos:
        time.sleep(0.001)

    # Poll OBJ until it no longer reports MOVING (no sleep between polls).
    obj_code = self._get_var(self.OBJ)
    while RobotiqGripper.ObjectStatus(obj_code) == RobotiqGripper.ObjectStatus.MOVING:
        obj_code = self._get_var(self.OBJ)

    # Report the position read now together with the last object status seen.
    reached_pos = self._get_var(self.POS)
    return reached_pos, RobotiqGripper.ObjectStatus(obj_code)
def main():
    """Manual smoke test: connect to a gripper, move it 20 -> 65 -> 20, then disconnect."""
    gripper = RobotiqGripper()
    gripper.connect(hostname="192.168.1.10", port=63352)
    # gripper.activate()

    # Print the position before each of the first two moves, waiting briefly after each.
    for target in (20, 65):
        print(gripper.get_current_position())
        gripper.move(target, 255, 1)
        time.sleep(0.2)

    print(gripper.get_current_position())
    gripper.move(20, 255, 1)
    gripper.disconnect()
if __name__ == "__main__":
main()

256
gello/robots/sim_robot.py Normal file
View file

@ -0,0 +1,256 @@
import pickle
import threading
import time
from typing import Any, Dict, Optional
import mujoco
import mujoco.viewer
import numpy as np
import zmq
from dm_control import mjcf
from gello.robots.robot import Robot
# Tautological check that only touches `mujoco.viewer`.
# NOTE(review): presumably here so the `import mujoco.viewer` above is not
# flagged/removed as unused — confirm before deleting.
assert mujoco.viewer is mujoco.viewer
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
) -> None:
    """Attach a hand model to the arm's "attachment_site".

    If the arm defines a "home" keyframe, its ``ctrl`` and ``qpos`` are
    extended to cover the hand's degrees of freedom: with the hand's own
    "home" keyframe values when it has one, otherwise with zeros.

    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm.
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    # Compiled up front; only used for the hand's actuator (nu) and position (nq) counts.
    hand_physics = mjcf.Physics.from_mjcf_model(hand_mjcf)

    site = arm_mjcf.find("site", "attachment_site")
    if site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Expand the ctrl and qpos keyframes to account for the new hand DoFs.
    arm_home = arm_mjcf.find("key", "home")
    if arm_home is not None:
        hand_home = hand_mjcf.find("key", "home")
        arm_home.ctrl = np.concatenate(
            [
                arm_home.ctrl,
                np.zeros(hand_physics.model.nu) if hand_home is None else hand_home.ctrl,
            ]
        )
        arm_home.qpos = np.concatenate(
            [
                arm_home.qpos,
                np.zeros(hand_physics.model.nq) if hand_home is None else hand_home.qpos,
            ]
        )

    site.attach(hand_mjcf)
def build_scene(robot_xml_path: str, gripper_xml_path: Optional[str] = None):
    """Build an MJCF scene containing the robot and, optionally, a gripper.

    Args:
        robot_xml_path: Path to the robot's MJCF file.
        gripper_xml_path: Optional path to a gripper MJCF file; when given,
            the gripper is attached to the robot at "attachment_site".

    Returns:
        A new mjcf.RootElement with the robot attached to its worldbody.
    """
    # assert robot_xml_path.endswith(".xml")
    scene = mjcf.RootElement()
    robot_model = mjcf.from_path(robot_xml_path)

    if gripper_xml_path is not None:
        attach_hand_to_arm(robot_model, mjcf.from_path(gripper_xml_path))

    scene.worldbody.attach(robot_model)
    return scene
class ZMQServerThread(threading.Thread):
    """Thread that runs a server's ``serve()`` loop and can ask it to stop."""

    def __init__(self, server):
        """Store the server; it must provide ``serve()`` and ``stop()``."""
        threading.Thread.__init__(self)
        self._server = server

    def run(self):
        """Thread body: delegate to the server's ``serve()``."""
        server = self._server
        server.serve()

    def terminate(self):
        """Ask the wrapped server to stop by calling its ``stop()``."""
        server = self._server
        server.stop()
class ZMQRobotServer:
    """A ZMQ REP server that exposes a robot's methods to a remote client.

    Requests are pickled dicts with a "method" key and an optional "args"
    dict; the reply is the pickled return value of the robot method.

    The socket is used only by the thread running ``serve()``. ``stop()``
    therefore just signals that loop, and the socket and context are released
    by the serving thread itself once the loop exits.
    """

    def __init__(self, robot: Robot, host: str = "127.0.0.1", port: int = 5556):
        """Bind a REP socket on ``tcp://host:port`` for the given robot."""
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()
        # Protects the decision of who releases the socket: serve() or stop().
        self._state_lock = threading.Lock()
        self._serving = False

    def _release(self) -> None:
        """Close the socket and terminate the context."""
        self._socket.close()
        self._context.term()

    def serve(self) -> None:
        """Serve the robot state and commands over ZMQ until ``stop()`` is called.

        Raises:
            NotImplementedError: If a request names an unknown method.
        """
        with self._state_lock:
            if self._stop_event.is_set():
                # stop() already ran and released the socket.
                return
            self._serving = True

        try:
            self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
            while not self._stop_event.is_set():
                try:
                    message = self._socket.recv()
                    # NOTE: pickle.loads executes arbitrary code from the payload;
                    # only expose this server to trusted peers.
                    request = pickle.loads(message)

                    # Call the appropriate method based on the request
                    method = request.get("method")
                    args = request.get("args", {})
                    result: Any
                    if method == "num_dofs":
                        result = self._robot.num_dofs()
                    elif method == "get_joint_state":
                        result = self._robot.get_joint_state()
                    elif method == "command_joint_state":
                        result = self._robot.command_joint_state(**args)
                    elif method == "get_observations":
                        result = self._robot.get_observations()
                    else:
                        result = {"error": "Invalid method"}
                        print(result)
                        raise NotImplementedError(
                            f"Invalid method: {method}, {args, result}"
                        )

                    self._socket.send(pickle.dumps(result))
                except zmq.error.Again:
                    print("Timeout in ZMQLeaderServer serve")
                    # Timeout occurred, check if the stop event is set
        finally:
            # Release from the thread that owns the socket, also when an
            # exception ends the loop.
            self._release()

    def stop(self) -> None:
        """Signal the serve loop to exit.

        If ``serve()`` is running, it notices the signal within the receive
        timeout and releases the socket itself. If it never started, the
        socket and context are released here.
        """
        with self._state_lock:
            self._stop_event.set()
            serving = self._serving
        if not serving:
            self._release()
class MujocoRobotServer:
    """Simulated robot backed by MuJoCo and served to clients over ZMQ.

    The model is built from an arm MJCF file plus an optional gripper. The ZMQ
    server runs in a background thread while ``serve()`` runs the physics and
    viewer loop in the calling thread.
    """

    def __init__(
        self,
        xml_path: str,
        gripper_xml_path: Optional[str] = None,
        host: str = "127.0.0.1",
        port: int = 5556,
        print_joints: bool = False,
    ):
        """Build the MuJoCo model and create (but do not start) the ZMQ server.

        Args:
            xml_path: Path to the robot MJCF file.
            gripper_xml_path: Optional path to a gripper MJCF file.
            host: Host the ZMQ server binds to.
            port: Port the ZMQ server binds to.
            print_joints: Whether ``serve()`` prints the joint state every step.
        """
        self._has_gripper = gripper_xml_path is not None
        arena = build_scene(xml_path, gripper_xml_path)

        # Collect mesh file contents so the model can be loaded from a string.
        assets: Dict[str, str] = {}
        for asset in arena.asset.all_children():
            if asset.tag == "mesh":
                f = asset.file
                assets[f.get_vfs_filename()] = asset.file.contents

        xml_string = arena.to_xml_string()
        # save xml_string to file (written to the current working directory)
        with open("arena.xml", "w") as xml_file:
            xml_file.write(xml_string)

        self._model = mujoco.MjModel.from_xml_string(xml_string, assets)
        self._data = mujoco.MjData(self._model)

        # One commanded value per actuator.
        self._num_joints = self._model.nu

        self._joint_state = np.zeros(self._num_joints)
        # Separate array so the command never aliases the reported state.
        self._joint_cmd = self._joint_state.copy()

        self._zmq_server = ZMQRobotServer(robot=self, host=host, port=port)
        self._zmq_server_thread = ZMQServerThread(self._zmq_server)

        self._print_joints = print_joints

    def num_dofs(self) -> int:
        """Return the number of commanded degrees of freedom."""
        return self._num_joints

    def get_joint_state(self) -> np.ndarray:
        """Return the joint positions recorded after the latest physics step."""
        return self._joint_state

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Store the command applied to the actuators on the next physics step.

        When a gripper is attached, the last entry is multiplied by 255 before
        being stored (presumably a normalized gripper command mapped onto the
        actuator's control range — TODO confirm).

        Args:
            joint_state: Command with exactly ``num_dofs()`` entries.
        """
        assert len(joint_state) == self._num_joints, (
            f"Expected joint state of length {self._num_joints}, "
            f"got {len(joint_state)}."
        )
        if self._has_gripper:
            _joint_state = joint_state.copy()
            _joint_state[-1] = _joint_state[-1] * 255
            self._joint_cmd = _joint_state
        else:
            self._joint_cmd = joint_state.copy()

    def freedrive_enabled(self) -> bool:
        """Always True for the simulated robot."""
        return True

    def set_freedrive_mode(self, enable: bool):
        """No-op for the simulated robot."""
        pass

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return joint positions/velocities, end-effector pose and gripper position.

        The end-effector pose is read from the site named "attachment_site".
        If that site cannot be resolved, the pose falls back to a zero position
        and the quaternion [1, 0, 0, 0].
        """
        joint_positions = self._data.qpos.copy()[: self._num_joints]
        joint_velocities = self._data.qvel.copy()[: self._num_joints]
        ee_site = "attachment_site"
        try:
            site_id = mujoco.mj_name2id(
                self._model, mujoco.mjtObj.mjOBJ_SITE, ee_site
            )
            if site_id < 0:
                # mj_name2id returns -1 for unknown names; indexing with it
                # would silently read the last site instead.
                raise LookupError(f"Site {ee_site!r} not found in the model.")
            ee_pos = self._data.site_xpos.copy()[site_id]
            ee_mat = self._data.site_xmat.copy()[site_id]
            ee_quat = np.zeros(4)
            mujoco.mju_mat2Quat(ee_quat, ee_mat)
        except Exception:
            ee_pos = np.zeros(3)
            ee_quat = np.zeros(4)
            ee_quat[0] = 1
        gripper_pos = self._data.qpos.copy()[self._num_joints - 1]
        return {
            "joint_positions": joint_positions,
            "joint_velocities": joint_velocities,
            "ee_pos_quat": np.concatenate([ee_pos, ee_quat]),
            "gripper_position": gripper_pos,
        }

    def serve(self) -> None:
        """Start the ZMQ server thread and run the physics/viewer loop.

        Returns when the viewer window is closed.
        """
        # start the zmq server
        self._zmq_server_thread.start()
        with mujoco.viewer.launch_passive(self._model, self._data) as viewer:
            while viewer.is_running():
                step_start = time.time()

                # mj_step can be replaced with code that also evaluates
                # a policy and applies a control signal before stepping the physics.
                self._data.ctrl[:] = self._joint_cmd
                # self._data.qpos[:] = self._joint_cmd
                mujoco.mj_step(self._model, self._data)
                self._joint_state = self._data.qpos.copy()[: self._num_joints]

                if self._print_joints:
                    print(self._joint_state)

                # Example modification of a viewer option: toggle contact points every two seconds.
                with viewer.lock():
                    # TODO remove?
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(
                        self._data.time % 2
                    )

                # Pick up changes to the physics state, apply perturbations, update options from GUI.
                viewer.sync()

                # Rudimentary time keeping, will drift relative to wall clock.
                time_until_next_step = self._model.opt.timestep - (
                    time.time() - step_start
                )
                if time_until_next_step > 0:
                    time.sleep(time_until_next_step)

    def stop(self) -> None:
        """Stop the ZMQ server thread and wait for it to finish.

        Safe to call when the thread was never started or has already exited.
        """
        thread = getattr(self, "_zmq_server_thread", None)
        if thread is None or not thread.is_alive():
            # Joining a never-started thread raises RuntimeError.
            return
        # Signal the server loop first; a bare join() would block forever.
        thread.terminate()
        thread.join()

    def __del__(self) -> None:
        self.stop()

133
gello/robots/ur.py Normal file
View file

@ -0,0 +1,133 @@
from typing import Dict
import numpy as np
from gello.robots.robot import Robot
class URRobot(Robot):
    """A class representing a UR robot."""

    def __init__(self, robot_ip: str = "192.168.1.10", no_gripper: bool = False):
        """Connect to the robot's RTDE interfaces and, optionally, its gripper.

        Args:
            robot_ip: IP address of the robot (also used for the gripper).
            no_gripper: If True, do not connect to a gripper.

        Raises:
            Exception: Whatever the RTDE control interface raises when the
                connection fails; it is logged and re-raised.
        """
        import rtde_control
        import rtde_receive

        for _ in range(4):
            print("in ur robot")
        try:
            self.robot = rtde_control.RTDEControlInterface(robot_ip)
        except Exception as e:
            print(e)
            print(robot_ip)
            # Without a control interface the object is unusable; surface the
            # real error instead of a later AttributeError on self.robot.
            raise

        self.r_inter = rtde_receive.RTDEReceiveInterface(robot_ip)
        if not no_gripper:
            from gello.robots.robotiq_gripper import RobotiqGripper

            self.gripper = RobotiqGripper()
            self.gripper.connect(hostname=robot_ip, port=63352)
            print("gripper connected")
            # gripper.activate()

        for _ in range(4):
            print("connect")

        self._free_drive = False
        self.robot.endFreedriveMode()
        self._use_gripper = not no_gripper

    def num_dofs(self) -> int:
        """Get the number of joints of the robot.

        Returns:
            int: 7 when a gripper is used (6 arm joints + gripper), else 6.
        """
        if self._use_gripper:
            return 7
        return 6

    def _get_gripper_pos(self) -> float:
        """Read the gripper position and normalize it from [0, 255] to [0, 1]."""
        import time

        time.sleep(0.01)
        gripper_pos = self.gripper.get_current_position()
        assert 0 <= gripper_pos <= 255, "Gripper position must be between 0 and 255"
        return gripper_pos / 255

    def get_joint_state(self) -> np.ndarray:
        """Get the current joint state of the robot.

        Returns:
            np.ndarray: The arm joint positions, followed by the normalized
            gripper position when a gripper is used.
        """
        robot_joints = self.r_inter.getActualQ()
        if self._use_gripper:
            gripper_pos = self._get_gripper_pos()
            pos = np.append(robot_joints, gripper_pos)
        else:
            # Convert so both branches return an ndarray as annotated.
            pos = np.asarray(robot_joints)
        return pos

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the robot to a given state.

        The first six entries are sent to the arm via ``servoJ``; when a
        gripper is used, the last entry is multiplied by 255 and sent to it.

        Args:
            joint_state (np.ndarray): The state to command the robot to.
        """
        velocity = 0.5
        acceleration = 0.5
        dt = 1.0 / 500  # 2ms
        lookahead_time = 0.2
        gain = 100

        robot_joints = joint_state[:6]
        t_start = self.robot.initPeriod()
        self.robot.servoJ(
            robot_joints, velocity, acceleration, dt, lookahead_time, gain
        )
        if self._use_gripper:
            gripper_pos = joint_state[-1] * 255
            self.gripper.move(gripper_pos, 255, 10)
        self.robot.waitPeriod(t_start)

    def freedrive_enabled(self) -> bool:
        """Check if the robot is in freedrive mode.

        Returns:
            bool: True if the robot is in freedrive mode, False otherwise.
        """
        return self._free_drive

    def set_freedrive_mode(self, enable: bool) -> None:
        """Set the freedrive mode of the robot.

        Only talks to the robot when the requested mode differs from the
        current one.

        Args:
            enable (bool): True to enable freedrive mode, False to disable it.
        """
        if enable and not self._free_drive:
            self._free_drive = True
            self.robot.freedriveMode()
        elif not enable and self._free_drive:
            self._free_drive = False
            self.robot.endFreedriveMode()

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the current observations.

        NOTE(review): "joint_velocities" repeats the joint positions and
        "ee_pos_quat" is all zeros; both look like placeholders. Without a
        gripper, "gripper_position" is the last arm joint value.
        """
        joints = self.get_joint_state()
        pos_quat = np.zeros(7)
        gripper_pos = np.array([joints[-1]])
        return {
            "joint_positions": joints,
            "joint_velocities": joints,
            "ee_pos_quat": pos_quat,
            "gripper_position": gripper_pos,
        }
def main():
    """Manual check: connect to a UR robot without gripper, enable freedrive, print observations."""
    ur = URRobot("192.168.1.11", no_gripper=True)
    print(ur)
    ur.set_freedrive_mode(True)
    observations = ur.get_observations()
    print(observations)
if __name__ == "__main__":
main()

358
gello/robots/xarm_robot.py Normal file
View file

@ -0,0 +1,358 @@
import dataclasses
import threading
import time
from typing import Dict, Optional
import numpy as np
from pyquaternion import Quaternion
from gello.robots.robot import Robot
def _aa_from_quat(quat: np.ndarray) -> np.ndarray:
"""Convert a quaternion to an axis-angle representation.
Args:
quat (np.ndarray): The quaternion to convert.
Returns:
np.ndarray: The axis-angle representation of the quaternion.
"""
assert quat.shape == (4,), "Input quaternion must be a 4D vector."
norm = np.linalg.norm(quat)
assert norm != 0, "Input quaternion must not be a zero vector."
quat = quat / norm # Normalize the quaternion
Q = Quaternion(w=quat[3], x=quat[0], y=quat[1], z=quat[2])
angle = Q.angle
axis = Q.axis
aa = axis * angle
return aa
def _quat_from_aa(aa: np.ndarray) -> np.ndarray:
"""Convert an axis-angle representation to a quaternion.
Args:
aa (np.ndarray): The axis-angle representation to convert.
Returns:
np.ndarray: The quaternion representation of the axis-angle.
"""
assert aa.shape == (3,), "Input axis-angle must be a 3D vector."
norm = np.linalg.norm(aa)
assert norm != 0, "Input axis-angle must not be a zero vector."
axis = aa / norm # Normalize the axis-angle
Q = Quaternion(axis=axis, angle=norm)
quat = np.array([Q.x, Q.y, Q.z, Q.w])
return quat
@dataclasses.dataclass(frozen=True)
class RobotState:
    """Immutable snapshot of the robot: Cartesian position, gripper, seven joints, axis-angle."""

    x: float
    y: float
    z: float
    gripper: float
    j1: float
    j2: float
    j3: float
    j4: float
    j5: float
    j6: float
    j7: float
    aa: np.ndarray

    @staticmethod
    def from_robot(
        cartesian: np.ndarray,
        joints: np.ndarray,
        gripper: float,
        aa: np.ndarray,
    ) -> "RobotState":
        """Build a state from raw readings.

        Uses the first three entries of ``cartesian`` and the first seven
        entries of ``joints``.
        """
        j1, j2, j3, j4, j5, j6, j7 = (joints[i] for i in range(7))
        return RobotState(
            x=cartesian[0],
            y=cartesian[1],
            z=cartesian[2],
            gripper=gripper,
            j1=j1,
            j2=j2,
            j3=j3,
            j4=j4,
            j5=j5,
            j6=j6,
            j7=j7,
            aa=aa,
        )

    def cartesian_pos(self) -> np.ndarray:
        """Return [x, y, z] as an array."""
        xyz = [self.x, self.y, self.z]
        return np.array(xyz)

    def quat(self) -> np.ndarray:
        """Return the orientation as a quaternion computed from ``aa``."""
        return _quat_from_aa(self.aa)

    def joints(self) -> np.ndarray:
        """Return the seven joint values [j1 .. j7] as an array."""
        values = [self.j1, self.j2, self.j3, self.j4, self.j5, self.j6, self.j7]
        return np.array(values)

    def gripper_pos(self) -> float:
        """Return the stored gripper value."""
        return self.gripper
class Rate:
    """Helper that paces a loop to a fixed period."""

    def __init__(self, *, duration):
        """Remember the target period (seconds) and the current time as the last tick."""
        self.duration = duration
        self.last = time.time()

    def sleep(self, duration=None) -> None:
        """Sleep for whatever is left of the period since the last tick.

        Args:
            duration: Period to use for this call; defaults to the one given
                at construction.
        """
        period = duration if duration is not None else self.duration
        assert period >= 0

        elapsed = time.time() - self.last
        left = period - elapsed
        assert elapsed >= 0

        # Skip sleeps of 0.1 ms or less.
        if left > 0.0001:
            time.sleep(left)
        self.last = time.time()
class XArmRobot(Robot):
    """Robot interface for an xArm with seven joints and a gripper.

    When ``real`` is True, a background thread repeatedly reads the robot
    state and moves the joints towards the latest target set through
    ``set_command``, limiting the joint-space step per iteration to
    ``max_delta``. When ``real`` is False no hardware is contacted and the
    state is all zeros.
    """

    GRIPPER_OPEN = 800
    GRIPPER_CLOSE = 0
    #  MAX_DELTA = 0.2
    DEFAULT_MAX_DELTA = 0.05

    def num_dofs(self) -> int:
        """Return 8: seven arm joints plus the gripper."""
        return 8

    def get_joint_state(self) -> np.ndarray:
        """Return the seven joint values followed by the gripper position."""
        state = self.get_state()
        gripper = state.gripper_pos()
        all_dofs = np.concatenate([state.joints(), np.array([gripper])])
        return all_dofs

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Set the target joints, and the gripper target if an 8th entry is given.

        Raises:
            ValueError: If ``joint_state`` does not have 7 or 8 entries.
        """
        if len(joint_state) == 7:
            self.set_command(joint_state, None)
        elif len(joint_state) == 8:
            self.set_command(joint_state[:7], joint_state[7])
        else:
            raise ValueError(
                f"Invalid joint state: {joint_state}, len={len(joint_state)}"
            )

    def stop(self):
        """Stop the command thread, then disconnect from the robot."""
        self.running = False
        # Wait for the command thread first: it keeps talking to the robot
        # until it observes running == False.
        if self.command_thread is not None:
            self.command_thread.join()
        if self.robot is not None:
            self.robot.disconnect()

    def __init__(
        self,
        ip: str = "192.168.1.226",
        real: bool = True,
        control_frequency: float = 50.0,
        max_delta: float = DEFAULT_MAX_DELTA,
    ):
        """Connect to the robot (if ``real``) and start the command thread.

        Args:
            ip: IP address of the robot.
            real: Whether to connect to real hardware.
            control_frequency: Target rate of the command thread, in Hz.
            max_delta: Maximum norm of the joint step commanded per iteration.
        """
        print(ip)
        self.real = real
        self.max_delta = max_delta
        if real:
            from xarm.wrapper import XArmAPI

            self.robot = XArmAPI(ip, is_radian=True)
        else:
            self.robot = None

        self._control_frequency = control_frequency
        self._clear_error_states()
        self._set_gripper_position(self.GRIPPER_OPEN)

        self.last_state_lock = threading.Lock()
        self.target_command_lock = threading.Lock()
        self.last_state = self._update_last_state()
        self.target_command = {
            "joints": self.last_state.joints(),
            "gripper": 0,
        }
        self.running = True
        self.command_thread = None
        if real:
            self.command_thread = threading.Thread(target=self._robot_thread)
            self.command_thread.start()

    def get_state(self) -> RobotState:
        """Return the most recently stored robot state."""
        with self.last_state_lock:
            return self.last_state

    def set_command(self, joints: np.ndarray, gripper: Optional[float] = None) -> None:
        """Replace the target picked up by the command thread.

        Args:
            joints: Target joint values.
            gripper: Target gripper value, or None to leave the gripper alone.
        """
        with self.target_command_lock:
            self.target_command = {
                "joints": joints,
                "gripper": gripper,
            }

    def _clear_error_states(self):
        """Clear errors/warnings and re-apply the robot and gripper settings.

        Does nothing when there is no robot. Each SDK call is followed by a
        one-second pause.
        """
        if self.robot is None:
            return
        self.robot.clean_error()
        self.robot.clean_warn()
        self.robot.motion_enable(True)
        time.sleep(1)
        self.robot.set_mode(1)
        time.sleep(1)
        self.robot.set_collision_sensitivity(0)
        time.sleep(1)
        self.robot.set_state(state=0)
        time.sleep(1)
        self.robot.set_gripper_enable(True)
        time.sleep(1)
        self.robot.set_gripper_mode(0)
        time.sleep(1)
        self.robot.set_gripper_speed(3000)
        time.sleep(1)

    def _get_gripper_pos(self) -> float:
        """Read the gripper position, normalized so GRIPPER_OPEN is 0 and GRIPPER_CLOSE is 1.

        Retries until the SDK returns code 0 with a value; error code 22
        additionally triggers ``_clear_error_states``. Returns 0.0 when there
        is no robot.
        """
        if self.robot is None:
            return 0.0
        code, gripper_pos = self.robot.get_gripper_position()
        while code != 0 or gripper_pos is None:
            print(f"Error code {code} in get_gripper_position(). {gripper_pos}")
            time.sleep(0.001)
            code, gripper_pos = self.robot.get_gripper_position()
            if code == 22:
                self._clear_error_states()

        normalized_gripper_pos = (gripper_pos - self.GRIPPER_OPEN) / (
            self.GRIPPER_CLOSE - self.GRIPPER_OPEN
        )
        return normalized_gripper_pos

    def _set_gripper_position(self, pos: int) -> None:
        """Send a raw gripper position without waiting; no-op without a robot."""
        if self.robot is None:
            return
        self.robot.set_gripper_position(pos, wait=False)

    def _robot_thread(self):
        """Command loop: refresh state, step towards the target, pace, log timing."""
        rate = Rate(
            duration=1 / self._control_frequency
        )  # command and update rate for robot
        step_times = []
        count = 0

        while self.running:
            s_t = time.time()
            # update last state
            self.last_state = self._update_last_state()
            with self.target_command_lock:
                joint_delta = np.array(
                    self.target_command["joints"] - self.last_state.joints()
                )
                gripper_command = self.target_command["gripper"]

            norm = np.linalg.norm(joint_delta)

            # limit the step so its norm is at most self.max_delta
            if norm > self.max_delta:
                delta = joint_delta / norm * self.max_delta
            else:
                delta = joint_delta

            # command position
            self._set_position(
                self.last_state.joints() + delta,
            )

            if gripper_command is not None:
                set_point = gripper_command
                # map the normalized set point back onto the raw gripper range
                self._set_gripper_position(
                    self.GRIPPER_OPEN
                    + set_point * (self.GRIPPER_CLOSE - self.GRIPPER_OPEN)
                )
            self.last_state = self._update_last_state()

            rate.sleep()
            step_times.append(time.time() - s_t)
            count += 1
            if count % 1000 == 0:
                # Mean from the mean step time; std/min/max from per-step
                # frequencies so they describe the spread over the window.
                frequency = 1 / np.mean(step_times)
                step_frequencies = 1 / np.array(step_times)
                print(
                    f"Low Level Frequency - mean: {frequency:10.3f}, std: {np.std(step_frequencies):10.3f}, min: {np.min(step_frequencies):10.3f}, max: {np.max(step_frequencies):10.3f}"
                )
                step_times = []

    def _update_last_state(self) -> RobotState:
        """Read gripper, joint angles and Cartesian pose and return them as a RobotState.

        Without a robot an all-zero state is returned. Failed reads are
        retried after clearing the error state.
        """
        with self.last_state_lock:
            if self.robot is None:
                return RobotState(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, np.zeros(3))

            gripper_pos = self._get_gripper_pos()

            code, servo_angle = self.robot.get_servo_angle(is_radian=True)
            while code != 0:
                print(f"Error code {code} in get_servo_angle().")
                self._clear_error_states()
                code, servo_angle = self.robot.get_servo_angle(is_radian=True)

            code, cart_pos = self.robot.get_position_aa(is_radian=True)
            while code != 0:
                print(f"Error code {code} in get_position().")
                self._clear_error_states()
                code, cart_pos = self.robot.get_position_aa(is_radian=True)

            cart_pos = np.array(cart_pos)
            aa = cart_pos[3:]
            # presumably converts millimetres to metres — TODO confirm
            cart_pos[:3] /= 1000

            return RobotState.from_robot(
                cart_pos,
                servo_angle,
                gripper_pos,
                aa,
            )

    def _set_position(
        self,
        joints: np.ndarray,
    ) -> None:
        """Send a joint target without waiting; clear errors on return code 1 or 9."""
        if self.robot is None:
            return
        ret = self.robot.set_servo_angle_j(joints, wait=False, is_radian=True)
        if ret in [1, 9]:
            self._clear_error_states()

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return joint state, end-effector pose and gripper position.

        NOTE(review): "joint_velocities" repeats the joint positions; it looks
        like a placeholder.
        """
        state = self.get_state()
        pos_quat = np.concatenate([state.cartesian_pos(), state.quat()])
        joints = self.get_joint_state()
        return {
            "joint_positions": joints,  # rotational joint + gripper state
            "joint_velocities": joints,
            "ee_pos_quat": pos_quat,
            "gripper_position": np.array(state.gripper_pos()),
        }
def main():
    """Manual check: connect to an xArm, print its state twice, then stop it."""
    ip = "192.168.1.226"
    robot = XArmRobot(ip)
    import time

    try:
        time.sleep(1)
        print(robot.get_state())

        time.sleep(1)
        print(robot.get_state())
        print("end")
    finally:
        # Always stop: the robot's command thread is not a daemon and would
        # otherwise keep the process alive after an error.
        robot.stop()
if __name__ == "__main__":
main()

View file

@ -0,0 +1,69 @@
import pickle
import threading
from typing import Optional, Tuple
import numpy as np
import zmq
from gello.cameras.camera import CameraDriver
# Default TCP port shared by the camera ZMQ client and server in this module.
DEFAULT_CAMERA_PORT = 5000
class ZMQClientCamera(CameraDriver):
    """Camera driver that fetches frames from a remote camera server over ZMQ."""

    def __init__(self, port: int = DEFAULT_CAMERA_PORT, host: str = "127.0.0.1"):
        """Open a REQ socket connected to ``tcp://host:port``."""
        context = zmq.Context()
        sock = context.socket(zmq.REQ)
        sock.connect(f"tcp://{host}:{port}")
        self._context = context
        self._socket = sock

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Request a camera reading from the server.

        Args:
            img_size: Forwarded as-is (pickled) to the server.

        Returns:
            Whatever the server replied with, unpickled.
        """
        # NOTE: pickle.loads can execute code from the payload; only connect
        # to trusted servers.
        self._socket.send(pickle.dumps(img_size))
        reply = self._socket.recv()
        return pickle.loads(reply)
class ZMQServerCamera:
    """ZMQ REP server that answers read requests using a camera driver."""

    def __init__(
        self,
        camera: CameraDriver,
        port: int = DEFAULT_CAMERA_PORT,
        host: str = "127.0.0.1",
    ):
        """Bind a REP socket on ``tcp://host:port`` for the given camera."""
        self._camera = camera
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        print(f"Camera Sever Binding to {addr}, Camera: {camera}")
        self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def serve(self) -> None:
        """Answer read requests until ``stop()`` is called.

        Each request carries a pickled image size; the reply is the pickled
        result of ``camera.read(img_size)``. A receive timeout of one second
        lets the loop notice the stop signal.
        """
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                requested_size = pickle.loads(self._socket.recv())
                reading = self._camera.read(requested_size)
                self._socket.send(pickle.dumps(reading))
            except zmq.Again:
                # Timeout occurred, loop around to check the stop event.
                print(self._timout_message)

    def stop(self) -> None:
        """Signal the server to stop serving."""
        self._stop_event.set()

View file

@ -0,0 +1,125 @@
import pickle
import threading
from typing import Any, Dict
import numpy as np
import zmq
from gello.robots.robot import Robot
# Default TCP port shared by the robot ZMQ server and client in this module.
DEFAULT_ROBOT_PORT = 6000
class ZMQServerRobot:
    """ZMQ REP server that forwards remote requests to a robot object."""

    def __init__(
        self,
        robot: Robot,
        port: int = DEFAULT_ROBOT_PORT,
        host: str = "127.0.0.1",
    ):
        """Bind a REP socket on ``tcp://host:port`` for the given robot."""
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        print(f"Robot Sever Binding to {addr}, Robot: {robot}")
        self._timout_message = f"Timeout in Robot Server, Robot: {robot}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def serve(self) -> None:
        """Answer robot requests until ``stop()`` is called.

        A request is a pickled dict with a "method" name and optional "args";
        the reply is the pickled return value of that robot method.

        Raises:
            NotImplementedError: If the requested method is not supported.
        """
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                # Wait for next request from client
                request = pickle.loads(self._socket.recv())

                method = request.get("method")
                args = request.get("args", {})
                result: Any
                if method in ("num_dofs", "get_joint_state", "get_observations"):
                    # These methods take no arguments.
                    result = getattr(self._robot, method)()
                elif method == "command_joint_state":
                    result = self._robot.command_joint_state(**args)
                else:
                    result = {"error": "Invalid method"}
                    print(result)
                    raise NotImplementedError(
                        f"Invalid method: {method}, {args, result}"
                    )

                self._socket.send(pickle.dumps(result))
            except zmq.Again:
                # Timeout occurred, loop around to check the stop event.
                print(self._timout_message)

    def stop(self) -> None:
        """Signal the server to stop serving."""
        self._stop_event.set()
class ZMQClientRobot(Robot):
    """Robot proxy that forwards every call to a remote robot server over ZMQ."""

    def __init__(self, port: int = DEFAULT_ROBOT_PORT, host: str = "127.0.0.1"):
        """Open a REQ socket connected to ``tcp://host:port``."""
        context = zmq.Context()
        sock = context.socket(zmq.REQ)
        sock.connect(f"tcp://{host}:{port}")
        self._context = context
        self._socket = sock

    def _exchange(self, request: Dict[str, Any]) -> Any:
        """Send one pickled request and return the unpickled reply."""
        self._socket.send(pickle.dumps(request))
        return pickle.loads(self._socket.recv())

    def num_dofs(self) -> int:
        """Get the number of joints in the robot.

        Returns:
            int: The number of joints reported by the server.
        """
        return self._exchange({"method": "num_dofs"})

    def get_joint_state(self) -> np.ndarray:
        """Get the current joint state of the remote robot.

        Returns:
            np.ndarray: The joint state reported by the server.
        """
        return self._exchange({"method": "get_joint_state"})

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the remote robot to the given state.

        Args:
            joint_state (np.ndarray): The state to command the robot to.

        Returns:
            Whatever the server replied with.
        """
        return self._exchange(
            {
                "method": "command_joint_state",
                "args": {"joint_state": joint_state},
            }
        )

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Get the current observations of the remote robot.

        Returns:
            Dict[str, np.ndarray]: The observations reported by the server.
        """
        return self._exchange({"method": "get_observations"})

Binary file not shown.

After

Width:  |  Height:  |  Size: 289 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 359 KiB

BIN
imgs/title.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 269 KiB

1
kill_nodes.sh Executable file
View file

@ -0,0 +1 @@
pkill -9 -f launch_nodes.py

8
mypy.ini Normal file
View file

@ -0,0 +1,8 @@
[mypy]
ignore_missing_imports = True
namespace_packages = True
no_implicit_optional = True
exclude = third_party/
[mypy-dynamixel_sdk.*]
ignore_missing_imports = True

12
pyrightconfig.json Normal file
View file

@ -0,0 +1,12 @@
{
"include": [
"gello"
],
"exclude": [
"**/node_modules",
"src/typestubs"
],
"reportMissingImports": false,
"reportMissingTypeStubs": false,
"reportGeneralTypeIssues": false
}

17
requirements.txt Normal file
View file

@ -0,0 +1,17 @@
# keep in alphabetical order
dm_control
Pillow
pygame
pyspacemouse
PyQt6
pyquaternion
pyrealsense2
pure-python-adb
quaternion
tyro
ur-rtde
zmq
xarm
xarm-python-sdk
numpy-quaternion
termcolor

13
requirements_dev.txt Normal file
View file

@ -0,0 +1,13 @@
# keep in alphabetical order
# for development
black
flake8
flake8-docstrings
isort
ipdb
jupyterlab
mypy
neovim
pyright
pytest
python-lsp-server[all]

View file

@ -0,0 +1,59 @@
from dataclasses import dataclass
import numpy as np
import tyro
from dm_control import composer, viewer
from gello.agents.gello_agent import DynamixelRobotConfig
from gello.dm_control_tasks.arms.ur5e import UR5e
from gello.dm_control_tasks.manipulation.arenas.floors import Floor
from gello.dm_control_tasks.manipulation.tasks.block_play import BlockPlay
@dataclass
class Args:
    """Command-line options for this script."""

    # When True, the policy in main() returns the GELLO device's joint state
    # instead of random actions.
    use_gello: bool = False
# Dynamixel configuration for the GELLO device used by main() when
# `use_gello` is set: six joints with ids 1-6, one offset and one sign per joint.
# NOTE(review): the offsets are written as sums of multiples of pi/2,
# presumably so each can be adjusted in quarter turns — confirm.
config = DynamixelRobotConfig(
    joint_ids=(1, 2, 3, 4, 5, 6),
    joint_offsets=(
        -np.pi / 2,
        1 * np.pi / 2 + np.pi,
        np.pi / 2 + 0 * np.pi,
        0 * np.pi + np.pi / 2,
        np.pi - 2 * np.pi / 2,
        -1 * np.pi / 2 + 2 * np.pi,
    ),
    joint_signs=(1, 1, -1, 1, 1, 1),
    gripper_config=(7, 20, -22),
)
def main(args: Args) -> None:
    """Launch the dm_control viewer on a UR5e block-play task.

    Actions come from the GELLO device when ``args.use_gello`` is set,
    otherwise they are sampled uniformly from the action spec.
    """
    reset_joints_left = np.deg2rad([90, -90, -90, -90, 90, 0, 0])
    robot = UR5e()
    # The task takes every reset joint except the last entry.
    task = BlockPlay(robot, Floor(), reset_joints=reset_joints_left[:-1])
    # task = BlockPlay(robot, Floor())
    env = composer.Environment(task=task)
    action_space = env.action_spec()

    if args.use_gello:
        gello = config.make_robot(
            port="/dev/cu.usbserial-FT7WBEIA", start_joints=reset_joints_left
        )

    def policy(timestep) -> np.ndarray:
        if not args.use_gello:
            return np.random.uniform(action_space.minimum, action_space.maximum)
        joint_command = np.array(gello.get_joint_state()).copy()
        # Scale the last entry by 255 before handing it to the environment.
        joint_command[-1] = joint_command[-1] * 255
        return joint_command

    viewer.launch(env, policy=policy)
if __name__ == "__main__":
main(tyro.cli(Args))

View file

@ -0,0 +1,98 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple
import numpy as np
import tyro
from gello.dynamixel.driver import DynamixelDriver
# Location of third_party/mujoco_menagerie, resolved relative to this file's directory.
MENAGERIE_ROOT: Path = Path(__file__).parent / "third_party" / "mujoco_menagerie"
@dataclass
class Args:
    """Command-line arguments for computing GELLO joint offsets."""

    port: str = "/dev/ttyUSB0"
    """The port that GELLO is connected to."""

    start_joints: Tuple[float, ...] = (0, 0, 0, 0, 0, 0)
    """The joint angles that the GELLO is placed in at (in radians)."""

    joint_signs: Tuple[float, ...] = (1, 1, -1, 1, 1, 1)
    """The sign (1 or -1) applied to each joint; one entry per start joint."""

    gripper: bool = True
    """Whether or not the gripper is attached."""

    def __post_init__(self):
        # One sign per start joint, and every sign must be +1 or -1.
        assert len(self.joint_signs) == len(self.start_joints)
        for idx, j in enumerate(self.joint_signs):
            assert (
                j == -1 or j == 1
            ), f"Joint idx: {idx} should be -1 or 1, but got {j}."

    @property
    def num_robot_joints(self) -> int:
        """Number of arm joints, i.e. the length of ``start_joints``."""
        return len(self.start_joints)

    @property
    def num_joints(self) -> int:
        """Number of arm joints plus one if the gripper is attached."""
        extra_joints = 1 if self.gripper else 0
        return self.num_robot_joints + extra_joints
def get_config(args: Args) -> None:
    """Print the joint offsets (and gripper angles) for a GELLO in a known pose.

    The device is assumed to be physically placed at ``args.start_joints``.
    For every arm joint, the offset is picked by brute force from multiples
    of pi/2 in [-8*pi, 8*pi] so that the signed, offset-corrected reading is
    as close as possible to the start joint.
    """
    motor_ids = list(range(1, args.num_joints + 1))
    driver = DynamixelDriver(motor_ids, port=args.port, baudrate=57600)

    def offset_error(offset: float, index: int, joint_state: np.ndarray) -> float:
        # Distance between the corrected reading and the expected start angle.
        corrected = args.joint_signs[index] * (joint_state[index] - offset)
        return np.abs(corrected - args.start_joints[index])

    for _ in range(10):
        driver.get_joints()  # warmup

    # Candidate offsets: -8*pi .. 8*pi in steps of pi/2.
    candidates = np.linspace(-8 * np.pi, 8 * np.pi, 8 * 4 + 1)
    curr_joints = driver.get_joints()
    best_offsets = []
    for joint_idx in range(args.num_robot_joints):
        chosen, lowest = 0, 1e6
        for candidate in candidates:
            err = offset_error(candidate, joint_idx, curr_joints)
            if err < lowest:
                lowest, chosen = err, candidate
        best_offsets.append(chosen)

    print()
    print("best offsets : ", [f"{x:.3f}" for x in best_offsets])
    print(
        "best offsets function of pi: ["
        + ", ".join([f"{int(np.round(x/(np.pi/2)))}*np.pi/2" for x in best_offsets])
        + " ]",
    )
    if args.gripper:
        # Both values are derived from a fresh reading of the last joint.
        print(
            "gripper open (degrees) ",
            np.rad2deg(driver.get_joints()[-1]) - 0.2,
        )
        print(
            "gripper close (degrees) ",
            np.rad2deg(driver.get_joints()[-1]) - 42,
        )
def main(args: Args) -> None:
    """Script entry point: run the offset calibration for ``args``."""
    get_config(args)


if __name__ == "__main__":
    # Parse the command line into an Args instance with tyro, then run.
    cli_args = tyro.cli(Args)
    main(cli_args)

39
scripts/launch.py Normal file
View file

@ -0,0 +1,39 @@
import os
import subprocess
# Absolute path of this script; run_docker_container() walks two levels up
# from it to find the directory it mounts into the container at /gello.
current_file_path = os.path.abspath(__file__)
def run_docker_container():
    """Launch an interactive ``gello:latest`` Docker container.

    The container is named ``gello_<USER>`` (from the ``USER`` environment
    variable), is removed on exit (``--rm``), and mounts the directory two
    levels above this script at ``/gello``.  Inside the container the
    DynamixelSDK Python package is pip-installed in editable mode and control
    is then handed to an interactive bash shell.  The exit status of
    ``docker`` is not inspected.
    """
    container_name = f"gello_{os.environ.get('USER')}"
    repo_root = os.path.abspath(os.path.join(current_file_path, "..", ".."))

    # Assemble the argv in segments; the order of flags is kept as-is.
    cmd = ["docker", "run", "--runtime=nvidia", "--rm"]
    cmd += ["--name", container_name, "--privileged"]
    cmd += ["--volume", f"{repo_root}:/gello"]
    cmd += ["--volume", "/home/gello:/homefolder"]
    cmd += ["--net=host"]
    cmd += ["--volume", "/dev/serial/by-id/:/dev/serial/by-id/"]
    cmd += ["-it", "gello:latest"]
    cmd += [
        "bash",
        "-c",
        "pip install -e third_party/DynamixelSDK/python && exec bash",
    ]

    subprocess.run(cmd)


if __name__ == "__main__":
    run_docker_container()

View file

@ -0,0 +1,41 @@
import time
from pathlib import Path
import mujoco
import mujoco.viewer
def main():
    """Load the Franka Panda menagerie model and step it in a passive viewer.

    The model XML is looked up under ``third_party/mujoco_menagerie`` two
    directory levels above this file.  Physics is stepped in a loop for as
    long as the viewer reports that it is running, sleeping after each step
    to roughly match the model's timestep.
    """
    repo_root = Path(__file__).parent.parent
    menagerie = repo_root / "third_party" / "mujoco_menagerie"
    model_xml = menagerie / "franka_emika_panda" / "panda.xml"
    # To load the UR5e instead, use: menagerie / "universal_robots_ur5e" / "ur5e.xml"

    model = mujoco.MjModel.from_xml_path(model_xml.as_posix())
    data = mujoco.MjData(model)

    with mujoco.viewer.launch_passive(model, data) as viewer:
        # Keep stepping while the viewer reports it is running.
        while viewer.is_running():
            tick = time.time()

            # A policy could be evaluated and a control signal applied here,
            # before the physics step.
            mujoco.mj_step(model, data)

            # Example option change: toggle contact points every two seconds
            # of simulation time.
            with viewer.lock():
                show_contacts = int(data.time % 2)
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = show_contacts

            # Pick up physics-state changes, perturbations and GUI option updates.
            viewer.sync()

            # Rudimentary pacing; will drift relative to the wall clock.
            remaining = model.opt.timestep - (time.time() - tick)
            if remaining > 0:
                time.sleep(remaining)


if __name__ == "__main__":
    main()

25
setup.py Normal file
View file

@ -0,0 +1,25 @@
import setuptools
# The README is passed to setup() as the long description (declared as
# Markdown below).  Decode it explicitly as UTF-8 so that building the package
# does not depend on the platform's default locale encoding.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="gello",
    version="0.0.1",
    author="Philipp Wu",
    author_email="philippwu@berkeley.edu",
    description="software for GELLO",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/wuphilipp/gello_software",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
    ],
    python_requires=">=3.8",
    license="MIT",
    install_requires=[
        "numpy",
    ],
)

1
third_party/DynamixelSDK vendored Submodule

@ -0,0 +1 @@
Subproject commit 3450c7078917b262d9b36042c15444047aae226e

1
third_party/mujoco_menagerie vendored Submodule

@ -0,0 +1 @@
Subproject commit c118f3d5d5ec9fb27f832ac78b3c4971234f0f4f