initial commit, add gello software code and instructions
This commit is contained in:
parent
e7d842ad35
commit
18cc23a38e
70 changed files with 5875 additions and 4 deletions
3
.flake8
Normal file
3
.flake8
Normal file
|
@ -0,0 +1,3 @@
|
|||
[flake8]
|
||||
ignore = D100, D101, D102, D103, D104, D105, D107, E203, E501, W503, SIM201, SIM113, B027, C408, B008
|
||||
docstring-convention = google
|
44
.github/workflows/pythonapp.yml
vendored
Normal file
44
.github/workflows/pythonapp.yml
vendored
Normal file
|
@ -0,0 +1,44 @@
|
|||
name: Python application
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
schedule:
|
||||
- cron: "0 0 * * 0"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
# python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements_dev.txt
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
git submodule init
|
||||
git submodule update
|
||||
pip install third_party/DynamixelSDK/python
|
||||
|
||||
- name: Black
|
||||
run: |
|
||||
black --check gello
|
||||
- name: flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 gello --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest gello
|
29
.gitignore
vendored
Normal file
29
.gitignore
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
# keep each top level section alphabetical
|
||||
# keep each item within the sections alphabetical
|
||||
|
||||
# large files
|
||||
*.png
|
||||
*.gif
|
||||
|
||||
# mac
|
||||
*.DS_Store
|
||||
|
||||
# misc
|
||||
*logs*
|
||||
*runs*
|
||||
*tb_logs*
|
||||
*wandb*
|
||||
*wandb_logs*
|
||||
outputs/*
|
||||
|
||||
# python
|
||||
*__pycache__
|
||||
*.egg-info
|
||||
*.hypothesis
|
||||
*.ipynb_checkpoints
|
||||
*.mypy_cache
|
||||
*.pyc
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
MUJOCO_LOG.TXT
|
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
[submodule "third_party/mujoco_menagerie"]
|
||||
path = third_party/mujoco_menagerie
|
||||
url = https://github.com/deepmind/mujoco_menagerie.git
|
||||
[submodule "third_party/DynamixelSDK"]
|
||||
path = third_party/DynamixelSDK
|
||||
url = https://github.com/ROBOTIS-GIT/DynamixelSDK.git
|
6
.isort.cfg
Normal file
6
.isort.cfg
Normal file
|
@ -0,0 +1,6 @@
|
|||
[settings]
|
||||
use_parentheses=True
|
||||
include_trailing_comma=True
|
||||
multi_line_output=3
|
||||
ensure_newline_before_comments=True
|
||||
line_length=88
|
38
.pre-commit-config.yaml
Normal file
38
.pre-commit-config.yaml
Normal file
|
@ -0,0 +1,38 @@
|
|||
repos:
|
||||
|
||||
# remove unused python imports
|
||||
- repo: https://github.com/myint/autoflake.git
|
||||
rev: v2.1.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args: ["--in-place", "--remove-all-unused-imports", "--ignore-init-module-imports"]
|
||||
|
||||
# sort imports
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
|
||||
# code format according to black
|
||||
- repo: https://github.com/ambv/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
# check for python styling with flake8
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
additional_dependencies: [
|
||||
'flake8-docstrings',
|
||||
'flake8-bugbear',
|
||||
'flake8-comprehensions',
|
||||
'flake8-simplify',
|
||||
]
|
||||
|
||||
# cleanup notebooks
|
||||
- repo: https://github.com/kynan/nbstripout
|
||||
rev: 0.6.1
|
||||
hooks:
|
||||
- id: nbstripout
|
22
Dockerfile
Normal file
22
Dockerfile
Normal file
|
@ -0,0 +1,22 @@
|
|||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
|
||||
WORKDIR /gello
|
||||
|
||||
# Set environment variables first (less likely to change)
|
||||
ENV PYTHONPATH=/gello:/gello/third_party/oculus_reader/
|
||||
|
||||
# Group apt updates and installs together
|
||||
RUN apt update && apt install -y \
|
||||
libhidapi-dev \
|
||||
python3-pip \
|
||||
android-tools-adb \
|
||||
libegl1-mesa-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Python alias setup
|
||||
RUN echo "alias python=python3" >> ~/.bashrc
|
||||
|
||||
# Install Python dependencies
|
||||
COPY requirements.txt /gello
|
||||
RUN pip install -r requirements.txt
|
21
LICENSE
Normal file
21
LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2023 Philipp Wu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
180
README.md
180
README.md
|
@ -1,10 +1,174 @@
|
|||
# GELLO Software
|
||||
This is the central repo that holds the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/
|
||||
# GELLO
|
||||
This is the central repo that holds the all the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/
|
||||
See the GELLO hardware repo for the STL files and hardware instructions for building your own GELLO https://github.com/wuphilipp/gello_mechanical
|
||||
```
|
||||
git clone https://github.com/wuphilipp/gello_software.git
|
||||
cd gello
|
||||
```
|
||||
|
||||
# Installation
|
||||
<p align="center">
|
||||
<img src="imgs/title.png" />
|
||||
</p>
|
||||
|
||||
Coming Soon
|
||||
|
||||
## Use your own enviroment
|
||||
```
|
||||
git submodule init
|
||||
git submodule update
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
pip install -e third_party/DynamixelSDK/python
|
||||
```
|
||||
|
||||
## Use with Docker
|
||||
First install ```docker``` following this [link](https://docs.docker.com/engine/install/ubuntu/) on your host machine.
|
||||
Then you can clone the repo and build the corresponding docker environment
|
||||
|
||||
Build the docker image and tag it as gello:latest. If you are going to name it differently, you need to change the launch.py image name
|
||||
```
|
||||
docker build . -t gello:latest
|
||||
```
|
||||
|
||||
We have provided an entry point into the docker container
|
||||
```
|
||||
python experiments/launch.py
|
||||
```
|
||||
|
||||
# GELLO configuration setup (PLEASE READ)
|
||||
Now that you have downloaded the code, there is some additional preparation work to properly configure the Dynamixels and GELLO.
|
||||
These instructions will guide you on how to update the motor ids of the Dynamixels and then how to extract the joint offsets to configure your GELLO.
|
||||
|
||||
## Update motor IDs
|
||||
Install the [dynamixel_wizard](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2/).
|
||||
By default, each motor has the ID 1. In order for multiple dynamixels to be controlled by the same U2D2 controller board, each dynamixel must have a unique ID.
|
||||
This process must be done one motor at a time. Connect each motor, starting from the base motor, and assign them in increasing order until you reach the gripper.
|
||||
|
||||
Steps:
|
||||
* Connect a single motor to the controller and connect the controller to the computer.
|
||||
* Open the dynamixel wizard
|
||||
* Click scan (found at the top left corner), this should detect the dynamixel. Connect to the motor
|
||||
* Look for the ID address and change the ID to the appropriate number.
|
||||
* Repeat for each motor
|
||||
|
||||
## Create the GELLO configuration and determining joint ID's
|
||||
After the motor ID's are set, we can now connect to the GELLO controller device. However each motor has its own joint offset, which will result in a joint offset between GELLO and your actual robot arm.
|
||||
Dynamixels have a symmetric 4 hole pattern which means there the joint offset is a multiple of pi/2.
|
||||
The `GelloAgent` class accepts a `DynamixelRobotConfig` (found in `gello/agents/gello_agent.py`). The Dynamixel config specifies the parameters you need to find to operate your GELLO. Look at the documentation for more details.
|
||||
|
||||
We have created a simple script to automatically detect the joint offset:
|
||||
* set GELLO into a known configuration, where you know what the corresponding joint angles should be. For example, we set out GELLO in this configuration, where we know the desired ground truth joints. (0, -90, 90, -90, -90, 0)
|
||||
<p align="center">
|
||||
<img src="imgs/gello_matching_joints.jpg" width="45%"/>
|
||||
<img src="imgs/robot_known_configuration.jpg" width="45%"/>
|
||||
</p>
|
||||
|
||||
* run
|
||||
```
|
||||
python scripts/gello_get_offset.py \
|
||||
--start-joints 0 -1.57 1.57 -1.57 -1.57 0 \ # in radians
|
||||
--joint-signs 1 1 -1 1 1 1 \
|
||||
--port /dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6
|
||||
# replace values with your own
|
||||
```
|
||||
* Use the known starting joints for `start-joints`.
|
||||
* Use the `joint-signs` for your own robot (see below).
|
||||
* Use your serial port for `port`. You can find the port id of your U2D2 Dynamixel device by running `ls /dev/serial/by-id` and looking for the path that starts with `usb-FTDI_USB__-__Serial_Converter` (on Ubuntu). On Mac, look in /dev/ and the device that starts with `cu.usbserial`
|
||||
|
||||
`joint-signs` for each robot type:
|
||||
* UR: `1 1 -1 1 1 1`
|
||||
* Panda: `1 -1 1 1 1 -1 1`
|
||||
* xArm: `1 1 1 1 1 1 1`
|
||||
|
||||
The script prints out a list of joint offsets. Go to `gello/agents/gello_agent.py` and add a DynamixelRobotConfig to the PORT_CONFIG_MAP. You are now ready to run your GELLO!
|
||||
|
||||
# Using GELLO to control a robot!
|
||||
|
||||
The code provided here is simple and only relies on python packages. The code does NOT use ROS, but a ROS wrapper can easily be adapted from this code.
|
||||
For multiprocessing, we leverage [ZMQ](https://zeromq.org/)
|
||||
|
||||
## Testing in sim
|
||||
First test your GELLO with a simulated robot to make sure that the joint angles match as expected.
|
||||
In one terminal run
|
||||
```
|
||||
python experiments/launch_nodes.py --robot <sim_ur, sim_panda, or sim_xarm>
|
||||
```
|
||||
This launched the robot node. A simulated robot using the mujoco viewer should appear.
|
||||
|
||||
Then, launch your GELLO (the controller node).
|
||||
```
|
||||
python experiments/run_env.py --agent=gello
|
||||
```
|
||||
You should be able to use GELLO to control the simulated robot!
|
||||
|
||||
## Running on a real robot.
|
||||
Once you have verified that your GELLO is properly configured, you can test it on a real robot!
|
||||
|
||||
Before you run with the real robot, you will have to install a robot specific python package.
|
||||
The supported robots are in `gello/robots`.
|
||||
* UR: [ur_rtde](https://sdurobotics.gitlab.io/ur_rtde/installation/installation.html)
|
||||
* panda: [polymetis](https://facebookresearch.github.io/fairo/polymetis/installation.html). If you use a different framework to control the panda, the code is easy to adpot. See/Modify `gello/robots/panda.py`
|
||||
* xArm: [xArm python SDK](https://github.com/xArm-Developer/xArm-Python-SDK)
|
||||
|
||||
```
|
||||
# Launch all of the node
|
||||
python experiments/launch_nodes.py --robot=<your robot>
|
||||
# run the enviroment loop
|
||||
python experiments/run_env.py --agent=gello
|
||||
```
|
||||
|
||||
Ideally you can start your GELLO near a known configuration each time. If this is possible, you can set the `--start-joint` flag with GELLO's known starting configuration. This also enables the robot to reset before you begin teleoperation.
|
||||
|
||||
## Collect data
|
||||
We have provided a simple example for collecting data with gello.
|
||||
To save trajectories with the keyboard, add the following flag `--use-save-interface`
|
||||
|
||||
Data can then be processed using the demo_to_gdict script.
|
||||
```
|
||||
python gello/data_utils/demo_to_gdict.py --source-dir=<source dir location>
|
||||
```
|
||||
|
||||
## Running a bimanual system with GELLO
|
||||
GELLO also be used in bimanual configurations.
|
||||
For an example, see the `bimanual_ur` robot in `launch_nodes.py` and `--bimanual` flag in the `run_env.py` script.
|
||||
|
||||
## Notes
|
||||
Due to the use of multiprocessing, sometimes python process are not killed properly. We have provided the kill_nodes script which will kill the
|
||||
python processes.
|
||||
```
|
||||
./kill_nodes.sh
|
||||
```
|
||||
|
||||
### Using a new robot!
|
||||
If you want to use a new robot you need a GELLO that is compatible. If the kiniamtics are close enough, you may directly use an existing GELLO. Otherwise you will have to design your own.
|
||||
To add a new robot, simply implement the `Robot` protocol found in `gello/robots/robot`. See `gello/robots/panda.py`, `gello/robots/ur.py`, `gello/robots/xarm_robot.py` for examples.
|
||||
|
||||
### Contributing
|
||||
Please make a PR if you would like to contribute! The goal of this project is to enable more accessible and higher quality teleoperation devices and we would love your input!
|
||||
|
||||
You can optionally install some dev packages.
|
||||
```
|
||||
pip install -r requirements_dev.txt
|
||||
```
|
||||
|
||||
The code is organized as follows:
|
||||
* `scripts`: contains some helpful python `scripts`
|
||||
* `experiments`: contains entrypoints into the gello code
|
||||
* `gello`: contains all of the `gello` python package code
|
||||
* `agents`: teleoperation agents
|
||||
* `cameras`: code to interface with camera hardware
|
||||
* `data_utils`: data processing utils. used for imitation learning
|
||||
* `dm_control_tasks`: dm_control utils to build a simple dm_control enviroment. used for demos
|
||||
* `dynamixel`: code to interface with the dynamixel hardware
|
||||
* `robots`: robot specific interfaces
|
||||
* `zmq_core`: zmq utilities for enabling a multi node system
|
||||
|
||||
|
||||
This code base uses `isort` and `black` for code formatting.
|
||||
pre-commits hooks are great. This will automatically do some checking/formatting. To use the pre-commit hooks, run the following:
|
||||
```
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
# Citation
|
||||
|
||||
|
@ -15,3 +179,11 @@ Coming Soon
|
|||
year={2023},
|
||||
}
|
||||
```
|
||||
|
||||
# License & Acknowledgements
|
||||
This source code is licensed under the MIT license found in the LICENSE file. in the root directory of this source tree.
|
||||
|
||||
This project builds on top of or utilizes the following third party dependencies.
|
||||
* [google-deepmind/mujoco_menagerie](https://github.com/google-deepmind/mujoco_menagerie): Prebuilt robot models for mujoco
|
||||
* [brentyi/tyro](https://github.com/brentyi/tyro): Argument parsing and configuration
|
||||
* [ZMQ](https://zeromq.org/): Enables easy create of node like processes in python.
|
||||
|
|
20
config_hostmachine.sh
Normal file
20
config_hostmachine.sh
Normal file
|
@ -0,0 +1,20 @@
|
|||
#!/bin/bash
|
||||
|
||||
# install nvidia driver 520
|
||||
sudo apt-get install nvidia-520 -y
|
||||
|
||||
# install docker
|
||||
sudo apt-get update
|
||||
sudo apt-get install ca-certificates curl gnupg
|
||||
|
||||
sudo install -m 0755 -d /etc/apt/keyrings
|
||||
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
|
||||
sudo chmod a+r /etc/apt/keyrings/docker.gpg
|
||||
|
||||
echo \
|
||||
"deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
|
||||
"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \
|
||||
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
||||
|
||||
sudo apt-get update
|
||||
sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
|
37
experiments/launch_camera_clients.py
Normal file
37
experiments/launch_camera_clients.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.zmq_core.camera_node import ZMQClientCamera
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
ports: Tuple[int, ...] = (5000, 5001)
|
||||
hostname: str = "127.0.0.1"
|
||||
# hostname: str = "128.32.175.167"
|
||||
|
||||
|
||||
def main(args):
|
||||
cameras = []
|
||||
import cv2
|
||||
|
||||
images_display_names = []
|
||||
for port in args.ports:
|
||||
cameras.append(ZMQClientCamera(port=port, host=args.hostname))
|
||||
images_display_names.append(f"image_{port}")
|
||||
cv2.namedWindow(images_display_names[-1], cv2.WINDOW_NORMAL)
|
||||
|
||||
while True:
|
||||
for display_name, camera in zip(images_display_names, cameras):
|
||||
image, depth = camera.read()
|
||||
stacked_depth = np.dstack([depth, depth, depth]).astype(np.uint8)
|
||||
image_depth = cv2.hconcat([image[:, :, ::-1], stacked_depth])
|
||||
cv2.imshow(display_name, image_depth)
|
||||
cv2.waitKey(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
40
experiments/launch_camera_nodes.py
Normal file
40
experiments/launch_camera_nodes.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
|
||||
import tyro
|
||||
|
||||
from gello.cameras.realsense_camera import RealSenseCamera, get_device_ids
|
||||
from gello.zmq_core.camera_node import ZMQServerCamera
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
# hostname: str = "127.0.0.1"
|
||||
hostname: str = "128.32.175.167"
|
||||
|
||||
|
||||
def launch_server(port: int, camera_id: int, args: Args):
|
||||
camera = RealSenseCamera(camera_id)
|
||||
server = ZMQServerCamera(camera, port=port, host=args.hostname)
|
||||
print(f"Starting camera server on port {port}")
|
||||
server.serve()
|
||||
|
||||
|
||||
def main(args):
|
||||
ids = get_device_ids()
|
||||
camera_port = 5000
|
||||
camera_servers = []
|
||||
for camera_id in ids:
|
||||
# start a python process for each camera
|
||||
print(f"Launching camera {camera_id} on port {camera_port}")
|
||||
camera_servers.append(
|
||||
Process(target=launch_server, args=(camera_port, camera_id, args))
|
||||
)
|
||||
camera_port += 1
|
||||
|
||||
for server in camera_servers:
|
||||
server.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
94
experiments/launch_nodes.py
Normal file
94
experiments/launch_nodes.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import tyro
|
||||
|
||||
from gello.robots.robot import BimanualRobot, PrintRobot
|
||||
from gello.zmq_core.robot_node import ZMQServerRobot
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
robot: str = "xarm"
|
||||
robot_port: int = 6001
|
||||
hostname: str = "127.0.0.1"
|
||||
robot_ip: str = "192.168.1.10"
|
||||
|
||||
|
||||
def launch_robot_server(args: Args):
|
||||
port = args.robot_port
|
||||
if args.robot == "sim_ur":
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
elif args.robot == "sim_panda":
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
|
||||
gripper_xml = None
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
elif args.robot == "sim_xarm":
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "ufactory_xarm7" / "xarm7.xml"
|
||||
gripper_xml = None
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
|
||||
else:
|
||||
if args.robot == "xarm":
|
||||
from gello.robots.xarm_robot import XArmRobot
|
||||
|
||||
robot = XArmRobot(ip=args.robot_ip)
|
||||
elif args.robot == "ur":
|
||||
from gello.robots.ur import URRobot
|
||||
|
||||
robot = URRobot(robot_ip=args.robot_ip)
|
||||
elif args.robot == "panda":
|
||||
from gello.robots.panda import PandaRobot
|
||||
|
||||
robot = PandaRobot(robot_ip=args.robot_ip)
|
||||
elif args.robot == "bimanual_ur":
|
||||
from gello.robots.ur import URRobot
|
||||
|
||||
# IP for the bimanual robot setup is hardcoded
|
||||
_robot_l = URRobot(robot_ip="192.168.2.10")
|
||||
_robot_r = URRobot(robot_ip="192.168.1.10")
|
||||
robot = BimanualRobot(_robot_l, _robot_r)
|
||||
elif args.robot == "none" or args.robot == "print":
|
||||
robot = PrintRobot(8)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
|
||||
)
|
||||
server = ZMQServerRobot(robot, port=port, host=args.hostname)
|
||||
print(f"Starting robot server on port {port}")
|
||||
server.serve()
|
||||
|
||||
|
||||
def main(args):
|
||||
launch_robot_server(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
192
experiments/quick_run.py
Normal file
192
experiments/quick_run.py
Normal file
|
@ -0,0 +1,192 @@
|
|||
import atexit
|
||||
import glob
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.agents.agent import DummyAgent
|
||||
from gello.agents.gello_agent import GelloAgent
|
||||
from gello.agents.spacemouse_agent import SpacemouseAgent
|
||||
from gello.env import RobotEnv
|
||||
from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
hz: int = 100
|
||||
|
||||
agent: str = "gello"
|
||||
robot: str = "ur5"
|
||||
gello_port: Optional[str] = None
|
||||
mock: bool = False
|
||||
verbose: bool = False
|
||||
|
||||
hostname: str = "127.0.0.1"
|
||||
robot_port: int = 6001
|
||||
|
||||
|
||||
def launch_robot_server(port: int, args: Args):
|
||||
if args.robot == "sim_ur":
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
elif args.robot == "sim_panda":
|
||||
from gello.robots.sim_robot import MujocoRobotServer
|
||||
|
||||
MENAGERIE_ROOT: Path = (
|
||||
Path(__file__).parent.parent / "third_party" / "mujoco_menagerie"
|
||||
)
|
||||
xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
|
||||
gripper_xml = None
|
||||
server = MujocoRobotServer(
|
||||
xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname
|
||||
)
|
||||
server.serve()
|
||||
|
||||
else:
|
||||
if args.robot == "xarm":
|
||||
from gello.robots.xarm_robot import XArmRobot
|
||||
|
||||
robot = XArmRobot()
|
||||
elif args.robot == "ur5":
|
||||
from gello.robots.ur import URRobot
|
||||
|
||||
robot = URRobot(robot_ip=args.robot_ip)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none"
|
||||
)
|
||||
server = ZMQServerRobot(robot, port=port, host=args.hostname)
|
||||
print(f"Starting robot server on port {port}")
|
||||
server.serve()
|
||||
|
||||
|
||||
def start_robot_process(args: Args):
|
||||
process = Process(target=launch_robot_server, args=(args.robot_port, args))
|
||||
|
||||
# Function to kill the child process
|
||||
def kill_child_process(process):
|
||||
print("Killing child process...")
|
||||
process.terminate()
|
||||
|
||||
# Register the kill_child_process function to be called at exit
|
||||
atexit.register(kill_child_process, process)
|
||||
process.start()
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
start_robot_process(args)
|
||||
|
||||
robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
|
||||
env = RobotEnv(robot_client, control_rate_hz=args.hz)
|
||||
|
||||
if args.agent == "gello":
|
||||
gello_port = args.gello_port
|
||||
if gello_port is None:
|
||||
usb_ports = glob.glob("/dev/serial/by-id/*")
|
||||
print(f"Found {len(usb_ports)} ports")
|
||||
if len(usb_ports) > 0:
|
||||
gello_port = usb_ports[0]
|
||||
print(f"using port {gello_port}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"No gello port found, please specify one or plug in gello"
|
||||
)
|
||||
agent = GelloAgent(port=gello_port)
|
||||
|
||||
reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0])
|
||||
curr_joints = env.get_obs()["joint_positions"]
|
||||
if reset_joints.shape == curr_joints.shape:
|
||||
max_delta = (np.abs(curr_joints - reset_joints)).max()
|
||||
steps = min(int(max_delta / 0.01), 100)
|
||||
|
||||
for jnt in np.linspace(curr_joints, reset_joints, steps):
|
||||
env.step(jnt)
|
||||
time.sleep(0.001)
|
||||
|
||||
elif args.agent == "quest":
|
||||
from gello.agents.quest_agent import SingleArmQuestAgent
|
||||
|
||||
agent = SingleArmQuestAgent(robot_type=args.robot, which_hand="l")
|
||||
elif args.agent == "spacemouse":
|
||||
agent = SpacemouseAgent(robot_type=args.robot, verbose=args.verbose)
|
||||
elif args.agent == "dummy" or args.agent == "none":
|
||||
agent = DummyAgent(num_dofs=robot_client.num_dofs())
|
||||
else:
|
||||
raise ValueError("Invalid agent name")
|
||||
|
||||
# going to start position
|
||||
print("Going to start position")
|
||||
start_pos = agent.act(env.get_obs())
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
|
||||
abs_deltas = np.abs(start_pos - joints)
|
||||
id_max_joint_delta = np.argmax(abs_deltas)
|
||||
|
||||
max_joint_delta = 0.8
|
||||
if abs_deltas[id_max_joint_delta] > max_joint_delta:
|
||||
id_mask = abs_deltas > max_joint_delta
|
||||
print()
|
||||
ids = np.arange(len(id_mask))[id_mask]
|
||||
for i, delta, joint, current_j in zip(
|
||||
ids,
|
||||
abs_deltas[id_mask],
|
||||
start_pos[id_mask],
|
||||
joints[id_mask],
|
||||
):
|
||||
print(
|
||||
f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
|
||||
assert len(start_pos) == len(
|
||||
joints
|
||||
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
|
||||
|
||||
max_delta = 0.05
|
||||
for _ in range(25):
|
||||
obs = env.get_obs()
|
||||
command_joints = agent.act(obs)
|
||||
current_joints = obs["joint_positions"]
|
||||
delta = command_joints - current_joints
|
||||
max_joint_delta = np.abs(delta).max()
|
||||
if max_joint_delta > max_delta:
|
||||
delta = delta / max_joint_delta * max_delta
|
||||
env.step(current_joints + delta)
|
||||
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
action = agent.act(obs)
|
||||
if (action - joints > 0.5).any():
|
||||
print("Action is too big")
|
||||
|
||||
# print which joints are too big
|
||||
joint_index = np.where(action - joints > 0.8)
|
||||
for j in joint_index:
|
||||
print(
|
||||
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
while True:
|
||||
action = agent.act(obs)
|
||||
obs = env.step(action)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
246
experiments/run_env.py
Normal file
246
experiments/run_env.py
Normal file
|
@ -0,0 +1,246 @@
|
|||
import datetime
|
||||
import glob
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.agents.agent import BimanualAgent, DummyAgent
|
||||
from gello.agents.gello_agent import GelloAgent
|
||||
from gello.data_utils.format_obs import save_frame
|
||||
from gello.env import RobotEnv
|
||||
from gello.robots.robot import PrintRobot
|
||||
from gello.zmq_core.robot_node import ZMQClientRobot
|
||||
|
||||
|
||||
def print_color(*args, color=None, attrs=(), **kwargs):
|
||||
import termcolor
|
||||
|
||||
if len(args) > 0:
|
||||
args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
agent: str = "none"
|
||||
robot_port: int = 6001
|
||||
wrist_camera_port: int = 5000
|
||||
base_camera_port: int = 5001
|
||||
hostname: str = "127.0.0.1"
|
||||
robot_type: str = None # only needed for quest agent or spacemouse agent
|
||||
hz: int = 100
|
||||
start_joints: Optional[Tuple[float, ...]] = None
|
||||
|
||||
gello_port: Optional[str] = None
|
||||
mock: bool = False
|
||||
use_save_interface: bool = False
|
||||
data_dir: str = "~/bc_data"
|
||||
bimanual: bool = False
|
||||
verbose: bool = False
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.mock:
|
||||
robot_client = PrintRobot(8, dont_print=True)
|
||||
camera_clients = {}
|
||||
else:
|
||||
camera_clients = {
|
||||
# you can optionally add camera nodes here for imitation learning purposes
|
||||
# "wrist": ZMQClientCamera(port=args.wrist_camera_port, host=args.hostname),
|
||||
# "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
|
||||
}
|
||||
robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
|
||||
env = RobotEnv(robot_client, control_rate_hz=args.hz, camera_dict=camera_clients)
|
||||
|
||||
if args.bimanual:
|
||||
if args.agent == "gello":
|
||||
# dynamixel control box port map (to distinguish left and right gello)
|
||||
right = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0"
|
||||
left = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0"
|
||||
left_agent = GelloAgent(port=left)
|
||||
right_agent = GelloAgent(port=right)
|
||||
agent = BimanualAgent(left_agent, right_agent)
|
||||
elif args.agent == "quest":
|
||||
from gello.agents.quest_agent import SingleArmQuestAgent
|
||||
|
||||
left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
|
||||
right_agent = SingleArmQuestAgent(
|
||||
robot_type=args.robot_type, which_hand="r"
|
||||
)
|
||||
agent = BimanualAgent(left_agent, right_agent)
|
||||
# raise NotImplementedError
|
||||
elif args.agent == "spacemouse":
|
||||
from gello.agents.spacemouse_agent import SpacemouseAgent
|
||||
|
||||
left_path = "/dev/hidraw0"
|
||||
right_path = "/dev/hidraw1"
|
||||
left_agent = SpacemouseAgent(
|
||||
robot_type=args.robot_type, device_path=left_path, verbose=args.verbose
|
||||
)
|
||||
right_agent = SpacemouseAgent(
|
||||
robot_type=args.robot_type,
|
||||
device_path=right_path,
|
||||
verbose=args.verbose,
|
||||
invert_button=True,
|
||||
)
|
||||
agent = BimanualAgent(left_agent, right_agent)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent name for bimanual: {args.agent}")
|
||||
|
||||
# System setup specific. This reset configuration works well on our setup. If you are mounting the robot
|
||||
# differently, you need a separate reset joint configuration.
|
||||
reset_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0, 0])
|
||||
reset_joints_right = np.deg2rad([0, -90, 90, -90, -90, 0, 0])
|
||||
reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
|
||||
curr_joints = env.get_obs()["joint_positions"]
|
||||
max_delta = (np.abs(curr_joints - reset_joints)).max()
|
||||
steps = min(int(max_delta / 0.01), 100)
|
||||
|
||||
for jnt in np.linspace(curr_joints, reset_joints, steps):
|
||||
env.step(jnt)
|
||||
else:
|
||||
if args.agent == "gello":
|
||||
gello_port = args.gello_port
|
||||
if gello_port is None:
|
||||
usb_ports = glob.glob("/dev/serial/by-id/*")
|
||||
print(f"Found {len(usb_ports)} ports")
|
||||
if len(usb_ports) > 0:
|
||||
gello_port = usb_ports[0]
|
||||
print(f"using port {gello_port}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"No gello port found, please specify one or plug in gello"
|
||||
)
|
||||
if args.start_joints is None:
|
||||
reset_joints = np.deg2rad(
|
||||
[0, -90, 90, -90, -90, 0, 0]
|
||||
) # Change this to your own reset joints
|
||||
else:
|
||||
reset_joints = args.start_joints
|
||||
agent = GelloAgent(port=gello_port, start_joints=args.start_joints)
|
||||
curr_joints = env.get_obs()["joint_positions"]
|
||||
if reset_joints.shape == curr_joints.shape:
|
||||
max_delta = (np.abs(curr_joints - reset_joints)).max()
|
||||
steps = min(int(max_delta / 0.01), 100)
|
||||
|
||||
for jnt in np.linspace(curr_joints, reset_joints, steps):
|
||||
env.step(jnt)
|
||||
time.sleep(0.001)
|
||||
elif args.agent == "quest":
|
||||
from gello.agents.quest_agent import SingleArmQuestAgent
|
||||
|
||||
agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
|
||||
elif args.agent == "spacemouse":
|
||||
from gello.agents.spacemouse_agent import SpacemouseAgent
|
||||
|
||||
agent = SpacemouseAgent(robot_type=args.robot_type, verbose=args.verbose)
|
||||
elif args.agent == "dummy" or args.agent == "none":
|
||||
agent = DummyAgent(num_dofs=robot_client.num_dofs())
|
||||
elif args.agent == "policy":
|
||||
raise NotImplementedError("add your imitation policy here if there is one")
|
||||
else:
|
||||
raise ValueError("Invalid agent name")
|
||||
|
||||
# going to start position
|
||||
print("Going to start position")
|
||||
start_pos = agent.act(env.get_obs())
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
|
||||
abs_deltas = np.abs(start_pos - joints)
|
||||
id_max_joint_delta = np.argmax(abs_deltas)
|
||||
|
||||
max_joint_delta = 0.8
|
||||
if abs_deltas[id_max_joint_delta] > max_joint_delta:
|
||||
id_mask = abs_deltas > max_joint_delta
|
||||
print()
|
||||
ids = np.arange(len(id_mask))[id_mask]
|
||||
for i, delta, joint, current_j in zip(
|
||||
ids,
|
||||
abs_deltas[id_mask],
|
||||
start_pos[id_mask],
|
||||
joints[id_mask],
|
||||
):
|
||||
print(
|
||||
f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}"
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
|
||||
assert len(start_pos) == len(
|
||||
joints
|
||||
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
|
||||
|
||||
max_delta = 0.05
|
||||
for _ in range(25):
|
||||
obs = env.get_obs()
|
||||
command_joints = agent.act(obs)
|
||||
current_joints = obs["joint_positions"]
|
||||
delta = command_joints - current_joints
|
||||
max_joint_delta = np.abs(delta).max()
|
||||
if max_joint_delta > max_delta:
|
||||
delta = delta / max_joint_delta * max_delta
|
||||
env.step(current_joints + delta)
|
||||
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
action = agent.act(obs)
|
||||
if (action - joints > 0.5).any():
|
||||
print("Action is too big")
|
||||
|
||||
# print which joints are too big
|
||||
joint_index = np.where(action - joints > 0.8)
|
||||
for j in joint_index:
|
||||
print(
|
||||
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
if args.use_save_interface:
|
||||
from gello.data_utils.keyboard_interface import KBReset
|
||||
|
||||
kb_interface = KBReset()
|
||||
|
||||
print_color("\nStart 🚀🚀🚀", color="green", attrs=("bold",))
|
||||
|
||||
save_path = None
|
||||
start_time = time.time()
|
||||
while True:
|
||||
num = time.time() - start_time
|
||||
message = f"\rTime passed: {round(num, 2)} "
|
||||
print_color(
|
||||
message,
|
||||
color="white",
|
||||
attrs=("bold",),
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
action = agent.act(obs)
|
||||
dt = datetime.datetime.now()
|
||||
if args.use_save_interface:
|
||||
state = kb_interface.update()
|
||||
if state == "start":
|
||||
dt_time = datetime.datetime.now()
|
||||
save_path = (
|
||||
Path(args.data_dir).expanduser()
|
||||
/ args.agent
|
||||
/ dt_time.strftime("%m%d_%H%M%S")
|
||||
)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Saving to {save_path}")
|
||||
elif state == "save":
|
||||
assert save_path is not None, "something went wrong"
|
||||
save_frame(save_path, dt, obs, action)
|
||||
elif state == "normal":
|
||||
save_path = None
|
||||
else:
|
||||
raise ValueError(f"Invalid state {state}")
|
||||
obs = env.step(action)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
0
gello/__init__.py
Normal file
0
gello/__init__.py
Normal file
43
gello/agents/agent.py
Normal file
43
gello/agents/agent.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from typing import Any, Dict, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Agent(Protocol):
|
||||
def act(self, obs: Dict[str, Any]) -> np.ndarray:
|
||||
"""Returns an action given an observation.
|
||||
|
||||
Args:
|
||||
obs: observation from the environment.
|
||||
|
||||
Returns:
|
||||
action: action to take on the environment.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DummyAgent(Agent):
|
||||
def __init__(self, num_dofs: int):
|
||||
self.num_dofs = num_dofs
|
||||
|
||||
def act(self, obs: Dict[str, Any]) -> np.ndarray:
|
||||
return np.zeros(self.num_dofs)
|
||||
|
||||
|
||||
class BimanualAgent(Agent):
|
||||
def __init__(self, agent_left: Agent, agent_right: Agent):
|
||||
self.agent_left = agent_left
|
||||
self.agent_right = agent_right
|
||||
|
||||
def act(self, obs: Dict[str, Any]) -> np.ndarray:
|
||||
left_obs = {}
|
||||
right_obs = {}
|
||||
for key, val in obs.items():
|
||||
L = val.shape[0]
|
||||
half_dim = L // 2
|
||||
assert L == half_dim * 2, f"{key} must be even, something is wrong"
|
||||
left_obs[key] = val[:half_dim]
|
||||
right_obs[key] = val[half_dim:]
|
||||
return np.concatenate(
|
||||
[self.agent_left.act(left_obs), self.agent_right.act(right_obs)]
|
||||
)
|
139
gello/agents/gello_agent.py
Normal file
139
gello/agents/gello_agent.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gello.agents.agent import Agent
|
||||
from gello.robots.dynamixel import DynamixelRobot
|
||||
|
||||
|
||||
@dataclass
|
||||
class DynamixelRobotConfig:
|
||||
joint_ids: Sequence[int]
|
||||
"""The joint ids of GELLO (not including the gripper). Usually (1, 2, 3 ...)."""
|
||||
|
||||
joint_offsets: Sequence[float]
|
||||
"""The joint offsets of GELLO. There needs to be a joint offset for each joint_id and should be a multiple of pi/2."""
|
||||
|
||||
joint_signs: Sequence[int]
|
||||
"""The joint signs of GELLO. There needs to be a joint sign for each joint_id and should be either 1 or -1.
|
||||
|
||||
This will be different for each arm design. Refernce the examples below for the correct signs for your robot.
|
||||
"""
|
||||
|
||||
gripper_config: Tuple[int, int, int]
|
||||
"""The gripper config of GELLO. This is a tuple of (gripper_joint_id, degrees in open_position, degrees in closed_position)."""
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.joint_ids) == len(self.joint_offsets)
|
||||
assert len(self.joint_ids) == len(self.joint_signs)
|
||||
|
||||
def make_robot(
|
||||
self, port: str = "/dev/ttyUSB0", start_joints: Optional[np.ndarray] = None
|
||||
) -> DynamixelRobot:
|
||||
return DynamixelRobot(
|
||||
joint_ids=self.joint_ids,
|
||||
joint_offsets=list(self.joint_offsets),
|
||||
real=True,
|
||||
joint_signs=list(self.joint_signs),
|
||||
port=port,
|
||||
gripper_config=self.gripper_config,
|
||||
start_joints=start_joints,
|
||||
)
|
||||
|
||||
|
||||
PORT_CONFIG_MAP: Dict[str, DynamixelRobotConfig] = {
|
||||
# xArm
|
||||
# "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig(
|
||||
# joint_ids=(1, 2, 3, 4, 5, 6, 7),
|
||||
# joint_offsets=(
|
||||
# 2 * np.pi / 2,
|
||||
# 2 * np.pi / 2,
|
||||
# 2 * np.pi / 2,
|
||||
# 2 * np.pi / 2,
|
||||
# -1 * np.pi / 2 + 2 * np.pi,
|
||||
# 1 * np.pi / 2,
|
||||
# 1 * np.pi / 2,
|
||||
# ),
|
||||
# joint_signs=(1, 1, 1, 1, 1, 1, 1),
|
||||
# gripper_config=(8, 279, 279 - 50),
|
||||
# ),
|
||||
# panda
|
||||
# "/dev/cu.usbserial-FT3M9NVB": DynamixelRobotConfig(
|
||||
"/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig(
|
||||
joint_ids=(1, 2, 3, 4, 5, 6, 7),
|
||||
joint_offsets=(
|
||||
3 * np.pi / 2,
|
||||
2 * np.pi / 2,
|
||||
1 * np.pi / 2,
|
||||
4 * np.pi / 2,
|
||||
-2 * np.pi / 2 + 2 * np.pi,
|
||||
3 * np.pi / 2,
|
||||
4 * np.pi / 2,
|
||||
),
|
||||
joint_signs=(1, -1, 1, 1, 1, -1, 1),
|
||||
gripper_config=(8, 195, 152),
|
||||
),
|
||||
# Left UR
|
||||
"/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0": DynamixelRobotConfig(
|
||||
joint_ids=(1, 2, 3, 4, 5, 6),
|
||||
joint_offsets=(
|
||||
0,
|
||||
1 * np.pi / 2 + np.pi,
|
||||
np.pi / 2 + 0 * np.pi,
|
||||
0 * np.pi + np.pi / 2,
|
||||
np.pi - 2 * np.pi / 2,
|
||||
-1 * np.pi / 2 + 2 * np.pi,
|
||||
),
|
||||
joint_signs=(1, 1, -1, 1, 1, 1),
|
||||
gripper_config=(7, 20, -22),
|
||||
),
|
||||
# Right UR
|
||||
"/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0": DynamixelRobotConfig(
|
||||
joint_ids=(1, 2, 3, 4, 5, 6),
|
||||
joint_offsets=(
|
||||
np.pi + 0 * np.pi,
|
||||
2 * np.pi + np.pi / 2,
|
||||
2 * np.pi + np.pi / 2,
|
||||
2 * np.pi + np.pi / 2,
|
||||
1 * np.pi,
|
||||
3 * np.pi / 2,
|
||||
),
|
||||
joint_signs=(1, 1, -1, 1, 1, 1),
|
||||
gripper_config=(7, 286, 248),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class GelloAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
dynamixel_config: Optional[DynamixelRobotConfig] = None,
|
||||
start_joints: Optional[np.ndarray] = None,
|
||||
):
|
||||
if dynamixel_config is not None:
|
||||
self._robot = dynamixel_config.make_robot(
|
||||
port=port, start_joints=start_joints
|
||||
)
|
||||
else:
|
||||
assert os.path.exists(port), port
|
||||
assert port in PORT_CONFIG_MAP, f"Port {port} not in config map"
|
||||
|
||||
config = PORT_CONFIG_MAP[port]
|
||||
self._robot = config.make_robot(port=port, start_joints=start_joints)
|
||||
|
||||
def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
return self._robot.get_joint_state()
|
||||
dyna_joints = self._robot.get_joint_state()
|
||||
# current_q = dyna_joints[:-1] # last one dim is the gripper
|
||||
current_gripper = dyna_joints[-1] # last one dim is the gripper
|
||||
|
||||
print(current_gripper)
|
||||
if current_gripper < 0.2:
|
||||
self._robot.set_torque_mode(False)
|
||||
return obs["joint_positions"]
|
||||
else:
|
||||
self._robot.set_torque_mode(False)
|
||||
return dyna_joints
|
178
gello/agents/quest_agent.py
Normal file
178
gello/agents/quest_agent.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import quaternion
|
||||
from dm_control import mjcf
|
||||
from dm_control.utils.inverse_kinematics import qpos_from_site_pose
|
||||
from oculus_reader.reader import OculusReader
|
||||
|
||||
from gello.agents.agent import Agent
|
||||
from gello.agents.spacemouse_agent import apply_transfer, mj2ur, ur2mj
|
||||
from gello.dm_control_tasks.arms.ur5e import UR5e
|
||||
|
||||
# cartensian space control, controller <> robot relative pose matters. This extrinsics is based on
|
||||
# our setup, for details please checkout the project page.
|
||||
quest2ur = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
|
||||
ur2quest = np.linalg.inv(quest2ur)
|
||||
|
||||
translation_scaling_factor = 2.0
|
||||
|
||||
|
||||
class SingleArmQuestAgent(Agent):
|
||||
def __init__(self, robot_type: str, which_hand: str, verbose: bool = False) -> None:
|
||||
"""Interact with the robot using the quest controller.
|
||||
|
||||
leftTrig: press to start control (also record the current position as the home position)
|
||||
leftJS: a tuple of (x,y) for the joystick, only need y to control the gripper
|
||||
"""
|
||||
self.which_hand = which_hand
|
||||
assert self.which_hand in ["l", "r"]
|
||||
|
||||
self.oculus_reader = OculusReader()
|
||||
if robot_type == "ur5":
|
||||
_robot = UR5e()
|
||||
else:
|
||||
raise ValueError(f"Unknown robot type: {robot_type}")
|
||||
self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model)
|
||||
self.control_active = False
|
||||
self.reference_quest_pose = None
|
||||
self.reference_ee_rot_ur = None
|
||||
self.reference_ee_pos_ur = None
|
||||
|
||||
self.robot_type = robot_type
|
||||
self._verbose = verbose
|
||||
|
||||
def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
if self.robot_type == "ur5":
|
||||
num_dof = 6
|
||||
current_qpos = obs["joint_positions"][:num_dof] # last one dim is the gripper
|
||||
current_gripper_angle = obs["joint_positions"][-1]
|
||||
# run the fk
|
||||
self.physics.data.qpos[:num_dof] = current_qpos
|
||||
self.physics.step()
|
||||
|
||||
ee_rot_mj = np.array(
|
||||
self.physics.named.data.site_xmat["attachment_site"]
|
||||
).reshape(3, 3)
|
||||
ee_pos_mj = np.array(self.physics.named.data.site_xpos["attachment_site"])
|
||||
if self.which_hand == "l":
|
||||
pose_key = "l"
|
||||
trigger_key = "leftTrig"
|
||||
# joystick_key = "leftJS"
|
||||
# left yx
|
||||
gripper_open_key = "Y"
|
||||
gripper_close_key = "X"
|
||||
elif self.which_hand == "r":
|
||||
pose_key = "r"
|
||||
trigger_key = "rightTrig"
|
||||
# joystick_key = "rightJS"
|
||||
# right ba for the key
|
||||
gripper_open_key = "B"
|
||||
gripper_close_key = "A"
|
||||
else:
|
||||
raise ValueError(f"Unknown hand: {self.which_hand}")
|
||||
# check the trigger button state
|
||||
pose_data, button_data = self.oculus_reader.get_transformations_and_buttons()
|
||||
if len(pose_data) == 0 or len(button_data) == 0:
|
||||
print("no data, quest not yet ready")
|
||||
return np.concatenate([current_qpos, [current_gripper_angle]])
|
||||
|
||||
new_gripper_angle = current_gripper_angle
|
||||
if button_data[gripper_open_key]:
|
||||
new_gripper_angle = 1
|
||||
if button_data[gripper_close_key]:
|
||||
new_gripper_angle = 0
|
||||
arm_not_move_return = np.concatenate([current_qpos, [new_gripper_angle]])
|
||||
if len(pose_data) == 0:
|
||||
print("no data, quest not yet ready")
|
||||
return arm_not_move_return
|
||||
|
||||
trigger_state = button_data[trigger_key][0]
|
||||
if trigger_state > 0.5:
|
||||
if self.control_active is True:
|
||||
if self._verbose:
|
||||
print("controlling the arm")
|
||||
current_pose = pose_data[pose_key]
|
||||
delta_rot = current_pose[:3, :3] @ np.linalg.inv(
|
||||
self.reference_quest_pose[:3, :3]
|
||||
)
|
||||
delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3]
|
||||
delta_pos_ur = (
|
||||
apply_transfer(quest2ur, delta_pos) * translation_scaling_factor
|
||||
)
|
||||
# ? is this the case?
|
||||
delta_rot_ur = quest2ur[:3, :3] @ delta_rot @ ur2quest[:3, :3]
|
||||
if self._verbose:
|
||||
print(
|
||||
f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}"
|
||||
)
|
||||
next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur
|
||||
next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur
|
||||
|
||||
target_quat = quaternion.as_float_array(
|
||||
quaternion.from_rotation_matrix(ur2mj[:3, :3] @ next_ee_rot_ur)
|
||||
)
|
||||
ik_result = qpos_from_site_pose(
|
||||
self.physics,
|
||||
"attachment_site",
|
||||
target_pos=apply_transfer(ur2mj, next_ee_pos_ur),
|
||||
target_quat=target_quat,
|
||||
tol=1e-14,
|
||||
max_steps=400,
|
||||
)
|
||||
self.physics.reset()
|
||||
if ik_result.success:
|
||||
new_qpos = ik_result.qpos[:num_dof]
|
||||
else:
|
||||
print("ik failed, using the original qpos")
|
||||
return arm_not_move_return
|
||||
command = np.concatenate([new_qpos, [new_gripper_angle]])
|
||||
return command
|
||||
|
||||
else: # last state is not in active
|
||||
self.control_active = True
|
||||
if self._verbose:
|
||||
print("control activated!")
|
||||
self.reference_quest_pose = pose_data[pose_key]
|
||||
|
||||
self.reference_ee_rot_ur = mj2ur[:3, :3] @ ee_rot_mj
|
||||
self.reference_ee_pos_ur = apply_transfer(mj2ur, ee_pos_mj)
|
||||
return arm_not_move_return
|
||||
else:
|
||||
if self._verbose:
|
||||
print("deactive control")
|
||||
self.control_active = False
|
||||
self.reference_quest_pose = None
|
||||
return arm_not_move_return
|
||||
|
||||
|
||||
class DualArmQuestAgent(Agent):
|
||||
def __init__(self, robot_type: str) -> None:
|
||||
self.left_arm = SingleArmQuestAgent(robot_type, "l")
|
||||
self.right_arm = SingleArmQuestAgent(robot_type, "r")
|
||||
|
||||
def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
oculus_reader = OculusReader()
|
||||
while True:
|
||||
"""
|
||||
example output:
|
||||
({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ],
|
||||
[-0.107737 , 0.0958919, 0.989544 , -0.833478 ],
|
||||
[ 0.549685 , 0.835106 , -0.0210789, -0.892425 ],
|
||||
[ 0. , 0. , 0. , 1. ]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ],
|
||||
[ 0.070887, 0.516083, -0.8536 , -0.238691],
|
||||
[-0.941994, -0.246809, -0.227447, -0.370447],
|
||||
[ 0. , 0. , 0. , 1. ]])},
|
||||
{'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)})
|
||||
|
||||
"""
|
||||
pose_data, button_data = oculus_reader.get_transformations_and_buttons()
|
||||
if len(pose_data) == 0:
|
||||
print("no data")
|
||||
continue
|
||||
else:
|
||||
print(pose_data["l"])
|
224
gello/agents/spacemouse_agent.py
Normal file
224
gello/agents/spacemouse_agent.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from dm_control import mjcf
|
||||
from dm_control.utils.inverse_kinematics import qpos_from_site_pose
|
||||
|
||||
from gello.agents.agent import Agent
|
||||
from gello.dm_control_tasks.arms.ur5e import UR5e
|
||||
|
||||
# mujoco has a slightly different coordinate system than UR control box
|
||||
mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
|
||||
ur2mj = np.linalg.inv(mj2ur)
|
||||
|
||||
# cartensian space control, controller <> robot relative pose matters. This extrinsics is based on
|
||||
# our setup, for details please checkout the project page.
|
||||
spacemouse2ur = np.array(
|
||||
[
|
||||
[-1, 0, 0, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
)
|
||||
ur2spacemouse = np.linalg.inv(spacemouse2ur)
|
||||
|
||||
|
||||
def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray:
|
||||
# xyz can be 3dim or 4dim (homogeneous) or can be a rotation matrix
|
||||
if len(xyz) == 3:
|
||||
xyz = np.append(xyz, 1)
|
||||
return np.matmul(mat, xyz)[:3]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpacemouseConfig:
|
||||
angle_scale: float = 0.24
|
||||
translation_scale: float = 0.06
|
||||
# only control the xyz, rotation direction, not the gripper
|
||||
invert_control: np.ndarray = np.ones(6)
|
||||
rotation_mode: str = "euler"
|
||||
|
||||
|
||||
class SpacemouseAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
robot_type: str,
|
||||
config: SpacemouseConfig = SpacemouseConfig(),
|
||||
device_path: Optional[str] = None,
|
||||
verbose: bool = True,
|
||||
invert_button: bool = False,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.last_state_lock = threading.Lock()
|
||||
self._invert_button = invert_button
|
||||
# example state:SpaceNavigator(t=3.581528532, x=0.0, y=0.0, z=0.0, roll=0.0, pitch=0.0, yaw=0.0, buttons=[0, 0])
|
||||
# all continuous inputs range from 0-1, buttons are 0 or 1
|
||||
self.spacemouse_latest_state = None
|
||||
self._device_path = device_path
|
||||
spacemouse_thread = threading.Thread(target=self._read_from_spacemouse)
|
||||
spacemouse_thread.start()
|
||||
self._verbose = verbose
|
||||
if self._verbose:
|
||||
print(f"robot_type: {robot_type}")
|
||||
if robot_type == "ur5":
|
||||
_robot = UR5e()
|
||||
else:
|
||||
raise ValueError(f"Unknown robot type: {robot_type}")
|
||||
self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model)
|
||||
|
||||
def _read_from_spacemouse(self):
|
||||
import pyspacemouse
|
||||
|
||||
if self._device_path is None:
|
||||
mouse = pyspacemouse.open()
|
||||
else:
|
||||
mouse = pyspacemouse.open(path=self._device_path)
|
||||
if mouse:
|
||||
while 1:
|
||||
state = mouse.read()
|
||||
with self.last_state_lock:
|
||||
self.spacemouse_latest_state = state
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
raise ValueError("Failed to open spacemouse")
|
||||
|
||||
def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
|
||||
import quaternion
|
||||
|
||||
# obs: the folllow robot's current state
|
||||
# in rad, 6/7 dof depends on the robot type
|
||||
num_dof = 6
|
||||
if self._verbose:
|
||||
print("act invoked")
|
||||
current_qpos = obs["joint_positions"][:num_dof] # last one dim is the gripper
|
||||
current_gripper_angle = obs["joint_positions"][-1]
|
||||
self.physics.data.qpos[:num_dof] = current_qpos
|
||||
|
||||
self.physics.step()
|
||||
ee_rot = np.array(self.physics.named.data.site_xmat["attachment_site"]).reshape(
|
||||
3, 3
|
||||
)
|
||||
ee_pos = np.array(self.physics.named.data.site_xpos["attachment_site"])
|
||||
|
||||
ee_rot = mj2ur[:3, :3] @ ee_rot
|
||||
ee_pos = apply_transfer(mj2ur, ee_pos)
|
||||
# ^ mujoco coordinate to UR
|
||||
|
||||
with self.last_state_lock:
|
||||
spacemouse_read = self.spacemouse_latest_state
|
||||
if self._verbose:
|
||||
print(f"spacemouse_read: {spacemouse_read}")
|
||||
|
||||
assert spacemouse_read is not None
|
||||
spacemouse_xyz_rot_np = np.array(
|
||||
[
|
||||
spacemouse_read.x,
|
||||
spacemouse_read.y,
|
||||
spacemouse_read.z,
|
||||
spacemouse_read.roll,
|
||||
spacemouse_read.pitch,
|
||||
spacemouse_read.yaw,
|
||||
]
|
||||
)
|
||||
spacemouse_button = (
|
||||
spacemouse_read.buttons
|
||||
) # size 2 list binary indicating left/right button pressing
|
||||
spacemouse_xyz_rot_np = spacemouse_xyz_rot_np * self.config.invert_control
|
||||
if np.max(np.abs(spacemouse_xyz_rot_np)) > 0.9:
|
||||
spacemouse_xyz_rot_np[np.abs(spacemouse_xyz_rot_np) < 0.6] = 0
|
||||
tx, ty, tz, r, p, y = spacemouse_xyz_rot_np
|
||||
# convert roll pick yaw to rotation matrix (rpy)
|
||||
|
||||
trans_transform = np.eye(4)
|
||||
# delta xyz from the spacemouse reading
|
||||
trans_transform[:3, 3] = apply_transfer(
|
||||
spacemouse2ur, np.array([tx, ty, tz]) * self.config.translation_scale
|
||||
)
|
||||
|
||||
# break rot_transform into each axis
|
||||
rot_transform_x = np.eye(4)
|
||||
rot_transform_x[:3, :3] = quaternion.as_rotation_matrix(
|
||||
quaternion.from_rotation_vector(
|
||||
np.array([-p, 0, 0]) * self.config.angle_scale
|
||||
)
|
||||
)
|
||||
|
||||
rot_transform_y = np.eye(4)
|
||||
rot_transform_y[:3, :3] = quaternion.as_rotation_matrix(
|
||||
quaternion.from_rotation_vector(
|
||||
np.array([0, r, 0]) * self.config.angle_scale
|
||||
)
|
||||
)
|
||||
|
||||
rot_transform_z = np.eye(4)
|
||||
rot_transform_z[:3, :3] = quaternion.as_rotation_matrix(
|
||||
quaternion.from_rotation_vector(
|
||||
np.array([0, 0, -y]) * self.config.angle_scale
|
||||
)
|
||||
)
|
||||
|
||||
# in ur space
|
||||
rot_transform = (
|
||||
spacemouse2ur
|
||||
@ rot_transform_z
|
||||
@ rot_transform_y
|
||||
@ rot_transform_x
|
||||
@ ur2spacemouse
|
||||
)
|
||||
|
||||
if self._verbose:
|
||||
print(f"rot_transform: {rot_transform}")
|
||||
print(f"new spacemounse cmd in ur space = {trans_transform[:3, 3]}")
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
new_ee_pos = trans_transform[:3, 3] + ee_pos
|
||||
if self.config.rotation_mode == "rpy":
|
||||
new_ee_rot = ee_rot @ rot_transform[:3, :3]
|
||||
elif self.config.rotation_mode == "euler":
|
||||
new_ee_rot = rot_transform[:3, :3] @ ee_rot
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown rotation mode: {self.config.rotation_mode}"
|
||||
)
|
||||
|
||||
target_quat = quaternion.as_float_array(
|
||||
quaternion.from_rotation_matrix(ur2mj[:3, :3] @ new_ee_rot)
|
||||
)
|
||||
ik_result = qpos_from_site_pose(
|
||||
self.physics,
|
||||
"attachment_site",
|
||||
target_pos=apply_transfer(ur2mj, new_ee_pos),
|
||||
target_quat=target_quat,
|
||||
tol=1e-14,
|
||||
max_steps=400,
|
||||
)
|
||||
self.physics.reset()
|
||||
if ik_result.success:
|
||||
new_qpos = ik_result.qpos[:num_dof]
|
||||
else:
|
||||
print("ik failed, using the original qpos")
|
||||
return np.concatenate([current_qpos, [current_gripper_angle]])
|
||||
new_gripper_angle = current_gripper_angle
|
||||
if self._invert_button:
|
||||
if spacemouse_button[1]:
|
||||
new_gripper_angle = 1
|
||||
if spacemouse_button[0]:
|
||||
new_gripper_angle = 0
|
||||
else:
|
||||
if spacemouse_button[1]:
|
||||
new_gripper_angle = 0
|
||||
if spacemouse_button[0]:
|
||||
new_gripper_angle = 1
|
||||
command = np.concatenate([new_qpos, [new_gripper_angle]])
|
||||
return command
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pyspacemouse
|
||||
|
||||
success = pyspacemouse.open("/dev/hidraw4")
|
||||
success = pyspacemouse.open("/dev/hidraw5")
|
81
gello/cameras/camera.py
Normal file
81
gello/cameras/camera.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
from pathlib import Path
|
||||
from typing import Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CameraDriver(Protocol):
|
||||
"""Camera protocol.
|
||||
|
||||
A protocol for a camera driver. This is used to abstract the camera from the rest of the code.
|
||||
"""
|
||||
|
||||
def read(
|
||||
self,
|
||||
img_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera.
|
||||
|
||||
Args:
|
||||
img_size: The size of the image to return. If None, the original size is returned.
|
||||
farthest: The farthest distance to map to 255.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The color image.
|
||||
np.ndarray: The depth image.
|
||||
"""
|
||||
|
||||
|
||||
class DummyCamera(CameraDriver):
|
||||
"""A dummy camera for testing."""
|
||||
|
||||
def read(
|
||||
self,
|
||||
img_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera.
|
||||
|
||||
Args:
|
||||
img_size: The size of the image to return. If None, the original size is returned.
|
||||
farthest: The farthest distance to map to 255.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The color image.
|
||||
np.ndarray: The depth image.
|
||||
"""
|
||||
if img_size is None:
|
||||
return (
|
||||
np.random.randint(255, size=(480, 640, 3), dtype=np.uint8),
|
||||
np.random.randint(255, size=(480, 640, 1), dtype=np.uint16),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
np.random.randint(
|
||||
255, size=(img_size[0], img_size[1], 3), dtype=np.uint8
|
||||
),
|
||||
np.random.randint(
|
||||
255, size=(img_size[0], img_size[1], 1), dtype=np.uint16
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SavedCamera(CameraDriver):
|
||||
def __init__(self, path: str = "example"):
|
||||
self.path = str(Path(__file__).parent / path)
|
||||
from PIL import Image
|
||||
|
||||
self._color_img = Image.open(f"{self.path}/image.png")
|
||||
self._depth_img = Image.open(f"{self.path}/depth.png")
|
||||
|
||||
def read(
|
||||
self,
|
||||
img_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
if img_size is not None:
|
||||
color_img = self._color_img.resize(img_size)
|
||||
depth_img = self._depth_img.resize(img_size)
|
||||
else:
|
||||
color_img = self._color_img
|
||||
depth_img = self._depth_img
|
||||
|
||||
return np.array(color_img), np.array(depth_img)[:, :, 0:1]
|
123
gello/cameras/realsense_camera.py
Normal file
123
gello/cameras/realsense_camera.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
import os
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gello.cameras.camera import CameraDriver
|
||||
|
||||
|
||||
def get_device_ids() -> List[str]:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
ctx = rs.context()
|
||||
devices = ctx.query_devices()
|
||||
device_ids = []
|
||||
for dev in devices:
|
||||
dev.hardware_reset()
|
||||
device_ids.append(dev.get_info(rs.camera_info.serial_number))
|
||||
time.sleep(2)
|
||||
return device_ids
|
||||
|
||||
|
||||
class RealSenseCamera(CameraDriver):
|
||||
def __repr__(self) -> str:
|
||||
return f"RealSenseCamera(device_id={self._device_id})"
|
||||
|
||||
def __init__(self, device_id: Optional[str] = None, flip: bool = False):
|
||||
import pyrealsense2 as rs
|
||||
|
||||
self._device_id = device_id
|
||||
|
||||
if device_id is None:
|
||||
ctx = rs.context()
|
||||
devices = ctx.query_devices()
|
||||
for dev in devices:
|
||||
dev.hardware_reset()
|
||||
time.sleep(2)
|
||||
self._pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
else:
|
||||
self._pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
config.enable_device(device_id)
|
||||
|
||||
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
|
||||
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
|
||||
self._pipeline.start(config)
|
||||
self._flip = flip
|
||||
|
||||
def read(
|
||||
self,
|
||||
img_size: Optional[Tuple[int, int]] = None, # farthest: float = 0.12
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera.
|
||||
|
||||
Args:
|
||||
img_size: The size of the image to return. If None, the original size is returned.
|
||||
farthest: The farthest distance to map to 255.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The color image, shape=(H, W, 3)
|
||||
np.ndarray: The depth image, shape=(H, W, 1)
|
||||
"""
|
||||
import cv2
|
||||
|
||||
frames = self._pipeline.wait_for_frames()
|
||||
color_frame = frames.get_color_frame()
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
depth_frame = frames.get_depth_frame()
|
||||
depth_image = np.asanyarray(depth_frame.get_data())
|
||||
# depth_image = cv2.convertScaleAbs(depth_image, alpha=0.03)
|
||||
if img_size is None:
|
||||
image = color_image[:, :, ::-1]
|
||||
depth = depth_image
|
||||
else:
|
||||
image = cv2.resize(color_image, img_size)[:, :, ::-1]
|
||||
depth = cv2.resize(depth_image, img_size)
|
||||
|
||||
# rotate 180 degree's because everything is upside down in order to center the camera
|
||||
if self._flip:
|
||||
image = cv2.rotate(image, cv2.ROTATE_180)
|
||||
depth = cv2.rotate(depth, cv2.ROTATE_180)[:, :, None]
|
||||
else:
|
||||
depth = depth[:, :, None]
|
||||
|
||||
return image, depth
|
||||
|
||||
|
||||
def _debug_read(camera, save_datastream=False):
|
||||
import cv2
|
||||
|
||||
cv2.namedWindow("image")
|
||||
cv2.namedWindow("depth")
|
||||
counter = 0
|
||||
if not os.path.exists("images"):
|
||||
os.makedirs("images")
|
||||
if save_datastream and not os.path.exists("stream"):
|
||||
os.makedirs("stream")
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
image, depth = camera.read()
|
||||
depth = np.concatenate([depth, depth, depth], axis=-1)
|
||||
key = cv2.waitKey(1)
|
||||
cv2.imshow("image", image[:, :, ::-1])
|
||||
cv2.imshow("depth", depth)
|
||||
if key == ord("s"):
|
||||
cv2.imwrite(f"images/image_{counter}.png", image[:, :, ::-1])
|
||||
cv2.imwrite(f"images/depth_{counter}.png", depth)
|
||||
if save_datastream:
|
||||
cv2.imwrite(f"stream/image_{counter}.png", image[:, :, ::-1])
|
||||
cv2.imwrite(f"stream/depth_{counter}.png", depth)
|
||||
counter += 1
|
||||
if key == 27:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
device_ids = get_device_ids()
|
||||
print(f"Found {len(device_ids)} devices")
|
||||
print(device_ids)
|
||||
rs = RealSenseCamera(flip=True, device_id=device_ids[0])
|
||||
im, depth = rs.read()
|
||||
_debug_read(rs, save_datastream=True)
|
231
gello/data_utils/conversion_utils.py
Normal file
231
gello/data_utils/conversion_utils.py
Normal file
|
@ -0,0 +1,231 @@
|
|||
from typing import Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import transforms3d._gohlketransforms as ttf
|
||||
|
||||
|
||||
def to_torch(array, device="cpu"):
|
||||
if isinstance(array, torch.Tensor):
|
||||
return array.to(device)
|
||||
if isinstance(array, np.ndarray):
|
||||
return torch.from_numpy(array).to(device)
|
||||
else:
|
||||
return torch.tensor(array).to(device)
|
||||
|
||||
|
||||
def to_numpy(array):
|
||||
if isinstance(array, torch.Tensor):
|
||||
return array.cpu().numpy()
|
||||
return array
|
||||
|
||||
|
||||
def center_crop(rgb_frame, depth_frame):
|
||||
H, W = rgb_frame.shape[-2:]
|
||||
sq_size = min(H, W)
|
||||
|
||||
# crop the center square
|
||||
if H > W:
|
||||
rgb_frame = rgb_frame[..., (-sq_size / 2) : (sq_size / 2), :sq_size]
|
||||
depth_frame = depth_frame[..., (-sq_size / 2) : (sq_size / 2), :sq_size]
|
||||
elif W < H:
|
||||
rgb_frame = rgb_frame[..., :sq_size, (-sq_size / 2) : (sq_size / 2)]
|
||||
depth_frame = depth_frame[..., :sq_size, (-sq_size / 2) : (sq_size / 2)]
|
||||
|
||||
return rgb_frame, depth_frame
|
||||
|
||||
|
||||
def resize(rgb, depth, size=224):
|
||||
rgb = rgb.transpose([1, 2, 0])
|
||||
rgb = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_LINEAR)
|
||||
rgb = rgb.transpose([2, 0, 1])
|
||||
|
||||
depth = cv2.resize(depth[0], (size, size), interpolation=cv2.INTER_LINEAR)
|
||||
depth = depth.reshape([1, size, size])
|
||||
return rgb, depth
|
||||
|
||||
|
||||
def filter_depth(depth, max_depth=2.0, min_depth=0.0):
|
||||
depth[np.isnan(depth)] = 0.0
|
||||
depth[np.isinf(depth)] = 0.0
|
||||
depth = np.clip(depth, min_depth, max_depth)
|
||||
return depth
|
||||
|
||||
|
||||
def preproc_obs(
|
||||
demo: Dict[str, np.ndarray], joint_only: bool = True
|
||||
) -> Dict[str, np.ndarray]:
|
||||
# c h w
|
||||
rgb_wrist = demo.get(f"wrist_rgb").transpose([2, 0, 1]) * 1.0 # type: ignore
|
||||
depth_wrist = demo.get(f"wrist_depth").transpose([2, 0, 1]) * 1.0 # type: ignore
|
||||
rgb_base = demo.get("base_rgb").transpose([2, 0, 1]) * 1.0 # type: ignore
|
||||
depth_base = demo.get("base_depth").transpose([2, 0, 1]) * 1.0 # type: ignore
|
||||
|
||||
# Center crop and fitler depth
|
||||
rgb_wrist, depth_wrist = resize(*center_crop(rgb_wrist, depth_wrist))
|
||||
rgb_base, depth_base = resize(*center_crop(rgb_base, depth_base))
|
||||
|
||||
depth_wrist = filter_depth(depth_wrist)
|
||||
depth_base = filter_depth(depth_base)
|
||||
|
||||
rgb = np.stack([rgb_wrist, rgb_base], axis=0)
|
||||
depth = np.stack([depth_wrist, depth_base], axis=0)
|
||||
|
||||
# Dummy
|
||||
dummy_cam = np.eye(4)
|
||||
K = np.eye(3)
|
||||
|
||||
# state
|
||||
qpos, qvel, ee_pos_quat, gripper_pos = (
|
||||
demo.get("joint_positions"), # type: ignore
|
||||
demo.get("joint_velocities"), # type: ignore
|
||||
demo.get("ee_pos_quat"), # type: ignore
|
||||
demo.get("gripper_position"), # type: ignore
|
||||
)
|
||||
|
||||
if joint_only:
|
||||
state: np.ndarray = qpos
|
||||
else:
|
||||
state: np.ndarray = np.concatenate([qpos, qvel, ee_pos_quat, gripper_pos[None]])
|
||||
|
||||
return {
|
||||
"rgb": rgb,
|
||||
"depth": depth,
|
||||
"camera_poses": dummy_cam,
|
||||
"K_matrices": K,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
|
||||
class Pose(object):
|
||||
def __init__(self, x, y, z, qw, qx, qy, qz):
|
||||
self.p = np.array([x, y, z])
|
||||
|
||||
# we internally use tf.transformations, which uses [x, y, z, w] for quaternions.
|
||||
self.q = np.array([qx, qy, qz, qw])
|
||||
|
||||
# make sure that the quaternion has positive scalar part
|
||||
if self.q[3] < 0:
|
||||
self.q *= -1
|
||||
|
||||
self.q = self.q / np.linalg.norm(self.q)
|
||||
|
||||
def __mul__(self, other):
|
||||
assert isinstance(other, Pose)
|
||||
p = self.p + ttf.quaternion_matrix(self.q)[:3, :3].dot(other.p)
|
||||
q = ttf.quaternion_multiply(self.q, other.q)
|
||||
return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2])
|
||||
|
||||
def __rmul__(self, other):
|
||||
assert isinstance(other, Pose)
|
||||
return other * self
|
||||
|
||||
def __str__(self):
|
||||
return "p: {}, q: {}".format(self.p, self.q)
|
||||
|
||||
def inv(self):
|
||||
R = ttf.quaternion_matrix(self.q)[:3, :3]
|
||||
p = -R.T.dot(self.p)
|
||||
q = ttf.quaternion_inverse(self.q)
|
||||
return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2])
|
||||
|
||||
def to_quaternion(self):
|
||||
"""
|
||||
this satisfies Pose(*to_quaternion(p)) == p
|
||||
"""
|
||||
q_reverted = np.array([self.q[3], self.q[0], self.q[1], self.q[2]])
|
||||
return np.concatenate([self.p, q_reverted])
|
||||
|
||||
def to_axis_angle(self):
|
||||
"""
|
||||
returns the axis-angle representation of the rotation.
|
||||
"""
|
||||
angle = 2 * np.arccos(self.q[3])
|
||||
angle = angle / np.pi
|
||||
if angle > 1:
|
||||
angle -= 2
|
||||
|
||||
axis = self.q[:3] / np.linalg.norm(self.q[:3])
|
||||
|
||||
# keep the axes positive
|
||||
if axis[0] < 0:
|
||||
axis *= -1
|
||||
angle *= -1
|
||||
|
||||
return np.concatenate([self.p, axis, [angle]])
|
||||
|
||||
def to_euler(self):
|
||||
q = np.array(ttf.euler_from_quaternion(self.q))
|
||||
if q[0] > np.pi:
|
||||
q[0] -= 2 * np.pi
|
||||
if q[1] > np.pi:
|
||||
q[1] -= 2 * np.pi
|
||||
if q[2] > np.pi:
|
||||
q[2] -= 2 * np.pi
|
||||
|
||||
q = q / np.pi
|
||||
|
||||
return np.concatenate([self.p, q, [0.0]])
|
||||
|
||||
def to_44_matrix(self):
|
||||
out = np.eye(4)
|
||||
out[:3, :3] = ttf.quaternion_matrix(self.q)[:3, :3]
|
||||
out[:3, 3] = self.p
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def from_axis_angle(x, y, z, ax, ay, az, phi):
|
||||
"""
|
||||
returns a Pose object from the axis-angle representation of the rotation.
|
||||
"""
|
||||
|
||||
phi = phi * np.pi
|
||||
p = np.array([x, y, z])
|
||||
qw = np.cos(phi / 2.0)
|
||||
qx = ax * np.sin(phi / 2.0)
|
||||
qy = ay * np.sin(phi / 2.0)
|
||||
qz = az * np.sin(phi / 2.0)
|
||||
|
||||
return Pose(p[0], p[1], p[2], qw, qx, qy, qz)
|
||||
|
||||
@staticmethod
|
||||
def from_euler(x, y, z, roll, pitch, yaw, _):
|
||||
"""
|
||||
returns a Pose object from the euler representation of the rotation.
|
||||
"""
|
||||
p = np.array([x, y, z])
|
||||
roll, pitch, yaw = roll * np.pi, pitch * np.pi, yaw * np.pi
|
||||
q = ttf.quaternion_from_euler(roll, pitch, yaw)
|
||||
return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2])
|
||||
|
||||
@staticmethod
|
||||
def from_quaternion(x, y, z, qw, qx, qy, qz):
|
||||
"""
|
||||
returns a Pose object from the quaternion representation of the rotation.
|
||||
"""
|
||||
p = np.array([x, y, z])
|
||||
return Pose(p[0], p[1], p[2], qw, qx, qy, qz)
|
||||
|
||||
|
||||
def compute_inverse_action(p, p_new, ee_control=False):
|
||||
assert isinstance(p, Pose) and isinstance(p_new, Pose)
|
||||
|
||||
if ee_control:
|
||||
dpose = p.inv() * p_new
|
||||
else:
|
||||
dpose = p_new * p.inv()
|
||||
|
||||
return dpose
|
||||
|
||||
|
||||
def compute_forward_action(p, dpose, ee_control=False):
|
||||
assert isinstance(p, Pose) and isinstance(dpose, Pose)
|
||||
dpose = Pose.from_quaternion(*dpose.to_quaternion())
|
||||
|
||||
if ee_control:
|
||||
p_new = p * dpose
|
||||
else:
|
||||
p_new = dpose * p
|
||||
|
||||
return p_new
|
345
gello/data_utils/demo_to_gdict.py
Normal file
345
gello/data_utils/demo_to_gdict.py
Normal file
|
@ -0,0 +1,345 @@
|
|||
import glob
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
from natsort import natsorted
|
||||
from tqdm import tqdm
|
||||
|
||||
from gello.data_utils.plot_utils import plot_in_grid
|
||||
|
||||
np.set_printoptions(precision=3, suppress=True)
|
||||
|
||||
import mediapy as mp
|
||||
from gdict.data import DictArray, GDict
|
||||
from simple_bc.utils.visualization_utils import make_grid_video_from_numpy
|
||||
|
||||
from gello.data_utils.conversion_utils import preproc_obs
|
||||
|
||||
# def get_act_bounds(source_dir: str) -> np.ndarray:
|
||||
# pkls = natsorted(
|
||||
# glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
|
||||
# )
|
||||
# if len(pkls) <= 30:
|
||||
# print(f"Skipping {source_dir} because it has less than 30 frames.")
|
||||
# return None
|
||||
# pkls = pkls[:-5]
|
||||
#
|
||||
# scale_factor = None
|
||||
# for pkl in pkls:
|
||||
# try:
|
||||
# with open(pkl, "rb") as f:
|
||||
# demo = pickle.load(f)
|
||||
# except Exception as e:
|
||||
# print(f"Skipping {pkl} because it is corrupted.")
|
||||
# print(f"Error: {e}")
|
||||
# raise Exception("Corrupted pkl")
|
||||
#
|
||||
# requested_control = demo.pop("control")
|
||||
# curr_scale_factor = np.abs(requested_control)
|
||||
# if scale_factor is None:
|
||||
# scale_factor = curr_scale_factor
|
||||
# else:
|
||||
# scale_factor = np.maximum(scale_factor, curr_scale_factor)
|
||||
# assert scale_factor is not None
|
||||
# return scale_factor
|
||||
|
||||
|
||||
def get_act_min_max(source_dir: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
pkls = natsorted(
|
||||
glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
|
||||
)
|
||||
if len(pkls) <= 30:
|
||||
print(f"Skipping {source_dir} because it has less than 30 frames.")
|
||||
raise RuntimeError("Too few frames")
|
||||
pkls = pkls[:-5]
|
||||
|
||||
scale_min = None
|
||||
scale_max = None
|
||||
for pkl in pkls:
|
||||
try:
|
||||
with open(pkl, "rb") as f:
|
||||
demo = pickle.load(f)
|
||||
except Exception as e:
|
||||
print(f"Skipping {pkl} because it is corrupted.")
|
||||
print(f"Error: {e}")
|
||||
raise Exception("Corrupted pkl")
|
||||
|
||||
requested_control = demo.pop("control")
|
||||
curr_scale_factor = requested_control
|
||||
if scale_min is None:
|
||||
assert scale_max is None
|
||||
scale_min = curr_scale_factor
|
||||
scale_max = curr_scale_factor
|
||||
else:
|
||||
assert scale_max is not None
|
||||
scale_min = np.minimum(scale_min, curr_scale_factor)
|
||||
scale_max = np.maximum(scale_min, curr_scale_factor)
|
||||
|
||||
assert scale_min is not None
|
||||
assert scale_max is not None
|
||||
return scale_min, scale_max
|
||||
|
||||
|
||||
def convert_single_demo(
|
||||
source_dir,
|
||||
i,
|
||||
traj_output_dir,
|
||||
rgb_output_dir,
|
||||
depth_output_dir,
|
||||
state_output_dir,
|
||||
action_output_dir,
|
||||
scale_factor,
|
||||
bias_factor,
|
||||
):
|
||||
"""
|
||||
1. converts the demo into a gdict
|
||||
2. visualizes the RGB of the demo
|
||||
3. visualizes the state + action space of the demo
|
||||
4. returns these to be collated by the caller.
|
||||
"""
|
||||
|
||||
pkls = natsorted(
|
||||
glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True
|
||||
)
|
||||
demo_stack = []
|
||||
|
||||
if len(pkls) <= 30:
|
||||
return 0
|
||||
|
||||
# go through the demo in reverse order.
|
||||
# remove the first few frames because they are not useful.
|
||||
pkls = pkls[:-5]
|
||||
|
||||
for pkl in pkls:
|
||||
curr_ts = {}
|
||||
try:
|
||||
with open(pkl, "rb") as f:
|
||||
demo = pickle.load(f)
|
||||
except:
|
||||
print(f"Skipping {pkl} because it is corrupted.")
|
||||
return 0
|
||||
|
||||
obs = preproc_obs(demo)
|
||||
action = demo.pop("control")
|
||||
action = (action - bias_factor) / scale_factor # normalize between -1 and 1
|
||||
|
||||
curr_ts["obs"] = obs
|
||||
curr_ts["actions"] = action
|
||||
curr_ts["dones"] = np.zeros(1) # random fill
|
||||
curr_ts["episode_dones"] = np.zeros(1) # random fill
|
||||
|
||||
curr_ts_wrapped = dict()
|
||||
curr_ts_wrapped[f"traj_{i}"] = curr_ts
|
||||
demo_stack = [curr_ts_wrapped] + demo_stack
|
||||
|
||||
demo_dict = DictArray.stack(demo_stack)
|
||||
GDict.to_hdf5(demo_dict, os.path.join(traj_output_dir + "", f"traj_{i}.h5"))
|
||||
|
||||
## save the base videos
|
||||
# save the base rgb and depth videos
|
||||
all_rgbs = demo_dict[f"traj_{i}"]["obs"]["rgb"][:, 1].transpose([0, 2, 3, 1])
|
||||
all_rgbs = all_rgbs.astype(np.uint8)
|
||||
_, H, W, _ = all_rgbs.shape
|
||||
all_depths = demo_dict[f"traj_{i}"]["obs"]["depth"][:, 1].reshape([-1, H, W])
|
||||
all_depths = all_depths / 5.0 # scale to 0-1
|
||||
|
||||
mp.write_video(
|
||||
os.path.join(rgb_output_dir + "", f"traj_{i}_rgb_base.mp4"), all_rgbs, fps=30
|
||||
)
|
||||
mp.write_video(
|
||||
os.path.join(depth_output_dir + "", f"traj_{i}_depth_base.mp4"),
|
||||
all_depths,
|
||||
fps=30,
|
||||
)
|
||||
|
||||
## save the wrist videos
|
||||
# save the rgb and depth videos
|
||||
all_rgbs = demo_dict[f"traj_{i}"]["obs"]["rgb"][:, 0].transpose([0, 2, 3, 1])
|
||||
all_rgbs = all_rgbs.astype(np.uint8)
|
||||
_, H, W, _ = all_rgbs.shape
|
||||
all_depths = demo_dict[f"traj_{i}"]["obs"]["depth"][:, 0].reshape([-1, H, W])
|
||||
all_depths = all_depths / 2.0 # scale to 0-1
|
||||
|
||||
mp.write_video(
|
||||
os.path.join(rgb_output_dir + "", f"traj_{i}_rgb_wrist.mp4"), all_rgbs, fps=30
|
||||
)
|
||||
mp.write_video(
|
||||
os.path.join(depth_output_dir + "", f"traj_{i}_depth_wrist.mp4"),
|
||||
all_depths,
|
||||
fps=30,
|
||||
)
|
||||
##
|
||||
|
||||
all_depths = np.tile(all_depths[..., None], [1, 1, 1, 3])
|
||||
|
||||
# save the state and action plots
|
||||
all_actions = demo_dict[f"traj_{i}"]["actions"]
|
||||
all_states = demo_dict[f"traj_{i}"]["obs"]["state"]
|
||||
|
||||
curr_actions = all_actions.reshape([1, *all_actions.shape])
|
||||
curr_states = all_states.reshape([-1, *all_states.shape])
|
||||
|
||||
plot_in_grid(
|
||||
curr_actions, os.path.join(action_output_dir + "", f"traj_{i}_actions.png")
|
||||
)
|
||||
plot_in_grid(
|
||||
curr_states, os.path.join(state_output_dir + "", f"traj_{i}_states.png")
|
||||
)
|
||||
|
||||
return all_rgbs, all_depths, all_actions, all_states
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
source_dir: str
|
||||
vis: bool = True
|
||||
|
||||
|
||||
def main(args):
|
||||
subdirs = natsorted(glob.glob(os.path.join(args.source_dir, "*/"), recursive=True))
|
||||
|
||||
output_dir = args.source_dir
|
||||
if output_dir[-1] == "/":
|
||||
output_dir = output_dir[:-1]
|
||||
|
||||
output_dir = os.path.join(output_dir, "_conv")
|
||||
|
||||
if not os.path.isdir(output_dir):
|
||||
os.mkdir(output_dir)
|
||||
|
||||
output_dir = os.path.join(output_dir, "multiview")
|
||||
|
||||
if not os.path.isdir(output_dir):
|
||||
os.mkdir(output_dir)
|
||||
else:
|
||||
print(f"Output directory {output_dir} already exists, and will be deleted")
|
||||
shutil.rmtree(output_dir)
|
||||
os.mkdir(output_dir)
|
||||
|
||||
train_dir = os.path.join(output_dir, "train")
|
||||
val_dir = os.path.join(output_dir, "val")
|
||||
|
||||
if not os.path.isdir(train_dir):
|
||||
os.mkdir(train_dir)
|
||||
if not os.path.isdir(val_dir):
|
||||
os.mkdir(val_dir)
|
||||
|
||||
val_size = int(min(0.1 * len(subdirs), 10))
|
||||
val_indices = np.random.choice(len(subdirs), size=val_size, replace=False)
|
||||
val_indices = set(val_indices)
|
||||
|
||||
print("Computing scale factors")
|
||||
pbar = tqdm(range(len(subdirs)))
|
||||
min_scale_factor = None
|
||||
max_scale_factor = None
|
||||
for i in pbar:
|
||||
try:
|
||||
curr_min, curr_max = get_act_min_max(subdirs[i])
|
||||
if min_scale_factor is None:
|
||||
assert max_scale_factor is None
|
||||
min_scale_factor = curr_min
|
||||
max_scale_factor = curr_max
|
||||
else:
|
||||
assert max_scale_factor is not None
|
||||
min_scale_factor = np.minimum(min_scale_factor, curr_min)
|
||||
max_scale_factor = np.maximum(max_scale_factor, curr_max)
|
||||
pbar.set_description(f"t: {i}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Skipping {subdirs[i]}")
|
||||
continue
|
||||
bias_factor = (min_scale_factor + max_scale_factor) / 2.0
|
||||
scale_factor = (max_scale_factor - min_scale_factor) / 2.0
|
||||
scale_factor[scale_factor == 0] = 1.0
|
||||
print("*" * 80)
|
||||
print(f"scale factors: {scale_factor}")
|
||||
print(f"bias factor: {bias_factor}")
|
||||
# make it into a copy pasteable string where the numbers are separated by commas
|
||||
scale_factor_str = ", ".join([f"{x}" for x in scale_factor])
|
||||
print(f"scale_factor = np.array([{scale_factor_str}])")
|
||||
bias_factor_str = ", ".join([f"{x}" for x in bias_factor])
|
||||
print(f"bias_factor = np.array([{bias_factor_str}])")
|
||||
print("*" * 80)
|
||||
|
||||
tot = 0
|
||||
|
||||
all_rgbs = []
|
||||
all_depths = []
|
||||
all_actions = []
|
||||
all_states = []
|
||||
|
||||
vis_dir = os.path.join(output_dir, "vis")
|
||||
state_output_dir = os.path.join(vis_dir, "state")
|
||||
action_output_dir = os.path.join(vis_dir, "action")
|
||||
rgb_output_dir = os.path.join(vis_dir, "rgb")
|
||||
depth_output_dir = os.path.join(vis_dir, "depth")
|
||||
|
||||
if not os.path.isdir(vis_dir):
|
||||
os.mkdir(vis_dir)
|
||||
if not os.path.isdir(state_output_dir):
|
||||
os.mkdir(state_output_dir)
|
||||
if not os.path.isdir(action_output_dir):
|
||||
os.mkdir(action_output_dir)
|
||||
if not os.path.isdir(rgb_output_dir):
|
||||
os.mkdir(rgb_output_dir)
|
||||
if not os.path.isdir(depth_output_dir):
|
||||
os.mkdir(depth_output_dir)
|
||||
|
||||
pbar = tqdm(range(len(subdirs)))
|
||||
for i in pbar:
|
||||
out_dir = val_dir if i in val_indices else train_dir
|
||||
out_dir = os.path.join(out_dir, "none")
|
||||
|
||||
if not os.path.isdir(out_dir):
|
||||
os.mkdir(out_dir)
|
||||
|
||||
ret = convert_single_demo(
|
||||
subdirs[i],
|
||||
i,
|
||||
out_dir,
|
||||
rgb_output_dir,
|
||||
depth_output_dir,
|
||||
state_output_dir,
|
||||
action_output_dir,
|
||||
scale_factor=scale_factor,
|
||||
bias_factor=bias_factor,
|
||||
)
|
||||
|
||||
if ret != 0:
|
||||
all_rgbs.append(ret[0])
|
||||
all_depths.append(ret[1])
|
||||
all_actions.append(ret[2])
|
||||
all_states.append(ret[3])
|
||||
tot += 1
|
||||
|
||||
pbar.set_description(f"t: {i}")
|
||||
|
||||
print(
|
||||
f"Finished converting all demos to {output_dir}! (num demos: {tot} / {len(subdirs)})"
|
||||
)
|
||||
|
||||
if args.vis:
|
||||
if len(all_rgbs) > 0:
|
||||
print(f"Visualizing all demos...")
|
||||
|
||||
plot_in_grid(
|
||||
all_actions, os.path.join(action_output_dir, "_all_actions.png")
|
||||
)
|
||||
plot_in_grid(all_states, os.path.join(state_output_dir, "_all_states.png"))
|
||||
make_grid_video_from_numpy(
|
||||
all_rgbs, 10, os.path.join(rgb_output_dir, "_all_rgb.mp4"), fps=30
|
||||
)
|
||||
make_grid_video_from_numpy(
|
||||
all_depths, 10, os.path.join(depth_output_dir, "_all_depth.mp4"), fps=30
|
||||
)
|
||||
|
||||
exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
22
gello/data_utils/format_obs.py
Normal file
22
gello/data_utils/format_obs.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import datetime
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_frame(
|
||||
folder: Path,
|
||||
timestamp: datetime.datetime,
|
||||
obs: Dict[str, np.ndarray],
|
||||
action: np.ndarray,
|
||||
) -> None:
|
||||
obs["control"] = action # add action to obs
|
||||
|
||||
# make folder if it doesn't exist
|
||||
folder.mkdir(exist_ok=True, parents=True)
|
||||
recorded_file = folder / (timestamp.isoformat() + ".pkl")
|
||||
|
||||
with open(recorded_file, "wb") as f:
|
||||
pickle.dump(obs, f)
|
59
gello/data_utils/keyboard_interface.py
Normal file
59
gello/data_utils/keyboard_interface.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import pygame
|
||||
|
||||
NORMAL = (128, 128, 128)
|
||||
RED = (255, 0, 0)
|
||||
GREEN = (0, 255, 0)
|
||||
|
||||
KEY_START = pygame.K_s
|
||||
KEY_CONTINUE = pygame.K_c
|
||||
KEY_QUIT_RECORDING = pygame.K_q
|
||||
|
||||
|
||||
class KBReset:
|
||||
def __init__(self):
|
||||
pygame.init()
|
||||
self._screen = pygame.display.set_mode((800, 800))
|
||||
self._set_color(NORMAL)
|
||||
self._saved = False
|
||||
|
||||
def update(self) -> str:
|
||||
pressed_last = self._get_pressed()
|
||||
if KEY_QUIT_RECORDING in pressed_last:
|
||||
self._set_color(RED)
|
||||
self._saved = False
|
||||
return "normal"
|
||||
|
||||
if self._saved:
|
||||
return "save"
|
||||
|
||||
if KEY_START in pressed_last:
|
||||
self._set_color(GREEN)
|
||||
self._saved = True
|
||||
return "start"
|
||||
|
||||
self._set_color(NORMAL)
|
||||
return "normal"
|
||||
|
||||
def _get_pressed(self):
|
||||
pressed = []
|
||||
pygame.event.pump()
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.KEYDOWN:
|
||||
pressed.append(event.key)
|
||||
return pressed
|
||||
|
||||
def _set_color(self, color):
|
||||
self._screen.fill(color)
|
||||
pygame.display.flip()
|
||||
|
||||
|
||||
def main():
|
||||
kb = KBReset()
|
||||
while True:
|
||||
state = kb.update()
|
||||
if state == "start":
|
||||
print("start")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
103
gello/data_utils/plot_utils.py
Normal file
103
gello/data_utils/plot_utils.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_in_grid(vals: np.ndarray, save_path: str):
|
||||
"""Plot the trajectories in a grid.
|
||||
|
||||
Args:
|
||||
vals: B x T x N, where
|
||||
B is the number of trajectories,
|
||||
T is the number of timesteps,
|
||||
N is the dimensionality of the values.
|
||||
save_path: path to save the plot.
|
||||
"""
|
||||
B = len(vals)
|
||||
N = vals[0].shape[-1]
|
||||
# fig, axes = plt.subplots(2, 4, figsize=(20, 10))
|
||||
rows = (N // 4) + (N % 4 > 0)
|
||||
fig, axes = plt.subplots(rows, 4, figsize=(20, 10))
|
||||
for b in range(B):
|
||||
curr = vals[b]
|
||||
for i in range(N):
|
||||
T = curr.shape[0]
|
||||
# give them transparency
|
||||
axes[i // 4, i % 4].plot(np.arange(T), curr[:, i], alpha=0.5)
|
||||
|
||||
for i in range(N):
|
||||
axes[i // 4, i % 4].set_title(f"Dim {i}")
|
||||
axes[i // 4, i % 4].set_ylim([-1.0, 1.0])
|
||||
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
|
||||
fig = plt.figure(figsize=(20, 10))
|
||||
ax = fig.add_subplot(141, projection="3d")
|
||||
for b in range(B):
|
||||
curr = vals[b]
|
||||
ax.plot(curr[:, 0], curr[:, 1], curr[:, 2], alpha=0.75)
|
||||
# scatter the start and end points
|
||||
ax.scatter(curr[0, 0], curr[0, 1], curr[0, 2], c="r")
|
||||
ax.scatter(curr[-1, 0], curr[-1, 1], curr[-1, 2], c="g")
|
||||
ax.set_xlabel("x")
|
||||
ax.set_ylabel("y")
|
||||
ax.set_zlabel("z")
|
||||
ax.legend(["Trajectory", "Start", "End"])
|
||||
ax.set_xlim([-1, 1])
|
||||
ax.set_ylim([-1, 1])
|
||||
ax.set_zlim([-1, 1])
|
||||
|
||||
# get the 2D view of the XY plane, with X pointing downwards
|
||||
ax.view_init(270, 0)
|
||||
|
||||
ax = fig.add_subplot(142, projection="3d")
|
||||
for b in range(B):
|
||||
curr = vals[b]
|
||||
ax.plot(curr[:, 0], curr[:, 1], curr[:, 2], alpha=0.75)
|
||||
ax.scatter(curr[0, 0], curr[0, 1], curr[0, 2], c="r")
|
||||
ax.scatter(curr[-1, 0], curr[-1, 1], curr[-1, 2], c="g")
|
||||
ax.set_xlabel("x")
|
||||
ax.set_ylabel("y")
|
||||
ax.set_zlabel("z")
|
||||
ax.legend(["Trajectory", "Start", "End"])
|
||||
ax.set_xlim([-1, 1])
|
||||
ax.set_ylim([-1, 1])
|
||||
ax.set_zlim([-1, 1])
|
||||
|
||||
# get the 2D view of the XZ plane, with X pointing leftwards
|
||||
ax.view_init(0, 0)
|
||||
|
||||
ax = fig.add_subplot(143, projection="3d")
|
||||
for b in range(B):
|
||||
curr = vals[b]
|
||||
ax.plot(curr[:, 0], curr[:, 1], curr[:, 2], alpha=0.75)
|
||||
ax.scatter(curr[0, 0], curr[0, 1], curr[0, 2], c="r")
|
||||
ax.scatter(curr[-1, 0], curr[-1, 1], curr[-1, 2], c="g")
|
||||
ax.set_xlabel("x")
|
||||
ax.set_ylabel("y")
|
||||
ax.set_zlabel("z")
|
||||
ax.legend(["Trajectory", "Start", "End"])
|
||||
ax.set_xlim([-1, 1])
|
||||
ax.set_ylim([-1, 1])
|
||||
ax.set_zlim([-1, 1])
|
||||
|
||||
# get the 2D view of the YZ plane, with Y pointing leftwards
|
||||
ax.view_init(0, 90)
|
||||
|
||||
ax = fig.add_subplot(144, projection="3d")
|
||||
for b in range(B):
|
||||
curr = vals[b]
|
||||
ax.plot(curr[:, 0], curr[:, 1], curr[:, 2], alpha=0.75)
|
||||
ax.scatter(curr[0, 0], curr[0, 1], curr[0, 2], c="r")
|
||||
ax.scatter(curr[-1, 0], curr[-1, 1], curr[-1, 2], c="g")
|
||||
ax.set_xlabel("x")
|
||||
ax.set_ylabel("y")
|
||||
ax.set_zlabel("z")
|
||||
ax.legend(["Trajectory", "Start", "End"])
|
||||
ax.set_xlim([-1, 1])
|
||||
ax.set_ylim([-1, 1])
|
||||
ax.set_zlim([-1, 1])
|
||||
|
||||
plt.savefig(save_path[:-4] + "_3d.png")
|
||||
|
||||
plt.close()
|
5
gello/dm_control_tasks/arenas/__init__.py
Normal file
5
gello/dm_control_tasks/arenas/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from gello.dm_control_tasks.arenas.base import Arena
|
||||
|
||||
__all__ = [
|
||||
"Arena",
|
||||
]
|
13
gello/dm_control_tasks/arenas/arena.xml
Normal file
13
gello/dm_control_tasks/arenas/arena.xml
Normal file
|
@ -0,0 +1,13 @@
|
|||
<mujoco model="arena">
|
||||
|
||||
<visual>
|
||||
<headlight diffuse="0.6 0.6 0.6" ambient="0.3 0.3 0.3" specular="0 0 0"/>
|
||||
<rgba haze="0.15 0.25 0.35 1"/>
|
||||
</visual>
|
||||
|
||||
<asset>
|
||||
<texture type="skybox" builtin="gradient" rgb1="0.3 0.5 0.7" rgb2="0 0 0" width="512"
|
||||
height="3072"/>
|
||||
</asset>
|
||||
|
||||
</mujoco>
|
26
gello/dm_control_tasks/arenas/base.py
Normal file
26
gello/dm_control_tasks/arenas/base.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""Arena composer class."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from dm_control import composer, mjcf
|
||||
|
||||
_ARENA_XML = Path(__file__).resolve().parent / "arena.xml"
|
||||
|
||||
|
||||
class Arena(composer.Entity):
|
||||
"""Base arena class."""
|
||||
|
||||
def _build(self, name: str = "arena") -> None:
|
||||
self._mjcf_root = mjcf.from_path(str(_ARENA_XML))
|
||||
if name is not None:
|
||||
self._mjcf_root.model = name
|
||||
|
||||
def add_free_entity(self, entity) -> mjcf.Element:
|
||||
"""Includes an entity as a free moving body, i.e., with a freejoint."""
|
||||
frame = self.attach(entity)
|
||||
frame.add("freejoint")
|
||||
return frame
|
||||
|
||||
@property
|
||||
def mjcf_model(self) -> mjcf.RootElement:
|
||||
return self._mjcf_root
|
7
gello/dm_control_tasks/arms/__init__.py
Normal file
7
gello/dm_control_tasks/arms/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from gello.dm_control_tasks.arms.franka import Franka
|
||||
from gello.dm_control_tasks.arms.ur5e import UR5e
|
||||
|
||||
__all__ = [
|
||||
"UR5e",
|
||||
"Franka",
|
||||
]
|
28
gello/dm_control_tasks/arms/franka.py
Normal file
28
gello/dm_control_tasks/arms/franka.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
"""Franka composer class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
|
||||
|
||||
XML = MENAGERIE_ROOT / "franka_emika_panda" / "panda_nohand.xml"
|
||||
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
|
||||
|
||||
class Franka(Manipulator):
|
||||
"""Franka Robot."""
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str = "franka",
|
||||
xml_path: Union[str, Path] = XML,
|
||||
gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
|
||||
) -> None:
|
||||
super()._build(name="franka", xml_path=XML, gripper_xml_path=GRIPPER_XML)
|
||||
|
||||
@property
|
||||
def flange(self) -> mjcf.Element:
|
||||
return self._mjcf_root.find("site", "attachment_site")
|
229
gello/dm_control_tasks/arms/manipulator.py
Normal file
229
gello/dm_control_tasks/arms/manipulator.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
"""Manipulator composer class."""
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from dm_control import composer, mjcf
|
||||
from dm_control.composer.observation import observable
|
||||
from dm_control.mujoco.wrapper import mjbindings
|
||||
from dm_control.suite.utils.randomizers import random_limited_quaternion
|
||||
|
||||
from gello.dm_control_tasks import mjcf_utils
|
||||
|
||||
|
||||
def attach_hand_to_arm(
|
||||
arm_mjcf: mjcf.RootElement,
|
||||
hand_mjcf: mjcf.RootElement,
|
||||
# attach_site: str,
|
||||
) -> None:
|
||||
"""Attaches a hand to an arm.
|
||||
|
||||
The arm must have a site named "attachment_site".
|
||||
|
||||
Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm
|
||||
|
||||
Args:
|
||||
arm_mjcf: The mjcf.RootElement of the arm.
|
||||
hand_mjcf: The mjcf.RootElement of the hand.
|
||||
attach_site: The name of the site to attach the hand to.
|
||||
|
||||
Raises:
|
||||
ValueError: If the arm does not have a site named "attachment_site".
|
||||
"""
|
||||
physics = mjcf.Physics.from_mjcf_model(hand_mjcf)
|
||||
|
||||
# attachment_site = arm_mjcf.find("site", attach_site)
|
||||
attachment_site = arm_mjcf.find("site", "attachment_site")
|
||||
if attachment_site is None:
|
||||
raise ValueError("No attachment site found in the arm model.")
|
||||
|
||||
# Expand the ctrl and qpos keyframes to account for the new hand DoFs.
|
||||
arm_key = arm_mjcf.find("key", "home")
|
||||
if arm_key is not None:
|
||||
hand_key = hand_mjcf.find("key", "home")
|
||||
if hand_key is None:
|
||||
arm_key.ctrl = np.concatenate([arm_key.ctrl, np.zeros(physics.model.nu)])
|
||||
arm_key.qpos = np.concatenate([arm_key.qpos, np.zeros(physics.model.nq)])
|
||||
else:
|
||||
arm_key.ctrl = np.concatenate([arm_key.ctrl, hand_key.ctrl])
|
||||
arm_key.qpos = np.concatenate([arm_key.qpos, hand_key.qpos])
|
||||
|
||||
attachment_site.attach(hand_mjcf)
|
||||
|
||||
|
||||
class Manipulator(composer.Entity, abc.ABC):
|
||||
"""A manipulator entity."""
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str,
|
||||
xml_path: Union[str, Path],
|
||||
gripper_xml_path: Optional[Union[str, Path]],
|
||||
) -> None:
|
||||
"""Builds the manipulator.
|
||||
|
||||
Subclasses can not override this method, but should call this method in their
|
||||
own _build() method.
|
||||
"""
|
||||
self._mjcf_root = mjcf.from_path(str(xml_path))
|
||||
self._arm_joints = tuple(mjcf_utils.safe_find_all(self._mjcf_root, "joint"))
|
||||
if gripper_xml_path:
|
||||
gripper_mjcf = mjcf.from_path(str(gripper_xml_path))
|
||||
attach_hand_to_arm(self._mjcf_root, gripper_mjcf)
|
||||
|
||||
self._mjcf_root.model = name
|
||||
self._add_mjcf_elements()
|
||||
|
||||
def set_joints(self, physics: mjcf.Physics, joints: np.ndarray) -> None:
|
||||
assert len(joints) == len(self._arm_joints)
|
||||
for joint, joint_value in zip(self._arm_joints, joints):
|
||||
joint_id = physics.bind(joint).element_id
|
||||
joint_name = physics.model.id2name(joint_id, "joint")
|
||||
physics.named.data.qpos[joint_name] = joint_value
|
||||
|
||||
def randomize_joints(
|
||||
self,
|
||||
physics: mjcf.Physics,
|
||||
random: Optional[np.random.RandomState] = None,
|
||||
) -> None:
|
||||
random = random or np.random # type: ignore
|
||||
assert random is not None
|
||||
hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE
|
||||
slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE
|
||||
ball = mjbindings.enums.mjtJoint.mjJNT_BALL
|
||||
free = mjbindings.enums.mjtJoint.mjJNT_FREE
|
||||
|
||||
qpos = physics.named.data.qpos
|
||||
|
||||
for joint in self._arm_joints:
|
||||
joint_id = physics.bind(joint).element_id
|
||||
# joint_id = physics.model.name2id(joint.name, "joint")
|
||||
joint_name = physics.model.id2name(joint_id, "joint")
|
||||
joint_type = physics.model.jnt_type[joint_id]
|
||||
is_limited = physics.model.jnt_limited[joint_id]
|
||||
range_min, range_max = physics.model.jnt_range[joint_id]
|
||||
|
||||
if is_limited:
|
||||
if joint_type in [hinge, slide]:
|
||||
qpos[joint_name] = random.uniform(range_min, range_max)
|
||||
|
||||
elif joint_type == ball:
|
||||
qpos[joint_name] = random_limited_quaternion(random, range_max)
|
||||
|
||||
else:
|
||||
if joint_type == hinge:
|
||||
qpos[joint_name] = random.uniform(-np.pi, np.pi)
|
||||
|
||||
elif joint_type == ball:
|
||||
quat = random.randn(4)
|
||||
quat /= np.linalg.norm(quat)
|
||||
qpos[joint_name] = quat
|
||||
|
||||
elif joint_type == free:
|
||||
# this should be random.randn, but changing it now could significantly
|
||||
# affect benchmark results.
|
||||
quat = random.rand(4)
|
||||
quat /= np.linalg.norm(quat)
|
||||
qpos[joint_name][3:] = quat
|
||||
|
||||
def _add_mjcf_elements(self) -> None:
|
||||
# Parse joints.
|
||||
joints = mjcf_utils.safe_find_all(self._mjcf_root, "joint")
|
||||
joints = [joint for joint in joints if joint.tag != "freejoint"]
|
||||
self._joints = tuple(joints)
|
||||
|
||||
# Parse actuators.
|
||||
actuators = mjcf_utils.safe_find_all(self._mjcf_root, "actuator")
|
||||
self._actuators = tuple(actuators)
|
||||
|
||||
# Parse qpos / ctrl keyframes.
|
||||
self._keyframes = {}
|
||||
keyframes = mjcf_utils.safe_find_all(self._mjcf_root, "key")
|
||||
if keyframes:
|
||||
for frame in keyframes:
|
||||
if frame.qpos is not None:
|
||||
qpos = np.array(frame.qpos)
|
||||
self._keyframes[frame.name] = qpos
|
||||
|
||||
# add a visualizeation the flange position that is green
|
||||
self.flange.parent.add(
|
||||
"geom",
|
||||
name="flange_geom",
|
||||
type="sphere",
|
||||
size="0.01",
|
||||
rgba="0 1 0 1",
|
||||
pos=self.flange.pos,
|
||||
contype="0",
|
||||
conaffinity="0",
|
||||
)
|
||||
|
||||
def _build_observables(self):
|
||||
return ArmObservables(self)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def flange(self) -> mjcf.Element:
|
||||
"""Returns the flange element.
|
||||
|
||||
The flange is the end effector of the manipulator where tools can be
|
||||
attached, such as a gripper.
|
||||
"""
|
||||
|
||||
@property
|
||||
def mjcf_model(self) -> mjcf.RootElement:
|
||||
return self._mjcf_root
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._mjcf_root.model
|
||||
|
||||
@property
|
||||
def joints(self) -> Tuple[mjcf.Element, ...]:
|
||||
return self._joints
|
||||
|
||||
@property
|
||||
def actuators(self) -> Tuple[mjcf.Element, ...]:
|
||||
return self._actuators
|
||||
|
||||
@property
|
||||
def keyframes(self) -> Dict[str, np.ndarray]:
|
||||
return self._keyframes
|
||||
|
||||
|
||||
class ArmObservables(composer.Observables):
|
||||
"""Base class for quadruped observables."""
|
||||
|
||||
@composer.observable
|
||||
def joints_pos(self):
|
||||
return observable.MJCFFeature("qpos", self._entity.joints)
|
||||
|
||||
@composer.observable
|
||||
def joints_vel(self):
|
||||
return observable.MJCFFeature("qvel", self._entity.joints)
|
||||
|
||||
@composer.observable
|
||||
def flange_position(self):
|
||||
return observable.MJCFFeature("xpos", self._entity.flange)
|
||||
|
||||
@composer.observable
|
||||
def flange_orientation(self):
|
||||
return observable.MJCFFeature("xmat", self._entity.flange)
|
||||
|
||||
# Semantic grouping of observables.
|
||||
def _collect_from_attachments(self, attribute_name: str):
|
||||
out: List[observable.MJCFFeature] = []
|
||||
for entity in self._entity.iter_entities(exclude_self=True):
|
||||
out.extend(getattr(entity.observables, attribute_name, []))
|
||||
return out
|
||||
|
||||
@property
|
||||
def proprioception(self):
|
||||
return [
|
||||
self.joints_pos,
|
||||
self.joints_vel,
|
||||
self.flange_position,
|
||||
# self.flange_orientation,
|
||||
# self.flange_velocity,
|
||||
self.flange_angular_velocity,
|
||||
] + self._collect_from_attachments("proprioception")
|
26
gello/dm_control_tasks/arms/ur5e.py
Normal file
26
gello/dm_control_tasks/arms/ur5e.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
"""UR5e composer class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.mjcf_utils import MENAGERIE_ROOT
|
||||
|
||||
|
||||
class UR5e(Manipulator):
|
||||
GRIPPER_XML = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml"
|
||||
XML = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str = "UR5e",
|
||||
xml_path: Union[str, Path] = XML,
|
||||
gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
|
||||
) -> None:
|
||||
super()._build(name=name, xml_path=xml_path, gripper_xml_path=gripper_xml_path)
|
||||
|
||||
@property
|
||||
def flange(self) -> mjcf.Element:
|
||||
return self._mjcf_root.find("site", "attachment_site")
|
22
gello/dm_control_tasks/arms/ur5e_test.py
Normal file
22
gello/dm_control_tasks/arms/ur5e_test.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""Tests for ur5e.py."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms import ur5e
|
||||
|
||||
|
||||
class UR5eTest(absltest.TestCase):
|
||||
def test_compiles_and_steps(self) -> None:
|
||||
robot = ur5e.UR5e()
|
||||
physics = mjcf.Physics.from_mjcf_model(robot.mjcf_model)
|
||||
physics.step()
|
||||
|
||||
def test_joints(self) -> None:
|
||||
robot = ur5e.UR5e()
|
||||
for joint in robot.joints:
|
||||
self.assertEqual(joint.tag, "joint")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
261
gello/dm_control_tasks/arms/utils.py
Normal file
261
gello/dm_control_tasks/arms/utils.py
Normal file
|
@ -0,0 +1,261 @@
|
|||
import collections
|
||||
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
from dm_control.mujoco.wrapper import mjbindings
|
||||
|
||||
mjlib = mjbindings.mjlib
|
||||
|
||||
|
||||
_INVALID_JOINT_NAMES_TYPE = (
|
||||
"`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}."
|
||||
)
|
||||
_REQUIRE_TARGET_POS_OR_QUAT = (
|
||||
"At least one of `target_pos` or `target_quat` must be specified."
|
||||
)
|
||||
|
||||
IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"])
|
||||
|
||||
|
||||
def qpos_from_site_pose(
|
||||
physics,
|
||||
site_name,
|
||||
target_pos=None,
|
||||
target_quat=None,
|
||||
joint_names=None,
|
||||
tol=1e-14,
|
||||
rot_weight=1.0,
|
||||
regularization_threshold=0.1,
|
||||
regularization_strength=3e-2,
|
||||
max_update_norm=2.0,
|
||||
progress_thresh=20.0,
|
||||
max_steps=100,
|
||||
inplace=False,
|
||||
):
|
||||
"""Find joint positions that satisfy a target site position and/or rotation.
|
||||
|
||||
Args:
|
||||
physics: A `mujoco.Physics` instance.
|
||||
site_name: A string specifying the name of the target site.
|
||||
target_pos: A (3,) numpy array specifying the desired Cartesian position of
|
||||
the site, or None if the position should be unconstrained (default).
|
||||
One or both of `target_pos` or `target_quat` must be specified.
|
||||
target_quat: A (4,) numpy array specifying the desired orientation of the
|
||||
site as a quaternion, or None if the orientation should be unconstrained
|
||||
(default). One or both of `target_pos` or `target_quat` must be specified.
|
||||
joint_names: (optional) A list, tuple or numpy array specifying the names of
|
||||
one or more joints that can be manipulated in order to achieve the target
|
||||
site pose. If None (default), all joints may be manipulated.
|
||||
tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm`
|
||||
in the stopping criterion).
|
||||
rot_weight: (optional) Determines the weight given to rotational error
|
||||
relative to translational error.
|
||||
regularization_threshold: (optional) L2 regularization will be used when
|
||||
inverting the Jacobian whilst `err_norm` is greater than this value.
|
||||
regularization_strength: (optional) Coefficient of the quadratic penalty
|
||||
on joint movements.
|
||||
max_update_norm: (optional) The maximum L2 norm of the update applied to
|
||||
the joint positions on each iteration. The update vector will be scaled
|
||||
such that its magnitude never exceeds this value.
|
||||
progress_thresh: (optional) If `err_norm` divided by the magnitude of the
|
||||
joint position update is greater than this value then the optimization
|
||||
will terminate prematurely. This is a useful heuristic to avoid getting
|
||||
stuck in local minima.
|
||||
max_steps: (optional) The maximum number of iterations to perform.
|
||||
inplace: (optional) If True, `physics.data` will be modified in place.
|
||||
Default value is False, i.e. a copy of `physics.data` will be made.
|
||||
|
||||
Returns:
|
||||
An `IKResult` namedtuple with the following fields:
|
||||
qpos: An (nq,) numpy array of joint positions.
|
||||
err_norm: A float, the weighted sum of L2 norms for the residual
|
||||
translational and rotational errors.
|
||||
steps: An int, the number of iterations that were performed.
|
||||
success: Boolean, True if we converged on a solution within `max_steps`,
|
||||
False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `target_pos` and `target_quat` are None, or if
|
||||
`joint_names` has an invalid type.
|
||||
"""
|
||||
dtype = physics.data.qpos.dtype
|
||||
|
||||
if target_pos is not None and target_quat is not None:
|
||||
jac = np.empty((6, physics.model.nv), dtype=dtype)
|
||||
err = np.empty(6, dtype=dtype)
|
||||
jac_pos, jac_rot = jac[:3], jac[3:]
|
||||
err_pos, err_rot = err[:3], err[3:]
|
||||
else:
|
||||
jac = np.empty((3, physics.model.nv), dtype=dtype)
|
||||
err = np.empty(3, dtype=dtype)
|
||||
if target_pos is not None:
|
||||
jac_pos, jac_rot = jac, None
|
||||
err_pos, err_rot = err, None
|
||||
elif target_quat is not None:
|
||||
jac_pos, jac_rot = None, jac
|
||||
err_pos, err_rot = None, err
|
||||
else:
|
||||
raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT)
|
||||
|
||||
update_nv = np.zeros(physics.model.nv, dtype=dtype)
|
||||
|
||||
if target_quat is not None:
|
||||
site_xquat = np.empty(4, dtype=dtype)
|
||||
neg_site_xquat = np.empty(4, dtype=dtype)
|
||||
err_rot_quat = np.empty(4, dtype=dtype)
|
||||
|
||||
if not inplace:
|
||||
physics = physics.copy(share_model=True)
|
||||
|
||||
# Ensure that the Cartesian position of the site is up to date.
|
||||
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
|
||||
|
||||
# Convert site name to index.
|
||||
site_id = physics.model.name2id(site_name, "site")
|
||||
|
||||
# These are views onto the underlying MuJoCo buffers. mj_fwdPosition will
|
||||
# update them in place, so we can avoid indexing overhead in the main loop.
|
||||
site_xpos = physics.named.data.site_xpos[site_name]
|
||||
site_xmat = physics.named.data.site_xmat[site_name]
|
||||
|
||||
# This is an index into the rows of `update` and the columns of `jac`
|
||||
# that selects DOFs associated with joints that we are allowed to manipulate.
|
||||
if joint_names is None:
|
||||
dof_indices = slice(None) # Update all DOFs.
|
||||
elif isinstance(joint_names, (list, np.ndarray, tuple)):
|
||||
if isinstance(joint_names, tuple):
|
||||
joint_names = list(joint_names)
|
||||
# Find the indices of the DOFs belonging to each named joint. Note that
|
||||
# these are not necessarily the same as the joint IDs, since a single joint
|
||||
# may have >1 DOF (e.g. ball joints).
|
||||
indexer = physics.named.model.dof_jntid.axes.row
|
||||
# `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row
|
||||
# indexer to map each joint name to the indices of its corresponding DOFs.
|
||||
dof_indices = indexer.convert_key_item(joint_names)
|
||||
else:
|
||||
raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names)))
|
||||
|
||||
steps = 0
|
||||
success = False
|
||||
|
||||
for steps in range(max_steps):
|
||||
err_norm = 0.0
|
||||
|
||||
if target_pos is not None:
|
||||
# Translational error.
|
||||
err_pos[:] = target_pos - site_xpos
|
||||
err_norm += np.linalg.norm(err_pos)
|
||||
if target_quat is not None:
|
||||
# Rotational error.
|
||||
mjlib.mju_mat2Quat(site_xquat, site_xmat)
|
||||
mjlib.mju_negQuat(neg_site_xquat, site_xquat)
|
||||
mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat)
|
||||
mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1)
|
||||
err_norm += np.linalg.norm(err_rot) * rot_weight
|
||||
|
||||
if err_norm < tol:
|
||||
logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm)
|
||||
success = True
|
||||
break
|
||||
else:
|
||||
# TODO(b/112141670): Generalize this to other entities besides sites.
|
||||
mjlib.mj_jacSite(
|
||||
physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id
|
||||
)
|
||||
jac_joints = jac[:, dof_indices]
|
||||
|
||||
# TODO(b/112141592): This does not take joint limits into consideration.
|
||||
reg_strength = (
|
||||
regularization_strength if err_norm > regularization_threshold else 0.0
|
||||
)
|
||||
update_joints = nullspace_method(
|
||||
jac_joints, err, regularization_strength=reg_strength
|
||||
)
|
||||
|
||||
update_norm = np.linalg.norm(update_joints)
|
||||
|
||||
# Check whether we are still making enough progress, and halt if not.
|
||||
progress_criterion = err_norm / update_norm
|
||||
if progress_criterion > progress_thresh:
|
||||
logging.debug(
|
||||
"Step %2i: err_norm / update_norm (%3g) > "
|
||||
"tolerance (%3g). Halting due to insufficient progress",
|
||||
steps,
|
||||
progress_criterion,
|
||||
progress_thresh,
|
||||
)
|
||||
break
|
||||
|
||||
if update_norm > max_update_norm:
|
||||
update_joints *= max_update_norm / update_norm
|
||||
|
||||
# Write the entries for the specified joints into the full `update_nv`
|
||||
# vector.
|
||||
update_nv[dof_indices] = update_joints
|
||||
|
||||
# Update `physics.qpos`, taking quaternions into account.
|
||||
mjlib.mj_integratePos(physics.model.ptr, physics.data.qpos, update_nv, 1)
|
||||
|
||||
# Compute the new Cartesian position of the site.
|
||||
mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr)
|
||||
|
||||
logging.debug(
|
||||
"Step %2i: err_norm=%-10.3g update_norm=%-10.3g",
|
||||
steps,
|
||||
err_norm,
|
||||
update_norm,
|
||||
)
|
||||
|
||||
if not success and steps == max_steps - 1:
|
||||
logging.warning(
|
||||
"Failed to converge after %i steps: err_norm=%3g", steps, err_norm
|
||||
)
|
||||
|
||||
if not inplace:
|
||||
# Our temporary copy of physics.data is about to go out of scope, and when
|
||||
# it does the underlying mjData pointer will be freed and physics.data.qpos
|
||||
# will be a view onto a block of deallocated memory. We therefore need to
|
||||
# make a copy of physics.data.qpos while physics.data is still alive.
|
||||
qpos = physics.data.qpos.copy()
|
||||
else:
|
||||
# If we're modifying physics.data in place then it's fine to return a view.
|
||||
qpos = physics.data.qpos
|
||||
|
||||
return IKResult(qpos=qpos, err_norm=err_norm, steps=steps, success=success)
|
||||
|
||||
|
||||
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
|
||||
"""Calculates the joint velocities to achieve a specified end effector delta.
|
||||
|
||||
Args:
|
||||
jac_joints: The Jacobian of the end effector with respect to the joints. A
|
||||
numpy array of shape `(ndelta, nv)`, where `ndelta` is the size of `delta`
|
||||
and `nv` is the number of degrees of freedom.
|
||||
delta: The desired end-effector delta. A numpy array of shape `(3,)` or
|
||||
`(6,)` containing either position deltas, rotation deltas, or both.
|
||||
regularization_strength: (optional) Coefficient of the quadratic penalty
|
||||
on joint movements. Default is zero, i.e. no regularization.
|
||||
|
||||
Returns:
|
||||
An `(nv,)` numpy array of joint velocities.
|
||||
|
||||
Reference:
|
||||
Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
|
||||
transpose, pseudoinverse and damped least squares methods.
|
||||
https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
|
||||
"""
|
||||
hess_approx = jac_joints.T.dot(jac_joints)
|
||||
joint_delta = jac_joints.T.dot(delta)
|
||||
if regularization_strength > 0:
|
||||
# L2 regularization
|
||||
hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
|
||||
return np.linalg.solve(hess_approx, joint_delta)
|
||||
else:
|
||||
return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
|
||||
|
||||
|
||||
class InverseKinematics:
|
||||
def __init__(self, xml_path: str):
|
||||
"""Initializes the inverse kinematics class."""
|
||||
...
|
||||
# TODO
|
0
gello/dm_control_tasks/manipulation/__init__.py
Normal file
0
gello/dm_control_tasks/manipulation/__init__.py
Normal file
111
gello/dm_control_tasks/manipulation/arenas/floors.py
Normal file
111
gello/dm_control_tasks/manipulation/arenas/floors.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
"""Simple floor arenas."""
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.dm_control_tasks.arenas import base
|
||||
|
||||
_GROUNDPLANE_QUAD_SIZE = 0.25
|
||||
|
||||
|
||||
class FixedManipulationArena(base.Arena):
|
||||
@property
|
||||
def arm_attachment_site(self) -> np.ndarray:
|
||||
return self.attachment_site
|
||||
|
||||
|
||||
class Floor(base.Arena):
|
||||
"""An arena with a checkered pattern."""
|
||||
|
||||
def _build(
|
||||
self,
|
||||
name: str = "floor",
|
||||
size: Tuple[float, float] = (8, 8),
|
||||
reflectance: float = 0.2,
|
||||
top_camera_y_padding_factor: float = 1.1,
|
||||
top_camera_distance: float = 5.0,
|
||||
) -> None:
|
||||
super()._build(name=name)
|
||||
|
||||
self._size = size
|
||||
self._top_camera_y_padding_factor = top_camera_y_padding_factor
|
||||
self._top_camera_distance = top_camera_distance
|
||||
|
||||
assert self._mjcf_root.worldbody is not None
|
||||
|
||||
z_offset = 0.00
|
||||
|
||||
# Add arm attachement site
|
||||
self._mjcf_root.worldbody.add(
|
||||
"site",
|
||||
name="arm_attachment",
|
||||
pos=(0, 0, z_offset),
|
||||
size=(0.01, 0.01, 0.01),
|
||||
type="sphere",
|
||||
rgba=(0, 0, 0, 0),
|
||||
)
|
||||
|
||||
# Add light.
|
||||
self._mjcf_root.worldbody.add(
|
||||
"light",
|
||||
pos=(0, 0, 1.5),
|
||||
dir=(0, 0, -1),
|
||||
directional=True,
|
||||
)
|
||||
|
||||
self._ground_texture = self._mjcf_root.asset.add(
|
||||
"texture",
|
||||
rgb1=[0.2, 0.3, 0.4],
|
||||
rgb2=[0.1, 0.2, 0.3],
|
||||
type="2d",
|
||||
builtin="checker",
|
||||
name="groundplane",
|
||||
width=200,
|
||||
height=200,
|
||||
mark="edge",
|
||||
markrgb=[0.8, 0.8, 0.8],
|
||||
)
|
||||
|
||||
self._ground_material = self._mjcf_root.asset.add(
|
||||
"material",
|
||||
name="groundplane",
|
||||
texrepeat=[2, 2], # Makes white squares exactly 1x1 length units.
|
||||
texuniform=True,
|
||||
reflectance=reflectance,
|
||||
texture=self._ground_texture,
|
||||
)
|
||||
|
||||
# Build groundplane.
|
||||
self._ground_geom = self._mjcf_root.worldbody.add(
|
||||
"geom",
|
||||
type="plane",
|
||||
name="groundplane",
|
||||
material=self._ground_material,
|
||||
size=list(size) + [_GROUNDPLANE_QUAD_SIZE],
|
||||
)
|
||||
|
||||
# Choose the FOV so that the floor always fits nicely within the frame
|
||||
# irrespective of actual floor size.
|
||||
fovy_radians = 2 * np.arctan2(
|
||||
top_camera_y_padding_factor * size[1], top_camera_distance
|
||||
)
|
||||
self._top_camera = self._mjcf_root.worldbody.add(
|
||||
"camera",
|
||||
name="top_camera",
|
||||
pos=[0, 0, top_camera_distance],
|
||||
quat=[1, 0, 0, 0],
|
||||
fovy=np.rad2deg(fovy_radians),
|
||||
)
|
||||
|
||||
@property
|
||||
def ground_geoms(self):
|
||||
return (self._ground_geom,)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return self._size
|
||||
|
||||
@property
|
||||
def arm_attachment_site(self) -> mjcf.Element:
|
||||
return self._mjcf_root.worldbody.find("site", "arm_attachment")
|
56
gello/dm_control_tasks/manipulation/tasks/base.py
Normal file
56
gello/dm_control_tasks/manipulation/tasks/base.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
"""A abstract class for all walker tasks."""
|
||||
|
||||
from dm_control import composer, mjcf
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
|
||||
|
||||
# Timestep of the physics simulation.
|
||||
_PHYSICS_TIMESTEP: float = 0.002
|
||||
|
||||
# Interval between agent actions, in seconds.
|
||||
# We send a control signal every (_CONTROL_TIMESTEP / _PHYSICS_TIMESTEP) physics steps.
|
||||
_CONTROL_TIMESTEP: float = 0.02 # 50 Hz.
|
||||
|
||||
|
||||
class ManipulationTask(composer.Task):
|
||||
"""Base composer task for walker robots."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arm: Manipulator,
|
||||
arena: FixedManipulationArena,
|
||||
physics_timestep=_PHYSICS_TIMESTEP,
|
||||
control_timestep=_CONTROL_TIMESTEP,
|
||||
) -> None:
|
||||
self._arm = arm
|
||||
self._arena = arena
|
||||
|
||||
self._arm_geoms = arm.root_body.find_all("geom")
|
||||
self._arena_geoms = arena.root_body.find_all("geom")
|
||||
arena.attach(arm, attach_site=arena.arm_attachment_site)
|
||||
|
||||
self.set_timesteps(
|
||||
control_timestep=control_timestep, physics_timestep=physics_timestep
|
||||
)
|
||||
|
||||
# Enable all robot observables. Note: No additional entity observables should
|
||||
# be added in subclasses.
|
||||
for observable in self._arm.observables.proprioception:
|
||||
observable.enabled = True
|
||||
|
||||
def in_collision(self, physics: mjcf.Physics) -> bool:
|
||||
"""Checks if the arm is in collision with the floor (self._arena)."""
|
||||
arm_ids = [physics.bind(g).element_id for g in self._arm_geoms]
|
||||
arena_ids = [physics.bind(g).element_id for g in self._arena_geoms]
|
||||
|
||||
return any(
|
||||
(con.geom1 in arm_ids and con.geom2 in arena_ids) # arm and arena
|
||||
or (con.geom1 in arena_ids and con.geom2 in arm_ids) # arena and arm
|
||||
or (con.geom1 in arm_ids and con.geom2 in arm_ids) # self collision
|
||||
for con in physics.data.contact[: physics.data.ncon]
|
||||
)
|
||||
|
||||
@property
|
||||
def root_entity(self):
|
||||
return self._arena
|
122
gello/dm_control_tasks/manipulation/tasks/block_play.py
Normal file
122
gello/dm_control_tasks/manipulation/tasks/block_play.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
"""A task where a walker must learn to stand."""
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from dm_control import mjcf
|
||||
from dm_control.suite.utils.randomizers import random_limited_quaternion
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
|
||||
from gello.dm_control_tasks.manipulation.tasks import base
|
||||
|
||||
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
|
||||
|
||||
|
||||
class BlockPlay(base.ManipulationTask):
|
||||
"""Task for a manipulator. Blocks are randomly placed in the scene."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arm: Manipulator,
|
||||
arena: FixedManipulationArena,
|
||||
physics_timestep=base._PHYSICS_TIMESTEP,
|
||||
control_timestep=base._CONTROL_TIMESTEP,
|
||||
num_blocks: int = 10,
|
||||
size: float = 0.03,
|
||||
reset_joints: Optional[np.ndarray] = None,
|
||||
) -> None:
|
||||
super().__init__(arm, arena, physics_timestep, control_timestep)
|
||||
|
||||
# find key frame
|
||||
key_frames = self.root_entity.mjcf_model.find_all("key")
|
||||
if len(key_frames) == 0:
|
||||
key_frames = None
|
||||
else:
|
||||
key_frames = key_frames[0]
|
||||
|
||||
# Create target.
|
||||
block_joints = []
|
||||
for i in range(num_blocks):
|
||||
# select random colors = np.random.uniform(0, 1, size=3)
|
||||
color = np.concatenate([np.random.uniform(0, 1, size=3), [1.0]])
|
||||
|
||||
# attach a body for block i
|
||||
b = self.root_entity.mjcf_model.worldbody.add(
|
||||
"body", name=f"block_{i}", pos=(0, 0, 0)
|
||||
)
|
||||
|
||||
# # add a freejoint to the block so it can be moved
|
||||
_joint = b.add("freejoint")
|
||||
block_joints.append(_joint)
|
||||
|
||||
# add a geom to the block
|
||||
b.add(
|
||||
"geom",
|
||||
name=f"block_geom_{i}",
|
||||
type="box",
|
||||
size=(size, size, size),
|
||||
rgba=color,
|
||||
# contype=0,
|
||||
# conaffinity=0,
|
||||
)
|
||||
assert key_frames is not None
|
||||
key_frames.qpos = np.concatenate([key_frames.qpos, np.zeros(7)])
|
||||
|
||||
# # save xml to file
|
||||
# _xml_string = self.root_entity.mjcf_model.to_xml_string()
|
||||
# with open("block_play.xml", "w") as f:
|
||||
# f.write(_xml_string)
|
||||
|
||||
self._block_joints = block_joints
|
||||
self._block_size = size
|
||||
self._reset_joints = reset_joints
|
||||
|
||||
def initialize_episode(self, physics, random_state):
|
||||
# Randomly set feasible target position
|
||||
if self._reset_joints is not None:
|
||||
self._arm.set_joints(physics, self._reset_joints)
|
||||
else:
|
||||
self._arm.randomize_joints(physics, random_state)
|
||||
physics.forward()
|
||||
|
||||
# check if arm is in collision with floor
|
||||
while self.in_collision(physics):
|
||||
self._arm.randomize_joints(physics, random_state)
|
||||
physics.forward()
|
||||
|
||||
# Randomize block positions
|
||||
for block_j in self._block_joints:
|
||||
randomize_pose(
|
||||
block_j,
|
||||
physics,
|
||||
random_state=random_state,
|
||||
position_range=0.5,
|
||||
z_offset=self._block_size * 2,
|
||||
)
|
||||
|
||||
physics.forward()
|
||||
|
||||
def get_reward(self, physics):
|
||||
# flange position
|
||||
return 0
|
||||
|
||||
|
||||
def randomize_pose(
|
||||
free_joint: mjcf.RootElement,
|
||||
physics: mjcf.Physics,
|
||||
random_state: np.random.RandomState,
|
||||
position_range: float = 0.5,
|
||||
z_offset: float = 0.0,
|
||||
) -> None:
|
||||
"""Randomize the pose of an entity."""
|
||||
entity_pos = random_state.uniform(-position_range, position_range, size=2)
|
||||
# make x, y farther than 0.1 from 0, 0
|
||||
while np.linalg.norm(entity_pos) < 0.3:
|
||||
entity_pos = random_state.uniform(-position_range, position_range, size=2)
|
||||
|
||||
entity_pos = np.concatenate([entity_pos, [z_offset]])
|
||||
entity_quat = random_limited_quaternion(random_state, limit=np.pi)
|
||||
|
||||
qpos = np.concatenate([entity_pos, entity_quat])
|
||||
|
||||
physics.bind(free_joint).qpos = qpos
|
63
gello/dm_control_tasks/manipulation/tasks/reach.py
Normal file
63
gello/dm_control_tasks/manipulation/tasks/reach.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
"""A task where a walker must learn to stand."""
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite.utils import randomizers
|
||||
from dm_control.utils import rewards
|
||||
|
||||
from gello.dm_control_tasks.arms.manipulator import Manipulator
|
||||
from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena
|
||||
from gello.dm_control_tasks.manipulation.tasks import base
|
||||
|
||||
_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6)
|
||||
|
||||
|
||||
class Reach(base.ManipulationTask):
|
||||
"""Reach task for a manipulator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arm: Manipulator,
|
||||
arena: FixedManipulationArena,
|
||||
physics_timestep=base._PHYSICS_TIMESTEP,
|
||||
control_timestep=base._CONTROL_TIMESTEP,
|
||||
distance_tolerance: float = 0.5,
|
||||
) -> None:
|
||||
super().__init__(arm, arena, physics_timestep, control_timestep)
|
||||
|
||||
self._distance_tolerance = distance_tolerance
|
||||
|
||||
# Create target.
|
||||
self._target = self.root_entity.mjcf_model.worldbody.add(
|
||||
"site",
|
||||
name="target",
|
||||
type="sphere",
|
||||
pos=(0, 0, 0),
|
||||
size=(0.1,),
|
||||
rgba=_TARGET_COLOR,
|
||||
)
|
||||
|
||||
def initialize_episode(self, physics, random_state):
|
||||
# Randomly set feasible target position
|
||||
randomizers.randomize_limited_and_rotational_joints(physics, random_state)
|
||||
physics.forward()
|
||||
flange_position = physics.bind(self._arm.flange).xpos[:3]
|
||||
print(flange_position)
|
||||
|
||||
# set target position to flange position
|
||||
physics.bind(self._target).pos = flange_position
|
||||
|
||||
# Randomize initial position of the arm.
|
||||
randomizers.randomize_limited_and_rotational_joints(physics, random_state)
|
||||
physics.forward()
|
||||
|
||||
def get_reward(self, physics):
|
||||
# flange position
|
||||
flange_pos = physics.bind(self._arm.flange).pos[:3]
|
||||
distance = np.linalg.norm(physics.bind(self._target).pos[:3] - flange_pos)
|
||||
return -rewards.tolerance(
|
||||
distance,
|
||||
bounds=(0, self._distance_tolerance),
|
||||
margin=self._distance_tolerance,
|
||||
value_at_margin=0,
|
||||
sigmoid="linear",
|
||||
)
|
23
gello/dm_control_tasks/mjcf_utils.py
Normal file
23
gello/dm_control_tasks/mjcf_utils.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from dm_control import mjcf
|
||||
|
||||
# Path to the root of the project.
|
||||
_PROJECT_ROOT: Path = Path(__file__).parent.parent.parent
|
||||
|
||||
# Path to the Menagerie submodule.
|
||||
MENAGERIE_ROOT: Path = _PROJECT_ROOT / "third_party" / "mujoco_menagerie"
|
||||
|
||||
|
||||
def safe_find_all(
|
||||
root: mjcf.RootElement,
|
||||
feature_name: str,
|
||||
immediate_children_only: bool = False,
|
||||
exclude_attachments: bool = False,
|
||||
) -> List[mjcf.Element]:
|
||||
"""Find all given elements or throw an error if none are found."""
|
||||
features = root.find_all(feature_name, immediate_children_only, exclude_attachments)
|
||||
if not features:
|
||||
raise ValueError(f"No {feature_name} found in the MJCF model.")
|
||||
return features
|
0
gello/dynamixel/__init__.py
Normal file
0
gello/dynamixel/__init__.py
Normal file
274
gello/dynamixel/driver.py
Normal file
274
gello/dynamixel/driver.py
Normal file
|
@ -0,0 +1,274 @@
|
|||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Protocol, Sequence
|
||||
|
||||
import numpy as np
|
||||
from dynamixel_sdk.group_sync_read import GroupSyncRead
|
||||
from dynamixel_sdk.group_sync_write import GroupSyncWrite
|
||||
from dynamixel_sdk.packet_handler import PacketHandler
|
||||
from dynamixel_sdk.port_handler import PortHandler
|
||||
from dynamixel_sdk.robotis_def import (
|
||||
COMM_SUCCESS,
|
||||
DXL_HIBYTE,
|
||||
DXL_HIWORD,
|
||||
DXL_LOBYTE,
|
||||
DXL_LOWORD,
|
||||
)
|
||||
|
||||
# Constants
|
||||
ADDR_TORQUE_ENABLE = 64
|
||||
ADDR_GOAL_POSITION = 116
|
||||
LEN_GOAL_POSITION = 4
|
||||
ADDR_PRESENT_POSITION = 132
|
||||
ADDR_PRESENT_POSITION = 140
|
||||
LEN_PRESENT_POSITION = 4
|
||||
TORQUE_ENABLE = 1
|
||||
TORQUE_DISABLE = 0
|
||||
|
||||
|
||||
class DynamixelDriverProtocol(Protocol):
|
||||
def set_joints(self, joint_angles: Sequence[float]):
|
||||
"""Set the joint angles for the Dynamixel servos.
|
||||
|
||||
Args:
|
||||
joint_angles (Sequence[float]): A list of joint angles.
|
||||
"""
|
||||
...
|
||||
|
||||
def torque_enabled(self) -> bool:
|
||||
"""Check if torque is enabled for the Dynamixel servos.
|
||||
|
||||
Returns:
|
||||
bool: True if torque is enabled, False if it is disabled.
|
||||
"""
|
||||
...
|
||||
|
||||
def set_torque_mode(self, enable: bool):
|
||||
"""Set the torque mode for the Dynamixel servos.
|
||||
|
||||
Args:
|
||||
enable (bool): True to enable torque, False to disable.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_joints(self) -> np.ndarray:
|
||||
"""Get the current joint angles in radians.
|
||||
|
||||
Returns:
|
||||
np.ndarray: An array of joint angles.
|
||||
"""
|
||||
...
|
||||
|
||||
def close(self):
|
||||
"""Close the driver."""
|
||||
|
||||
|
||||
class FakeDynamixelDriver(DynamixelDriverProtocol):
|
||||
def __init__(self, ids: Sequence[int]):
|
||||
self._ids = ids
|
||||
self._joint_angles = np.zeros(len(ids), dtype=int)
|
||||
self._torque_enabled = False
|
||||
|
||||
def set_joints(self, joint_angles: Sequence[float]):
|
||||
if len(joint_angles) != len(self._ids):
|
||||
raise ValueError(
|
||||
"The length of joint_angles must match the number of servos"
|
||||
)
|
||||
if not self._torque_enabled:
|
||||
raise RuntimeError("Torque must be enabled to set joint angles")
|
||||
self._joint_angles = np.array(joint_angles)
|
||||
|
||||
def torque_enabled(self) -> bool:
|
||||
return self._torque_enabled
|
||||
|
||||
def set_torque_mode(self, enable: bool):
|
||||
self._torque_enabled = enable
|
||||
|
||||
def get_joints(self) -> np.ndarray:
|
||||
return self._joint_angles.copy()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class DynamixelDriver(DynamixelDriverProtocol):
|
||||
def __init__(
|
||||
self, ids: Sequence[int], port: str = "/dev/ttyUSB0", baudrate: int = 57600
|
||||
):
|
||||
"""Initialize the DynamixelDriver class.
|
||||
|
||||
Args:
|
||||
ids (Sequence[int]): A list of IDs for the Dynamixel servos.
|
||||
port (str): The USB port to connect to the arm.
|
||||
baudrate (int): The baudrate for communication.
|
||||
"""
|
||||
self._ids = ids
|
||||
self._joint_angles = None
|
||||
self._lock = Lock()
|
||||
|
||||
# Initialize the port handler, packet handler, and group sync read/write
|
||||
self._portHandler = PortHandler(port)
|
||||
self._packetHandler = PacketHandler(2.0)
|
||||
self._groupSyncRead = GroupSyncRead(
|
||||
self._portHandler,
|
||||
self._packetHandler,
|
||||
ADDR_PRESENT_POSITION,
|
||||
LEN_PRESENT_POSITION,
|
||||
)
|
||||
self._groupSyncWrite = GroupSyncWrite(
|
||||
self._portHandler,
|
||||
self._packetHandler,
|
||||
ADDR_GOAL_POSITION,
|
||||
LEN_GOAL_POSITION,
|
||||
)
|
||||
|
||||
# Open the port and set the baudrate
|
||||
if not self._portHandler.openPort():
|
||||
raise RuntimeError("Failed to open the port")
|
||||
|
||||
if not self._portHandler.setBaudRate(baudrate):
|
||||
raise RuntimeError(f"Failed to change the baudrate, {baudrate}")
|
||||
|
||||
# Add parameters for each Dynamixel servo to the group sync read
|
||||
for dxl_id in self._ids:
|
||||
if not self._groupSyncRead.addParam(dxl_id):
|
||||
raise RuntimeError(
|
||||
f"Failed to add parameter for Dynamixel with ID {dxl_id}"
|
||||
)
|
||||
|
||||
# Disable torque for each Dynamixel servo
|
||||
self._torque_enabled = False
|
||||
try:
|
||||
self.set_torque_mode(self._torque_enabled)
|
||||
except Exception as e:
|
||||
print(f"port: {port}, {e}")
|
||||
|
||||
self._stop_thread = Event()
|
||||
self._start_reading_thread()
|
||||
|
||||
def set_joints(self, joint_angles: Sequence[float]):
|
||||
if len(joint_angles) != len(self._ids):
|
||||
raise ValueError(
|
||||
"The length of joint_angles must match the number of servos"
|
||||
)
|
||||
if not self._torque_enabled:
|
||||
raise RuntimeError("Torque must be enabled to set joint angles")
|
||||
|
||||
for dxl_id, angle in zip(self._ids, joint_angles):
|
||||
# Convert the angle to the appropriate value for the servo
|
||||
position_value = int(angle * 2048 / np.pi)
|
||||
|
||||
# Allocate goal position value into byte array
|
||||
param_goal_position = [
|
||||
DXL_LOBYTE(DXL_LOWORD(position_value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(position_value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(position_value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(position_value)),
|
||||
]
|
||||
|
||||
# Add goal position value to the Syncwrite parameter storage
|
||||
dxl_addparam_result = self._groupSyncWrite.addParam(
|
||||
dxl_id, param_goal_position
|
||||
)
|
||||
if not dxl_addparam_result:
|
||||
raise RuntimeError(
|
||||
f"Failed to set joint angle for Dynamixel with ID {dxl_id}"
|
||||
)
|
||||
|
||||
# Syncwrite goal position
|
||||
dxl_comm_result = self._groupSyncWrite.txPacket()
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
raise RuntimeError("Failed to syncwrite goal position")
|
||||
|
||||
# Clear syncwrite parameter storage
|
||||
self._groupSyncWrite.clearParam()
|
||||
|
||||
def torque_enabled(self) -> bool:
|
||||
return self._torque_enabled
|
||||
|
||||
def set_torque_mode(self, enable: bool):
|
||||
torque_value = TORQUE_ENABLE if enable else TORQUE_DISABLE
|
||||
with self._lock:
|
||||
for dxl_id in self._ids:
|
||||
dxl_comm_result, dxl_error = self._packetHandler.write1ByteTxRx(
|
||||
self._portHandler, dxl_id, ADDR_TORQUE_ENABLE, torque_value
|
||||
)
|
||||
if dxl_comm_result != COMM_SUCCESS or dxl_error != 0:
|
||||
print(dxl_comm_result)
|
||||
print(dxl_error)
|
||||
raise RuntimeError(
|
||||
f"Failed to set torque mode for Dynamixel with ID {dxl_id}"
|
||||
)
|
||||
|
||||
self._torque_enabled = enable
|
||||
|
||||
def _start_reading_thread(self):
|
||||
self._reading_thread = Thread(target=self._read_joint_angles)
|
||||
self._reading_thread.daemon = True
|
||||
self._reading_thread.start()
|
||||
|
||||
def _read_joint_angles(self):
|
||||
# Continuously read joint angles and update the joint_angles array
|
||||
while not self._stop_thread.is_set():
|
||||
time.sleep(0.001)
|
||||
with self._lock:
|
||||
_joint_angles = np.zeros(len(self._ids), dtype=int)
|
||||
dxl_comm_result = self._groupSyncRead.txRxPacket()
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
print(f"warning, comm failed: {dxl_comm_result}")
|
||||
continue
|
||||
for i, dxl_id in enumerate(self._ids):
|
||||
if self._groupSyncRead.isAvailable(
|
||||
dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION
|
||||
):
|
||||
angle = self._groupSyncRead.getData(
|
||||
dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION
|
||||
)
|
||||
angle = np.int32(np.uint32(angle))
|
||||
_joint_angles[i] = angle
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Failed to get joint angles for Dynamixel with ID {dxl_id}"
|
||||
)
|
||||
self._joint_angles = _joint_angles
|
||||
# self._groupSyncRead.clearParam() # TODO what does this do? should i add it
|
||||
|
||||
def get_joints(self) -> np.ndarray:
|
||||
# Return a copy of the joint_angles array to avoid race conditions
|
||||
while self._joint_angles is None:
|
||||
time.sleep(0.1)
|
||||
with self._lock:
|
||||
_j = self._joint_angles.copy()
|
||||
return _j / 2048.0 * np.pi
|
||||
|
||||
def close(self):
|
||||
self._stop_thread.set()
|
||||
self._reading_thread.join()
|
||||
self._portHandler.closePort()
|
||||
|
||||
|
||||
def main():
|
||||
# Set the port, baudrate, and servo IDs
|
||||
ids = [1]
|
||||
|
||||
# Create a DynamixelDriver instance
|
||||
try:
|
||||
driver = DynamixelDriver(ids)
|
||||
except FileNotFoundError:
|
||||
driver = DynamixelDriver(ids, port="/dev/cu.usbserial-FT7WBMUB")
|
||||
|
||||
driver.set_torque_mode(True)
|
||||
driver.set_torque_mode(False)
|
||||
|
||||
# Print the joint angles
|
||||
try:
|
||||
while True:
|
||||
joint_angles = driver.get_joints()
|
||||
print(f"Joint angles for IDs {ids}: {joint_angles}")
|
||||
# print(f"Joint angles for IDs {ids[1]}: {joint_angles[1]}")
|
||||
except KeyboardInterrupt:
|
||||
driver.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
35
gello/dynamixel/tests/test_driver.py
Normal file
35
gello/dynamixel/tests/test_driver.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gello.dynamixel.driver import FakeDynamixelDriver
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_driver():
|
||||
return FakeDynamixelDriver(ids=[1, 2])
|
||||
|
||||
|
||||
def test_set_joints(fake_driver):
|
||||
fake_driver.set_torque_mode(True)
|
||||
fake_driver.set_joints([np.pi / 2, np.pi / 2])
|
||||
assert np.allclose(fake_driver.get_joints(), [np.pi / 2, np.pi / 2])
|
||||
|
||||
|
||||
def test_set_joints_wrong_length(fake_driver):
|
||||
with pytest.raises(ValueError):
|
||||
fake_driver.set_joints([np.pi / 2])
|
||||
|
||||
|
||||
def test_set_joints_torque_disabled(fake_driver):
|
||||
with pytest.raises(RuntimeError):
|
||||
fake_driver.set_joints([np.pi / 2, np.pi / 2])
|
||||
|
||||
|
||||
def test_torque_enabled(fake_driver):
|
||||
assert not fake_driver.torque_enabled()
|
||||
fake_driver.set_torque_mode(True)
|
||||
assert fake_driver.torque_enabled()
|
||||
|
||||
|
||||
def test_get_joints(fake_driver):
|
||||
assert np.allclose(fake_driver.get_joints(), [0, 0])
|
88
gello/env.py
Normal file
88
gello/env.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gello.cameras.camera import CameraDriver
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
|
||||
class Rate:
|
||||
def __init__(self, rate: float):
|
||||
self.last = time.time()
|
||||
self.rate = rate
|
||||
|
||||
def sleep(self) -> None:
|
||||
while self.last + 1.0 / self.rate > time.time():
|
||||
time.sleep(0.0001)
|
||||
self.last = time.time()
|
||||
|
||||
|
||||
class RobotEnv:
|
||||
def __init__(
|
||||
self,
|
||||
robot: Robot,
|
||||
control_rate_hz: float = 100.0,
|
||||
camera_dict: Optional[Dict[str, CameraDriver]] = None,
|
||||
) -> None:
|
||||
self._robot = robot
|
||||
self._rate = Rate(control_rate_hz)
|
||||
self._camera_dict = {} if camera_dict is None else camera_dict
|
||||
|
||||
def robot(self) -> Robot:
|
||||
"""Get the robot object.
|
||||
|
||||
Returns:
|
||||
robot: the robot object.
|
||||
"""
|
||||
return self._robot
|
||||
|
||||
def __len__(self):
|
||||
return 0
|
||||
|
||||
def step(self, joints: np.ndarray) -> Dict[str, Any]:
|
||||
"""Step the environment forward.
|
||||
|
||||
Args:
|
||||
joints: joint angles command to step the environment with.
|
||||
|
||||
Returns:
|
||||
obs: observation from the environment.
|
||||
"""
|
||||
assert len(joints) == (
|
||||
self._robot.num_dofs()
|
||||
), f"input:{len(joints)}, robot:{self._robot.num_dofs()}"
|
||||
assert self._robot.num_dofs() == len(joints)
|
||||
self._robot.command_joint_state(joints)
|
||||
self._rate.sleep()
|
||||
return self.get_obs()
|
||||
|
||||
def get_obs(self) -> Dict[str, Any]:
|
||||
"""Get observation from the environment.
|
||||
|
||||
Returns:
|
||||
obs: observation from the environment.
|
||||
"""
|
||||
observations = {}
|
||||
for name, camera in self._camera_dict.items():
|
||||
image, depth = camera.read()
|
||||
observations[f"{name}_rgb"] = image
|
||||
observations[f"{name}_depth"] = depth
|
||||
|
||||
robot_obs = self._robot.get_observations()
|
||||
assert "joint_positions" in robot_obs
|
||||
assert "joint_velocities" in robot_obs
|
||||
assert "ee_pos_quat" in robot_obs
|
||||
observations["joint_positions"] = robot_obs["joint_positions"]
|
||||
observations["joint_velocities"] = robot_obs["joint_velocities"]
|
||||
observations["ee_pos_quat"] = robot_obs["ee_pos_quat"]
|
||||
observations["gripper_position"] = robot_obs["gripper_position"]
|
||||
return observations
|
||||
|
||||
|
||||
def main() -> None:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
134
gello/robots/dynamixel.py
Normal file
134
gello/robots/dynamixel.py
Normal file
|
@ -0,0 +1,134 @@
|
|||
from typing import Dict, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
|
||||
class DynamixelRobot(Robot):
|
||||
"""A class representing a UR robot."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
joint_ids: Sequence[int],
|
||||
joint_offsets: Optional[Sequence[float]] = None,
|
||||
joint_signs: Optional[Sequence[int]] = None,
|
||||
real: bool = False,
|
||||
port: str = "/dev/ttyUSB0",
|
||||
baudrate: int = 57600,
|
||||
gripper_config: Optional[Tuple[int, float, float]] = None,
|
||||
start_joints: Optional[np.ndarray] = None,
|
||||
):
|
||||
from gello.dynamixel.driver import (
|
||||
DynamixelDriver,
|
||||
DynamixelDriverProtocol,
|
||||
FakeDynamixelDriver,
|
||||
)
|
||||
|
||||
print(f"attempting to connect to port: {port}")
|
||||
self.gripper_open_close: Optional[Tuple[float, float]]
|
||||
if gripper_config is not None:
|
||||
assert joint_offsets is not None
|
||||
assert joint_signs is not None
|
||||
|
||||
# joint_ids.append(gripper_config[0])
|
||||
# joint_offsets.append(0.0)
|
||||
# joint_signs.append(1)
|
||||
joint_ids = tuple(joint_ids) + (gripper_config[0],)
|
||||
joint_offsets = tuple(joint_offsets) + (0.0,)
|
||||
joint_signs = tuple(joint_signs) + (1,)
|
||||
self.gripper_open_close = (
|
||||
gripper_config[1] * np.pi / 180,
|
||||
gripper_config[2] * np.pi / 180,
|
||||
)
|
||||
else:
|
||||
self.gripper_open_close = None
|
||||
|
||||
self._joint_ids = joint_ids
|
||||
self._driver: DynamixelDriverProtocol
|
||||
|
||||
if joint_offsets is None:
|
||||
self._joint_offsets = np.zeros(len(joint_ids))
|
||||
else:
|
||||
self._joint_offsets = np.array(joint_offsets)
|
||||
|
||||
if joint_signs is None:
|
||||
self._joint_signs = np.ones(len(joint_ids))
|
||||
else:
|
||||
self._joint_signs = np.array(joint_signs)
|
||||
|
||||
assert len(self._joint_ids) == len(self._joint_offsets), (
|
||||
f"joint_ids: {len(self._joint_ids)}, "
|
||||
f"joint_offsets: {len(self._joint_offsets)}"
|
||||
)
|
||||
assert len(self._joint_ids) == len(self._joint_signs), (
|
||||
f"joint_ids: {len(self._joint_ids)}, "
|
||||
f"joint_signs: {len(self._joint_signs)}"
|
||||
)
|
||||
assert np.all(
|
||||
np.abs(self._joint_signs) == 1
|
||||
), f"joint_signs: {self._joint_signs}"
|
||||
|
||||
if real:
|
||||
self._driver = DynamixelDriver(joint_ids, port=port, baudrate=baudrate)
|
||||
self._driver.set_torque_mode(False)
|
||||
else:
|
||||
self._driver = FakeDynamixelDriver(joint_ids)
|
||||
self._torque_on = False
|
||||
self._last_pos = None
|
||||
self._alpha = 0.99
|
||||
|
||||
if start_joints is not None:
|
||||
# loop through all joints and add +- 2pi to the joint offsets to get the closest to start joints
|
||||
new_joint_offsets = []
|
||||
current_joints = self.get_joint_state()
|
||||
assert current_joints.shape == start_joints.shape
|
||||
if gripper_config is not None:
|
||||
current_joints = current_joints[:-1]
|
||||
start_joints = start_joints[:-1]
|
||||
for c_joint, s_joint, joint_offset in zip(
|
||||
current_joints, start_joints, self._joint_offsets
|
||||
):
|
||||
new_joint_offsets.append(
|
||||
np.pi * 2 * np.round((s_joint - c_joint) / (2 * np.pi))
|
||||
+ joint_offset
|
||||
)
|
||||
if gripper_config is not None:
|
||||
new_joint_offsets.append(self._joint_offsets[-1])
|
||||
self._joint_offsets = np.array(new_joint_offsets)
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return len(self._joint_ids)
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
pos = (self._driver.get_joints() - self._joint_offsets) * self._joint_signs
|
||||
assert len(pos) == self.num_dofs()
|
||||
|
||||
if self.gripper_open_close is not None:
|
||||
# map pos to [0, 1]
|
||||
g_pos = (pos[-1] - self.gripper_open_close[0]) / (
|
||||
self.gripper_open_close[1] - self.gripper_open_close[0]
|
||||
)
|
||||
g_pos = min(max(0, g_pos), 1)
|
||||
pos[-1] = g_pos
|
||||
|
||||
if self._last_pos is None:
|
||||
self._last_pos = pos
|
||||
else:
|
||||
# exponential smoothing
|
||||
pos = self._last_pos * (1 - self._alpha) + pos * self._alpha
|
||||
self._last_pos = pos
|
||||
|
||||
return pos
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
self._driver.set_joints((joint_state + self._joint_offsets).tolist())
|
||||
|
||||
def set_torque_mode(self, mode: bool):
|
||||
if mode == self._torque_on:
|
||||
return
|
||||
self._driver.set_torque_mode(mode)
|
||||
self._torque_on = mode
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
return {"joint_state": self.get_joint_state()}
|
88
gello/robots/panda.py
Normal file
88
gello/robots/panda.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import time
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
MAX_OPEN = 0.09
|
||||
|
||||
|
||||
class PandaRobot(Robot):
|
||||
"""A class representing a UR robot."""
|
||||
|
||||
def __init__(self, robot_ip: str = "100.97.47.74"):
|
||||
from polymetis import GripperInterface, RobotInterface
|
||||
|
||||
self.robot = RobotInterface(
|
||||
ip_address=robot_ip,
|
||||
)
|
||||
self.gripper = GripperInterface(
|
||||
ip_address="localhost",
|
||||
)
|
||||
self.robot.go_home()
|
||||
self.robot.start_joint_impedance()
|
||||
self.gripper.goto(width=MAX_OPEN, speed=255, force=255)
|
||||
time.sleep(1)
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
"""Get the number of joints of the robot.
|
||||
|
||||
Returns:
|
||||
int: The number of joints of the robot.
|
||||
"""
|
||||
return 8
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
"""Get the current state of the leader robot.
|
||||
|
||||
Returns:
|
||||
T: The current state of the leader robot.
|
||||
"""
|
||||
robot_joints = self.robot.get_joint_positions()
|
||||
gripper_pos = self.gripper.get_state()
|
||||
pos = np.append(robot_joints, gripper_pos.width / MAX_OPEN)
|
||||
return pos
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
"""Command the leader robot to a given state.
|
||||
|
||||
Args:
|
||||
joint_state (np.ndarray): The state to command the leader robot to.
|
||||
"""
|
||||
import torch
|
||||
|
||||
self.robot.update_desired_joint_positions(torch.tensor(joint_state[:-1]))
|
||||
self.gripper.goto(width=(MAX_OPEN * (1 - joint_state[-1])), speed=1, force=1)
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
joints = self.get_joint_state()
|
||||
pos_quat = np.zeros(7)
|
||||
gripper_pos = np.array([joints[-1]])
|
||||
return {
|
||||
"joint_positions": joints,
|
||||
"joint_velocities": joints,
|
||||
"ee_pos_quat": pos_quat,
|
||||
"gripper_position": gripper_pos,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
robot = PandaRobot()
|
||||
current_joints = robot.get_joint_state()
|
||||
# move a small delta 0.1 rad
|
||||
move_joints = current_joints + 0.05
|
||||
# make last joint (gripper) closed
|
||||
move_joints[-1] = 0.5
|
||||
time.sleep(1)
|
||||
m = 0.09
|
||||
robot.gripper.goto(1 * m, speed=255, force=255)
|
||||
time.sleep(1)
|
||||
robot.gripper.goto(1.05 * m, speed=255, force=255)
|
||||
time.sleep(1)
|
||||
robot.gripper.goto(1.1 * m, speed=255, force=255)
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
128
gello/robots/robot.py
Normal file
128
gello/robots/robot.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Dict, Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Robot(Protocol):
|
||||
"""Robot protocol.
|
||||
|
||||
A protocol for a robot that can be controlled.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def num_dofs(self) -> int:
|
||||
"""Get the number of joints of the robot.
|
||||
|
||||
Returns:
|
||||
int: The number of joints of the robot.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
"""Get the current state of the leader robot.
|
||||
|
||||
Returns:
|
||||
T: The current state of the leader robot.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
"""Command the leader robot to a given state.
|
||||
|
||||
Args:
|
||||
joint_state (np.ndarray): The state to command the leader robot to.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
"""Get the current observations of the robot.
|
||||
|
||||
This is to extract all the information that is available from the robot,
|
||||
such as joint positions, joint velocities, etc. This may also include
|
||||
information from additional sensors, such as cameras, force sensors, etc.
|
||||
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: A dictionary of observations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PrintRobot(Robot):
|
||||
"""A robot that prints the commanded joint state."""
|
||||
|
||||
def __init__(self, num_dofs: int, dont_print: bool = False):
|
||||
self._num_dofs = num_dofs
|
||||
self._joint_state = np.zeros(num_dofs)
|
||||
self._dont_print = dont_print
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return self._num_dofs
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
return self._joint_state
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
assert len(joint_state) == (self._num_dofs), (
|
||||
f"Expected joint state of length {self._num_dofs}, "
|
||||
f"got {len(joint_state)}."
|
||||
)
|
||||
self._joint_state = joint_state
|
||||
if not self._dont_print:
|
||||
print(self._joint_state)
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
joint_state = self.get_joint_state()
|
||||
pos_quat = np.zeros(7)
|
||||
return {
|
||||
"joint_positions": joint_state,
|
||||
"joint_velocities": joint_state,
|
||||
"ee_pos_quat": pos_quat,
|
||||
"gripper_position": np.array(0),
|
||||
}
|
||||
|
||||
|
||||
class BimanualRobot(Robot):
|
||||
def __init__(self, robot_l: Robot, robot_r: Robot):
|
||||
self._robot_l = robot_l
|
||||
self._robot_r = robot_r
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return self._robot_l.num_dofs() + self._robot_r.num_dofs()
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
return np.concatenate(
|
||||
(self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
|
||||
)
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
self._robot_l.command_joint_state(joint_state[: self._robot_l.num_dofs()])
|
||||
self._robot_r.command_joint_state(joint_state[self._robot_l.num_dofs() :])
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
l_obs = self._robot_l.get_observations()
|
||||
r_obs = self._robot_r.get_observations()
|
||||
assert l_obs.keys() == r_obs.keys()
|
||||
return_obs = {}
|
||||
for k in l_obs.keys():
|
||||
try:
|
||||
return_obs[k] = np.concatenate((l_obs[k], r_obs[k]))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(k)
|
||||
print(l_obs[k])
|
||||
print(r_obs[k])
|
||||
raise RuntimeError()
|
||||
|
||||
return return_obs
|
||||
|
||||
|
||||
def main():
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
358
gello/robots/robotiq_gripper.py
Normal file
358
gello/robots/robotiq_gripper.py
Normal file
|
@ -0,0 +1,358 @@
|
|||
"""Module to control Robotiq's grippers - tested with HAND-E.
|
||||
|
||||
Taken from https://github.com/githubuser0xFFFF/py_robotiq_gripper/blob/master/src/robotiq_gripper.py
|
||||
"""
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import OrderedDict, Tuple, Union
|
||||
|
||||
|
||||
class RobotiqGripper:
|
||||
"""Communicates with the gripper directly, via socket with string commands, leveraging string names for variables."""
|
||||
|
||||
# WRITE VARIABLES (CAN ALSO READ)
|
||||
ACT = (
|
||||
"ACT" # act : activate (1 while activated, can be reset to clear fault status)
|
||||
)
|
||||
GTO = (
|
||||
"GTO" # gto : go to (will perform go to with the actions set in pos, for, spe)
|
||||
)
|
||||
ATR = "ATR" # atr : auto-release (emergency slow move)
|
||||
ADR = (
|
||||
"ADR" # adr : auto-release direction (open(1) or close(0) during auto-release)
|
||||
)
|
||||
FOR = "FOR" # for : force (0-255)
|
||||
SPE = "SPE" # spe : speed (0-255)
|
||||
POS = "POS" # pos : position (0-255), 0 = open
|
||||
# READ VARIABLES
|
||||
STA = "STA" # status (0 = is reset, 1 = activating, 3 = active)
|
||||
PRE = "PRE" # position request (echo of last commanded position)
|
||||
OBJ = "OBJ" # object detection (0 = moving, 1 = outer grip, 2 = inner grip, 3 = no object at rest)
|
||||
FLT = "FLT" # fault (0=ok, see manual for errors if not zero)
|
||||
|
||||
ENCODING = "UTF-8" # ASCII and UTF-8 both seem to work
|
||||
|
||||
class GripperStatus(Enum):
|
||||
"""Gripper status reported by the gripper. The integer values have to match what the gripper sends."""
|
||||
|
||||
RESET = 0
|
||||
ACTIVATING = 1
|
||||
# UNUSED = 2 # This value is currently not used by the gripper firmware
|
||||
ACTIVE = 3
|
||||
|
||||
class ObjectStatus(Enum):
|
||||
"""Object status reported by the gripper. The integer values have to match what the gripper sends."""
|
||||
|
||||
MOVING = 0
|
||||
STOPPED_OUTER_OBJECT = 1
|
||||
STOPPED_INNER_OBJECT = 2
|
||||
AT_DEST = 3
|
||||
|
||||
def __init__(self):
|
||||
"""Constructor."""
|
||||
self.socket = None
|
||||
self.command_lock = threading.Lock()
|
||||
self._min_position = 0
|
||||
self._max_position = 255
|
||||
self._min_speed = 0
|
||||
self._max_speed = 255
|
||||
self._min_force = 0
|
||||
self._max_force = 255
|
||||
|
||||
def connect(self, hostname: str, port: int, socket_timeout: float = 10.0) -> None:
|
||||
"""Connects to a gripper at the given address.
|
||||
|
||||
:param hostname: Hostname or ip.
|
||||
:param port: Port.
|
||||
:param socket_timeout: Timeout for blocking socket operations.
|
||||
"""
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
assert self.socket is not None
|
||||
self.socket.connect((hostname, port))
|
||||
self.socket.settimeout(socket_timeout)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Closes the connection with the gripper."""
|
||||
assert self.socket is not None
|
||||
self.socket.close()
|
||||
|
||||
def _set_vars(self, var_dict: OrderedDict[str, Union[int, float]]):
|
||||
"""Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response.
|
||||
|
||||
:param var_dict: Dictionary of variables to set (variable_name, value).
|
||||
:return: True on successful reception of ack, false if no ack was received, indicating the set may not
|
||||
have been effective.
|
||||
"""
|
||||
assert self.socket is not None
|
||||
# construct unique command
|
||||
cmd = "SET"
|
||||
for variable, value in var_dict.items():
|
||||
cmd += f" {variable} {str(value)}"
|
||||
cmd += "\n" # new line is required for the command to finish
|
||||
# atomic commands send/rcv
|
||||
with self.command_lock:
|
||||
self.socket.sendall(cmd.encode(self.ENCODING))
|
||||
data = self.socket.recv(1024)
|
||||
return self._is_ack(data)
|
||||
|
||||
def _set_var(self, variable: str, value: Union[int, float]):
|
||||
"""Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response.
|
||||
|
||||
:param variable: Variable to set.
|
||||
:param value: Value to set for the variable.
|
||||
:return: True on successful reception of ack, false if no ack was received, indicating the set may not
|
||||
have been effective.
|
||||
"""
|
||||
return self._set_vars(OrderedDict([(variable, value)]))
|
||||
|
||||
def _get_var(self, variable: str):
|
||||
"""Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the response is received or the socket times out.
|
||||
|
||||
:param variable: Name of the variable to retrieve.
|
||||
:return: Value of the variable as integer.
|
||||
"""
|
||||
assert self.socket is not None
|
||||
# atomic commands send/rcv
|
||||
with self.command_lock:
|
||||
cmd = f"GET {variable}\n"
|
||||
self.socket.sendall(cmd.encode(self.ENCODING))
|
||||
data = self.socket.recv(1024)
|
||||
|
||||
# expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value
|
||||
# note some special variables (like FLT) may send 2 bytes, instead of an integer. We assume integer here
|
||||
var_name, value_str = data.decode(self.ENCODING).split()
|
||||
if var_name != variable:
|
||||
raise ValueError(
|
||||
f"Unexpected response {data} ({data.decode(self.ENCODING)}): does not match '{variable}'"
|
||||
)
|
||||
value = int(value_str)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _is_ack(data: str):
|
||||
return data == b"ack"
|
||||
|
||||
def _reset(self):
|
||||
"""Reset the gripper.
|
||||
|
||||
The following code is executed in the corresponding script function
|
||||
def rq_reset(gripper_socket="1"):
|
||||
rq_set_var("ACT", 0, gripper_socket)
|
||||
rq_set_var("ATR", 0, gripper_socket)
|
||||
|
||||
while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
|
||||
rq_set_var("ACT", 0, gripper_socket)
|
||||
rq_set_var("ATR", 0, gripper_socket)
|
||||
sync()
|
||||
end
|
||||
|
||||
sleep(0.5)
|
||||
end
|
||||
"""
|
||||
self._set_var(self.ACT, 0)
|
||||
self._set_var(self.ATR, 0)
|
||||
while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0:
|
||||
self._set_var(self.ACT, 0)
|
||||
self._set_var(self.ATR, 0)
|
||||
time.sleep(0.5)
|
||||
|
||||
def activate(self, auto_calibrate: bool = True):
|
||||
"""Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags.
|
||||
|
||||
:param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion.
|
||||
|
||||
The following code is executed in the corresponding script function
|
||||
def rq_activate(gripper_socket="1"):
|
||||
if (not rq_is_gripper_activated(gripper_socket)):
|
||||
rq_reset(gripper_socket)
|
||||
|
||||
while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
|
||||
rq_reset(gripper_socket)
|
||||
sync()
|
||||
end
|
||||
|
||||
rq_set_var("ACT",1, gripper_socket)
|
||||
end
|
||||
end
|
||||
|
||||
def rq_activate_and_wait(gripper_socket="1"):
|
||||
if (not rq_is_gripper_activated(gripper_socket)):
|
||||
rq_activate(gripper_socket)
|
||||
sleep(1.0)
|
||||
|
||||
while(not rq_get_var("ACT", 1, gripper_socket) == 1 or not rq_get_var("STA", 1, gripper_socket) == 3):
|
||||
sleep(0.1)
|
||||
end
|
||||
|
||||
sleep(0.5)
|
||||
end
|
||||
end
|
||||
"""
|
||||
if not self.is_active():
|
||||
self._reset()
|
||||
while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0:
|
||||
time.sleep(0.01)
|
||||
|
||||
self._set_var(self.ACT, 1)
|
||||
time.sleep(1.0)
|
||||
while not self._get_var(self.ACT) == 1 or not self._get_var(self.STA) == 3:
|
||||
time.sleep(0.01)
|
||||
|
||||
# auto-calibrate position range if desired
|
||||
if auto_calibrate:
|
||||
self.auto_calibrate()
|
||||
|
||||
def is_active(self):
|
||||
"""Returns whether the gripper is active."""
|
||||
status = self._get_var(self.STA)
|
||||
return (
|
||||
RobotiqGripper.GripperStatus(status) == RobotiqGripper.GripperStatus.ACTIVE
|
||||
)
|
||||
|
||||
def get_min_position(self) -> int:
|
||||
"""Returns the minimum position the gripper can reach (open position)."""
|
||||
return self._min_position
|
||||
|
||||
def get_max_position(self) -> int:
|
||||
"""Returns the maximum position the gripper can reach (closed position)."""
|
||||
return self._max_position
|
||||
|
||||
def get_open_position(self) -> int:
|
||||
"""Returns what is considered the open position for gripper (minimum position value)."""
|
||||
return self.get_min_position()
|
||||
|
||||
def get_closed_position(self) -> int:
|
||||
"""Returns what is considered the closed position for gripper (maximum position value)."""
|
||||
return self.get_max_position()
|
||||
|
||||
def is_open(self):
|
||||
"""Returns whether the current position is considered as being fully open."""
|
||||
return self.get_current_position() <= self.get_open_position()
|
||||
|
||||
def is_closed(self):
|
||||
"""Returns whether the current position is considered as being fully closed."""
|
||||
return self.get_current_position() >= self.get_closed_position()
|
||||
|
||||
def get_current_position(self) -> int:
|
||||
"""Returns the current position as returned by the physical hardware."""
|
||||
return self._get_var(self.POS)
|
||||
|
||||
def auto_calibrate(self, log: bool = True) -> None:
|
||||
"""Attempts to calibrate the open and closed positions, by slowly closing and opening the gripper.
|
||||
|
||||
:param log: Whether to print the results to log.
|
||||
"""
|
||||
# first try to open in case we are holding an object
|
||||
(position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
|
||||
if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
|
||||
raise RuntimeError(f"Calibration failed opening to start: {str(status)}")
|
||||
|
||||
# try to close as far as possible, and record the number
|
||||
(position, status) = self.move_and_wait_for_pos(
|
||||
self.get_closed_position(), 64, 1
|
||||
)
|
||||
if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
|
||||
raise RuntimeError(
|
||||
f"Calibration failed because of an object: {str(status)}"
|
||||
)
|
||||
assert position <= self._max_position
|
||||
self._max_position = position
|
||||
|
||||
# try to open as far as possible, and record the number
|
||||
(position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
|
||||
if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
|
||||
raise RuntimeError(
|
||||
f"Calibration failed because of an object: {str(status)}"
|
||||
)
|
||||
assert position >= self._min_position
|
||||
self._min_position = position
|
||||
|
||||
if log:
|
||||
print(
|
||||
f"Gripper auto-calibrated to [{self.get_min_position()}, {self.get_max_position()}]"
|
||||
)
|
||||
|
||||
def move(self, position: int, speed: int, force: int) -> Tuple[bool, int]:
|
||||
"""Sends commands to start moving towards the given position, with the specified speed and force.
|
||||
|
||||
:param position: Position to move to [min_position, max_position]
|
||||
:param speed: Speed to move at [min_speed, max_speed]
|
||||
:param force: Force to use [min_force, max_force]
|
||||
:return: A tuple with a bool indicating whether the action it was successfully sent, and an integer with
|
||||
the actual position that was requested, after being adjusted to the min/max calibrated range.
|
||||
"""
|
||||
|
||||
def clip_val(min_val, val, max_val):
|
||||
return max(min_val, min(val, max_val))
|
||||
|
||||
clip_pos = clip_val(self._min_position, position, self._max_position)
|
||||
clip_spe = clip_val(self._min_speed, speed, self._max_speed)
|
||||
clip_for = clip_val(self._min_force, force, self._max_force)
|
||||
|
||||
# moves to the given position with the given speed and force
|
||||
var_dict = OrderedDict(
|
||||
[
|
||||
(self.POS, clip_pos),
|
||||
(self.SPE, clip_spe),
|
||||
(self.FOR, clip_for),
|
||||
(self.GTO, 1),
|
||||
]
|
||||
)
|
||||
succ = self._set_vars(var_dict)
|
||||
time.sleep(0.008) # need to wait (dont know why)
|
||||
return succ, clip_pos
|
||||
|
||||
def move_and_wait_for_pos(
|
||||
self, position: int, speed: int, force: int
|
||||
) -> Tuple[int, ObjectStatus]: # noqa
|
||||
"""Sends commands to start moving towards the given position, with the specified speed and force, and then waits for the move to complete.
|
||||
|
||||
:param position: Position to move to [min_position, max_position]
|
||||
:param speed: Speed to move at [min_speed, max_speed]
|
||||
:param force: Force to use [min_force, max_force]
|
||||
:return: A tuple with an integer representing the last position returned by the gripper after it notified
|
||||
that the move had completed, a status indicating how the move ended (see ObjectStatus enum for details). Note
|
||||
that it is possible that the position was not reached, if an object was detected during motion.
|
||||
"""
|
||||
set_ok, cmd_pos = self.move(position, speed, force)
|
||||
if not set_ok:
|
||||
raise RuntimeError("Failed to set variables for move.")
|
||||
|
||||
# wait until the gripper acknowledges that it will try to go to the requested position
|
||||
while self._get_var(self.PRE) != cmd_pos:
|
||||
time.sleep(0.001)
|
||||
|
||||
# wait until not moving
|
||||
cur_obj = self._get_var(self.OBJ)
|
||||
while (
|
||||
RobotiqGripper.ObjectStatus(cur_obj) == RobotiqGripper.ObjectStatus.MOVING
|
||||
):
|
||||
cur_obj = self._get_var(self.OBJ)
|
||||
|
||||
# report the actual position and the object status
|
||||
final_pos = self._get_var(self.POS)
|
||||
final_obj = cur_obj
|
||||
return final_pos, RobotiqGripper.ObjectStatus(final_obj)
|
||||
|
||||
|
||||
def main():
|
||||
# test open and closing the gripper
|
||||
gripper = RobotiqGripper()
|
||||
gripper.connect(hostname="192.168.1.10", port=63352)
|
||||
# gripper.activate()
|
||||
print(gripper.get_current_position())
|
||||
gripper.move(20, 255, 1)
|
||||
time.sleep(0.2)
|
||||
print(gripper.get_current_position())
|
||||
gripper.move(65, 255, 1)
|
||||
time.sleep(0.2)
|
||||
print(gripper.get_current_position())
|
||||
gripper.move(20, 255, 1)
|
||||
gripper.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
256
gello/robots/sim_robot.py
Normal file
256
gello/robots/sim_robot.py
Normal file
|
@ -0,0 +1,256 @@
|
|||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import zmq
|
||||
from dm_control import mjcf
|
||||
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
assert mujoco.viewer is mujoco.viewer
|
||||
|
||||
|
||||
def attach_hand_to_arm(
|
||||
arm_mjcf: mjcf.RootElement,
|
||||
hand_mjcf: mjcf.RootElement,
|
||||
) -> None:
|
||||
"""Attaches a hand to an arm.
|
||||
|
||||
The arm must have a site named "attachment_site".
|
||||
|
||||
Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm
|
||||
|
||||
Args:
|
||||
arm_mjcf: The mjcf.RootElement of the arm.
|
||||
hand_mjcf: The mjcf.RootElement of the hand.
|
||||
|
||||
Raises:
|
||||
ValueError: If the arm does not have a site named "attachment_site".
|
||||
"""
|
||||
physics = mjcf.Physics.from_mjcf_model(hand_mjcf)
|
||||
|
||||
attachment_site = arm_mjcf.find("site", "attachment_site")
|
||||
if attachment_site is None:
|
||||
raise ValueError("No attachment site found in the arm model.")
|
||||
|
||||
# Expand the ctrl and qpos keyframes to account for the new hand DoFs.
|
||||
arm_key = arm_mjcf.find("key", "home")
|
||||
if arm_key is not None:
|
||||
hand_key = hand_mjcf.find("key", "home")
|
||||
if hand_key is None:
|
||||
arm_key.ctrl = np.concatenate([arm_key.ctrl, np.zeros(physics.model.nu)])
|
||||
arm_key.qpos = np.concatenate([arm_key.qpos, np.zeros(physics.model.nq)])
|
||||
else:
|
||||
arm_key.ctrl = np.concatenate([arm_key.ctrl, hand_key.ctrl])
|
||||
arm_key.qpos = np.concatenate([arm_key.qpos, hand_key.qpos])
|
||||
|
||||
attachment_site.attach(hand_mjcf)
|
||||
|
||||
|
||||
def build_scene(robot_xml_path: str, gripper_xml_path: Optional[str] = None):
|
||||
# assert robot_xml_path.endswith(".xml")
|
||||
|
||||
arena = mjcf.RootElement()
|
||||
arm_simulate = mjcf.from_path(robot_xml_path)
|
||||
# arm_copy = mjcf.from_path(xml_path)
|
||||
|
||||
if gripper_xml_path is not None:
|
||||
# attach gripper to the robot at "attachment_site"
|
||||
gripper_simulate = mjcf.from_path(gripper_xml_path)
|
||||
attach_hand_to_arm(arm_simulate, gripper_simulate)
|
||||
|
||||
arena.worldbody.attach(arm_simulate)
|
||||
# arena.worldbody.attach(arm_copy)
|
||||
|
||||
return arena
|
||||
|
||||
|
||||
class ZMQServerThread(threading.Thread):
|
||||
def __init__(self, server):
|
||||
super().__init__()
|
||||
self._server = server
|
||||
|
||||
def run(self):
|
||||
self._server.serve()
|
||||
|
||||
def terminate(self):
|
||||
self._server.stop()
|
||||
|
||||
|
||||
class ZMQRobotServer:
|
||||
"""A class representing a ZMQ server for a robot."""
|
||||
|
||||
def __init__(self, robot: Robot, host: str = "127.0.0.1", port: int = 5556):
|
||||
self._robot = robot
|
||||
self._context = zmq.Context()
|
||||
self._socket = self._context.socket(zmq.REP)
|
||||
addr = f"tcp://{host}:{port}"
|
||||
self._socket.bind(addr)
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def serve(self) -> None:
|
||||
"""Serve the robot state and commands over ZMQ."""
|
||||
self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
message = self._socket.recv()
|
||||
request = pickle.loads(message)
|
||||
|
||||
# Call the appropriate method based on the request
|
||||
method = request.get("method")
|
||||
args = request.get("args", {})
|
||||
result: Any
|
||||
if method == "num_dofs":
|
||||
result = self._robot.num_dofs()
|
||||
elif method == "get_joint_state":
|
||||
result = self._robot.get_joint_state()
|
||||
elif method == "command_joint_state":
|
||||
result = self._robot.command_joint_state(**args)
|
||||
elif method == "get_observations":
|
||||
result = self._robot.get_observations()
|
||||
else:
|
||||
result = {"error": "Invalid method"}
|
||||
print(result)
|
||||
raise NotImplementedError(
|
||||
f"Invalid method: {method}, {args, result}"
|
||||
)
|
||||
|
||||
self._socket.send(pickle.dumps(result))
|
||||
except zmq.error.Again:
|
||||
print("Timeout in ZMQLeaderServer serve")
|
||||
# Timeout occurred, check if the stop event is set
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
self._socket.close()
|
||||
self._context.term()
|
||||
|
||||
|
||||
class MujocoRobotServer:
|
||||
def __init__(
|
||||
self,
|
||||
xml_path: str,
|
||||
gripper_xml_path: Optional[str] = None,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 5556,
|
||||
print_joints: bool = False,
|
||||
):
|
||||
self._has_gripper = gripper_xml_path is not None
|
||||
arena = build_scene(xml_path, gripper_xml_path)
|
||||
|
||||
assets: Dict[str, str] = {}
|
||||
for asset in arena.asset.all_children():
|
||||
if asset.tag == "mesh":
|
||||
f = asset.file
|
||||
assets[f.get_vfs_filename()] = asset.file.contents
|
||||
|
||||
xml_string = arena.to_xml_string()
|
||||
# save xml_string to file
|
||||
with open("arena.xml", "w") as f:
|
||||
f.write(xml_string)
|
||||
|
||||
self._model = mujoco.MjModel.from_xml_string(xml_string, assets)
|
||||
self._data = mujoco.MjData(self._model)
|
||||
|
||||
self._num_joints = self._model.nu
|
||||
|
||||
self._joint_state = np.zeros(self._num_joints)
|
||||
self._joint_cmd = self._joint_state
|
||||
|
||||
self._zmq_server = ZMQRobotServer(robot=self, host=host, port=port)
|
||||
self._zmq_server_thread = ZMQServerThread(self._zmq_server)
|
||||
|
||||
self._print_joints = print_joints
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return self._num_joints
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
return self._joint_state
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
assert len(joint_state) == self._num_joints, (
|
||||
f"Expected joint state of length {self._num_joints}, "
|
||||
f"got {len(joint_state)}."
|
||||
)
|
||||
if self._has_gripper:
|
||||
_joint_state = joint_state.copy()
|
||||
_joint_state[-1] = _joint_state[-1] * 255
|
||||
self._joint_cmd = _joint_state
|
||||
else:
|
||||
self._joint_cmd = joint_state.copy()
|
||||
|
||||
def freedrive_enabled(self) -> bool:
|
||||
return True
|
||||
|
||||
def set_freedrive_mode(self, enable: bool):
|
||||
pass
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
joint_positions = self._data.qpos.copy()[: self._num_joints]
|
||||
joint_velocities = self._data.qvel.copy()[: self._num_joints]
|
||||
ee_site = "attachment_site"
|
||||
try:
|
||||
ee_pos = self._data.site_xpos.copy()[
|
||||
mujoco.mj_name2id(self._model, 6, ee_site)
|
||||
]
|
||||
ee_mat = self._data.site_xmat.copy()[
|
||||
mujoco.mj_name2id(self._model, 6, ee_site)
|
||||
]
|
||||
ee_quat = np.zeros(4)
|
||||
mujoco.mju_mat2Quat(ee_quat, ee_mat)
|
||||
except Exception:
|
||||
ee_pos = np.zeros(3)
|
||||
ee_quat = np.zeros(4)
|
||||
ee_quat[0] = 1
|
||||
gripper_pos = self._data.qpos.copy()[self._num_joints - 1]
|
||||
return {
|
||||
"joint_positions": joint_positions,
|
||||
"joint_velocities": joint_velocities,
|
||||
"ee_pos_quat": np.concatenate([ee_pos, ee_quat]),
|
||||
"gripper_position": gripper_pos,
|
||||
}
|
||||
|
||||
def serve(self) -> None:
|
||||
# start the zmq server
|
||||
self._zmq_server_thread.start()
|
||||
with mujoco.viewer.launch_passive(self._model, self._data) as viewer:
|
||||
while viewer.is_running():
|
||||
step_start = time.time()
|
||||
|
||||
# mj_step can be replaced with code that also evaluates
|
||||
# a policy and applies a control signal before stepping the physics.
|
||||
self._data.ctrl[:] = self._joint_cmd
|
||||
# self._data.qpos[:] = self._joint_cmd
|
||||
mujoco.mj_step(self._model, self._data)
|
||||
self._joint_state = self._data.qpos.copy()[: self._num_joints]
|
||||
|
||||
if self._print_joints:
|
||||
print(self._joint_state)
|
||||
|
||||
# Example modification of a viewer option: toggle contact points every two seconds.
|
||||
with viewer.lock():
|
||||
# TODO remove?
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(
|
||||
self._data.time % 2
|
||||
)
|
||||
|
||||
# Pick up changes to the physics state, apply perturbations, update options from GUI.
|
||||
viewer.sync()
|
||||
|
||||
# Rudimentary time keeping, will drift relative to wall clock.
|
||||
time_until_next_step = self._model.opt.timestep - (
|
||||
time.time() - step_start
|
||||
)
|
||||
if time_until_next_step > 0:
|
||||
time.sleep(time_until_next_step)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._zmq_server_thread.join()
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.stop()
|
133
gello/robots/ur.py
Normal file
133
gello/robots/ur.py
Normal file
|
@ -0,0 +1,133 @@
|
|||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
|
||||
class URRobot(Robot):
|
||||
"""A class representing a UR robot."""
|
||||
|
||||
def __init__(self, robot_ip: str = "192.168.1.10", no_gripper: bool = False):
|
||||
import rtde_control
|
||||
import rtde_receive
|
||||
|
||||
[print("in ur robot") for _ in range(4)]
|
||||
try:
|
||||
self.robot = rtde_control.RTDEControlInterface(robot_ip)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(robot_ip)
|
||||
|
||||
self.r_inter = rtde_receive.RTDEReceiveInterface(robot_ip)
|
||||
if not no_gripper:
|
||||
from gello.robots.robotiq_gripper import RobotiqGripper
|
||||
|
||||
self.gripper = RobotiqGripper()
|
||||
self.gripper.connect(hostname=robot_ip, port=63352)
|
||||
print("gripper connected")
|
||||
# gripper.activate()
|
||||
|
||||
[print("connect") for _ in range(4)]
|
||||
|
||||
self._free_drive = False
|
||||
self.robot.endFreedriveMode()
|
||||
self._use_gripper = not no_gripper
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
"""Get the number of joints of the robot.
|
||||
|
||||
Returns:
|
||||
int: The number of joints of the robot.
|
||||
"""
|
||||
if self._use_gripper:
|
||||
return 7
|
||||
return 6
|
||||
|
||||
def _get_gripper_pos(self) -> float:
|
||||
import time
|
||||
|
||||
time.sleep(0.01)
|
||||
gripper_pos = self.gripper.get_current_position()
|
||||
assert 0 <= gripper_pos <= 255, "Gripper position must be between 0 and 255"
|
||||
return gripper_pos / 255
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
"""Get the current state of the leader robot.
|
||||
|
||||
Returns:
|
||||
T: The current state of the leader robot.
|
||||
"""
|
||||
robot_joints = self.r_inter.getActualQ()
|
||||
if self._use_gripper:
|
||||
gripper_pos = self._get_gripper_pos()
|
||||
pos = np.append(robot_joints, gripper_pos)
|
||||
else:
|
||||
pos = robot_joints
|
||||
return pos
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
"""Command the leader robot to a given state.
|
||||
|
||||
Args:
|
||||
joint_state (np.ndarray): The state to command the leader robot to.
|
||||
"""
|
||||
velocity = 0.5
|
||||
acceleration = 0.5
|
||||
dt = 1.0 / 500 # 2ms
|
||||
lookahead_time = 0.2
|
||||
gain = 100
|
||||
|
||||
robot_joints = joint_state[:6]
|
||||
t_start = self.robot.initPeriod()
|
||||
self.robot.servoJ(
|
||||
robot_joints, velocity, acceleration, dt, lookahead_time, gain
|
||||
)
|
||||
if self._use_gripper:
|
||||
gripper_pos = joint_state[-1] * 255
|
||||
self.gripper.move(gripper_pos, 255, 10)
|
||||
self.robot.waitPeriod(t_start)
|
||||
|
||||
def freedrive_enabled(self) -> bool:
|
||||
"""Check if the robot is in freedrive mode.
|
||||
|
||||
Returns:
|
||||
bool: True if the robot is in freedrive mode, False otherwise.
|
||||
"""
|
||||
return self._free_drive
|
||||
|
||||
def set_freedrive_mode(self, enable: bool) -> None:
|
||||
"""Set the freedrive mode of the robot.
|
||||
|
||||
Args:
|
||||
enable (bool): True to enable freedrive mode, False to disable it.
|
||||
"""
|
||||
if enable and not self._free_drive:
|
||||
self._free_drive = True
|
||||
self.robot.freedriveMode()
|
||||
elif not enable and self._free_drive:
|
||||
self._free_drive = False
|
||||
self.robot.endFreedriveMode()
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
joints = self.get_joint_state()
|
||||
pos_quat = np.zeros(7)
|
||||
gripper_pos = np.array([joints[-1]])
|
||||
return {
|
||||
"joint_positions": joints,
|
||||
"joint_velocities": joints,
|
||||
"ee_pos_quat": pos_quat,
|
||||
"gripper_position": gripper_pos,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
robot_ip = "192.168.1.11"
|
||||
ur = URRobot(robot_ip, no_gripper=True)
|
||||
print(ur)
|
||||
ur.set_freedrive_mode(True)
|
||||
print(ur.get_observations())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
358
gello/robots/xarm_robot.py
Normal file
358
gello/robots/xarm_robot.py
Normal file
|
@ -0,0 +1,358 @@
|
|||
import dataclasses
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from pyquaternion import Quaternion
|
||||
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
|
||||
def _aa_from_quat(quat: np.ndarray) -> np.ndarray:
|
||||
"""Convert a quaternion to an axis-angle representation.
|
||||
|
||||
Args:
|
||||
quat (np.ndarray): The quaternion to convert.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The axis-angle representation of the quaternion.
|
||||
"""
|
||||
assert quat.shape == (4,), "Input quaternion must be a 4D vector."
|
||||
norm = np.linalg.norm(quat)
|
||||
assert norm != 0, "Input quaternion must not be a zero vector."
|
||||
quat = quat / norm # Normalize the quaternion
|
||||
|
||||
Q = Quaternion(w=quat[3], x=quat[0], y=quat[1], z=quat[2])
|
||||
angle = Q.angle
|
||||
axis = Q.axis
|
||||
aa = axis * angle
|
||||
return aa
|
||||
|
||||
|
||||
def _quat_from_aa(aa: np.ndarray) -> np.ndarray:
|
||||
"""Convert an axis-angle representation to a quaternion.
|
||||
|
||||
Args:
|
||||
aa (np.ndarray): The axis-angle representation to convert.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The quaternion representation of the axis-angle.
|
||||
"""
|
||||
assert aa.shape == (3,), "Input axis-angle must be a 3D vector."
|
||||
norm = np.linalg.norm(aa)
|
||||
assert norm != 0, "Input axis-angle must not be a zero vector."
|
||||
axis = aa / norm # Normalize the axis-angle
|
||||
|
||||
Q = Quaternion(axis=axis, angle=norm)
|
||||
quat = np.array([Q.x, Q.y, Q.z, Q.w])
|
||||
return quat
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RobotState:
|
||||
x: float
|
||||
y: float
|
||||
z: float
|
||||
gripper: float
|
||||
j1: float
|
||||
j2: float
|
||||
j3: float
|
||||
j4: float
|
||||
j5: float
|
||||
j6: float
|
||||
j7: float
|
||||
aa: np.ndarray
|
||||
|
||||
@staticmethod
|
||||
def from_robot(
|
||||
cartesian: np.ndarray,
|
||||
joints: np.ndarray,
|
||||
gripper: float,
|
||||
aa: np.ndarray,
|
||||
) -> "RobotState":
|
||||
return RobotState(
|
||||
cartesian[0],
|
||||
cartesian[1],
|
||||
cartesian[2],
|
||||
gripper,
|
||||
joints[0],
|
||||
joints[1],
|
||||
joints[2],
|
||||
joints[3],
|
||||
joints[4],
|
||||
joints[5],
|
||||
joints[6],
|
||||
aa,
|
||||
)
|
||||
|
||||
def cartesian_pos(self) -> np.ndarray:
|
||||
return np.array([self.x, self.y, self.z])
|
||||
|
||||
def quat(self) -> np.ndarray:
|
||||
return _quat_from_aa(self.aa)
|
||||
|
||||
def joints(self) -> np.ndarray:
|
||||
return np.array([self.j1, self.j2, self.j3, self.j4, self.j5, self.j6, self.j7])
|
||||
|
||||
def gripper_pos(self) -> float:
|
||||
return self.gripper
|
||||
|
||||
|
||||
class Rate:
|
||||
def __init__(self, *, duration):
|
||||
self.duration = duration
|
||||
self.last = time.time()
|
||||
|
||||
def sleep(self, duration=None) -> None:
|
||||
duration = self.duration if duration is None else duration
|
||||
assert duration >= 0
|
||||
now = time.time()
|
||||
passed = now - self.last
|
||||
remaining = duration - passed
|
||||
assert passed >= 0
|
||||
if remaining > 0.0001:
|
||||
time.sleep(remaining)
|
||||
self.last = time.time()
|
||||
|
||||
|
||||
class XArmRobot(Robot):
|
||||
GRIPPER_OPEN = 800
|
||||
GRIPPER_CLOSE = 0
|
||||
# MAX_DELTA = 0.2
|
||||
DEFAULT_MAX_DELTA = 0.05
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
return 8
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
state = self.get_state()
|
||||
gripper = state.gripper_pos()
|
||||
all_dofs = np.concatenate([state.joints(), np.array([gripper])])
|
||||
return all_dofs
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
if len(joint_state) == 7:
|
||||
self.set_command(joint_state, None)
|
||||
elif len(joint_state) == 8:
|
||||
self.set_command(joint_state[:7], joint_state[7])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid joint state: {joint_state}, len={len(joint_state)}"
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
if self.robot is not None:
|
||||
self.robot.disconnect()
|
||||
|
||||
if self.command_thread is not None:
|
||||
self.command_thread.join()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip: str = "192.168.1.226",
|
||||
real: bool = True,
|
||||
control_frequency: float = 50.0,
|
||||
max_delta: float = DEFAULT_MAX_DELTA,
|
||||
):
|
||||
print(ip)
|
||||
self.real = real
|
||||
self.max_delta = max_delta
|
||||
if real:
|
||||
from xarm.wrapper import XArmAPI
|
||||
|
||||
self.robot = XArmAPI(ip, is_radian=True)
|
||||
else:
|
||||
self.robot = None
|
||||
|
||||
self._control_frequency = control_frequency
|
||||
self._clear_error_states()
|
||||
self._set_gripper_position(self.GRIPPER_OPEN)
|
||||
|
||||
self.last_state_lock = threading.Lock()
|
||||
self.target_command_lock = threading.Lock()
|
||||
self.last_state = self._update_last_state()
|
||||
self.target_command = {
|
||||
"joints": self.last_state.joints(),
|
||||
"gripper": 0,
|
||||
}
|
||||
self.running = True
|
||||
self.command_thread = None
|
||||
if real:
|
||||
self.command_thread = threading.Thread(target=self._robot_thread)
|
||||
self.command_thread.start()
|
||||
|
||||
def get_state(self) -> RobotState:
|
||||
with self.last_state_lock:
|
||||
return self.last_state
|
||||
|
||||
def set_command(self, joints: np.ndarray, gripper: Optional[float] = None) -> None:
|
||||
with self.target_command_lock:
|
||||
self.target_command = {
|
||||
"joints": joints,
|
||||
"gripper": gripper,
|
||||
}
|
||||
|
||||
def _clear_error_states(self):
|
||||
if self.robot is None:
|
||||
return
|
||||
self.robot.clean_error()
|
||||
self.robot.clean_warn()
|
||||
self.robot.motion_enable(True)
|
||||
time.sleep(1)
|
||||
self.robot.set_mode(1)
|
||||
time.sleep(1)
|
||||
self.robot.set_collision_sensitivity(0)
|
||||
time.sleep(1)
|
||||
self.robot.set_state(state=0)
|
||||
time.sleep(1)
|
||||
self.robot.set_gripper_enable(True)
|
||||
time.sleep(1)
|
||||
self.robot.set_gripper_mode(0)
|
||||
time.sleep(1)
|
||||
self.robot.set_gripper_speed(3000)
|
||||
time.sleep(1)
|
||||
|
||||
def _get_gripper_pos(self) -> float:
|
||||
if self.robot is None:
|
||||
return 0.0
|
||||
code, gripper_pos = self.robot.get_gripper_position()
|
||||
while code != 0 or gripper_pos is None:
|
||||
print(f"Error code {code} in get_gripper_position(). {gripper_pos}")
|
||||
time.sleep(0.001)
|
||||
code, gripper_pos = self.robot.get_gripper_position()
|
||||
if code == 22:
|
||||
self._clear_error_states()
|
||||
|
||||
normalized_gripper_pos = (gripper_pos - self.GRIPPER_OPEN) / (
|
||||
self.GRIPPER_CLOSE - self.GRIPPER_OPEN
|
||||
)
|
||||
return normalized_gripper_pos
|
||||
|
||||
def _set_gripper_position(self, pos: int) -> None:
|
||||
if self.robot is None:
|
||||
return
|
||||
self.robot.set_gripper_position(pos, wait=False)
|
||||
# while self.robot.get_is_moving():
|
||||
# time.sleep(0.01)
|
||||
|
||||
def _robot_thread(self):
|
||||
rate = Rate(
|
||||
duration=1 / self._control_frequency
|
||||
) # command and update rate for robot
|
||||
step_times = []
|
||||
count = 0
|
||||
|
||||
while self.running:
|
||||
s_t = time.time()
|
||||
# update last state
|
||||
self.last_state = self._update_last_state()
|
||||
with self.target_command_lock:
|
||||
joint_delta = np.array(
|
||||
self.target_command["joints"] - self.last_state.joints()
|
||||
)
|
||||
gripper_command = self.target_command["gripper"]
|
||||
|
||||
norm = np.linalg.norm(joint_delta)
|
||||
|
||||
# threshold delta to be at most 0.01 in norm space
|
||||
if norm > self.max_delta:
|
||||
delta = joint_delta / norm * self.max_delta
|
||||
else:
|
||||
delta = joint_delta
|
||||
|
||||
# command position
|
||||
self._set_position(
|
||||
self.last_state.joints() + delta,
|
||||
)
|
||||
|
||||
if gripper_command is not None:
|
||||
set_point = gripper_command
|
||||
self._set_gripper_position(
|
||||
self.GRIPPER_OPEN
|
||||
+ set_point * (self.GRIPPER_CLOSE - self.GRIPPER_OPEN)
|
||||
)
|
||||
self.last_state = self._update_last_state()
|
||||
|
||||
rate.sleep()
|
||||
step_times.append(time.time() - s_t)
|
||||
count += 1
|
||||
if count % 1000 == 0:
|
||||
# Mean, Std, Min, Max, only show 3 decimal places and string pad with 10 spaces
|
||||
frequency = 1 / np.mean(step_times)
|
||||
# print(f"Step time - mean: {np.mean(step_times):10.3f}, std: {np.std(step_times):10.3f}, min: {np.min(step_times):10.3f}, max: {np.max(step_times):10.3f}")
|
||||
print(
|
||||
f"Low Level Frequency - mean: {frequency:10.3f}, std: {np.std(frequency):10.3f}, min: {np.min(frequency):10.3f}, max: {np.max(frequency):10.3f}"
|
||||
)
|
||||
step_times = []
|
||||
|
||||
def _update_last_state(self) -> RobotState:
|
||||
with self.last_state_lock:
|
||||
if self.robot is None:
|
||||
return RobotState(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, np.zeros(3))
|
||||
|
||||
gripper_pos = self._get_gripper_pos()
|
||||
|
||||
code, servo_angle = self.robot.get_servo_angle(is_radian=True)
|
||||
while code != 0:
|
||||
print(f"Error code {code} in get_servo_angle().")
|
||||
self._clear_error_states()
|
||||
code, servo_angle = self.robot.get_servo_angle(is_radian=True)
|
||||
|
||||
code, cart_pos = self.robot.get_position_aa(is_radian=True)
|
||||
while code != 0:
|
||||
print(f"Error code {code} in get_position().")
|
||||
self._clear_error_states()
|
||||
code, cart_pos = self.robot.get_position_aa(is_radian=True)
|
||||
|
||||
cart_pos = np.array(cart_pos)
|
||||
aa = cart_pos[3:]
|
||||
cart_pos[:3] /= 1000
|
||||
|
||||
return RobotState.from_robot(
|
||||
cart_pos,
|
||||
servo_angle,
|
||||
gripper_pos,
|
||||
aa,
|
||||
)
|
||||
|
||||
def _set_position(
|
||||
self,
|
||||
joints: np.ndarray,
|
||||
) -> None:
|
||||
if self.robot is None:
|
||||
return
|
||||
# threhold xyz to be in min max
|
||||
ret = self.robot.set_servo_angle_j(joints, wait=False, is_radian=True)
|
||||
if ret in [1, 9]:
|
||||
self._clear_error_states()
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
state = self.get_state()
|
||||
pos_quat = np.concatenate([state.cartesian_pos(), state.quat()])
|
||||
joints = self.get_joint_state()
|
||||
return {
|
||||
"joint_positions": joints, # rotational joint + gripper state
|
||||
"joint_velocities": joints,
|
||||
"ee_pos_quat": pos_quat,
|
||||
"gripper_position": np.array(state.gripper_pos()),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
ip = "192.168.1.226"
|
||||
robot = XArmRobot(ip)
|
||||
import time
|
||||
|
||||
time.sleep(1)
|
||||
print(robot.get_state())
|
||||
|
||||
time.sleep(1)
|
||||
print(robot.get_state())
|
||||
print("end")
|
||||
robot.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
69
gello/zmq_core/camera_node.py
Normal file
69
gello/zmq_core/camera_node.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import pickle
|
||||
import threading
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
from gello.cameras.camera import CameraDriver
|
||||
|
||||
DEFAULT_CAMERA_PORT = 5000
|
||||
|
||||
|
||||
class ZMQClientCamera(CameraDriver):
|
||||
"""A class representing a ZMQ client for a leader robot."""
|
||||
|
||||
def __init__(self, port: int = DEFAULT_CAMERA_PORT, host: str = "127.0.0.1"):
|
||||
self._context = zmq.Context()
|
||||
self._socket = self._context.socket(zmq.REQ)
|
||||
self._socket.connect(f"tcp://{host}:{port}")
|
||||
|
||||
def read(
|
||||
self,
|
||||
img_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get the current state of the leader robot.
|
||||
|
||||
Returns:
|
||||
T: The current state of the leader robot.
|
||||
"""
|
||||
# pack the image_size and send it to the server
|
||||
send_message = pickle.dumps(img_size)
|
||||
self._socket.send(send_message)
|
||||
state_dict = pickle.loads(self._socket.recv())
|
||||
return state_dict
|
||||
|
||||
|
||||
class ZMQServerCamera:
|
||||
def __init__(
|
||||
self,
|
||||
camera: CameraDriver,
|
||||
port: int = DEFAULT_CAMERA_PORT,
|
||||
host: str = "127.0.0.1",
|
||||
):
|
||||
self._camera = camera
|
||||
self._context = zmq.Context()
|
||||
self._socket = self._context.socket(zmq.REP)
|
||||
addr = f"tcp://{host}:{port}"
|
||||
debug_message = f"Camera Sever Binding to {addr}, Camera: {camera}"
|
||||
print(debug_message)
|
||||
self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
|
||||
self._socket.bind(addr)
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def serve(self) -> None:
|
||||
"""Serve the leader robot state over ZMQ."""
|
||||
self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
message = self._socket.recv()
|
||||
img_size = pickle.loads(message)
|
||||
camera_read = self._camera.read(img_size)
|
||||
self._socket.send(pickle.dumps(camera_read))
|
||||
except zmq.Again:
|
||||
print(self._timout_message)
|
||||
# Timeout occurred, check if the stop event is set
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the server to stop serving."""
|
||||
self._stop_event.set()
|
125
gello/zmq_core/robot_node.py
Normal file
125
gello/zmq_core/robot_node.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
import pickle
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
from gello.robots.robot import Robot
|
||||
|
||||
DEFAULT_ROBOT_PORT = 6000
|
||||
|
||||
|
||||
class ZMQServerRobot:
|
||||
def __init__(
|
||||
self,
|
||||
robot: Robot,
|
||||
port: int = DEFAULT_ROBOT_PORT,
|
||||
host: str = "127.0.0.1",
|
||||
):
|
||||
self._robot = robot
|
||||
self._context = zmq.Context()
|
||||
self._socket = self._context.socket(zmq.REP)
|
||||
addr = f"tcp://{host}:{port}"
|
||||
debug_message = f"Robot Sever Binding to {addr}, Robot: {robot}"
|
||||
print(debug_message)
|
||||
self._timout_message = f"Timeout in Robot Server, Robot: {robot}"
|
||||
self._socket.bind(addr)
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def serve(self) -> None:
|
||||
"""Serve the leader robot state over ZMQ."""
|
||||
self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
# Wait for next request from client
|
||||
message = self._socket.recv()
|
||||
request = pickle.loads(message)
|
||||
|
||||
# Call the appropriate method based on the request
|
||||
method = request.get("method")
|
||||
args = request.get("args", {})
|
||||
result: Any
|
||||
if method == "num_dofs":
|
||||
result = self._robot.num_dofs()
|
||||
elif method == "get_joint_state":
|
||||
result = self._robot.get_joint_state()
|
||||
elif method == "command_joint_state":
|
||||
result = self._robot.command_joint_state(**args)
|
||||
elif method == "get_observations":
|
||||
result = self._robot.get_observations()
|
||||
else:
|
||||
result = {"error": "Invalid method"}
|
||||
print(result)
|
||||
raise NotImplementedError(
|
||||
f"Invalid method: {method}, {args, result}"
|
||||
)
|
||||
|
||||
self._socket.send(pickle.dumps(result))
|
||||
except zmq.Again:
|
||||
print(self._timout_message)
|
||||
# Timeout occurred, check if the stop event is set
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the server to stop serving."""
|
||||
self._stop_event.set()
|
||||
|
||||
|
||||
class ZMQClientRobot(Robot):
|
||||
"""A class representing a ZMQ client for a leader robot."""
|
||||
|
||||
def __init__(self, port: int = DEFAULT_ROBOT_PORT, host: str = "127.0.0.1"):
|
||||
self._context = zmq.Context()
|
||||
self._socket = self._context.socket(zmq.REQ)
|
||||
self._socket.connect(f"tcp://{host}:{port}")
|
||||
|
||||
def num_dofs(self) -> int:
|
||||
"""Get the number of joints in the robot.
|
||||
|
||||
Returns:
|
||||
int: The number of joints in the robot.
|
||||
"""
|
||||
request = {"method": "num_dofs"}
|
||||
send_message = pickle.dumps(request)
|
||||
self._socket.send(send_message)
|
||||
result = pickle.loads(self._socket.recv())
|
||||
return result
|
||||
|
||||
def get_joint_state(self) -> np.ndarray:
|
||||
"""Get the current state of the leader robot.
|
||||
|
||||
Returns:
|
||||
T: The current state of the leader robot.
|
||||
"""
|
||||
request = {"method": "get_joint_state"}
|
||||
send_message = pickle.dumps(request)
|
||||
self._socket.send(send_message)
|
||||
result = pickle.loads(self._socket.recv())
|
||||
return result
|
||||
|
||||
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
||||
"""Command the leader robot to the given state.
|
||||
|
||||
Args:
|
||||
joint_state (T): The state to command the leader robot to.
|
||||
"""
|
||||
request = {
|
||||
"method": "command_joint_state",
|
||||
"args": {"joint_state": joint_state},
|
||||
}
|
||||
send_message = pickle.dumps(request)
|
||||
self._socket.send(send_message)
|
||||
result = pickle.loads(self._socket.recv())
|
||||
return result
|
||||
|
||||
def get_observations(self) -> Dict[str, np.ndarray]:
|
||||
"""Get the current observations of the leader robot.
|
||||
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: The current observations of the leader robot.
|
||||
"""
|
||||
request = {"method": "get_observations"}
|
||||
send_message = pickle.dumps(request)
|
||||
self._socket.send(send_message)
|
||||
result = pickle.loads(self._socket.recv())
|
||||
return result
|
BIN
imgs/gello_matching_joints.jpg
Normal file
BIN
imgs/gello_matching_joints.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 289 KiB |
BIN
imgs/robot_known_configuration.jpg
Normal file
BIN
imgs/robot_known_configuration.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 359 KiB |
BIN
imgs/title.png
Normal file
BIN
imgs/title.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 269 KiB |
1
kill_nodes.sh
Executable file
1
kill_nodes.sh
Executable file
|
@ -0,0 +1 @@
|
|||
pkill -9 -f launch_nodes.py
|
8
mypy.ini
Normal file
8
mypy.ini
Normal file
|
@ -0,0 +1,8 @@
|
|||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
namespace_packages = True
|
||||
no_implicit_optional = True
|
||||
exclude = third_party/
|
||||
|
||||
[mypy-dynamixel_sdk.*]
|
||||
ignore_missing_imports = True
|
12
pyrightconfig.json
Normal file
12
pyrightconfig.json
Normal file
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"include": [
|
||||
"gello"
|
||||
],
|
||||
"exclude": [
|
||||
"**/node_modules",
|
||||
"src/typestubs"
|
||||
],
|
||||
"reportMissingImports": false,
|
||||
"reportMissingTypeStubs": false,
|
||||
"reportGeneralTypeIssues": false
|
||||
}
|
17
requirements.txt
Normal file
17
requirements.txt
Normal file
|
@ -0,0 +1,17 @@
|
|||
# keep in alphabetical order
|
||||
dm_control
|
||||
Pillow
|
||||
pygame
|
||||
pyspacemouse
|
||||
PyQt6
|
||||
pyquaternion
|
||||
pyrealsense2
|
||||
pure-python-adb
|
||||
quaternion
|
||||
tyro
|
||||
ur-rtde
|
||||
zmq
|
||||
xarm
|
||||
xarm-python-sdk
|
||||
numpy-quaternion
|
||||
termcolor
|
13
requirements_dev.txt
Normal file
13
requirements_dev.txt
Normal file
|
@ -0,0 +1,13 @@
|
|||
# keep in alphabetical order
|
||||
# for development
|
||||
black
|
||||
flake8
|
||||
flake8-docstrings
|
||||
isort
|
||||
ipdb
|
||||
jupyterlab
|
||||
mypy
|
||||
neovim
|
||||
pyright
|
||||
pytest
|
||||
python-lsp-server[all]
|
59
scripts/arm_blocks_play.py
Normal file
59
scripts/arm_blocks_play.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
from dm_control import composer, viewer
|
||||
|
||||
from gello.agents.gello_agent import DynamixelRobotConfig
|
||||
from gello.dm_control_tasks.arms.ur5e import UR5e
|
||||
from gello.dm_control_tasks.manipulation.arenas.floors import Floor
|
||||
from gello.dm_control_tasks.manipulation.tasks.block_play import BlockPlay
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
use_gello: bool = False
|
||||
|
||||
|
||||
config = DynamixelRobotConfig(
|
||||
joint_ids=(1, 2, 3, 4, 5, 6),
|
||||
joint_offsets=(
|
||||
-np.pi / 2,
|
||||
1 * np.pi / 2 + np.pi,
|
||||
np.pi / 2 + 0 * np.pi,
|
||||
0 * np.pi + np.pi / 2,
|
||||
np.pi - 2 * np.pi / 2,
|
||||
-1 * np.pi / 2 + 2 * np.pi,
|
||||
),
|
||||
joint_signs=(1, 1, -1, 1, 1, 1),
|
||||
gripper_config=(7, 20, -22),
|
||||
)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
reset_joints_left = np.deg2rad([90, -90, -90, -90, 90, 0, 0])
|
||||
robot = UR5e()
|
||||
task = BlockPlay(robot, Floor(), reset_joints=reset_joints_left[:-1])
|
||||
# task = BlockPlay(robot, Floor())
|
||||
env = composer.Environment(task=task)
|
||||
|
||||
action_space = env.action_spec()
|
||||
if args.use_gello:
|
||||
gello = config.make_robot(
|
||||
port="/dev/cu.usbserial-FT7WBEIA", start_joints=reset_joints_left
|
||||
)
|
||||
|
||||
def policy(timestep) -> np.ndarray:
|
||||
if args.use_gello:
|
||||
joint_command = gello.get_joint_state()
|
||||
joint_command = np.array(joint_command).copy()
|
||||
|
||||
joint_command[-1] = joint_command[-1] * 255
|
||||
return joint_command
|
||||
return np.random.uniform(action_space.minimum, action_space.maximum)
|
||||
|
||||
viewer.launch(env, policy=policy)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
98
scripts/gello_get_offset.py
Normal file
98
scripts/gello_get_offset.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
|
||||
from gello.dynamixel.driver import DynamixelDriver
|
||||
|
||||
MENAGERIE_ROOT: Path = Path(__file__).parent / "third_party" / "mujoco_menagerie"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
port: str = "/dev/ttyUSB0"
|
||||
"""The port that GELLO is connected to."""
|
||||
|
||||
start_joints: Tuple[float, ...] = (0, 0, 0, 0, 0, 0)
|
||||
"""The joint angles that the GELLO is placed in at (in radians)."""
|
||||
|
||||
joint_signs: Tuple[float, ...] = (1, 1, -1, 1, 1, 1)
|
||||
"""The joint angles that the GELLO is placed in at (in radians)."""
|
||||
|
||||
gripper: bool = True
|
||||
"""Whether or not the gripper is attached."""
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.joint_signs) == len(self.start_joints)
|
||||
for idx, j in enumerate(self.joint_signs):
|
||||
assert (
|
||||
j == -1 or j == 1
|
||||
), f"Joint idx: {idx} should be -1 or 1, but got {j}."
|
||||
|
||||
@property
|
||||
def num_robot_joints(self) -> int:
|
||||
return len(self.start_joints)
|
||||
|
||||
@property
|
||||
def num_joints(self) -> int:
|
||||
extra_joints = 1 if self.gripper else 0
|
||||
return self.num_robot_joints + extra_joints
|
||||
|
||||
|
||||
def get_config(args: Args) -> None:
|
||||
joint_ids = list(range(1, args.num_joints + 1))
|
||||
driver = DynamixelDriver(joint_ids, port=args.port, baudrate=57600)
|
||||
|
||||
# assume that the joint state shouold be args.start_joints
|
||||
# find the offset, which is a multiple of np.pi/2 that minimizes the error between the current joint state and args.start_joints
|
||||
# this is done by brute force, we seach in a range of +/- 8pi
|
||||
|
||||
def get_error(offset: float, index: int, joint_state: np.ndarray) -> float:
|
||||
joint_sign_i = args.joint_signs[index]
|
||||
joint_i = joint_sign_i * (joint_state[index] - offset)
|
||||
start_i = args.start_joints[index]
|
||||
return np.abs(joint_i - start_i)
|
||||
|
||||
for _ in range(10):
|
||||
driver.get_joints() # warmup
|
||||
|
||||
for _ in range(1):
|
||||
best_offsets = []
|
||||
curr_joints = driver.get_joints()
|
||||
for i in range(args.num_robot_joints):
|
||||
best_offset = 0
|
||||
best_error = 1e6
|
||||
for offset in np.linspace(
|
||||
-8 * np.pi, 8 * np.pi, 8 * 4 + 1
|
||||
): # intervals of pi/2
|
||||
error = get_error(offset, i, curr_joints)
|
||||
if error < best_error:
|
||||
best_error = error
|
||||
best_offset = offset
|
||||
best_offsets.append(best_offset)
|
||||
print()
|
||||
print("best offsets : ", [f"{x:.3f}" for x in best_offsets])
|
||||
print(
|
||||
"best offsets function of pi: ["
|
||||
+ ", ".join([f"{int(np.round(x/(np.pi/2)))}*np.pi/2" for x in best_offsets])
|
||||
+ " ]",
|
||||
)
|
||||
if args.gripper:
|
||||
print(
|
||||
"gripper open (degrees) ",
|
||||
np.rad2deg(driver.get_joints()[-1]) - 0.2,
|
||||
)
|
||||
print(
|
||||
"gripper close (degrees) ",
|
||||
np.rad2deg(driver.get_joints()[-1]) - 42,
|
||||
)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
get_config(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(tyro.cli(Args))
|
39
scripts/launch.py
Normal file
39
scripts/launch.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import os
|
||||
import subprocess
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
|
||||
def run_docker_container():
|
||||
user = os.getenv("USER")
|
||||
container_name = f"gello_{user}"
|
||||
gello_path = os.path.abspath(os.path.join(current_file_path, "../../"))
|
||||
volume_mapping = f"{gello_path}:/gello"
|
||||
|
||||
cmd = [
|
||||
"docker",
|
||||
"run",
|
||||
"--runtime=nvidia",
|
||||
"--rm",
|
||||
"--name",
|
||||
container_name,
|
||||
"--privileged",
|
||||
"--volume",
|
||||
volume_mapping,
|
||||
"--volume",
|
||||
"/home/gello:/homefolder",
|
||||
"--net=host",
|
||||
"--volume",
|
||||
"/dev/serial/by-id/:/dev/serial/by-id/",
|
||||
"-it",
|
||||
"gello:latest",
|
||||
"bash",
|
||||
"-c",
|
||||
"pip install -e third_party/DynamixelSDK/python && exec bash",
|
||||
]
|
||||
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_docker_container()
|
41
scripts/visualize_example.py
Normal file
41
scripts/visualize_example.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
|
||||
|
||||
def main():
|
||||
_PROJECT_ROOT: Path = Path(__file__).parent.parent
|
||||
_MENAGERIE_ROOT: Path = _PROJECT_ROOT / "third_party" / "mujoco_menagerie"
|
||||
xml = _MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml"
|
||||
|
||||
# xml = _MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml"
|
||||
|
||||
m = mujoco.MjModel.from_xml_path(xml.as_posix())
|
||||
d = mujoco.MjData(m)
|
||||
|
||||
with mujoco.viewer.launch_passive(m, d) as viewer:
|
||||
# Close the viewer automatically after 30 wall-seconds.
|
||||
while viewer.is_running():
|
||||
step_start = time.time()
|
||||
|
||||
# mj_step can be replaced with code that also evaluates
|
||||
# a policy and applies a control signal before stepping the physics.
|
||||
mujoco.mj_step(m, d)
|
||||
|
||||
# Example modification of a viewer option: toggle contact points every two seconds.
|
||||
with viewer.lock():
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = int(d.time % 2)
|
||||
|
||||
# Pick up changes to the physics state, apply perturbations, update options from GUI.
|
||||
viewer.sync()
|
||||
|
||||
# Rudimentary time keeping, will drift relative to wall clock.
|
||||
time_until_next_step = m.opt.timestep - (time.time() - step_start)
|
||||
if time_until_next_step > 0:
|
||||
time.sleep(time_until_next_step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
25
setup.py
Normal file
25
setup.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import setuptools
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setuptools.setup(
|
||||
name="gello",
|
||||
version="0.0.1",
|
||||
author="Philipp Wu",
|
||||
author_email="philippwu@berkeley.edu",
|
||||
description="software for GELLO",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/wuphilipp/gello_software",
|
||||
packages=setuptools.find_packages(),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
],
|
||||
python_requires=">=3.8",
|
||||
license="MIT",
|
||||
install_requires=[
|
||||
"numpy",
|
||||
],
|
||||
)
|
1
third_party/DynamixelSDK
vendored
Submodule
1
third_party/DynamixelSDK
vendored
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 3450c7078917b262d9b36042c15444047aae226e
|
1
third_party/mujoco_menagerie
vendored
Submodule
1
third_party/mujoco_menagerie
vendored
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit c118f3d5d5ec9fb27f832ac78b3c4971234f0f4f
|
Loading…
Add table
Add a link
Reference in a new issue