diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..ca5f8de --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +ignore = D100, D101, D102, D103, D104, D105, D107, E203, E501, W503, SIM201, SIM113, B027, C408, B008 +docstring-convention = google diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml new file mode 100644 index 0000000..5db5e14 --- /dev/null +++ b/.github/workflows/pythonapp.yml @@ -0,0 +1,44 @@ +name: Python application + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + schedule: + - cron: "0 0 * * 0" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + # python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.10"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements_dev.txt + pip install -r requirements.txt + pip install -e . 
+ git submodule init + git submodule update + pip install third_party/DynamixelSDK/python + + - name: Black + run: | + black --check gello + - name: flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 gello --count --select=E9,F63,F7,F82 --show-source --statistics + - name: Test with pytest + run: | + pytest gello diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1b141e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# keep each top level section alphabetical +# keep each item within the sections alphabetical + +# large files +*.png +*.gif + +# mac +*.DS_Store + +# misc +*logs* +*runs* +*tb_logs* +*wandb* +*wandb_logs* +outputs/* + +# python +*__pycache__ +*.egg-info +*.hypothesis +*.ipynb_checkpoints +*.mypy_cache +*.pyc + +# vim +*.swp +MUJOCO_LOG.TXT diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..b8c7d8f --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "third_party/mujoco_menagerie"] + path = third_party/mujoco_menagerie + url = https://github.com/deepmind/mujoco_menagerie.git +[submodule "third_party/DynamixelSDK"] + path = third_party/DynamixelSDK + url = https://github.com/ROBOTIS-GIT/DynamixelSDK.git diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..04bab0f --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,6 @@ +[settings] +use_parentheses=True +include_trailing_comma=True +multi_line_output=3 +ensure_newline_before_comments=True +line_length=88 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..7c92e72 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,38 @@ +repos: + +# remove unused python imports +- repo: https://github.com/myint/autoflake.git + rev: v2.1.1 + hooks: + - id: autoflake + args: ["--in-place", "--remove-all-unused-imports", "--ignore-init-module-imports"] + +# sort imports +- repo: https://github.com/timothycrosley/isort + rev: 5.12.0 + hooks: + - id: isort + +# code 
format according to black +- repo: https://github.com/ambv/black + rev: 23.3.0 + hooks: + - id: black + +# check for python styling with flake8 +- repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + additional_dependencies: [ + 'flake8-docstrings', + 'flake8-bugbear', + 'flake8-comprehensions', + 'flake8-simplify', + ] + +# cleanup notebooks +- repo: https://github.com/kynan/nbstripout + rev: 0.6.1 + hooks: + - id: nbstripout diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e3f44ea --- /dev/null +++ b/Dockerfile @@ -0,0 +1,22 @@ +FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 + +WORKDIR /gello + +# Set environment variables first (less likely to change) +ENV PYTHONPATH=/gello:/gello/third_party/oculus_reader/ + +# Group apt updates and installs together +RUN apt update && apt install -y \ + libhidapi-dev \ + python3-pip \ + android-tools-adb \ + libegl1-mesa-dev && \ + rm -rf /var/lib/apt/lists/* + + +# Python alias setup +RUN echo "alias python=python3" >> ~/.bashrc + +# Install Python dependencies +COPY requirements.txt /gello +RUN pip install -r requirements.txt \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7b3eeaa --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Philipp Wu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 1bf2160..e929fb6 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,174 @@ -# GELLO Software -This is the central repo that holds the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/ +# GELLO +This is the central repo that holds all the software for GELLO. See the website for the paper and other resources for GELLO https://wuphilipp.github.io/gello_site/ +See the GELLO hardware repo for the STL files and hardware instructions for building your own GELLO https://github.com/wuphilipp/gello_mechanical +``` +git clone https://github.com/wuphilipp/gello_software.git +cd gello_software +``` -# Installation +

+ +

-Coming Soon +## Use your own environment +``` +git submodule init +git submodule update +pip install -r requirements.txt +pip install -e . +pip install -e third_party/DynamixelSDK/python +``` + +## Use with Docker +First install ```docker``` following this [link](https://docs.docker.com/engine/install/ubuntu/) on your host machine. +Then you can clone the repo and build the corresponding docker environment + +Build the docker image and tag it as gello:latest. If you are going to name it differently, you need to change the launch.py image name +``` +docker build . -t gello:latest +``` + +We have provided an entry point into the docker container +``` +python experiments/launch.py +``` + +# GELLO configuration setup (PLEASE READ) +Now that you have downloaded the code, there is some additional preparation work to properly configure the Dynamixels and GELLO. +These instructions will guide you on how to update the motor ids of the Dynamixels and then how to extract the joint offsets to configure your GELLO. + +## Update motor IDs +Install the [dynamixel_wizard](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2/). +By default, each motor has the ID 1. In order for multiple dynamixels to be controlled by the same U2D2 controller board, each dynamixel must have a unique ID. +This process must be done one motor at a time. Connect each motor, starting from the base motor, and assign them in increasing order until you reach the gripper. + +Steps: + * Connect a single motor to the controller and connect the controller to the computer. + * Open the dynamixel wizard + * Click scan (found at the top left corner), this should detect the dynamixel. Connect to the motor + * Look for the ID address and change the ID to the appropriate number. + * Repeat for each motor + +## Create the GELLO configuration and determine the joint offsets +After the motor IDs are set, we can now connect to the GELLO controller device. 
However, each motor has its own joint offset, which will result in a joint offset between GELLO and your actual robot arm. +Dynamixels have a symmetric 4 hole pattern which means the joint offset is a multiple of pi/2. +The `GelloAgent` class accepts a `DynamixelRobotConfig` (found in `gello/agents/gello_agent.py`). The Dynamixel config specifies the parameters you need to find to operate your GELLO. Look at the documentation for more details. + +We have created a simple script to automatically detect the joint offset: +* set GELLO into a known configuration, where you know what the corresponding joint angles should be. For example, we set our GELLO in this configuration, where we know the desired ground truth joints. (0, -90, 90, -90, -90, 0) +

+ + +

+ +* run +``` +python scripts/gello_get_offset.py \ + --start-joints 0 -1.57 1.57 -1.57 -1.57 0 \ + --joint-signs 1 1 -1 1 1 1 \ + --port /dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6 +# replace values with your own +``` +* Use the known starting joints (in radians) for `start-joints`. +* Use the `joint-signs` for your own robot (see below). +* Use your serial port for `port`. You can find the port id of your U2D2 Dynamixel device by running `ls /dev/serial/by-id` and looking for the path that starts with `usb-FTDI_USB__-__Serial_Converter` (on Ubuntu). On Mac, look in /dev/ for the device that starts with `cu.usbserial` + +`joint-signs` for each robot type: +* UR: `1 1 -1 1 1 1` +* Panda: `1 -1 1 1 1 -1 1` +* xArm: `1 1 1 1 1 1 1` + +The script prints out a list of joint offsets. Go to `gello/agents/gello_agent.py` and add a DynamixelRobotConfig to the PORT_CONFIG_MAP. You are now ready to run your GELLO! + +# Using GELLO to control a robot! + +The code provided here is simple and only relies on python packages. The code does NOT use ROS, but a ROS wrapper can easily be adapted from this code. +For multiprocessing, we leverage [ZMQ](https://zeromq.org/) + +## Testing in sim +First test your GELLO with a simulated robot to make sure that the joint angles match as expected. +In one terminal run +``` +python experiments/launch_nodes.py --robot +``` +This launches the robot node. A simulated robot using the mujoco viewer should appear. + +Then, launch your GELLO (the controller node). +``` +python experiments/run_env.py --agent=gello +``` +You should be able to use GELLO to control the simulated robot! + +## Running on a real robot. +Once you have verified that your GELLO is properly configured, you can test it on a real robot! + +Before you run with the real robot, you will have to install a robot specific python package. +The supported robots are in `gello/robots`. 
+ * UR: [ur_rtde](https://sdurobotics.gitlab.io/ur_rtde/installation/installation.html) + * panda: [polymetis](https://facebookresearch.github.io/fairo/polymetis/installation.html). If you use a different framework to control the panda, the code is easy to adapt. See/Modify `gello/robots/panda.py` + * xArm: [xArm python SDK](https://github.com/xArm-Developer/xArm-Python-SDK) + +``` +# Launch all of the nodes +python experiments/launch_nodes.py --robot= +# run the environment loop +python experiments/run_env.py --agent=gello +``` + +Ideally you can start your GELLO near a known configuration each time. If this is possible, you can set the `--start-joints` flag with GELLO's known starting configuration. This also enables the robot to reset before you begin teleoperation. + +## Collect data +We have provided a simple example for collecting data with gello. +To save trajectories with the keyboard, add the following flag `--use-save-interface` + +Data can then be processed using the demo_to_gdict script. +``` +python gello/data_utils/demo_to_gdict.py --source-dir= +``` + +## Running a bimanual system with GELLO +GELLO can also be used in bimanual configurations. +For an example, see the `bimanual_ur` robot in `launch_nodes.py` and `--bimanual` flag in the `run_env.py` script. + +## Notes +Due to the use of multiprocessing, sometimes python processes are not killed properly. We have provided the kill_nodes script which will kill the +python processes. +``` +./kill_nodes.sh +``` + +### Using a new robot! +If you want to use a new robot you need a GELLO that is compatible. If the kinematics are close enough, you may directly use an existing GELLO. Otherwise you will have to design your own. +To add a new robot, simply implement the `Robot` protocol found in `gello/robots/robot`. See `gello/robots/panda.py`, `gello/robots/ur.py`, `gello/robots/xarm_robot.py` for examples. + +### Contributing +Please make a PR if you would like to contribute! 
The goal of this project is to enable more accessible and higher quality teleoperation devices and we would love your input! + +You can optionally install some dev packages. +``` +pip install -r requirements_dev.txt +``` + +The code is organized as follows: + * `scripts`: contains some helpful python `scripts` + * `experiments`: contains entrypoints into the gello code + * `gello`: contains all of the `gello` python package code + * `agents`: teleoperation agents + * `cameras`: code to interface with camera hardware + * `data_utils`: data processing utils. used for imitation learning + * `dm_control_tasks`: dm_control utils to build a simple dm_control environment. used for demos + * `dynamixel`: code to interface with the dynamixel hardware + * `robots`: robot specific interfaces + * `zmq_core`: zmq utilities for enabling a multi node system + + +This code base uses `isort` and `black` for code formatting. +pre-commit hooks are great. They will automatically do some checking/formatting. To use the pre-commit hooks, run the following: +``` +pip install pre-commit +pre-commit install +``` # Citation @@ -15,3 +179,11 @@ Coming Soon year={2023}, } ``` + +# License & Acknowledgements +This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. + +This project builds on top of or utilizes the following third party dependencies. + * [google-deepmind/mujoco_menagerie](https://github.com/google-deepmind/mujoco_menagerie): Prebuilt robot models for mujoco + * [brentyi/tyro](https://github.com/brentyi/tyro): Argument parsing and configuration + * [ZMQ](https://zeromq.org/): Enables easy creation of node-like processes in python. 
diff --git a/config_hostmachine.sh b/config_hostmachine.sh new file mode 100644 index 0000000..60051b8 --- /dev/null +++ b/config_hostmachine.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# install nvidia driver 520 +sudo apt-get install nvidia-520 -y + +# install docker +sudo apt-get update +sudo apt-get install ca-certificates curl gnupg + +sudo install -m 0755 -d /etc/apt/keyrings +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg +sudo chmod a+r /etc/apt/keyrings/docker.gpg + +echo \ + "deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + "$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + +sudo apt-get update +sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin diff --git a/experiments/launch_camera_clients.py b/experiments/launch_camera_clients.py new file mode 100644 index 0000000..a89efb7 --- /dev/null +++ b/experiments/launch_camera_clients.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import tyro + +from gello.zmq_core.camera_node import ZMQClientCamera + + +@dataclass +class Args: + ports: Tuple[int, ...] 
= (5000, 5001) + hostname: str = "127.0.0.1" + # hostname: str = "128.32.175.167" + + +def main(args): + cameras = [] + import cv2 + + images_display_names = [] + for port in args.ports: + cameras.append(ZMQClientCamera(port=port, host=args.hostname)) + images_display_names.append(f"image_{port}") + cv2.namedWindow(images_display_names[-1], cv2.WINDOW_NORMAL) + + while True: + for display_name, camera in zip(images_display_names, cameras): + image, depth = camera.read() + stacked_depth = np.dstack([depth, depth, depth]).astype(np.uint8) + image_depth = cv2.hconcat([image[:, :, ::-1], stacked_depth]) + cv2.imshow(display_name, image_depth) + cv2.waitKey(1) + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/experiments/launch_camera_nodes.py b/experiments/launch_camera_nodes.py new file mode 100644 index 0000000..b5c6168 --- /dev/null +++ b/experiments/launch_camera_nodes.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from multiprocessing import Process + +import tyro + +from gello.cameras.realsense_camera import RealSenseCamera, get_device_ids +from gello.zmq_core.camera_node import ZMQServerCamera + + +@dataclass +class Args: + # hostname: str = "127.0.0.1" + hostname: str = "128.32.175.167" + + +def launch_server(port: int, camera_id: int, args: Args): + camera = RealSenseCamera(camera_id) + server = ZMQServerCamera(camera, port=port, host=args.hostname) + print(f"Starting camera server on port {port}") + server.serve() + + +def main(args): + ids = get_device_ids() + camera_port = 5000 + camera_servers = [] + for camera_id in ids: + # start a python process for each camera + print(f"Launching camera {camera_id} on port {camera_port}") + camera_servers.append( + Process(target=launch_server, args=(camera_port, camera_id, args)) + ) + camera_port += 1 + + for server in camera_servers: + server.start() + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/experiments/launch_nodes.py b/experiments/launch_nodes.py new file mode 
100644 index 0000000..9251711 --- /dev/null +++ b/experiments/launch_nodes.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from pathlib import Path + +import tyro + +from gello.robots.robot import BimanualRobot, PrintRobot +from gello.zmq_core.robot_node import ZMQServerRobot + + +@dataclass +class Args: + robot: str = "xarm" + robot_port: int = 6001 + hostname: str = "127.0.0.1" + robot_ip: str = "192.168.1.10" + + +def launch_robot_server(args: Args): + port = args.robot_port + if args.robot == "sim_ur": + MENAGERIE_ROOT: Path = ( + Path(__file__).parent.parent / "third_party" / "mujoco_menagerie" + ) + xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml" + gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml" + from gello.robots.sim_robot import MujocoRobotServer + + server = MujocoRobotServer( + xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname + ) + server.serve() + elif args.robot == "sim_panda": + from gello.robots.sim_robot import MujocoRobotServer + + MENAGERIE_ROOT: Path = ( + Path(__file__).parent.parent / "third_party" / "mujoco_menagerie" + ) + xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml" + gripper_xml = None + server = MujocoRobotServer( + xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname + ) + server.serve() + elif args.robot == "sim_xarm": + from gello.robots.sim_robot import MujocoRobotServer + + MENAGERIE_ROOT: Path = ( + Path(__file__).parent.parent / "third_party" / "mujoco_menagerie" + ) + xml = MENAGERIE_ROOT / "ufactory_xarm7" / "xarm7.xml" + gripper_xml = None + server = MujocoRobotServer( + xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname + ) + server.serve() + + else: + if args.robot == "xarm": + from gello.robots.xarm_robot import XArmRobot + + robot = XArmRobot(ip=args.robot_ip) + elif args.robot == "ur": + from gello.robots.ur import URRobot + + robot = URRobot(robot_ip=args.robot_ip) + elif args.robot == "panda": + from 
gello.robots.panda import PandaRobot + + robot = PandaRobot(robot_ip=args.robot_ip) + elif args.robot == "bimanual_ur": + from gello.robots.ur import URRobot + + # IP for the bimanual robot setup is hardcoded + _robot_l = URRobot(robot_ip="192.168.2.10") + _robot_r = URRobot(robot_ip="192.168.1.10") + robot = BimanualRobot(_robot_l, _robot_r) + elif args.robot == "none" or args.robot == "print": + robot = PrintRobot(8) + + else: + raise NotImplementedError( + f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none" + ) + server = ZMQServerRobot(robot, port=port, host=args.hostname) + print(f"Starting robot server on port {port}") + server.serve() + + +def main(args): + launch_robot_server(args) + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/experiments/quick_run.py b/experiments/quick_run.py new file mode 100644 index 0000000..9d1fc90 --- /dev/null +++ b/experiments/quick_run.py @@ -0,0 +1,192 @@ +import atexit +import glob +import time +from dataclasses import dataclass +from multiprocessing import Process +from pathlib import Path +from typing import Optional + +import numpy as np +import tyro + +from gello.agents.agent import DummyAgent +from gello.agents.gello_agent import GelloAgent +from gello.agents.spacemouse_agent import SpacemouseAgent +from gello.env import RobotEnv +from gello.zmq_core.robot_node import ZMQClientRobot, ZMQServerRobot + + +@dataclass +class Args: + hz: int = 100 + + agent: str = "gello" + robot: str = "ur5" + gello_port: Optional[str] = None + mock: bool = False + verbose: bool = False + + hostname: str = "127.0.0.1" + robot_port: int = 6001 + + +def launch_robot_server(port: int, args: Args): + if args.robot == "sim_ur": + MENAGERIE_ROOT: Path = ( + Path(__file__).parent.parent / "third_party" / "mujoco_menagerie" + ) + xml = MENAGERIE_ROOT / "universal_robots_ur5e" / "ur5e.xml" + gripper_xml = MENAGERIE_ROOT / "robotiq_2f85" / "2f85.xml" + from gello.robots.sim_robot import 
MujocoRobotServer + + server = MujocoRobotServer( + xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname + ) + server.serve() + elif args.robot == "sim_panda": + from gello.robots.sim_robot import MujocoRobotServer + + MENAGERIE_ROOT: Path = ( + Path(__file__).parent.parent / "third_party" / "mujoco_menagerie" + ) + xml = MENAGERIE_ROOT / "franka_emika_panda" / "panda.xml" + gripper_xml = None + server = MujocoRobotServer( + xml_path=xml, gripper_xml_path=gripper_xml, port=port, host=args.hostname + ) + server.serve() + + else: + if args.robot == "xarm": + from gello.robots.xarm_robot import XArmRobot + + robot = XArmRobot() + elif args.robot == "ur5": + from gello.robots.ur import URRobot + + robot = URRobot(robot_ip=args.robot_ip) + else: + raise NotImplementedError( + f"Robot {args.robot} not implemented, choose one of: sim_ur, xarm, ur, bimanual_ur, none" + ) + server = ZMQServerRobot(robot, port=port, host=args.hostname) + print(f"Starting robot server on port {port}") + server.serve() + + +def start_robot_process(args: Args): + process = Process(target=launch_robot_server, args=(args.robot_port, args)) + + # Function to kill the child process + def kill_child_process(process): + print("Killing child process...") + process.terminate() + + # Register the kill_child_process function to be called at exit + atexit.register(kill_child_process, process) + process.start() + + +def main(args: Args): + start_robot_process(args) + + robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname) + env = RobotEnv(robot_client, control_rate_hz=args.hz) + + if args.agent == "gello": + gello_port = args.gello_port + if gello_port is None: + usb_ports = glob.glob("/dev/serial/by-id/*") + print(f"Found {len(usb_ports)} ports") + if len(usb_ports) > 0: + gello_port = usb_ports[0] + print(f"using port {gello_port}") + else: + raise ValueError( + "No gello port found, please specify one or plug in gello" + ) + agent = GelloAgent(port=gello_port) + + 
reset_joints = np.array([0, 0, 0, -np.pi, 0, np.pi, 0, 0]) + curr_joints = env.get_obs()["joint_positions"] + if reset_joints.shape == curr_joints.shape: + max_delta = (np.abs(curr_joints - reset_joints)).max() + steps = min(int(max_delta / 0.01), 100) + + for jnt in np.linspace(curr_joints, reset_joints, steps): + env.step(jnt) + time.sleep(0.001) + + elif args.agent == "quest": + from gello.agents.quest_agent import SingleArmQuestAgent + + agent = SingleArmQuestAgent(robot_type=args.robot, which_hand="l") + elif args.agent == "spacemouse": + agent = SpacemouseAgent(robot_type=args.robot, verbose=args.verbose) + elif args.agent == "dummy" or args.agent == "none": + agent = DummyAgent(num_dofs=robot_client.num_dofs()) + else: + raise ValueError("Invalid agent name") + + # going to start position + print("Going to start position") + start_pos = agent.act(env.get_obs()) + obs = env.get_obs() + joints = obs["joint_positions"] + + abs_deltas = np.abs(start_pos - joints) + id_max_joint_delta = np.argmax(abs_deltas) + + max_joint_delta = 0.8 + if abs_deltas[id_max_joint_delta] > max_joint_delta: + id_mask = abs_deltas > max_joint_delta + print() + ids = np.arange(len(id_mask))[id_mask] + for i, delta, joint, current_j in zip( + ids, + abs_deltas[id_mask], + start_pos[id_mask], + joints[id_mask], + ): + print( + f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}" + ) + return + + print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}") + assert len(start_pos) == len( + joints + ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}" + + max_delta = 0.05 + for _ in range(25): + obs = env.get_obs() + command_joints = agent.act(obs) + current_joints = obs["joint_positions"] + delta = command_joints - current_joints + max_joint_delta = np.abs(delta).max() + if max_joint_delta > max_delta: + delta = delta / max_joint_delta * max_delta + env.step(current_joints + delta) + + obs = env.get_obs() + joints = 
obs["joint_positions"] + action = agent.act(obs) + if (action - joints > 0.5).any(): + print("Action is too big") + + # print which joints are too big + joint_index = np.where(action - joints > 0.8) + for j in joint_index: + print( + f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}" + ) + exit() + + while True: + action = agent.act(obs) + obs = env.step(action) + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/experiments/run_env.py b/experiments/run_env.py new file mode 100644 index 0000000..bd2319a --- /dev/null +++ b/experiments/run_env.py @@ -0,0 +1,246 @@ +import datetime +import glob +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import numpy as np +import tyro + +from gello.agents.agent import BimanualAgent, DummyAgent +from gello.agents.gello_agent import GelloAgent +from gello.data_utils.format_obs import save_frame +from gello.env import RobotEnv +from gello.robots.robot import PrintRobot +from gello.zmq_core.robot_node import ZMQClientRobot + + +def print_color(*args, color=None, attrs=(), **kwargs): + import termcolor + + if len(args) > 0: + args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args) + print(*args, **kwargs) + + +@dataclass +class Args: + agent: str = "none" + robot_port: int = 6001 + wrist_camera_port: int = 5000 + base_camera_port: int = 5001 + hostname: str = "127.0.0.1" + robot_type: str = None # only needed for quest agent or spacemouse agent + hz: int = 100 + start_joints: Optional[Tuple[float, ...]] = None + + gello_port: Optional[str] = None + mock: bool = False + use_save_interface: bool = False + data_dir: str = "~/bc_data" + bimanual: bool = False + verbose: bool = False + + +def main(args): + if args.mock: + robot_client = PrintRobot(8, dont_print=True) + camera_clients = {} + else: + camera_clients = { + # you can optionally add camera nodes here for imitation learning purposes + # 
"wrist": ZMQClientCamera(port=args.wrist_camera_port, host=args.hostname), + # "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname), + } + robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname) + env = RobotEnv(robot_client, control_rate_hz=args.hz, camera_dict=camera_clients) + + if args.bimanual: + if args.agent == "gello": + # dynamixel control box port map (to distinguish left and right gello) + right = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0" + left = "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0" + left_agent = GelloAgent(port=left) + right_agent = GelloAgent(port=right) + agent = BimanualAgent(left_agent, right_agent) + elif args.agent == "quest": + from gello.agents.quest_agent import SingleArmQuestAgent + + left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l") + right_agent = SingleArmQuestAgent( + robot_type=args.robot_type, which_hand="r" + ) + agent = BimanualAgent(left_agent, right_agent) + # raise NotImplementedError + elif args.agent == "spacemouse": + from gello.agents.spacemouse_agent import SpacemouseAgent + + left_path = "/dev/hidraw0" + right_path = "/dev/hidraw1" + left_agent = SpacemouseAgent( + robot_type=args.robot_type, device_path=left_path, verbose=args.verbose + ) + right_agent = SpacemouseAgent( + robot_type=args.robot_type, + device_path=right_path, + verbose=args.verbose, + invert_button=True, + ) + agent = BimanualAgent(left_agent, right_agent) + else: + raise ValueError(f"Invalid agent name for bimanual: {args.agent}") + + # System setup specific. This reset configuration works well on our setup. If you are mounting the robot + # differently, you need a separate reset joint configuration. 
+ reset_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0, 0]) + reset_joints_right = np.deg2rad([0, -90, 90, -90, -90, 0, 0]) + reset_joints = np.concatenate([reset_joints_left, reset_joints_right]) + curr_joints = env.get_obs()["joint_positions"] + max_delta = (np.abs(curr_joints - reset_joints)).max() + steps = min(int(max_delta / 0.01), 100) + + for jnt in np.linspace(curr_joints, reset_joints, steps): + env.step(jnt) + else: + if args.agent == "gello": + gello_port = args.gello_port + if gello_port is None: + usb_ports = glob.glob("/dev/serial/by-id/*") + print(f"Found {len(usb_ports)} ports") + if len(usb_ports) > 0: + gello_port = usb_ports[0] + print(f"using port {gello_port}") + else: + raise ValueError( + "No gello port found, please specify one or plug in gello" + ) + if args.start_joints is None: + reset_joints = np.deg2rad( + [0, -90, 90, -90, -90, 0, 0] + ) # Change this to your own reset joints + else: + reset_joints = args.start_joints + agent = GelloAgent(port=gello_port, start_joints=args.start_joints) + curr_joints = env.get_obs()["joint_positions"] + if reset_joints.shape == curr_joints.shape: + max_delta = (np.abs(curr_joints - reset_joints)).max() + steps = min(int(max_delta / 0.01), 100) + + for jnt in np.linspace(curr_joints, reset_joints, steps): + env.step(jnt) + time.sleep(0.001) + elif args.agent == "quest": + from gello.agents.quest_agent import SingleArmQuestAgent + + agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l") + elif args.agent == "spacemouse": + from gello.agents.spacemouse_agent import SpacemouseAgent + + agent = SpacemouseAgent(robot_type=args.robot_type, verbose=args.verbose) + elif args.agent == "dummy" or args.agent == "none": + agent = DummyAgent(num_dofs=robot_client.num_dofs()) + elif args.agent == "policy": + raise NotImplementedError("add your imitation policy here if there is one") + else: + raise ValueError("Invalid agent name") + + # going to start position + print("Going to start position") + 
start_pos = agent.act(env.get_obs()) + obs = env.get_obs() + joints = obs["joint_positions"] + + abs_deltas = np.abs(start_pos - joints) + id_max_joint_delta = np.argmax(abs_deltas) + + max_joint_delta = 0.8 + if abs_deltas[id_max_joint_delta] > max_joint_delta: + id_mask = abs_deltas > max_joint_delta + print() + ids = np.arange(len(id_mask))[id_mask] + for i, delta, joint, current_j in zip( + ids, + abs_deltas[id_mask], + start_pos[id_mask], + joints[id_mask], + ): + print( + f"joint[{i}]: \t delta: {delta:4.3f} , leader: \t{joint:4.3f} , follower: \t{current_j:4.3f}" + ) + return + + print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}") + assert len(start_pos) == len( + joints + ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}" + + max_delta = 0.05 + for _ in range(25): + obs = env.get_obs() + command_joints = agent.act(obs) + current_joints = obs["joint_positions"] + delta = command_joints - current_joints + max_joint_delta = np.abs(delta).max() + if max_joint_delta > max_delta: + delta = delta / max_joint_delta * max_delta + env.step(current_joints + delta) + + obs = env.get_obs() + joints = obs["joint_positions"] + action = agent.act(obs) + if (action - joints > 0.5).any(): + print("Action is too big") + + # print which joints are too big + joint_index = np.where(action - joints > 0.8) + for j in joint_index: + print( + f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}" + ) + exit() + + if args.use_save_interface: + from gello.data_utils.keyboard_interface import KBReset + + kb_interface = KBReset() + + print_color("\nStart 🚀🚀🚀", color="green", attrs=("bold",)) + + save_path = None + start_time = time.time() + while True: + num = time.time() - start_time + message = f"\rTime passed: {round(num, 2)} " + print_color( + message, + color="white", + attrs=("bold",), + end="", + flush=True, + ) + action = agent.act(obs) + dt = datetime.datetime.now() + if args.use_save_interface: + state = 
kb_interface.update() + if state == "start": + dt_time = datetime.datetime.now() + save_path = ( + Path(args.data_dir).expanduser() + / args.agent + / dt_time.strftime("%m%d_%H%M%S") + ) + save_path.mkdir(parents=True, exist_ok=True) + print(f"Saving to {save_path}") + elif state == "save": + assert save_path is not None, "something went wrong" + save_frame(save_path, dt, obs, action) + elif state == "normal": + save_path = None + else: + raise ValueError(f"Invalid state {state}") + obs = env.step(action) + + +if __name__ == "__main__": + main(tyro.cli(Args)) diff --git a/gello/__init__.py b/gello/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gello/agents/agent.py b/gello/agents/agent.py new file mode 100644 index 0000000..f07457e --- /dev/null +++ b/gello/agents/agent.py @@ -0,0 +1,43 @@ +from typing import Any, Dict, Protocol + +import numpy as np + + +class Agent(Protocol): + def act(self, obs: Dict[str, Any]) -> np.ndarray: + """Returns an action given an observation. + + Args: + obs: observation from the environment. + + Returns: + action: action to take on the environment. 
+ """ + raise NotImplementedError + + +class DummyAgent(Agent): + def __init__(self, num_dofs: int): + self.num_dofs = num_dofs + + def act(self, obs: Dict[str, Any]) -> np.ndarray: + return np.zeros(self.num_dofs) + + +class BimanualAgent(Agent): + def __init__(self, agent_left: Agent, agent_right: Agent): + self.agent_left = agent_left + self.agent_right = agent_right + + def act(self, obs: Dict[str, Any]) -> np.ndarray: + left_obs = {} + right_obs = {} + for key, val in obs.items(): + L = val.shape[0] + half_dim = L // 2 + assert L == half_dim * 2, f"{key} must be even, something is wrong" + left_obs[key] = val[:half_dim] + right_obs[key] = val[half_dim:] + return np.concatenate( + [self.agent_left.act(left_obs), self.agent_right.act(right_obs)] + ) diff --git a/gello/agents/gello_agent.py b/gello/agents/gello_agent.py new file mode 100644 index 0000000..fca60c3 --- /dev/null +++ b/gello/agents/gello_agent.py @@ -0,0 +1,139 @@ +import os +from dataclasses import dataclass +from typing import Dict, Optional, Sequence, Tuple + +import numpy as np + +from gello.agents.agent import Agent +from gello.robots.dynamixel import DynamixelRobot + + +@dataclass +class DynamixelRobotConfig: + joint_ids: Sequence[int] + """The joint ids of GELLO (not including the gripper). Usually (1, 2, 3 ...).""" + + joint_offsets: Sequence[float] + """The joint offsets of GELLO. There needs to be a joint offset for each joint_id and should be a multiple of pi/2.""" + + joint_signs: Sequence[int] + """The joint signs of GELLO. There needs to be a joint sign for each joint_id and should be either 1 or -1. + + This will be different for each arm design. Refernce the examples below for the correct signs for your robot. + """ + + gripper_config: Tuple[int, int, int] + """The gripper config of GELLO. 
This is a tuple of (gripper_joint_id, degrees in open_position, degrees in closed_position).""" + + def __post_init__(self): + assert len(self.joint_ids) == len(self.joint_offsets) + assert len(self.joint_ids) == len(self.joint_signs) + + def make_robot( + self, port: str = "/dev/ttyUSB0", start_joints: Optional[np.ndarray] = None + ) -> DynamixelRobot: + return DynamixelRobot( + joint_ids=self.joint_ids, + joint_offsets=list(self.joint_offsets), + real=True, + joint_signs=list(self.joint_signs), + port=port, + gripper_config=self.gripper_config, + start_joints=start_joints, + ) + + +PORT_CONFIG_MAP: Dict[str, DynamixelRobotConfig] = { + # xArm + # "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig( + # joint_ids=(1, 2, 3, 4, 5, 6, 7), + # joint_offsets=( + # 2 * np.pi / 2, + # 2 * np.pi / 2, + # 2 * np.pi / 2, + # 2 * np.pi / 2, + # -1 * np.pi / 2 + 2 * np.pi, + # 1 * np.pi / 2, + # 1 * np.pi / 2, + # ), + # joint_signs=(1, 1, 1, 1, 1, 1, 1), + # gripper_config=(8, 279, 279 - 50), + # ), + # panda + # "/dev/cu.usbserial-FT3M9NVB": DynamixelRobotConfig( + "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT3M9NVB-if00-port0": DynamixelRobotConfig( + joint_ids=(1, 2, 3, 4, 5, 6, 7), + joint_offsets=( + 3 * np.pi / 2, + 2 * np.pi / 2, + 1 * np.pi / 2, + 4 * np.pi / 2, + -2 * np.pi / 2 + 2 * np.pi, + 3 * np.pi / 2, + 4 * np.pi / 2, + ), + joint_signs=(1, -1, 1, 1, 1, -1, 1), + gripper_config=(8, 195, 152), + ), + # Left UR + "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBEIA-if00-port0": DynamixelRobotConfig( + joint_ids=(1, 2, 3, 4, 5, 6), + joint_offsets=( + 0, + 1 * np.pi / 2 + np.pi, + np.pi / 2 + 0 * np.pi, + 0 * np.pi + np.pi / 2, + np.pi - 2 * np.pi / 2, + -1 * np.pi / 2 + 2 * np.pi, + ), + joint_signs=(1, 1, -1, 1, 1, 1), + gripper_config=(7, 20, -22), + ), + # Right UR + "/dev/serial/by-id/usb-FTDI_USB__-__Serial_Converter_FT7WBG6A-if00-port0": DynamixelRobotConfig( + joint_ids=(1, 2, 3, 4, 5, 6), + 
joint_offsets=( + np.pi + 0 * np.pi, + 2 * np.pi + np.pi / 2, + 2 * np.pi + np.pi / 2, + 2 * np.pi + np.pi / 2, + 1 * np.pi, + 3 * np.pi / 2, + ), + joint_signs=(1, 1, -1, 1, 1, 1), + gripper_config=(7, 286, 248), + ), +} + + +class GelloAgent(Agent): + def __init__( + self, + port: str, + dynamixel_config: Optional[DynamixelRobotConfig] = None, + start_joints: Optional[np.ndarray] = None, + ): + if dynamixel_config is not None: + self._robot = dynamixel_config.make_robot( + port=port, start_joints=start_joints + ) + else: + assert os.path.exists(port), port + assert port in PORT_CONFIG_MAP, f"Port {port} not in config map" + + config = PORT_CONFIG_MAP[port] + self._robot = config.make_robot(port=port, start_joints=start_joints) + + def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray: + return self._robot.get_joint_state() + dyna_joints = self._robot.get_joint_state() + # current_q = dyna_joints[:-1] # last one dim is the gripper + current_gripper = dyna_joints[-1] # last one dim is the gripper + + print(current_gripper) + if current_gripper < 0.2: + self._robot.set_torque_mode(False) + return obs["joint_positions"] + else: + self._robot.set_torque_mode(False) + return dyna_joints diff --git a/gello/agents/quest_agent.py b/gello/agents/quest_agent.py new file mode 100644 index 0000000..be7af59 --- /dev/null +++ b/gello/agents/quest_agent.py @@ -0,0 +1,178 @@ +from typing import Dict + +import numpy as np +import quaternion +from dm_control import mjcf +from dm_control.utils.inverse_kinematics import qpos_from_site_pose +from oculus_reader.reader import OculusReader + +from gello.agents.agent import Agent +from gello.agents.spacemouse_agent import apply_transfer, mj2ur, ur2mj +from gello.dm_control_tasks.arms.ur5e import UR5e + +# cartensian space control, controller <> robot relative pose matters. This extrinsics is based on +# our setup, for details please checkout the project page. 
+quest2ur = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) +ur2quest = np.linalg.inv(quest2ur) + +translation_scaling_factor = 2.0 + + +class SingleArmQuestAgent(Agent): + def __init__(self, robot_type: str, which_hand: str, verbose: bool = False) -> None: + """Interact with the robot using the quest controller. + + leftTrig: press to start control (also record the current position as the home position) + leftJS: a tuple of (x,y) for the joystick, only need y to control the gripper + """ + self.which_hand = which_hand + assert self.which_hand in ["l", "r"] + + self.oculus_reader = OculusReader() + if robot_type == "ur5": + _robot = UR5e() + else: + raise ValueError(f"Unknown robot type: {robot_type}") + self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model) + self.control_active = False + self.reference_quest_pose = None + self.reference_ee_rot_ur = None + self.reference_ee_pos_ur = None + + self.robot_type = robot_type + self._verbose = verbose + + def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray: + if self.robot_type == "ur5": + num_dof = 6 + current_qpos = obs["joint_positions"][:num_dof] # last one dim is the gripper + current_gripper_angle = obs["joint_positions"][-1] + # run the fk + self.physics.data.qpos[:num_dof] = current_qpos + self.physics.step() + + ee_rot_mj = np.array( + self.physics.named.data.site_xmat["attachment_site"] + ).reshape(3, 3) + ee_pos_mj = np.array(self.physics.named.data.site_xpos["attachment_site"]) + if self.which_hand == "l": + pose_key = "l" + trigger_key = "leftTrig" + # joystick_key = "leftJS" + # left yx + gripper_open_key = "Y" + gripper_close_key = "X" + elif self.which_hand == "r": + pose_key = "r" + trigger_key = "rightTrig" + # joystick_key = "rightJS" + # right ba for the key + gripper_open_key = "B" + gripper_close_key = "A" + else: + raise ValueError(f"Unknown hand: {self.which_hand}") + # check the trigger button state + pose_data, button_data = 
self.oculus_reader.get_transformations_and_buttons() + if len(pose_data) == 0 or len(button_data) == 0: + print("no data, quest not yet ready") + return np.concatenate([current_qpos, [current_gripper_angle]]) + + new_gripper_angle = current_gripper_angle + if button_data[gripper_open_key]: + new_gripper_angle = 1 + if button_data[gripper_close_key]: + new_gripper_angle = 0 + arm_not_move_return = np.concatenate([current_qpos, [new_gripper_angle]]) + if len(pose_data) == 0: + print("no data, quest not yet ready") + return arm_not_move_return + + trigger_state = button_data[trigger_key][0] + if trigger_state > 0.5: + if self.control_active is True: + if self._verbose: + print("controlling the arm") + current_pose = pose_data[pose_key] + delta_rot = current_pose[:3, :3] @ np.linalg.inv( + self.reference_quest_pose[:3, :3] + ) + delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3] + delta_pos_ur = ( + apply_transfer(quest2ur, delta_pos) * translation_scaling_factor + ) + # ? is this the case? 
+ delta_rot_ur = quest2ur[:3, :3] @ delta_rot @ ur2quest[:3, :3] + if self._verbose: + print( + f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}" + ) + next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur + next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur + + target_quat = quaternion.as_float_array( + quaternion.from_rotation_matrix(ur2mj[:3, :3] @ next_ee_rot_ur) + ) + ik_result = qpos_from_site_pose( + self.physics, + "attachment_site", + target_pos=apply_transfer(ur2mj, next_ee_pos_ur), + target_quat=target_quat, + tol=1e-14, + max_steps=400, + ) + self.physics.reset() + if ik_result.success: + new_qpos = ik_result.qpos[:num_dof] + else: + print("ik failed, using the original qpos") + return arm_not_move_return + command = np.concatenate([new_qpos, [new_gripper_angle]]) + return command + + else: # last state is not in active + self.control_active = True + if self._verbose: + print("control activated!") + self.reference_quest_pose = pose_data[pose_key] + + self.reference_ee_rot_ur = mj2ur[:3, :3] @ ee_rot_mj + self.reference_ee_pos_ur = apply_transfer(mj2ur, ee_pos_mj) + return arm_not_move_return + else: + if self._verbose: + print("deactive control") + self.control_active = False + self.reference_quest_pose = None + return arm_not_move_return + + +class DualArmQuestAgent(Agent): + def __init__(self, robot_type: str) -> None: + self.left_arm = SingleArmQuestAgent(robot_type, "l") + self.right_arm = SingleArmQuestAgent(robot_type, "r") + + def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray: + pass + + +if __name__ == "__main__": + oculus_reader = OculusReader() + while True: + """ + example output: + ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ], + [-0.107737 , 0.0958919, 0.989544 , -0.833478 ], + [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ], + [ 0. , 0. , 0. , 1. 
]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ], + [ 0.070887, 0.516083, -0.8536 , -0.238691], + [-0.941994, -0.246809, -0.227447, -0.370447], + [ 0. , 0. , 0. , 1. ]])}, + {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)}) + + """ + pose_data, button_data = oculus_reader.get_transformations_and_buttons() + if len(pose_data) == 0: + print("no data") + continue + else: + print(pose_data["l"]) diff --git a/gello/agents/spacemouse_agent.py b/gello/agents/spacemouse_agent.py new file mode 100644 index 0000000..03a42f1 --- /dev/null +++ b/gello/agents/spacemouse_agent.py @@ -0,0 +1,224 @@ +import threading +import time +from dataclasses import dataclass +from typing import Dict, Optional + +import numpy as np +from dm_control import mjcf +from dm_control.utils.inverse_kinematics import qpos_from_site_pose + +from gello.agents.agent import Agent +from gello.dm_control_tasks.arms.ur5e import UR5e + +# mujoco has a slightly different coordinate system than UR control box +mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) +ur2mj = np.linalg.inv(mj2ur) + +# cartensian space control, controller <> robot relative pose matters. This extrinsics is based on +# our setup, for details please checkout the project page. 
+spacemouse2ur = np.array( + [ + [-1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] +) +ur2spacemouse = np.linalg.inv(spacemouse2ur) + + +def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray: + # xyz can be 3dim or 4dim (homogeneous) or can be a rotation matrix + if len(xyz) == 3: + xyz = np.append(xyz, 1) + return np.matmul(mat, xyz)[:3] + + +@dataclass +class SpacemouseConfig: + angle_scale: float = 0.24 + translation_scale: float = 0.06 + # only control the xyz, rotation direction, not the gripper + invert_control: np.ndarray = np.ones(6) + rotation_mode: str = "euler" + + +class SpacemouseAgent(Agent): + def __init__( + self, + robot_type: str, + config: SpacemouseConfig = SpacemouseConfig(), + device_path: Optional[str] = None, + verbose: bool = True, + invert_button: bool = False, + ) -> None: + self.config = config + self.last_state_lock = threading.Lock() + self._invert_button = invert_button + # example state:SpaceNavigator(t=3.581528532, x=0.0, y=0.0, z=0.0, roll=0.0, pitch=0.0, yaw=0.0, buttons=[0, 0]) + # all continuous inputs range from 0-1, buttons are 0 or 1 + self.spacemouse_latest_state = None + self._device_path = device_path + spacemouse_thread = threading.Thread(target=self._read_from_spacemouse) + spacemouse_thread.start() + self._verbose = verbose + if self._verbose: + print(f"robot_type: {robot_type}") + if robot_type == "ur5": + _robot = UR5e() + else: + raise ValueError(f"Unknown robot type: {robot_type}") + self.physics = mjcf.Physics.from_mjcf_model(_robot.mjcf_model) + + def _read_from_spacemouse(self): + import pyspacemouse + + if self._device_path is None: + mouse = pyspacemouse.open() + else: + mouse = pyspacemouse.open(path=self._device_path) + if mouse: + while 1: + state = mouse.read() + with self.last_state_lock: + self.spacemouse_latest_state = state + time.sleep(0.001) + else: + raise ValueError("Failed to open spacemouse") + + def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray: + import 
quaternion + + # obs: the folllow robot's current state + # in rad, 6/7 dof depends on the robot type + num_dof = 6 + if self._verbose: + print("act invoked") + current_qpos = obs["joint_positions"][:num_dof] # last one dim is the gripper + current_gripper_angle = obs["joint_positions"][-1] + self.physics.data.qpos[:num_dof] = current_qpos + + self.physics.step() + ee_rot = np.array(self.physics.named.data.site_xmat["attachment_site"]).reshape( + 3, 3 + ) + ee_pos = np.array(self.physics.named.data.site_xpos["attachment_site"]) + + ee_rot = mj2ur[:3, :3] @ ee_rot + ee_pos = apply_transfer(mj2ur, ee_pos) + # ^ mujoco coordinate to UR + + with self.last_state_lock: + spacemouse_read = self.spacemouse_latest_state + if self._verbose: + print(f"spacemouse_read: {spacemouse_read}") + + assert spacemouse_read is not None + spacemouse_xyz_rot_np = np.array( + [ + spacemouse_read.x, + spacemouse_read.y, + spacemouse_read.z, + spacemouse_read.roll, + spacemouse_read.pitch, + spacemouse_read.yaw, + ] + ) + spacemouse_button = ( + spacemouse_read.buttons + ) # size 2 list binary indicating left/right button pressing + spacemouse_xyz_rot_np = spacemouse_xyz_rot_np * self.config.invert_control + if np.max(np.abs(spacemouse_xyz_rot_np)) > 0.9: + spacemouse_xyz_rot_np[np.abs(spacemouse_xyz_rot_np) < 0.6] = 0 + tx, ty, tz, r, p, y = spacemouse_xyz_rot_np + # convert roll pick yaw to rotation matrix (rpy) + + trans_transform = np.eye(4) + # delta xyz from the spacemouse reading + trans_transform[:3, 3] = apply_transfer( + spacemouse2ur, np.array([tx, ty, tz]) * self.config.translation_scale + ) + + # break rot_transform into each axis + rot_transform_x = np.eye(4) + rot_transform_x[:3, :3] = quaternion.as_rotation_matrix( + quaternion.from_rotation_vector( + np.array([-p, 0, 0]) * self.config.angle_scale + ) + ) + + rot_transform_y = np.eye(4) + rot_transform_y[:3, :3] = quaternion.as_rotation_matrix( + quaternion.from_rotation_vector( + np.array([0, r, 0]) * 
self.config.angle_scale + ) + ) + + rot_transform_z = np.eye(4) + rot_transform_z[:3, :3] = quaternion.as_rotation_matrix( + quaternion.from_rotation_vector( + np.array([0, 0, -y]) * self.config.angle_scale + ) + ) + + # in ur space + rot_transform = ( + spacemouse2ur + @ rot_transform_z + @ rot_transform_y + @ rot_transform_x + @ ur2spacemouse + ) + + if self._verbose: + print(f"rot_transform: {rot_transform}") + print(f"new spacemounse cmd in ur space = {trans_transform[:3, 3]}") + + # import pdb; pdb.set_trace() + new_ee_pos = trans_transform[:3, 3] + ee_pos + if self.config.rotation_mode == "rpy": + new_ee_rot = ee_rot @ rot_transform[:3, :3] + elif self.config.rotation_mode == "euler": + new_ee_rot = rot_transform[:3, :3] @ ee_rot + else: + raise NotImplementedError( + f"Unknown rotation mode: {self.config.rotation_mode}" + ) + + target_quat = quaternion.as_float_array( + quaternion.from_rotation_matrix(ur2mj[:3, :3] @ new_ee_rot) + ) + ik_result = qpos_from_site_pose( + self.physics, + "attachment_site", + target_pos=apply_transfer(ur2mj, new_ee_pos), + target_quat=target_quat, + tol=1e-14, + max_steps=400, + ) + self.physics.reset() + if ik_result.success: + new_qpos = ik_result.qpos[:num_dof] + else: + print("ik failed, using the original qpos") + return np.concatenate([current_qpos, [current_gripper_angle]]) + new_gripper_angle = current_gripper_angle + if self._invert_button: + if spacemouse_button[1]: + new_gripper_angle = 1 + if spacemouse_button[0]: + new_gripper_angle = 0 + else: + if spacemouse_button[1]: + new_gripper_angle = 0 + if spacemouse_button[0]: + new_gripper_angle = 1 + command = np.concatenate([new_qpos, [new_gripper_angle]]) + return command + + +if __name__ == "__main__": + import pyspacemouse + + success = pyspacemouse.open("/dev/hidraw4") + success = pyspacemouse.open("/dev/hidraw5") diff --git a/gello/cameras/camera.py b/gello/cameras/camera.py new file mode 100644 index 0000000..6a4de19 --- /dev/null +++ b/gello/cameras/camera.py @@ 
-0,0 +1,81 @@ +from pathlib import Path +from typing import Optional, Protocol, Tuple + +import numpy as np + + +class CameraDriver(Protocol): + """Camera protocol. + + A protocol for a camera driver. This is used to abstract the camera from the rest of the code. + """ + + def read( + self, + img_size: Optional[Tuple[int, int]] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """Read a frame from the camera. + + Args: + img_size: The size of the image to return. If None, the original size is returned. + farthest: The farthest distance to map to 255. + + Returns: + np.ndarray: The color image. + np.ndarray: The depth image. + """ + + +class DummyCamera(CameraDriver): + """A dummy camera for testing.""" + + def read( + self, + img_size: Optional[Tuple[int, int]] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """Read a frame from the camera. + + Args: + img_size: The size of the image to return. If None, the original size is returned. + farthest: The farthest distance to map to 255. + + Returns: + np.ndarray: The color image. + np.ndarray: The depth image. 
+ """ + if img_size is None: + return ( + np.random.randint(255, size=(480, 640, 3), dtype=np.uint8), + np.random.randint(255, size=(480, 640, 1), dtype=np.uint16), + ) + else: + return ( + np.random.randint( + 255, size=(img_size[0], img_size[1], 3), dtype=np.uint8 + ), + np.random.randint( + 255, size=(img_size[0], img_size[1], 1), dtype=np.uint16 + ), + ) + + +class SavedCamera(CameraDriver): + def __init__(self, path: str = "example"): + self.path = str(Path(__file__).parent / path) + from PIL import Image + + self._color_img = Image.open(f"{self.path}/image.png") + self._depth_img = Image.open(f"{self.path}/depth.png") + + def read( + self, + img_size: Optional[Tuple[int, int]] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + if img_size is not None: + color_img = self._color_img.resize(img_size) + depth_img = self._depth_img.resize(img_size) + else: + color_img = self._color_img + depth_img = self._depth_img + + return np.array(color_img), np.array(depth_img)[:, :, 0:1] diff --git a/gello/cameras/realsense_camera.py b/gello/cameras/realsense_camera.py new file mode 100644 index 0000000..1fd72b8 --- /dev/null +++ b/gello/cameras/realsense_camera.py @@ -0,0 +1,123 @@ +import os +import time +from typing import List, Optional, Tuple + +import numpy as np + +from gello.cameras.camera import CameraDriver + + +def get_device_ids() -> List[str]: + import pyrealsense2 as rs + + ctx = rs.context() + devices = ctx.query_devices() + device_ids = [] + for dev in devices: + dev.hardware_reset() + device_ids.append(dev.get_info(rs.camera_info.serial_number)) + time.sleep(2) + return device_ids + + +class RealSenseCamera(CameraDriver): + def __repr__(self) -> str: + return f"RealSenseCamera(device_id={self._device_id})" + + def __init__(self, device_id: Optional[str] = None, flip: bool = False): + import pyrealsense2 as rs + + self._device_id = device_id + + if device_id is None: + ctx = rs.context() + devices = ctx.query_devices() + for dev in devices: + 
dev.hardware_reset() + time.sleep(2) + self._pipeline = rs.pipeline() + config = rs.config() + else: + self._pipeline = rs.pipeline() + config = rs.config() + config.enable_device(device_id) + + config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) + config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) + self._pipeline.start(config) + self._flip = flip + + def read( + self, + img_size: Optional[Tuple[int, int]] = None, # farthest: float = 0.12 + ) -> Tuple[np.ndarray, np.ndarray]: + """Read a frame from the camera. + + Args: + img_size: The size of the image to return. If None, the original size is returned. + farthest: The farthest distance to map to 255. + + Returns: + np.ndarray: The color image, shape=(H, W, 3) + np.ndarray: The depth image, shape=(H, W, 1) + """ + import cv2 + + frames = self._pipeline.wait_for_frames() + color_frame = frames.get_color_frame() + color_image = np.asanyarray(color_frame.get_data()) + depth_frame = frames.get_depth_frame() + depth_image = np.asanyarray(depth_frame.get_data()) + # depth_image = cv2.convertScaleAbs(depth_image, alpha=0.03) + if img_size is None: + image = color_image[:, :, ::-1] + depth = depth_image + else: + image = cv2.resize(color_image, img_size)[:, :, ::-1] + depth = cv2.resize(depth_image, img_size) + + # rotate 180 degree's because everything is upside down in order to center the camera + if self._flip: + image = cv2.rotate(image, cv2.ROTATE_180) + depth = cv2.rotate(depth, cv2.ROTATE_180)[:, :, None] + else: + depth = depth[:, :, None] + + return image, depth + + +def _debug_read(camera, save_datastream=False): + import cv2 + + cv2.namedWindow("image") + cv2.namedWindow("depth") + counter = 0 + if not os.path.exists("images"): + os.makedirs("images") + if save_datastream and not os.path.exists("stream"): + os.makedirs("stream") + while True: + time.sleep(0.1) + image, depth = camera.read() + depth = np.concatenate([depth, depth, depth], axis=-1) + key = cv2.waitKey(1) + 
cv2.imshow("image", image[:, :, ::-1]) + cv2.imshow("depth", depth) + if key == ord("s"): + cv2.imwrite(f"images/image_{counter}.png", image[:, :, ::-1]) + cv2.imwrite(f"images/depth_{counter}.png", depth) + if save_datastream: + cv2.imwrite(f"stream/image_{counter}.png", image[:, :, ::-1]) + cv2.imwrite(f"stream/depth_{counter}.png", depth) + counter += 1 + if key == 27: + break + + +if __name__ == "__main__": + device_ids = get_device_ids() + print(f"Found {len(device_ids)} devices") + print(device_ids) + rs = RealSenseCamera(flip=True, device_id=device_ids[0]) + im, depth = rs.read() + _debug_read(rs, save_datastream=True) diff --git a/gello/data_utils/conversion_utils.py b/gello/data_utils/conversion_utils.py new file mode 100644 index 0000000..d7f6977 --- /dev/null +++ b/gello/data_utils/conversion_utils.py @@ -0,0 +1,231 @@ +from typing import Dict + +import cv2 +import numpy as np +import torch +import transforms3d._gohlketransforms as ttf + + +def to_torch(array, device="cpu"): + if isinstance(array, torch.Tensor): + return array.to(device) + if isinstance(array, np.ndarray): + return torch.from_numpy(array).to(device) + else: + return torch.tensor(array).to(device) + + +def to_numpy(array): + if isinstance(array, torch.Tensor): + return array.cpu().numpy() + return array + + +def center_crop(rgb_frame, depth_frame): + H, W = rgb_frame.shape[-2:] + sq_size = min(H, W) + + # crop the center square + if H > W: + rgb_frame = rgb_frame[..., (-sq_size / 2) : (sq_size / 2), :sq_size] + depth_frame = depth_frame[..., (-sq_size / 2) : (sq_size / 2), :sq_size] + elif W < H: + rgb_frame = rgb_frame[..., :sq_size, (-sq_size / 2) : (sq_size / 2)] + depth_frame = depth_frame[..., :sq_size, (-sq_size / 2) : (sq_size / 2)] + + return rgb_frame, depth_frame + + +def resize(rgb, depth, size=224): + rgb = rgb.transpose([1, 2, 0]) + rgb = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_LINEAR) + rgb = rgb.transpose([2, 0, 1]) + + depth = cv2.resize(depth[0], (size, 
size), interpolation=cv2.INTER_LINEAR) + depth = depth.reshape([1, size, size]) + return rgb, depth + + +def filter_depth(depth, max_depth=2.0, min_depth=0.0): + depth[np.isnan(depth)] = 0.0 + depth[np.isinf(depth)] = 0.0 + depth = np.clip(depth, min_depth, max_depth) + return depth + + +def preproc_obs( + demo: Dict[str, np.ndarray], joint_only: bool = True +) -> Dict[str, np.ndarray]: + # c h w + rgb_wrist = demo.get(f"wrist_rgb").transpose([2, 0, 1]) * 1.0 # type: ignore + depth_wrist = demo.get(f"wrist_depth").transpose([2, 0, 1]) * 1.0 # type: ignore + rgb_base = demo.get("base_rgb").transpose([2, 0, 1]) * 1.0 # type: ignore + depth_base = demo.get("base_depth").transpose([2, 0, 1]) * 1.0 # type: ignore + + # Center crop and fitler depth + rgb_wrist, depth_wrist = resize(*center_crop(rgb_wrist, depth_wrist)) + rgb_base, depth_base = resize(*center_crop(rgb_base, depth_base)) + + depth_wrist = filter_depth(depth_wrist) + depth_base = filter_depth(depth_base) + + rgb = np.stack([rgb_wrist, rgb_base], axis=0) + depth = np.stack([depth_wrist, depth_base], axis=0) + + # Dummy + dummy_cam = np.eye(4) + K = np.eye(3) + + # state + qpos, qvel, ee_pos_quat, gripper_pos = ( + demo.get("joint_positions"), # type: ignore + demo.get("joint_velocities"), # type: ignore + demo.get("ee_pos_quat"), # type: ignore + demo.get("gripper_position"), # type: ignore + ) + + if joint_only: + state: np.ndarray = qpos + else: + state: np.ndarray = np.concatenate([qpos, qvel, ee_pos_quat, gripper_pos[None]]) + + return { + "rgb": rgb, + "depth": depth, + "camera_poses": dummy_cam, + "K_matrices": K, + "state": state, + } + + +class Pose(object): + def __init__(self, x, y, z, qw, qx, qy, qz): + self.p = np.array([x, y, z]) + + # we internally use tf.transformations, which uses [x, y, z, w] for quaternions. 
+ self.q = np.array([qx, qy, qz, qw]) + + # make sure that the quaternion has positive scalar part + if self.q[3] < 0: + self.q *= -1 + + self.q = self.q / np.linalg.norm(self.q) + + def __mul__(self, other): + assert isinstance(other, Pose) + p = self.p + ttf.quaternion_matrix(self.q)[:3, :3].dot(other.p) + q = ttf.quaternion_multiply(self.q, other.q) + return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2]) + + def __rmul__(self, other): + assert isinstance(other, Pose) + return other * self + + def __str__(self): + return "p: {}, q: {}".format(self.p, self.q) + + def inv(self): + R = ttf.quaternion_matrix(self.q)[:3, :3] + p = -R.T.dot(self.p) + q = ttf.quaternion_inverse(self.q) + return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2]) + + def to_quaternion(self): + """ + this satisfies Pose(*to_quaternion(p)) == p + """ + q_reverted = np.array([self.q[3], self.q[0], self.q[1], self.q[2]]) + return np.concatenate([self.p, q_reverted]) + + def to_axis_angle(self): + """ + returns the axis-angle representation of the rotation. + """ + angle = 2 * np.arccos(self.q[3]) + angle = angle / np.pi + if angle > 1: + angle -= 2 + + axis = self.q[:3] / np.linalg.norm(self.q[:3]) + + # keep the axes positive + if axis[0] < 0: + axis *= -1 + angle *= -1 + + return np.concatenate([self.p, axis, [angle]]) + + def to_euler(self): + q = np.array(ttf.euler_from_quaternion(self.q)) + if q[0] > np.pi: + q[0] -= 2 * np.pi + if q[1] > np.pi: + q[1] -= 2 * np.pi + if q[2] > np.pi: + q[2] -= 2 * np.pi + + q = q / np.pi + + return np.concatenate([self.p, q, [0.0]]) + + def to_44_matrix(self): + out = np.eye(4) + out[:3, :3] = ttf.quaternion_matrix(self.q)[:3, :3] + out[:3, 3] = self.p + return out + + @staticmethod + def from_axis_angle(x, y, z, ax, ay, az, phi): + """ + returns a Pose object from the axis-angle representation of the rotation. 
+ """ + + phi = phi * np.pi + p = np.array([x, y, z]) + qw = np.cos(phi / 2.0) + qx = ax * np.sin(phi / 2.0) + qy = ay * np.sin(phi / 2.0) + qz = az * np.sin(phi / 2.0) + + return Pose(p[0], p[1], p[2], qw, qx, qy, qz) + + @staticmethod + def from_euler(x, y, z, roll, pitch, yaw, _): + """ + returns a Pose object from the euler representation of the rotation. + """ + p = np.array([x, y, z]) + roll, pitch, yaw = roll * np.pi, pitch * np.pi, yaw * np.pi + q = ttf.quaternion_from_euler(roll, pitch, yaw) + return Pose(p[0], p[1], p[2], q[3], q[0], q[1], q[2]) + + @staticmethod + def from_quaternion(x, y, z, qw, qx, qy, qz): + """ + returns a Pose object from the quaternion representation of the rotation. + """ + p = np.array([x, y, z]) + return Pose(p[0], p[1], p[2], qw, qx, qy, qz) + + +def compute_inverse_action(p, p_new, ee_control=False): + assert isinstance(p, Pose) and isinstance(p_new, Pose) + + if ee_control: + dpose = p.inv() * p_new + else: + dpose = p_new * p.inv() + + return dpose + + +def compute_forward_action(p, dpose, ee_control=False): + assert isinstance(p, Pose) and isinstance(dpose, Pose) + dpose = Pose.from_quaternion(*dpose.to_quaternion()) + + if ee_control: + p_new = p * dpose + else: + p_new = dpose * p + + return p_new diff --git a/gello/data_utils/demo_to_gdict.py b/gello/data_utils/demo_to_gdict.py new file mode 100644 index 0000000..37005b0 --- /dev/null +++ b/gello/data_utils/demo_to_gdict.py @@ -0,0 +1,345 @@ +import glob +import os +import pickle +import shutil +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import tyro +from natsort import natsorted +from tqdm import tqdm + +from gello.data_utils.plot_utils import plot_in_grid + +np.set_printoptions(precision=3, suppress=True) + +import mediapy as mp +from gdict.data import DictArray, GDict +from simple_bc.utils.visualization_utils import make_grid_video_from_numpy + +from gello.data_utils.conversion_utils import preproc_obs + +# def 
get_act_bounds(source_dir: str) -> np.ndarray: +# pkls = natsorted( +# glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True +# ) +# if len(pkls) <= 30: +# print(f"Skipping {source_dir} because it has less than 30 frames.") +# return None +# pkls = pkls[:-5] +# +# scale_factor = None +# for pkl in pkls: +# try: +# with open(pkl, "rb") as f: +# demo = pickle.load(f) +# except Exception as e: +# print(f"Skipping {pkl} because it is corrupted.") +# print(f"Error: {e}") +# raise Exception("Corrupted pkl") +# +# requested_control = demo.pop("control") +# curr_scale_factor = np.abs(requested_control) +# if scale_factor is None: +# scale_factor = curr_scale_factor +# else: +# scale_factor = np.maximum(scale_factor, curr_scale_factor) +# assert scale_factor is not None +# return scale_factor + + +def get_act_min_max(source_dir: str) -> Tuple[np.ndarray, np.ndarray]: + pkls = natsorted( + glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True + ) + if len(pkls) <= 30: + print(f"Skipping {source_dir} because it has less than 30 frames.") + raise RuntimeError("Too few frames") + pkls = pkls[:-5] + + scale_min = None + scale_max = None + for pkl in pkls: + try: + with open(pkl, "rb") as f: + demo = pickle.load(f) + except Exception as e: + print(f"Skipping {pkl} because it is corrupted.") + print(f"Error: {e}") + raise Exception("Corrupted pkl") + + requested_control = demo.pop("control") + curr_scale_factor = requested_control + if scale_min is None: + assert scale_max is None + scale_min = curr_scale_factor + scale_max = curr_scale_factor + else: + assert scale_max is not None + scale_min = np.minimum(scale_min, curr_scale_factor) + scale_max = np.maximum(scale_min, curr_scale_factor) + + assert scale_min is not None + assert scale_max is not None + return scale_min, scale_max + + +def convert_single_demo( + source_dir, + i, + traj_output_dir, + rgb_output_dir, + depth_output_dir, + state_output_dir, + action_output_dir, + 
scale_factor, + bias_factor, +): + """ + 1. converts the demo into a gdict + 2. visualizes the RGB of the demo + 3. visualizes the state + action space of the demo + 4. returns these to be collated by the caller. + """ + + pkls = natsorted( + glob.glob(os.path.join(source_dir, "**/*.pkl"), recursive=True), reverse=True + ) + demo_stack = [] + + if len(pkls) <= 30: + return 0 + + # go through the demo in reverse order. + # remove the first few frames because they are not useful. + pkls = pkls[:-5] + + for pkl in pkls: + curr_ts = {} + try: + with open(pkl, "rb") as f: + demo = pickle.load(f) + except: + print(f"Skipping {pkl} because it is corrupted.") + return 0 + + obs = preproc_obs(demo) + action = demo.pop("control") + action = (action - bias_factor) / scale_factor # normalize between -1 and 1 + + curr_ts["obs"] = obs + curr_ts["actions"] = action + curr_ts["dones"] = np.zeros(1) # random fill + curr_ts["episode_dones"] = np.zeros(1) # random fill + + curr_ts_wrapped = dict() + curr_ts_wrapped[f"traj_{i}"] = curr_ts + demo_stack = [curr_ts_wrapped] + demo_stack + + demo_dict = DictArray.stack(demo_stack) + GDict.to_hdf5(demo_dict, os.path.join(traj_output_dir + "", f"traj_{i}.h5")) + + ## save the base videos + # save the base rgb and depth videos + all_rgbs = demo_dict[f"traj_{i}"]["obs"]["rgb"][:, 1].transpose([0, 2, 3, 1]) + all_rgbs = all_rgbs.astype(np.uint8) + _, H, W, _ = all_rgbs.shape + all_depths = demo_dict[f"traj_{i}"]["obs"]["depth"][:, 1].reshape([-1, H, W]) + all_depths = all_depths / 5.0 # scale to 0-1 + + mp.write_video( + os.path.join(rgb_output_dir + "", f"traj_{i}_rgb_base.mp4"), all_rgbs, fps=30 + ) + mp.write_video( + os.path.join(depth_output_dir + "", f"traj_{i}_depth_base.mp4"), + all_depths, + fps=30, + ) + + ## save the wrist videos + # save the rgb and depth videos + all_rgbs = demo_dict[f"traj_{i}"]["obs"]["rgb"][:, 0].transpose([0, 2, 3, 1]) + all_rgbs = all_rgbs.astype(np.uint8) + _, H, W, _ = all_rgbs.shape + all_depths = 
def main(args):
    """Convert every demo folder under ``args.source_dir`` into normalized trajectories.

    Steps:
      1. List the immediate sub-directories of ``args.source_dir`` (one per demo).
      2. (Re)create ``<source_dir>/_conv/multiview`` with ``train`` and ``val``
         sub-folders; an existing ``multiview`` folder is deleted first.
      3. Pick up to 10 demos (at most 10% of them) at random for validation.
      4. Compute global action min/max over all demos and derive the bias and
         scale used to normalize actions.
      5. Convert each demo with ``convert_single_demo`` and, if ``args.vis`` is
         set, write the aggregate plots and grid videos.

    Args:
        args: Parsed ``Args`` instance (``source_dir`` and ``vis``).

    Raises:
        RuntimeError: If there is no demo folder, or no demo folder yields
            action statistics.
        SystemExit: Always raised with code 0 once the conversion has finished.
    """
    source_root = args.source_dir.rstrip("/")
    conv_root = os.path.join(source_root, "_conv")
    output_dir = os.path.join(conv_root, "multiview")

    # Skip the output folder of a previous run so that it is not mistaken for a
    # demo and trajectory indices stay stable across re-runs.
    candidates = natsorted(
        glob.glob(os.path.join(args.source_dir, "*/"), recursive=True)
    )
    subdirs = [
        d for d in candidates if os.path.normpath(d) != os.path.normpath(conv_root)
    ]
    if not subdirs:
        raise RuntimeError(f"No demo folders found under {args.source_dir}")

    if os.path.isdir(output_dir):
        print(f"Output directory {output_dir} already exists, and will be deleted")
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    val_size = int(min(0.1 * len(subdirs), 10))
    val_indices = set(np.random.choice(len(subdirs), size=val_size, replace=False))

    print("Computing scale factors")
    pbar = tqdm(range(len(subdirs)))
    min_scale_factor = None
    max_scale_factor = None
    for i in pbar:
        try:
            curr_min, curr_max = get_act_min_max(subdirs[i])
            if min_scale_factor is None:
                min_scale_factor = curr_min
                max_scale_factor = curr_max
            else:
                min_scale_factor = np.minimum(min_scale_factor, curr_min)
                max_scale_factor = np.maximum(max_scale_factor, curr_max)
            pbar.set_description(f"t: {i}")
        except Exception as e:
            # Best effort: one unusable demo must not abort the whole run.
            print(f"Error: {e}")
            print(f"Skipping {subdirs[i]}")
            continue

    if min_scale_factor is None or max_scale_factor is None:
        # Previously this fell through to `None + None` and died with a TypeError.
        raise RuntimeError(
            f"No usable demos under {args.source_dir}; cannot compute scale factors"
        )

    bias_factor = (min_scale_factor + max_scale_factor) / 2.0
    scale_factor = (max_scale_factor - min_scale_factor) / 2.0
    # Constant action dimensions would otherwise cause a division by zero.
    scale_factor[scale_factor == 0] = 1.0
    print("*" * 80)
    print(f"scale factors: {scale_factor}")
    print(f"bias factor: {bias_factor}")
    # make it into a copy pasteable string where the numbers are separated by commas
    scale_factor_str = ", ".join([f"{x}" for x in scale_factor])
    print(f"scale_factor = np.array([{scale_factor_str}])")
    bias_factor_str = ", ".join([f"{x}" for x in bias_factor])
    print(f"bias_factor = np.array([{bias_factor_str}])")
    print("*" * 80)

    tot = 0

    all_rgbs = []
    all_depths = []
    all_actions = []
    all_states = []

    vis_dir = os.path.join(output_dir, "vis")
    state_output_dir = os.path.join(vis_dir, "state")
    action_output_dir = os.path.join(vis_dir, "action")
    rgb_output_dir = os.path.join(vis_dir, "rgb")
    depth_output_dir = os.path.join(vis_dir, "depth")
    for directory in (
        state_output_dir,
        action_output_dir,
        rgb_output_dir,
        depth_output_dir,
    ):
        os.makedirs(directory, exist_ok=True)  # also creates vis_dir

    pbar = tqdm(range(len(subdirs)))
    for i in pbar:
        out_dir = val_dir if i in val_indices else train_dir
        out_dir = os.path.join(out_dir, "none")
        os.makedirs(out_dir, exist_ok=True)

        ret = convert_single_demo(
            subdirs[i],
            i,
            out_dir,
            rgb_output_dir,
            depth_output_dir,
            state_output_dir,
            action_output_dir,
            scale_factor=scale_factor,
            bias_factor=bias_factor,
        )

        # convert_single_demo returns 0 for demos it skipped.
        if ret != 0:
            all_rgbs.append(ret[0])
            all_depths.append(ret[1])
            all_actions.append(ret[2])
            all_states.append(ret[3])
            tot += 1

        pbar.set_description(f"t: {i}")

    print(
        f"Finished converting all demos to {output_dir}! (num demos: {tot} / {len(subdirs)})"
    )

    if args.vis and len(all_rgbs) > 0:
        print("Visualizing all demos...")

        plot_in_grid(all_actions, os.path.join(action_output_dir, "_all_actions.png"))
        plot_in_grid(all_states, os.path.join(state_output_dir, "_all_states.png"))
        make_grid_video_from_numpy(
            all_rgbs, 10, os.path.join(rgb_output_dir, "_all_rgb.mp4"), fps=30
        )
        make_grid_video_from_numpy(
            all_depths, 10, os.path.join(depth_output_dir, "_all_depth.mp4"), fps=30
        )

    # Same effect as the former `exit(0)` without depending on the `site` module.
    raise SystemExit(0)
def plot_in_grid(vals: np.ndarray, save_path: str):
    """Plot the trajectories in a grid.

    Two images are written: ``save_path`` holds one panel per value dimension
    (4 panels per row), and ``<save_path without extension>_3d.png`` holds four
    3D panels of the first three dimensions. The 3D image is skipped when there
    are fewer than three dimensions.

    Args:
        vals: B x T x N (an array, or a sequence of B arrays of shape T x N), where
            B is the number of trajectories,
            T is the number of timesteps,
            N is the dimensionality of the values.
        save_path: path to save the plot.
    """
    import os  # local import: only needed here to split the file extension

    n_dims = vals[0].shape[-1]
    n_cols = 4
    n_rows = max(1, -(-n_dims // n_cols))  # ceil division
    # squeeze=False keeps `axes` 2-D even for a single row; without it the
    # [row, col] indexing below fails whenever N <= 4.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 10), squeeze=False)
    for traj in vals:
        steps = np.arange(traj.shape[0])
        for dim in range(n_dims):
            # give them transparency
            axes[dim // n_cols, dim % n_cols].plot(steps, traj[:, dim], alpha=0.5)

    for dim in range(n_dims):
        panel = axes[dim // n_cols, dim % n_cols]
        panel.set_title(f"Dim {dim}")
        panel.set_ylim([-1.0, 1.0])

    fig.savefig(save_path)
    plt.close(fig)

    if n_dims < 3:
        # The 3D panels index dimensions 0..2 and cannot be drawn.
        return

    # (elev, azim) per panel; None keeps matplotlib's default 3D view.
    views = (
        (270, 0),  # 2D view of the XY plane, with X pointing downwards
        (0, 0),  # 2D view of the XZ plane, with X pointing leftwards
        (0, 90),  # 2D view of the YZ plane, with Y pointing leftwards
        None,
    )
    fig = plt.figure(figsize=(20, 10))
    for col, view in enumerate(views, start=1):
        ax = fig.add_subplot(1, len(views), col, projection="3d")
        for traj in vals:
            ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], alpha=0.75)
            # scatter the start and end points
            ax.scatter(traj[0, 0], traj[0, 1], traj[0, 2], c="r")
            ax.scatter(traj[-1, 0], traj[-1, 1], traj[-1, 2], c="g")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
        ax.legend(["Trajectory", "Start", "End"])
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])
        if view is not None:
            ax.view_init(*view)

    fig.savefig(os.path.splitext(save_path)[0] + "_3d.png")
    plt.close(fig)
class Franka(Manipulator):
    """Franka Robot."""

    def _build(
        self,
        name: str = "franka",
        xml_path: Union[str, Path] = XML,
        gripper_xml_path: Optional[Union[str, Path]] = GRIPPER_XML,
    ) -> None:
        """Build the arm, forwarding the given arguments to ``Manipulator._build``.

        Args:
            name: Model name assigned to the MJCF root.
            xml_path: Path to the arm MJCF file.
            gripper_xml_path: Path to the gripper MJCF file, or ``None`` to
                build the arm without a gripper.
        """
        # Forward the caller's arguments. The previous code always passed the
        # module-level defaults, so custom values were silently ignored.
        super()._build(name=name, xml_path=xml_path, gripper_xml_path=gripper_xml_path)

    @property
    def flange(self) -> mjcf.Element:
        """The site named ``attachment_site`` in this arm's MJCF model."""
        return self._mjcf_root.find("site", "attachment_site")
def attach_hand_to_arm(
    arm_mjcf: mjcf.RootElement,
    hand_mjcf: mjcf.RootElement,
) -> None:
    """Attaches a hand to an arm.

    The arm must have a site named "attachment_site"; the hand is always
    attached there (the site name is not configurable).

    If the arm defines a "home" keyframe, its ``ctrl`` and ``qpos`` are extended
    with the hand's "home" keyframe values, or with zeros when the hand has no
    such keyframe.

    Taken from https://github.com/deepmind/mujoco_menagerie/blob/main/FAQ.md#how-do-i-attach-a-hand-to-an-arm

    Args:
        arm_mjcf: The mjcf.RootElement of the arm.
        hand_mjcf: The mjcf.RootElement of the hand.

    Raises:
        ValueError: If the arm does not have a site named "attachment_site".
    """
    attachment_site = arm_mjcf.find("site", "attachment_site")
    if attachment_site is None:
        raise ValueError("No attachment site found in the arm model.")

    # Expand the ctrl and qpos keyframes to account for the new hand DoFs.
    arm_key = arm_mjcf.find("key", "home")
    if arm_key is not None:
        hand_key = hand_mjcf.find("key", "home")
        if hand_key is None:
            # Compile the hand only in this branch: it is the only place that
            # needs the actuator (nu) and position (nq) counts.
            physics = mjcf.Physics.from_mjcf_model(hand_mjcf)
            hand_ctrl = np.zeros(physics.model.nu)
            hand_qpos = np.zeros(physics.model.nq)
        else:
            hand_ctrl = hand_key.ctrl
            hand_qpos = hand_key.qpos
        arm_key.ctrl = np.concatenate([arm_key.ctrl, hand_ctrl])
        arm_key.qpos = np.concatenate([arm_key.qpos, hand_qpos])

    attachment_site.attach(hand_mjcf)
+ """ + self._mjcf_root = mjcf.from_path(str(xml_path)) + self._arm_joints = tuple(mjcf_utils.safe_find_all(self._mjcf_root, "joint")) + if gripper_xml_path: + gripper_mjcf = mjcf.from_path(str(gripper_xml_path)) + attach_hand_to_arm(self._mjcf_root, gripper_mjcf) + + self._mjcf_root.model = name + self._add_mjcf_elements() + + def set_joints(self, physics: mjcf.Physics, joints: np.ndarray) -> None: + assert len(joints) == len(self._arm_joints) + for joint, joint_value in zip(self._arm_joints, joints): + joint_id = physics.bind(joint).element_id + joint_name = physics.model.id2name(joint_id, "joint") + physics.named.data.qpos[joint_name] = joint_value + + def randomize_joints( + self, + physics: mjcf.Physics, + random: Optional[np.random.RandomState] = None, + ) -> None: + random = random or np.random # type: ignore + assert random is not None + hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE + slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE + ball = mjbindings.enums.mjtJoint.mjJNT_BALL + free = mjbindings.enums.mjtJoint.mjJNT_FREE + + qpos = physics.named.data.qpos + + for joint in self._arm_joints: + joint_id = physics.bind(joint).element_id + # joint_id = physics.model.name2id(joint.name, "joint") + joint_name = physics.model.id2name(joint_id, "joint") + joint_type = physics.model.jnt_type[joint_id] + is_limited = physics.model.jnt_limited[joint_id] + range_min, range_max = physics.model.jnt_range[joint_id] + + if is_limited: + if joint_type in [hinge, slide]: + qpos[joint_name] = random.uniform(range_min, range_max) + + elif joint_type == ball: + qpos[joint_name] = random_limited_quaternion(random, range_max) + + else: + if joint_type == hinge: + qpos[joint_name] = random.uniform(-np.pi, np.pi) + + elif joint_type == ball: + quat = random.randn(4) + quat /= np.linalg.norm(quat) + qpos[joint_name] = quat + + elif joint_type == free: + # this should be random.randn, but changing it now could significantly + # affect benchmark results. 
+ quat = random.rand(4) + quat /= np.linalg.norm(quat) + qpos[joint_name][3:] = quat + + def _add_mjcf_elements(self) -> None: + # Parse joints. + joints = mjcf_utils.safe_find_all(self._mjcf_root, "joint") + joints = [joint for joint in joints if joint.tag != "freejoint"] + self._joints = tuple(joints) + + # Parse actuators. + actuators = mjcf_utils.safe_find_all(self._mjcf_root, "actuator") + self._actuators = tuple(actuators) + + # Parse qpos / ctrl keyframes. + self._keyframes = {} + keyframes = mjcf_utils.safe_find_all(self._mjcf_root, "key") + if keyframes: + for frame in keyframes: + if frame.qpos is not None: + qpos = np.array(frame.qpos) + self._keyframes[frame.name] = qpos + + # add a visualizeation the flange position that is green + self.flange.parent.add( + "geom", + name="flange_geom", + type="sphere", + size="0.01", + rgba="0 1 0 1", + pos=self.flange.pos, + contype="0", + conaffinity="0", + ) + + def _build_observables(self): + return ArmObservables(self) + + @property + @abc.abstractmethod + def flange(self) -> mjcf.Element: + """Returns the flange element. + + The flange is the end effector of the manipulator where tools can be + attached, such as a gripper. 
class ArmObservables(composer.Observables):
    """Observables exposed by a manipulator arm entity."""

    @composer.observable
    def joints_pos(self):
        """``qpos`` of the entity's joints."""
        return observable.MJCFFeature("qpos", self._entity.joints)

    @composer.observable
    def joints_vel(self):
        """``qvel`` of the entity's joints."""
        return observable.MJCFFeature("qvel", self._entity.joints)

    @composer.observable
    def flange_position(self):
        """``xpos`` of the entity's flange element."""
        return observable.MJCFFeature("xpos", self._entity.flange)

    @composer.observable
    def flange_orientation(self):
        """``xmat`` of the entity's flange element."""
        return observable.MJCFFeature("xmat", self._entity.flange)

    # Semantic grouping of observables.
    def _collect_from_attachments(self, attribute_name: str):
        """Concatenate ``attribute_name`` from the observables of all attached entities.

        Entities whose observables lack the attribute contribute nothing.
        """
        out: List[observable.MJCFFeature] = []
        for entity in self._entity.iter_entities(exclude_self=True):
            out.extend(getattr(entity.observables, attribute_name, []))
        return out

    @property
    def proprioception(self):
        """Proprioceptive observables of the arm followed by those of its attachments."""
        # `flange_angular_velocity` used to be listed here, but this class
        # defines no such observable, so reading this property raised
        # AttributeError. Re-add it once the observable exists.
        return [
            self.joints_pos,
            self.joints_vel,
            self.flange_position,
        ] + self._collect_from_attachments("proprioception")
def test_joints(self) -> None: + robot = ur5e.UR5e() + for joint in robot.joints: + self.assertEqual(joint.tag, "joint") + + +if __name__ == "__main__": + absltest.main() diff --git a/gello/dm_control_tasks/arms/utils.py b/gello/dm_control_tasks/arms/utils.py new file mode 100644 index 0000000..4e778de --- /dev/null +++ b/gello/dm_control_tasks/arms/utils.py @@ -0,0 +1,261 @@ +import collections + +import numpy as np +from absl import logging +from dm_control.mujoco.wrapper import mjbindings + +mjlib = mjbindings.mjlib + + +_INVALID_JOINT_NAMES_TYPE = ( + "`joint_names` must be either None, a list, a tuple, or a numpy array; " "got {}." +) +_REQUIRE_TARGET_POS_OR_QUAT = ( + "At least one of `target_pos` or `target_quat` must be specified." +) + +IKResult = collections.namedtuple("IKResult", ["qpos", "err_norm", "steps", "success"]) + + +def qpos_from_site_pose( + physics, + site_name, + target_pos=None, + target_quat=None, + joint_names=None, + tol=1e-14, + rot_weight=1.0, + regularization_threshold=0.1, + regularization_strength=3e-2, + max_update_norm=2.0, + progress_thresh=20.0, + max_steps=100, + inplace=False, +): + """Find joint positions that satisfy a target site position and/or rotation. + + Args: + physics: A `mujoco.Physics` instance. + site_name: A string specifying the name of the target site. + target_pos: A (3,) numpy array specifying the desired Cartesian position of + the site, or None if the position should be unconstrained (default). + One or both of `target_pos` or `target_quat` must be specified. + target_quat: A (4,) numpy array specifying the desired orientation of the + site as a quaternion, or None if the orientation should be unconstrained + (default). One or both of `target_pos` or `target_quat` must be specified. + joint_names: (optional) A list, tuple or numpy array specifying the names of + one or more joints that can be manipulated in order to achieve the target + site pose. If None (default), all joints may be manipulated. 
+ tol: (optional) Precision goal for `qpos` (the maximum value of `err_norm` + in the stopping criterion). + rot_weight: (optional) Determines the weight given to rotational error + relative to translational error. + regularization_threshold: (optional) L2 regularization will be used when + inverting the Jacobian whilst `err_norm` is greater than this value. + regularization_strength: (optional) Coefficient of the quadratic penalty + on joint movements. + max_update_norm: (optional) The maximum L2 norm of the update applied to + the joint positions on each iteration. The update vector will be scaled + such that its magnitude never exceeds this value. + progress_thresh: (optional) If `err_norm` divided by the magnitude of the + joint position update is greater than this value then the optimization + will terminate prematurely. This is a useful heuristic to avoid getting + stuck in local minima. + max_steps: (optional) The maximum number of iterations to perform. + inplace: (optional) If True, `physics.data` will be modified in place. + Default value is False, i.e. a copy of `physics.data` will be made. + + Returns: + An `IKResult` namedtuple with the following fields: + qpos: An (nq,) numpy array of joint positions. + err_norm: A float, the weighted sum of L2 norms for the residual + translational and rotational errors. + steps: An int, the number of iterations that were performed. + success: Boolean, True if we converged on a solution within `max_steps`, + False otherwise. + + Raises: + ValueError: If both `target_pos` and `target_quat` are None, or if + `joint_names` has an invalid type. 
+ """ + dtype = physics.data.qpos.dtype + + if target_pos is not None and target_quat is not None: + jac = np.empty((6, physics.model.nv), dtype=dtype) + err = np.empty(6, dtype=dtype) + jac_pos, jac_rot = jac[:3], jac[3:] + err_pos, err_rot = err[:3], err[3:] + else: + jac = np.empty((3, physics.model.nv), dtype=dtype) + err = np.empty(3, dtype=dtype) + if target_pos is not None: + jac_pos, jac_rot = jac, None + err_pos, err_rot = err, None + elif target_quat is not None: + jac_pos, jac_rot = None, jac + err_pos, err_rot = None, err + else: + raise ValueError(_REQUIRE_TARGET_POS_OR_QUAT) + + update_nv = np.zeros(physics.model.nv, dtype=dtype) + + if target_quat is not None: + site_xquat = np.empty(4, dtype=dtype) + neg_site_xquat = np.empty(4, dtype=dtype) + err_rot_quat = np.empty(4, dtype=dtype) + + if not inplace: + physics = physics.copy(share_model=True) + + # Ensure that the Cartesian position of the site is up to date. + mjlib.mj_fwdPosition(physics.model.ptr, physics.data.ptr) + + # Convert site name to index. + site_id = physics.model.name2id(site_name, "site") + + # These are views onto the underlying MuJoCo buffers. mj_fwdPosition will + # update them in place, so we can avoid indexing overhead in the main loop. + site_xpos = physics.named.data.site_xpos[site_name] + site_xmat = physics.named.data.site_xmat[site_name] + + # This is an index into the rows of `update` and the columns of `jac` + # that selects DOFs associated with joints that we are allowed to manipulate. + if joint_names is None: + dof_indices = slice(None) # Update all DOFs. + elif isinstance(joint_names, (list, np.ndarray, tuple)): + if isinstance(joint_names, tuple): + joint_names = list(joint_names) + # Find the indices of the DOFs belonging to each named joint. Note that + # these are not necessarily the same as the joint IDs, since a single joint + # may have >1 DOF (e.g. ball joints). 
+ indexer = physics.named.model.dof_jntid.axes.row + # `dof_jntid` is an `(nv,)` array indexed by joint name. We use its row + # indexer to map each joint name to the indices of its corresponding DOFs. + dof_indices = indexer.convert_key_item(joint_names) + else: + raise ValueError(_INVALID_JOINT_NAMES_TYPE.format(type(joint_names))) + + steps = 0 + success = False + + for steps in range(max_steps): + err_norm = 0.0 + + if target_pos is not None: + # Translational error. + err_pos[:] = target_pos - site_xpos + err_norm += np.linalg.norm(err_pos) + if target_quat is not None: + # Rotational error. + mjlib.mju_mat2Quat(site_xquat, site_xmat) + mjlib.mju_negQuat(neg_site_xquat, site_xquat) + mjlib.mju_mulQuat(err_rot_quat, target_quat, neg_site_xquat) + mjlib.mju_quat2Vel(err_rot, err_rot_quat, 1) + err_norm += np.linalg.norm(err_rot) * rot_weight + + if err_norm < tol: + logging.debug("Converged after %i steps: err_norm=%3g", steps, err_norm) + success = True + break + else: + # TODO(b/112141670): Generalize this to other entities besides sites. + mjlib.mj_jacSite( + physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id + ) + jac_joints = jac[:, dof_indices] + + # TODO(b/112141592): This does not take joint limits into consideration. + reg_strength = ( + regularization_strength if err_norm > regularization_threshold else 0.0 + ) + update_joints = nullspace_method( + jac_joints, err, regularization_strength=reg_strength + ) + + update_norm = np.linalg.norm(update_joints) + + # Check whether we are still making enough progress, and halt if not. + progress_criterion = err_norm / update_norm + if progress_criterion > progress_thresh: + logging.debug( + "Step %2i: err_norm / update_norm (%3g) > " + "tolerance (%3g). 
def nullspace_method(jac_joints, delta, regularization_strength=0.0):
    """Calculates the joint velocities to achieve a specified end effector delta.

    Solves the normal equations ``(J^T J) x = J^T delta``. With a positive
    ``regularization_strength`` the system ``(J^T J + strength * I)`` is solved
    directly; otherwise a least-squares solution is returned.

    Args:
      jac_joints: The Jacobian of the end effector with respect to the joints. A
        numpy array (or array-like) of shape `(ndelta, nv)`, where `ndelta` is
        the size of `delta` and `nv` is the number of degrees of freedom.
      delta: The desired end-effector delta. A numpy array (or array-like) of
        shape `(3,)` or `(6,)` containing either position deltas, rotation
        deltas, or both.
      regularization_strength: (optional) Coefficient of the quadratic penalty
        on joint movements. Default is zero, i.e. no regularization.

    Returns:
      An `(nv,)` numpy array of joint velocities.

    Reference:
      Buss, S. R. S. (2004). Introduction to inverse kinematics with jacobian
      transpose, pseudoinverse and damped least squares methods.
      https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf
    """
    jac_joints = np.asarray(jac_joints)
    if not np.issubdtype(jac_joints.dtype, np.inexact):
        # Integer/bool Jacobians would make the in-place regularization below
        # fail with a casting error; promote them to float up front.
        jac_joints = jac_joints.astype(float)
    delta = np.asarray(delta)

    hess_approx = jac_joints.T.dot(jac_joints)
    joint_delta = jac_joints.T.dot(delta)
    if regularization_strength > 0:
        # L2 regularization
        hess_approx += np.eye(hess_approx.shape[0]) * regularization_strength
        return np.linalg.solve(hess_approx, joint_delta)
    # rcond=-1 keeps numpy's legacy machine-precision cutoff for singular values.
    return np.linalg.lstsq(hess_approx, joint_delta, rcond=-1)[0]
None + + z_offset = 0.00 + + # Add arm attachement site + self._mjcf_root.worldbody.add( + "site", + name="arm_attachment", + pos=(0, 0, z_offset), + size=(0.01, 0.01, 0.01), + type="sphere", + rgba=(0, 0, 0, 0), + ) + + # Add light. + self._mjcf_root.worldbody.add( + "light", + pos=(0, 0, 1.5), + dir=(0, 0, -1), + directional=True, + ) + + self._ground_texture = self._mjcf_root.asset.add( + "texture", + rgb1=[0.2, 0.3, 0.4], + rgb2=[0.1, 0.2, 0.3], + type="2d", + builtin="checker", + name="groundplane", + width=200, + height=200, + mark="edge", + markrgb=[0.8, 0.8, 0.8], + ) + + self._ground_material = self._mjcf_root.asset.add( + "material", + name="groundplane", + texrepeat=[2, 2], # Makes white squares exactly 1x1 length units. + texuniform=True, + reflectance=reflectance, + texture=self._ground_texture, + ) + + # Build groundplane. + self._ground_geom = self._mjcf_root.worldbody.add( + "geom", + type="plane", + name="groundplane", + material=self._ground_material, + size=list(size) + [_GROUNDPLANE_QUAD_SIZE], + ) + + # Choose the FOV so that the floor always fits nicely within the frame + # irrespective of actual floor size. 
+ fovy_radians = 2 * np.arctan2( + top_camera_y_padding_factor * size[1], top_camera_distance + ) + self._top_camera = self._mjcf_root.worldbody.add( + "camera", + name="top_camera", + pos=[0, 0, top_camera_distance], + quat=[1, 0, 0, 0], + fovy=np.rad2deg(fovy_radians), + ) + + @property + def ground_geoms(self): + return (self._ground_geom,) + + @property + def size(self): + return self._size + + @property + def arm_attachment_site(self) -> mjcf.Element: + return self._mjcf_root.worldbody.find("site", "arm_attachment") diff --git a/gello/dm_control_tasks/manipulation/tasks/base.py b/gello/dm_control_tasks/manipulation/tasks/base.py new file mode 100644 index 0000000..b15fb05 --- /dev/null +++ b/gello/dm_control_tasks/manipulation/tasks/base.py @@ -0,0 +1,56 @@ +"""A abstract class for all walker tasks.""" + +from dm_control import composer, mjcf + +from gello.dm_control_tasks.arms.manipulator import Manipulator +from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena + +# Timestep of the physics simulation. +_PHYSICS_TIMESTEP: float = 0.002 + +# Interval between agent actions, in seconds. +# We send a control signal every (_CONTROL_TIMESTEP / _PHYSICS_TIMESTEP) physics steps. +_CONTROL_TIMESTEP: float = 0.02 # 50 Hz. + + +class ManipulationTask(composer.Task): + """Base composer task for walker robots.""" + + def __init__( + self, + arm: Manipulator, + arena: FixedManipulationArena, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP, + ) -> None: + self._arm = arm + self._arena = arena + + self._arm_geoms = arm.root_body.find_all("geom") + self._arena_geoms = arena.root_body.find_all("geom") + arena.attach(arm, attach_site=arena.arm_attachment_site) + + self.set_timesteps( + control_timestep=control_timestep, physics_timestep=physics_timestep + ) + + # Enable all robot observables. Note: No additional entity observables should + # be added in subclasses. 
+ for observable in self._arm.observables.proprioception: + observable.enabled = True + + def in_collision(self, physics: mjcf.Physics) -> bool: + """Checks if the arm is in collision with the floor (self._arena).""" + arm_ids = [physics.bind(g).element_id for g in self._arm_geoms] + arena_ids = [physics.bind(g).element_id for g in self._arena_geoms] + + return any( + (con.geom1 in arm_ids and con.geom2 in arena_ids) # arm and arena + or (con.geom1 in arena_ids and con.geom2 in arm_ids) # arena and arm + or (con.geom1 in arm_ids and con.geom2 in arm_ids) # self collision + for con in physics.data.contact[: physics.data.ncon] + ) + + @property + def root_entity(self): + return self._arena diff --git a/gello/dm_control_tasks/manipulation/tasks/block_play.py b/gello/dm_control_tasks/manipulation/tasks/block_play.py new file mode 100644 index 0000000..09589a0 --- /dev/null +++ b/gello/dm_control_tasks/manipulation/tasks/block_play.py @@ -0,0 +1,122 @@ +"""A task where a walker must learn to stand.""" +from typing import Optional + +import numpy as np +from dm_control import mjcf +from dm_control.suite.utils.randomizers import random_limited_quaternion + +from gello.dm_control_tasks.arms.manipulator import Manipulator +from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena +from gello.dm_control_tasks.manipulation.tasks import base + +_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6) + + +class BlockPlay(base.ManipulationTask): + """Task for a manipulator. 
Blocks are randomly placed in the scene.""" + + def __init__( + self, + arm: Manipulator, + arena: FixedManipulationArena, + physics_timestep=base._PHYSICS_TIMESTEP, + control_timestep=base._CONTROL_TIMESTEP, + num_blocks: int = 10, + size: float = 0.03, + reset_joints: Optional[np.ndarray] = None, + ) -> None: + super().__init__(arm, arena, physics_timestep, control_timestep) + + # find key frame + key_frames = self.root_entity.mjcf_model.find_all("key") + if len(key_frames) == 0: + key_frames = None + else: + key_frames = key_frames[0] + + # Create target. + block_joints = [] + for i in range(num_blocks): + # select random colors = np.random.uniform(0, 1, size=3) + color = np.concatenate([np.random.uniform(0, 1, size=3), [1.0]]) + + # attach a body for block i + b = self.root_entity.mjcf_model.worldbody.add( + "body", name=f"block_{i}", pos=(0, 0, 0) + ) + + # # add a freejoint to the block so it can be moved + _joint = b.add("freejoint") + block_joints.append(_joint) + + # add a geom to the block + b.add( + "geom", + name=f"block_geom_{i}", + type="box", + size=(size, size, size), + rgba=color, + # contype=0, + # conaffinity=0, + ) + assert key_frames is not None + key_frames.qpos = np.concatenate([key_frames.qpos, np.zeros(7)]) + + # # save xml to file + # _xml_string = self.root_entity.mjcf_model.to_xml_string() + # with open("block_play.xml", "w") as f: + # f.write(_xml_string) + + self._block_joints = block_joints + self._block_size = size + self._reset_joints = reset_joints + + def initialize_episode(self, physics, random_state): + # Randomly set feasible target position + if self._reset_joints is not None: + self._arm.set_joints(physics, self._reset_joints) + else: + self._arm.randomize_joints(physics, random_state) + physics.forward() + + # check if arm is in collision with floor + while self.in_collision(physics): + self._arm.randomize_joints(physics, random_state) + physics.forward() + + # Randomize block positions + for block_j in self._block_joints: + 
randomize_pose( + block_j, + physics, + random_state=random_state, + position_range=0.5, + z_offset=self._block_size * 2, + ) + + physics.forward() + + def get_reward(self, physics): + # flange position + return 0 + + +def randomize_pose( + free_joint: mjcf.RootElement, + physics: mjcf.Physics, + random_state: np.random.RandomState, + position_range: float = 0.5, + z_offset: float = 0.0, +) -> None: + """Randomize the pose of an entity.""" + entity_pos = random_state.uniform(-position_range, position_range, size=2) + # make x, y farther than 0.1 from 0, 0 + while np.linalg.norm(entity_pos) < 0.3: + entity_pos = random_state.uniform(-position_range, position_range, size=2) + + entity_pos = np.concatenate([entity_pos, [z_offset]]) + entity_quat = random_limited_quaternion(random_state, limit=np.pi) + + qpos = np.concatenate([entity_pos, entity_quat]) + + physics.bind(free_joint).qpos = qpos diff --git a/gello/dm_control_tasks/manipulation/tasks/reach.py b/gello/dm_control_tasks/manipulation/tasks/reach.py new file mode 100644 index 0000000..e743e46 --- /dev/null +++ b/gello/dm_control_tasks/manipulation/tasks/reach.py @@ -0,0 +1,63 @@ +"""A task where a walker must learn to stand.""" + +import numpy as np +from dm_control.suite.utils import randomizers +from dm_control.utils import rewards + +from gello.dm_control_tasks.arms.manipulator import Manipulator +from gello.dm_control_tasks.manipulation.arenas.floors import FixedManipulationArena +from gello.dm_control_tasks.manipulation.tasks import base + +_TARGET_COLOR = (0.8, 0.2, 0.2, 0.6) + + +class Reach(base.ManipulationTask): + """Reach task for a manipulator.""" + + def __init__( + self, + arm: Manipulator, + arena: FixedManipulationArena, + physics_timestep=base._PHYSICS_TIMESTEP, + control_timestep=base._CONTROL_TIMESTEP, + distance_tolerance: float = 0.5, + ) -> None: + super().__init__(arm, arena, physics_timestep, control_timestep) + + self._distance_tolerance = distance_tolerance + + # Create target. 
+ self._target = self.root_entity.mjcf_model.worldbody.add( + "site", + name="target", + type="sphere", + pos=(0, 0, 0), + size=(0.1,), + rgba=_TARGET_COLOR, + ) + + def initialize_episode(self, physics, random_state): + # Randomly set feasible target position + randomizers.randomize_limited_and_rotational_joints(physics, random_state) + physics.forward() + flange_position = physics.bind(self._arm.flange).xpos[:3] + print(flange_position) + + # set target position to flange position + physics.bind(self._target).pos = flange_position + + # Randomize initial position of the arm. + randomizers.randomize_limited_and_rotational_joints(physics, random_state) + physics.forward() + + def get_reward(self, physics): + # flange position + flange_pos = physics.bind(self._arm.flange).pos[:3] + distance = np.linalg.norm(physics.bind(self._target).pos[:3] - flange_pos) + return -rewards.tolerance( + distance, + bounds=(0, self._distance_tolerance), + margin=self._distance_tolerance, + value_at_margin=0, + sigmoid="linear", + ) diff --git a/gello/dm_control_tasks/mjcf_utils.py b/gello/dm_control_tasks/mjcf_utils.py new file mode 100644 index 0000000..1e36cec --- /dev/null +++ b/gello/dm_control_tasks/mjcf_utils.py @@ -0,0 +1,23 @@ +from pathlib import Path +from typing import List + +from dm_control import mjcf + +# Path to the root of the project. +_PROJECT_ROOT: Path = Path(__file__).parent.parent.parent + +# Path to the Menagerie submodule. 
+MENAGERIE_ROOT: Path = _PROJECT_ROOT / "third_party" / "mujoco_menagerie" + + +def safe_find_all( + root: mjcf.RootElement, + feature_name: str, + immediate_children_only: bool = False, + exclude_attachments: bool = False, +) -> List[mjcf.Element]: + """Find all given elements or throw an error if none are found.""" + features = root.find_all(feature_name, immediate_children_only, exclude_attachments) + if not features: + raise ValueError(f"No {feature_name} found in the MJCF model.") + return features diff --git a/gello/dynamixel/__init__.py b/gello/dynamixel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gello/dynamixel/driver.py b/gello/dynamixel/driver.py new file mode 100644 index 0000000..360c64b --- /dev/null +++ b/gello/dynamixel/driver.py @@ -0,0 +1,274 @@ +import time +from threading import Event, Lock, Thread +from typing import Protocol, Sequence + +import numpy as np +from dynamixel_sdk.group_sync_read import GroupSyncRead +from dynamixel_sdk.group_sync_write import GroupSyncWrite +from dynamixel_sdk.packet_handler import PacketHandler +from dynamixel_sdk.port_handler import PortHandler +from dynamixel_sdk.robotis_def import ( + COMM_SUCCESS, + DXL_HIBYTE, + DXL_HIWORD, + DXL_LOBYTE, + DXL_LOWORD, +) + +# Constants +ADDR_TORQUE_ENABLE = 64 +ADDR_GOAL_POSITION = 116 +LEN_GOAL_POSITION = 4 +ADDR_PRESENT_POSITION = 132 +ADDR_PRESENT_POSITION = 140 +LEN_PRESENT_POSITION = 4 +TORQUE_ENABLE = 1 +TORQUE_DISABLE = 0 + + +class DynamixelDriverProtocol(Protocol): + def set_joints(self, joint_angles: Sequence[float]): + """Set the joint angles for the Dynamixel servos. + + Args: + joint_angles (Sequence[float]): A list of joint angles. + """ + ... + + def torque_enabled(self) -> bool: + """Check if torque is enabled for the Dynamixel servos. + + Returns: + bool: True if torque is enabled, False if it is disabled. + """ + ... + + def set_torque_mode(self, enable: bool): + """Set the torque mode for the Dynamixel servos. 
+ + Args: + enable (bool): True to enable torque, False to disable. + """ + ... + + def get_joints(self) -> np.ndarray: + """Get the current joint angles in radians. + + Returns: + np.ndarray: An array of joint angles. + """ + ... + + def close(self): + """Close the driver.""" + + +class FakeDynamixelDriver(DynamixelDriverProtocol): + def __init__(self, ids: Sequence[int]): + self._ids = ids + self._joint_angles = np.zeros(len(ids), dtype=int) + self._torque_enabled = False + + def set_joints(self, joint_angles: Sequence[float]): + if len(joint_angles) != len(self._ids): + raise ValueError( + "The length of joint_angles must match the number of servos" + ) + if not self._torque_enabled: + raise RuntimeError("Torque must be enabled to set joint angles") + self._joint_angles = np.array(joint_angles) + + def torque_enabled(self) -> bool: + return self._torque_enabled + + def set_torque_mode(self, enable: bool): + self._torque_enabled = enable + + def get_joints(self) -> np.ndarray: + return self._joint_angles.copy() + + def close(self): + pass + + +class DynamixelDriver(DynamixelDriverProtocol): + def __init__( + self, ids: Sequence[int], port: str = "/dev/ttyUSB0", baudrate: int = 57600 + ): + """Initialize the DynamixelDriver class. + + Args: + ids (Sequence[int]): A list of IDs for the Dynamixel servos. + port (str): The USB port to connect to the arm. + baudrate (int): The baudrate for communication. 
+ """ + self._ids = ids + self._joint_angles = None + self._lock = Lock() + + # Initialize the port handler, packet handler, and group sync read/write + self._portHandler = PortHandler(port) + self._packetHandler = PacketHandler(2.0) + self._groupSyncRead = GroupSyncRead( + self._portHandler, + self._packetHandler, + ADDR_PRESENT_POSITION, + LEN_PRESENT_POSITION, + ) + self._groupSyncWrite = GroupSyncWrite( + self._portHandler, + self._packetHandler, + ADDR_GOAL_POSITION, + LEN_GOAL_POSITION, + ) + + # Open the port and set the baudrate + if not self._portHandler.openPort(): + raise RuntimeError("Failed to open the port") + + if not self._portHandler.setBaudRate(baudrate): + raise RuntimeError(f"Failed to change the baudrate, {baudrate}") + + # Add parameters for each Dynamixel servo to the group sync read + for dxl_id in self._ids: + if not self._groupSyncRead.addParam(dxl_id): + raise RuntimeError( + f"Failed to add parameter for Dynamixel with ID {dxl_id}" + ) + + # Disable torque for each Dynamixel servo + self._torque_enabled = False + try: + self.set_torque_mode(self._torque_enabled) + except Exception as e: + print(f"port: {port}, {e}") + + self._stop_thread = Event() + self._start_reading_thread() + + def set_joints(self, joint_angles: Sequence[float]): + if len(joint_angles) != len(self._ids): + raise ValueError( + "The length of joint_angles must match the number of servos" + ) + if not self._torque_enabled: + raise RuntimeError("Torque must be enabled to set joint angles") + + for dxl_id, angle in zip(self._ids, joint_angles): + # Convert the angle to the appropriate value for the servo + position_value = int(angle * 2048 / np.pi) + + # Allocate goal position value into byte array + param_goal_position = [ + DXL_LOBYTE(DXL_LOWORD(position_value)), + DXL_HIBYTE(DXL_LOWORD(position_value)), + DXL_LOBYTE(DXL_HIWORD(position_value)), + DXL_HIBYTE(DXL_HIWORD(position_value)), + ] + + # Add goal position value to the Syncwrite parameter storage + 
dxl_addparam_result = self._groupSyncWrite.addParam( + dxl_id, param_goal_position + ) + if not dxl_addparam_result: + raise RuntimeError( + f"Failed to set joint angle for Dynamixel with ID {dxl_id}" + ) + + # Syncwrite goal position + dxl_comm_result = self._groupSyncWrite.txPacket() + if dxl_comm_result != COMM_SUCCESS: + raise RuntimeError("Failed to syncwrite goal position") + + # Clear syncwrite parameter storage + self._groupSyncWrite.clearParam() + + def torque_enabled(self) -> bool: + return self._torque_enabled + + def set_torque_mode(self, enable: bool): + torque_value = TORQUE_ENABLE if enable else TORQUE_DISABLE + with self._lock: + for dxl_id in self._ids: + dxl_comm_result, dxl_error = self._packetHandler.write1ByteTxRx( + self._portHandler, dxl_id, ADDR_TORQUE_ENABLE, torque_value + ) + if dxl_comm_result != COMM_SUCCESS or dxl_error != 0: + print(dxl_comm_result) + print(dxl_error) + raise RuntimeError( + f"Failed to set torque mode for Dynamixel with ID {dxl_id}" + ) + + self._torque_enabled = enable + + def _start_reading_thread(self): + self._reading_thread = Thread(target=self._read_joint_angles) + self._reading_thread.daemon = True + self._reading_thread.start() + + def _read_joint_angles(self): + # Continuously read joint angles and update the joint_angles array + while not self._stop_thread.is_set(): + time.sleep(0.001) + with self._lock: + _joint_angles = np.zeros(len(self._ids), dtype=int) + dxl_comm_result = self._groupSyncRead.txRxPacket() + if dxl_comm_result != COMM_SUCCESS: + print(f"warning, comm failed: {dxl_comm_result}") + continue + for i, dxl_id in enumerate(self._ids): + if self._groupSyncRead.isAvailable( + dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION + ): + angle = self._groupSyncRead.getData( + dxl_id, ADDR_PRESENT_POSITION, LEN_PRESENT_POSITION + ) + angle = np.int32(np.uint32(angle)) + _joint_angles[i] = angle + else: + raise RuntimeError( + f"Failed to get joint angles for Dynamixel with ID {dxl_id}" + ) + 
self._joint_angles = _joint_angles + # self._groupSyncRead.clearParam() # TODO what does this do? should i add it + + def get_joints(self) -> np.ndarray: + # Return a copy of the joint_angles array to avoid race conditions + while self._joint_angles is None: + time.sleep(0.1) + with self._lock: + _j = self._joint_angles.copy() + return _j / 2048.0 * np.pi + + def close(self): + self._stop_thread.set() + self._reading_thread.join() + self._portHandler.closePort() + + +def main(): + # Set the port, baudrate, and servo IDs + ids = [1] + + # Create a DynamixelDriver instance + try: + driver = DynamixelDriver(ids) + except FileNotFoundError: + driver = DynamixelDriver(ids, port="/dev/cu.usbserial-FT7WBMUB") + + driver.set_torque_mode(True) + driver.set_torque_mode(False) + + # Print the joint angles + try: + while True: + joint_angles = driver.get_joints() + print(f"Joint angles for IDs {ids}: {joint_angles}") + # print(f"Joint angles for IDs {ids[1]}: {joint_angles[1]}") + except KeyboardInterrupt: + driver.close() + + +if __name__ == "__main__": + main() diff --git a/gello/dynamixel/tests/test_driver.py b/gello/dynamixel/tests/test_driver.py new file mode 100644 index 0000000..218744b --- /dev/null +++ b/gello/dynamixel/tests/test_driver.py @@ -0,0 +1,35 @@ +import numpy as np +import pytest + +from gello.dynamixel.driver import FakeDynamixelDriver + + +@pytest.fixture +def fake_driver(): + return FakeDynamixelDriver(ids=[1, 2]) + + +def test_set_joints(fake_driver): + fake_driver.set_torque_mode(True) + fake_driver.set_joints([np.pi / 2, np.pi / 2]) + assert np.allclose(fake_driver.get_joints(), [np.pi / 2, np.pi / 2]) + + +def test_set_joints_wrong_length(fake_driver): + with pytest.raises(ValueError): + fake_driver.set_joints([np.pi / 2]) + + +def test_set_joints_torque_disabled(fake_driver): + with pytest.raises(RuntimeError): + fake_driver.set_joints([np.pi / 2, np.pi / 2]) + + +def test_torque_enabled(fake_driver): + assert not fake_driver.torque_enabled() + 
fake_driver.set_torque_mode(True) + assert fake_driver.torque_enabled() + + +def test_get_joints(fake_driver): + assert np.allclose(fake_driver.get_joints(), [0, 0]) diff --git a/gello/env.py b/gello/env.py new file mode 100644 index 0000000..f88d583 --- /dev/null +++ b/gello/env.py @@ -0,0 +1,88 @@ +import time +from typing import Any, Dict, Optional + +import numpy as np + +from gello.cameras.camera import CameraDriver +from gello.robots.robot import Robot + + +class Rate: + def __init__(self, rate: float): + self.last = time.time() + self.rate = rate + + def sleep(self) -> None: + while self.last + 1.0 / self.rate > time.time(): + time.sleep(0.0001) + self.last = time.time() + + +class RobotEnv: + def __init__( + self, + robot: Robot, + control_rate_hz: float = 100.0, + camera_dict: Optional[Dict[str, CameraDriver]] = None, + ) -> None: + self._robot = robot + self._rate = Rate(control_rate_hz) + self._camera_dict = {} if camera_dict is None else camera_dict + + def robot(self) -> Robot: + """Get the robot object. + + Returns: + robot: the robot object. + """ + return self._robot + + def __len__(self): + return 0 + + def step(self, joints: np.ndarray) -> Dict[str, Any]: + """Step the environment forward. + + Args: + joints: joint angles command to step the environment with. + + Returns: + obs: observation from the environment. + """ + assert len(joints) == ( + self._robot.num_dofs() + ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}" + assert self._robot.num_dofs() == len(joints) + self._robot.command_joint_state(joints) + self._rate.sleep() + return self.get_obs() + + def get_obs(self) -> Dict[str, Any]: + """Get observation from the environment. + + Returns: + obs: observation from the environment. 
+ """ + observations = {} + for name, camera in self._camera_dict.items(): + image, depth = camera.read() + observations[f"{name}_rgb"] = image + observations[f"{name}_depth"] = depth + + robot_obs = self._robot.get_observations() + assert "joint_positions" in robot_obs + assert "joint_velocities" in robot_obs + assert "ee_pos_quat" in robot_obs + observations["joint_positions"] = robot_obs["joint_positions"] + observations["joint_velocities"] = robot_obs["joint_velocities"] + observations["ee_pos_quat"] = robot_obs["ee_pos_quat"] + observations["gripper_position"] = robot_obs["gripper_position"] + return observations + + +def main() -> None: + pass + + +if __name__ == "__main__": + main() diff --git a/gello/robots/dynamixel.py b/gello/robots/dynamixel.py new file mode 100644 index 0000000..d8676f6 --- /dev/null +++ b/gello/robots/dynamixel.py @@ -0,0 +1,134 @@ +from typing import Dict, Optional, Sequence, Tuple + +import numpy as np + +from gello.robots.robot import Robot + + +class DynamixelRobot(Robot): + """A class representing a UR robot.""" + + def __init__( + self, + joint_ids: Sequence[int], + joint_offsets: Optional[Sequence[float]] = None, + joint_signs: Optional[Sequence[int]] = None, + real: bool = False, + port: str = "/dev/ttyUSB0", + baudrate: int = 57600, + gripper_config: Optional[Tuple[int, float, float]] = None, + start_joints: Optional[np.ndarray] = None, + ): + from gello.dynamixel.driver import ( + DynamixelDriver, + DynamixelDriverProtocol, + FakeDynamixelDriver, + ) + + print(f"attempting to connect to port: {port}") + self.gripper_open_close: Optional[Tuple[float, float]] + if gripper_config is not None: + assert joint_offsets is not None + assert joint_signs is not None + + # joint_ids.append(gripper_config[0]) + # joint_offsets.append(0.0) + # joint_signs.append(1) + joint_ids = tuple(joint_ids) + (gripper_config[0],) + joint_offsets = tuple(joint_offsets) + (0.0,) + joint_signs = tuple(joint_signs) + (1,) + self.gripper_open_close = ( + 
gripper_config[1] * np.pi / 180, + gripper_config[2] * np.pi / 180, + ) + else: + self.gripper_open_close = None + + self._joint_ids = joint_ids + self._driver: DynamixelDriverProtocol + + if joint_offsets is None: + self._joint_offsets = np.zeros(len(joint_ids)) + else: + self._joint_offsets = np.array(joint_offsets) + + if joint_signs is None: + self._joint_signs = np.ones(len(joint_ids)) + else: + self._joint_signs = np.array(joint_signs) + + assert len(self._joint_ids) == len(self._joint_offsets), ( + f"joint_ids: {len(self._joint_ids)}, " + f"joint_offsets: {len(self._joint_offsets)}" + ) + assert len(self._joint_ids) == len(self._joint_signs), ( + f"joint_ids: {len(self._joint_ids)}, " + f"joint_signs: {len(self._joint_signs)}" + ) + assert np.all( + np.abs(self._joint_signs) == 1 + ), f"joint_signs: {self._joint_signs}" + + if real: + self._driver = DynamixelDriver(joint_ids, port=port, baudrate=baudrate) + self._driver.set_torque_mode(False) + else: + self._driver = FakeDynamixelDriver(joint_ids) + self._torque_on = False + self._last_pos = None + self._alpha = 0.99 + + if start_joints is not None: + # loop through all joints and add +- 2pi to the joint offsets to get the closest to start joints + new_joint_offsets = [] + current_joints = self.get_joint_state() + assert current_joints.shape == start_joints.shape + if gripper_config is not None: + current_joints = current_joints[:-1] + start_joints = start_joints[:-1] + for c_joint, s_joint, joint_offset in zip( + current_joints, start_joints, self._joint_offsets + ): + new_joint_offsets.append( + np.pi * 2 * np.round((s_joint - c_joint) / (2 * np.pi)) + + joint_offset + ) + if gripper_config is not None: + new_joint_offsets.append(self._joint_offsets[-1]) + self._joint_offsets = np.array(new_joint_offsets) + + def num_dofs(self) -> int: + return len(self._joint_ids) + + def get_joint_state(self) -> np.ndarray: + pos = (self._driver.get_joints() - self._joint_offsets) * self._joint_signs + assert len(pos) == 
self.num_dofs() + + if self.gripper_open_close is not None: + # map pos to [0, 1] + g_pos = (pos[-1] - self.gripper_open_close[0]) / ( + self.gripper_open_close[1] - self.gripper_open_close[0] + ) + g_pos = min(max(0, g_pos), 1) + pos[-1] = g_pos + + if self._last_pos is None: + self._last_pos = pos + else: + # exponential smoothing + pos = self._last_pos * (1 - self._alpha) + pos * self._alpha + self._last_pos = pos + + return pos + + def command_joint_state(self, joint_state: np.ndarray) -> None: + self._driver.set_joints((joint_state + self._joint_offsets).tolist()) + + def set_torque_mode(self, mode: bool): + if mode == self._torque_on: + return + self._driver.set_torque_mode(mode) + self._torque_on = mode + + def get_observations(self) -> Dict[str, np.ndarray]: + return {"joint_state": self.get_joint_state()} diff --git a/gello/robots/panda.py b/gello/robots/panda.py new file mode 100644 index 0000000..704f16a --- /dev/null +++ b/gello/robots/panda.py @@ -0,0 +1,88 @@ +import time +from typing import Dict + +import numpy as np + +from gello.robots.robot import Robot + +MAX_OPEN = 0.09 + + +class PandaRobot(Robot): + """A class representing a UR robot.""" + + def __init__(self, robot_ip: str = "100.97.47.74"): + from polymetis import GripperInterface, RobotInterface + + self.robot = RobotInterface( + ip_address=robot_ip, + ) + self.gripper = GripperInterface( + ip_address="localhost", + ) + self.robot.go_home() + self.robot.start_joint_impedance() + self.gripper.goto(width=MAX_OPEN, speed=255, force=255) + time.sleep(1) + + def num_dofs(self) -> int: + """Get the number of joints of the robot. + + Returns: + int: The number of joints of the robot. + """ + return 8 + + def get_joint_state(self) -> np.ndarray: + """Get the current state of the leader robot. + + Returns: + T: The current state of the leader robot. 
+ """ + robot_joints = self.robot.get_joint_positions() + gripper_pos = self.gripper.get_state() + pos = np.append(robot_joints, gripper_pos.width / MAX_OPEN) + return pos + + def command_joint_state(self, joint_state: np.ndarray) -> None: + """Command the leader robot to a given state. + + Args: + joint_state (np.ndarray): The state to command the leader robot to. + """ + import torch + + self.robot.update_desired_joint_positions(torch.tensor(joint_state[:-1])) + self.gripper.goto(width=(MAX_OPEN * (1 - joint_state[-1])), speed=1, force=1) + + def get_observations(self) -> Dict[str, np.ndarray]: + joints = self.get_joint_state() + pos_quat = np.zeros(7) + gripper_pos = np.array([joints[-1]]) + return { + "joint_positions": joints, + "joint_velocities": joints, + "ee_pos_quat": pos_quat, + "gripper_position": gripper_pos, + } + + +def main(): + robot = PandaRobot() + current_joints = robot.get_joint_state() + # move a small delta 0.1 rad + move_joints = current_joints + 0.05 + # make last joint (gripper) closed + move_joints[-1] = 0.5 + time.sleep(1) + m = 0.09 + robot.gripper.goto(1 * m, speed=255, force=255) + time.sleep(1) + robot.gripper.goto(1.05 * m, speed=255, force=255) + time.sleep(1) + robot.gripper.goto(1.1 * m, speed=255, force=255) + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/gello/robots/robot.py b/gello/robots/robot.py new file mode 100644 index 0000000..b4a6e96 --- /dev/null +++ b/gello/robots/robot.py @@ -0,0 +1,128 @@ +from abc import abstractmethod +from typing import Dict, Protocol + +import numpy as np + + +class Robot(Protocol): + """Robot protocol. + + A protocol for a robot that can be controlled. + """ + + @abstractmethod + def num_dofs(self) -> int: + """Get the number of joints of the robot. + + Returns: + int: The number of joints of the robot. + """ + raise NotImplementedError + + @abstractmethod + def get_joint_state(self) -> np.ndarray: + """Get the current state of the leader robot. 
+ + Returns: + T: The current state of the leader robot. + """ + raise NotImplementedError + + @abstractmethod + def command_joint_state(self, joint_state: np.ndarray) -> None: + """Command the leader robot to a given state. + + Args: + joint_state (np.ndarray): The state to command the leader robot to. + """ + raise NotImplementedError + + @abstractmethod + def get_observations(self) -> Dict[str, np.ndarray]: + """Get the current observations of the robot. + + This is to extract all the information that is available from the robot, + such as joint positions, joint velocities, etc. This may also include + information from additional sensors, such as cameras, force sensors, etc. + + Returns: + Dict[str, np.ndarray]: A dictionary of observations. + """ + raise NotImplementedError + + +class PrintRobot(Robot): + """A robot that prints the commanded joint state.""" + + def __init__(self, num_dofs: int, dont_print: bool = False): + self._num_dofs = num_dofs + self._joint_state = np.zeros(num_dofs) + self._dont_print = dont_print + + def num_dofs(self) -> int: + return self._num_dofs + + def get_joint_state(self) -> np.ndarray: + return self._joint_state + + def command_joint_state(self, joint_state: np.ndarray) -> None: + assert len(joint_state) == (self._num_dofs), ( + f"Expected joint state of length {self._num_dofs}, " + f"got {len(joint_state)}." 
class BimanualRobot(Robot):
    """Two robots exposed as a single robot with concatenated joints.

    The left robot owns the first ``robot_l.num_dofs()`` entries of every
    joint vector and the right robot owns the remainder.
    """

    def __init__(self, robot_l: Robot, robot_r: Robot):
        self._robot_l = robot_l
        self._robot_r = robot_r

    def num_dofs(self) -> int:
        """Return the combined number of joints of both robots."""
        return self._robot_l.num_dofs() + self._robot_r.num_dofs()

    def get_joint_state(self) -> np.ndarray:
        """Return the left joint state followed by the right joint state."""
        return np.concatenate(
            (self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
        )

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Split ``joint_state`` at the left robot's DoF count and forward it.

        Args:
            joint_state (np.ndarray): Left joints followed by right joints.
        """
        split = self._robot_l.num_dofs()
        self._robot_l.command_joint_state(joint_state[:split])
        self._robot_r.command_joint_state(joint_state[split:])

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Merge both robots' observations key by key.

        Each value is promoted to at least one dimension before concatenation,
        because some robots report scalar (0-d) entries such as
        ``gripper_position`` and ``np.concatenate`` rejects 0-d arrays.

        Returns:
            Dict[str, np.ndarray]: For every key, the left value followed by
            the right value.

        Raises:
            RuntimeError: If the values of a key cannot be concatenated.
        """
        l_obs = self._robot_l.get_observations()
        r_obs = self._robot_r.get_observations()
        assert l_obs.keys() == r_obs.keys(), (
            f"Observation keys differ: {sorted(l_obs)} vs {sorted(r_obs)}"
        )
        merged: Dict[str, np.ndarray] = {}
        for key, left_value in l_obs.items():
            right_value = r_obs[key]
            try:
                merged[key] = np.concatenate(
                    (np.atleast_1d(left_value), np.atleast_1d(right_value))
                )
            except (ValueError, TypeError) as err:
                # Keep the offending key and values in the error instead of
                # printing them and raising an empty RuntimeError.
                raise RuntimeError(
                    f"Could not concatenate observation {key!r}: "
                    f"left={left_value!r}, right={right_value!r}"
                ) from err
        return merged
class RobotiqGripper:
    """Communicates with the gripper directly, via socket with string commands, leveraging string names for variables.

    Every request is a single text line (``SET <var> <value> ...`` or
    ``GET <var>``). A lock keeps each send paired with its own reply.
    """

    # WRITE VARIABLES (CAN ALSO READ)
    ACT = "ACT"  # act : activate (1 while activated, can be reset to clear fault status)
    GTO = "GTO"  # gto : go to (will perform go to with the actions set in pos, for, spe)
    ATR = "ATR"  # atr : auto-release (emergency slow move)
    ADR = "ADR"  # adr : auto-release direction (open(1) or close(0) during auto-release)
    FOR = "FOR"  # for : force (0-255)
    SPE = "SPE"  # spe : speed (0-255)
    POS = "POS"  # pos : position (0-255), 0 = open
    # READ VARIABLES
    STA = "STA"  # status (0 = is reset, 1 = activating, 3 = active)
    PRE = "PRE"  # position request (echo of last commanded position)
    OBJ = "OBJ"  # object detection (0 = moving, 1 = outer grip, 2 = inner grip, 3 = no object at rest)
    FLT = "FLT"  # fault (0=ok, see manual for errors if not zero)

    ENCODING = "UTF-8"  # ASCII and UTF-8 both seem to work

    _POLL_INTERVAL = 0.001  # seconds between status polls while waiting on a move

    class GripperStatus(Enum):
        """Gripper status reported by the gripper. The integer values have to match what the gripper sends."""

        RESET = 0
        ACTIVATING = 1
        # UNUSED = 2  # This value is currently not used by the gripper firmware
        ACTIVE = 3

    class ObjectStatus(Enum):
        """Object status reported by the gripper. The integer values have to match what the gripper sends."""

        MOVING = 0
        STOPPED_OUTER_OBJECT = 1
        STOPPED_INNER_OBJECT = 2
        AT_DEST = 3

    def __init__(self):
        """Constructor."""
        self.socket = None
        self.command_lock = threading.Lock()
        self._min_position = 0
        self._max_position = 255
        self._min_speed = 0
        self._max_speed = 255
        self._min_force = 0
        self._max_force = 255

    def connect(self, hostname: str, port: int, socket_timeout: float = 10.0) -> None:
        """Connects to a gripper at the given address.

        :param hostname: Hostname or ip.
        :param port: Port.
        :param socket_timeout: Timeout for blocking socket operations.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Set the timeout first so the connection attempt itself is bounded.
        sock.settimeout(socket_timeout)
        try:
            sock.connect((hostname, port))
        except OSError:
            # Do not leak the descriptor when the gripper is unreachable.
            sock.close()
            raise
        self.socket = sock

    def disconnect(self) -> None:
        """Closes the connection with the gripper."""
        assert self.socket is not None
        self.socket.close()
        # Forget the closed socket so later calls fail on the "connected" assertion.
        self.socket = None

    def _set_vars(self, var_dict: OrderedDict[str, Union[int, float]]):
        """Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response.

        :param var_dict: Dictionary of variables to set (variable_name, value).
        :return: True on successful reception of ack, false if no ack was received, indicating the set may not
        have been effective.
        """
        assert self.socket is not None
        # construct unique command; the trailing new line is required for the command to finish
        assignments = "".join(
            f" {variable} {str(value)}" for variable, value in var_dict.items()
        )
        cmd = f"SET{assignments}\n"
        # atomic commands send/rcv
        with self.command_lock:
            self.socket.sendall(cmd.encode(self.ENCODING))
            data = self.socket.recv(1024)
        return self._is_ack(data)

    def _set_var(self, variable: str, value: Union[int, float]):
        """Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response.

        :param variable: Variable to set.
        :param value: Value to set for the variable.
        :return: True on successful reception of ack, false if no ack was received, indicating the set may not
        have been effective.
        """
        return self._set_vars(OrderedDict([(variable, value)]))

    def _get_var(self, variable: str):
        """Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the response is received or the socket times out.

        :param variable: Name of the variable to retrieve.
        :return: Value of the variable as integer.
        :raises ValueError: If the reply does not echo the requested variable.
        """
        assert self.socket is not None
        # atomic commands send/rcv
        with self.command_lock:
            cmd = f"GET {variable}\n"
            self.socket.sendall(cmd.encode(self.ENCODING))
            data = self.socket.recv(1024)

        # expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value
        # note some special variables (like FLT) may send 2 bytes, instead of an integer. We assume integer here
        text = data.decode(self.ENCODING)
        var_name, value_str = text.split()
        if var_name != variable:
            raise ValueError(
                f"Unexpected response {data} ({text}): does not match '{variable}'"
            )
        return int(value_str)

    @staticmethod
    def _is_ack(data: bytes) -> bool:
        """Returns whether the raw reply is exactly the gripper's ``ack``."""
        return data == b"ack"

    def _reset(self):
        """Reset the gripper.

        Follows the ``rq_reset`` routine of the gripper script: clear ACT and
        ATR, repeat until both ACT and STA read 0, then wait half a second.
        """
        self._set_var(self.ACT, 0)
        self._set_var(self.ATR, 0)
        while self._get_var(self.ACT) != 0 or self._get_var(self.STA) != 0:
            self._set_var(self.ACT, 0)
            self._set_var(self.ATR, 0)
        time.sleep(0.5)

    def activate(self, auto_calibrate: bool = True):
        """Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags.

        Follows the ``rq_activate`` / ``rq_activate_and_wait`` routines of the
        gripper script: when the gripper is not active it is reset, ACT is set
        to 1, and the call waits until ACT reads 1 and STA reads 3.

        :param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion.
        """
        if not self.is_active():
            self._reset()
            while self._get_var(self.ACT) != 0 or self._get_var(self.STA) != 0:
                time.sleep(0.01)

            self._set_var(self.ACT, 1)
            time.sleep(1.0)
            while self._get_var(self.ACT) != 1 or self._get_var(self.STA) != 3:
                time.sleep(0.01)

        # auto-calibrate position range if desired
        if auto_calibrate:
            self.auto_calibrate()

    def is_active(self):
        """Returns whether the gripper is active."""
        status = self._get_var(self.STA)
        return (
            RobotiqGripper.GripperStatus(status) == RobotiqGripper.GripperStatus.ACTIVE
        )

    def get_min_position(self) -> int:
        """Returns the minimum position the gripper can reach (open position)."""
        return self._min_position

    def get_max_position(self) -> int:
        """Returns the maximum position the gripper can reach (closed position)."""
        return self._max_position

    def get_open_position(self) -> int:
        """Returns what is considered the open position for gripper (minimum position value)."""
        return self.get_min_position()

    def get_closed_position(self) -> int:
        """Returns what is considered the closed position for gripper (maximum position value)."""
        return self.get_max_position()

    def is_open(self):
        """Returns whether the current position is considered as being fully open."""
        return self.get_current_position() <= self.get_open_position()

    def is_closed(self):
        """Returns whether the current position is considered as being fully closed."""
        return self.get_current_position() >= self.get_closed_position()

    def get_current_position(self) -> int:
        """Returns the current position as returned by the physical hardware."""
        return self._get_var(self.POS)

    def auto_calibrate(self, log: bool = True) -> None:
        """Attempts to calibrate the open and closed positions, by slowly closing and opening the gripper.

        :param log: Whether to print the results to log.
        :raises RuntimeError: If any of the three moves does not end at its destination.
        """
        # first try to open in case we are holding an object
        (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
        if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
            raise RuntimeError(f"Calibration failed opening to start: {str(status)}")

        # try to close as far as possible, and record the number
        (position, status) = self.move_and_wait_for_pos(
            self.get_closed_position(), 64, 1
        )
        if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
            raise RuntimeError(
                f"Calibration failed because of an object: {str(status)}"
            )
        assert position <= self._max_position
        self._max_position = position

        # try to open as far as possible, and record the number
        (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
        if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
            raise RuntimeError(
                f"Calibration failed because of an object: {str(status)}"
            )
        assert position >= self._min_position
        self._min_position = position

        if log:
            print(
                f"Gripper auto-calibrated to [{self.get_min_position()}, {self.get_max_position()}]"
            )

    def move(self, position: int, speed: int, force: int) -> Tuple[bool, int]:
        """Sends commands to start moving towards the given position, with the specified speed and force.

        Values are clipped to the calibrated range and rounded to integers
        before being sent: the PRE echo read back by
        ``move_and_wait_for_pos`` is parsed as an integer, so a fractional
        request could never be matched.

        :param position: Position to move to [min_position, max_position]
        :param speed: Speed to move at [min_speed, max_speed]
        :param force: Force to use [min_force, max_force]
        :return: A tuple with a bool indicating whether the action it was successfully sent, and an integer with
        the actual position that was requested, after being adjusted to the min/max calibrated range.
        """

        def clip_val(min_val, val, max_val):
            return int(round(max(min_val, min(val, max_val))))

        clip_pos = clip_val(self._min_position, position, self._max_position)
        clip_spe = clip_val(self._min_speed, speed, self._max_speed)
        clip_for = clip_val(self._min_force, force, self._max_force)

        # moves to the given position with the given speed and force
        var_dict = OrderedDict(
            [
                (self.POS, clip_pos),
                (self.SPE, clip_spe),
                (self.FOR, clip_for),
                (self.GTO, 1),
            ]
        )
        succ = self._set_vars(var_dict)
        time.sleep(0.008)  # need to wait (dont know why)
        return succ, clip_pos

    def move_and_wait_for_pos(
        self, position: int, speed: int, force: int
    ) -> Tuple[int, ObjectStatus]:  # noqa
        """Sends commands to start moving towards the given position, with the specified speed and force, and then waits for the move to complete.

        :param position: Position to move to [min_position, max_position]
        :param speed: Speed to move at [min_speed, max_speed]
        :param force: Force to use [min_force, max_force]
        :return: A tuple with an integer representing the last position returned by the gripper after it notified
        that the move had completed, a status indicating how the move ended (see ObjectStatus enum for details). Note
        that it is possible that the position was not reached, if an object was detected during motion.
        :raises RuntimeError: If the gripper did not acknowledge the move command.
        """
        set_ok, cmd_pos = self.move(position, speed, force)
        if not set_ok:
            raise RuntimeError("Failed to set variables for move.")

        # wait until the gripper acknowledges that it will try to go to the requested position
        while self._get_var(self.PRE) != cmd_pos:
            time.sleep(self._POLL_INTERVAL)

        # wait until not moving; pause between polls instead of flooding the socket
        cur_obj = self._get_var(self.OBJ)
        while (
            RobotiqGripper.ObjectStatus(cur_obj) == RobotiqGripper.ObjectStatus.MOVING
        ):
            time.sleep(self._POLL_INTERVAL)
            cur_obj = self._get_var(self.OBJ)

        # report the actual position and the object status
        final_pos = self._get_var(self.POS)
        return final_pos, RobotiqGripper.ObjectStatus(cur_obj)
class ZMQRobotServer:
    """A class representing a ZMQ server for a robot.

    Requests are pickled dicts of the form ``{"method": name, "args": {...}}``
    and every request is answered with exactly one pickled reply.
    """

    def __init__(self, robot: Robot, host: str = "127.0.0.1", port: int = 5556):
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def _dispatch(self, request: Any) -> Any:
        """Run the robot method named by ``request`` and return its result.

        An unknown method produces an error dict instead of an exception, so
        the caller can still send a reply.
        """
        method = request.get("method")
        args = request.get("args", {})
        if method == "num_dofs":
            return self._robot.num_dofs()
        if method == "get_joint_state":
            return self._robot.get_joint_state()
        if method == "command_joint_state":
            return self._robot.command_joint_state(**args)
        if method == "get_observations":
            return self._robot.get_observations()
        result = {"error": "Invalid method"}
        print(result)
        print(f"Invalid method: {method}, {args, result}")
        return result

    def serve(self) -> None:
        """Serve the robot state and commands over ZMQ.

        Loops until ``stop()`` is called. A receive timeout of one second lets
        the loop notice the stop request even when no client is talking.
        """
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                message = self._socket.recv()
                # SECURITY: pickle.loads executes arbitrary code from the
                # payload; only expose this socket to trusted peers.
                request = pickle.loads(message)

                # A REP socket must answer every request. Raising here would
                # end the serving thread and leave the REQ client blocked, so
                # invalid methods are answered with an error dict instead.
                result = self._dispatch(request)
                self._socket.send(pickle.dumps(result))
            except zmq.error.Again:
                print("Timeout in ZMQLeaderServer serve")
                # Timeout occurred, check if the stop event is set
            except zmq.error.ZMQError:
                # stop() closes the socket from another thread; an error seen
                # after a stop request is the expected way out of the loop.
                if self._stop_event.is_set():
                    break
                raise

    def stop(self) -> None:
        """Request the serve loop to end and release the ZMQ resources."""
        self._stop_event.set()
        self._socket.close()
        self._context.term()
ZMQServerThread(self._zmq_server) + + self._print_joints = print_joints + + def num_dofs(self) -> int: + return self._num_joints + + def get_joint_state(self) -> np.ndarray: + return self._joint_state + + def command_joint_state(self, joint_state: np.ndarray) -> None: + assert len(joint_state) == self._num_joints, ( + f"Expected joint state of length {self._num_joints}, " + f"got {len(joint_state)}." + ) + if self._has_gripper: + _joint_state = joint_state.copy() + _joint_state[-1] = _joint_state[-1] * 255 + self._joint_cmd = _joint_state + else: + self._joint_cmd = joint_state.copy() + + def freedrive_enabled(self) -> bool: + return True + + def set_freedrive_mode(self, enable: bool): + pass + + def get_observations(self) -> Dict[str, np.ndarray]: + joint_positions = self._data.qpos.copy()[: self._num_joints] + joint_velocities = self._data.qvel.copy()[: self._num_joints] + ee_site = "attachment_site" + try: + ee_pos = self._data.site_xpos.copy()[ + mujoco.mj_name2id(self._model, 6, ee_site) + ] + ee_mat = self._data.site_xmat.copy()[ + mujoco.mj_name2id(self._model, 6, ee_site) + ] + ee_quat = np.zeros(4) + mujoco.mju_mat2Quat(ee_quat, ee_mat) + except Exception: + ee_pos = np.zeros(3) + ee_quat = np.zeros(4) + ee_quat[0] = 1 + gripper_pos = self._data.qpos.copy()[self._num_joints - 1] + return { + "joint_positions": joint_positions, + "joint_velocities": joint_velocities, + "ee_pos_quat": np.concatenate([ee_pos, ee_quat]), + "gripper_position": gripper_pos, + } + + def serve(self) -> None: + # start the zmq server + self._zmq_server_thread.start() + with mujoco.viewer.launch_passive(self._model, self._data) as viewer: + while viewer.is_running(): + step_start = time.time() + + # mj_step can be replaced with code that also evaluates + # a policy and applies a control signal before stepping the physics. 
class URRobot(Robot):
    """A class representing a UR robot."""

    def __init__(self, robot_ip: str = "192.168.1.10", no_gripper: bool = False):
        """Connect to the arm and, unless disabled, to its gripper.

        Args:
            robot_ip (str): Address used for the control interface, the
                receive interface and the gripper socket.
            no_gripper (bool): When True, no gripper connection is made and
                the robot exposes 6 instead of 7 degrees of freedom.
        """
        import rtde_control
        import rtde_receive

        for _ in range(4):
            print("in ur robot")
        try:
            self.robot = rtde_control.RTDEControlInterface(robot_ip)
        except Exception as e:
            print(e)
            print(robot_ip)
            # Without a control interface nothing below can work; re-raise the
            # real cause instead of failing later on a missing attribute.
            raise

        self.r_inter = rtde_receive.RTDEReceiveInterface(robot_ip)
        if not no_gripper:
            from gello.robots.robotiq_gripper import RobotiqGripper

            self.gripper = RobotiqGripper()
            self.gripper.connect(hostname=robot_ip, port=63352)
            print("gripper connected")
            # NOTE: the gripper is only connected here, not activated.

        for _ in range(4):
            print("connect")

        self._free_drive = False
        self.robot.endFreedriveMode()
        self._use_gripper = not no_gripper

    def num_dofs(self) -> int:
        """Get the number of joints of the robot.

        Returns:
            int: 7 with a gripper (6 arm joints plus the gripper), else 6.
        """
        if self._use_gripper:
            return 7
        return 6

    def _get_gripper_pos(self) -> float:
        """Read the gripper position and scale it from 0..255 to 0..1."""
        import time

        time.sleep(0.01)
        gripper_pos = self.gripper.get_current_position()
        assert 0 <= gripper_pos <= 255, "Gripper position must be between 0 and 255"
        return gripper_pos / 255

    def get_joint_state(self) -> np.ndarray:
        """Get the current joint state of the robot.

        Returns:
            np.ndarray: The arm joint positions, followed by the normalized
            gripper position when a gripper is in use.
        """
        robot_joints = self.r_inter.getActualQ()
        if self._use_gripper:
            gripper_pos = self._get_gripper_pos()
            return np.append(robot_joints, gripper_pos)
        # Always hand back an array, matching the annotated return type.
        return np.asarray(robot_joints)

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the robot to a given joint state.

        Args:
            joint_state (np.ndarray): The first 6 entries are the arm joints;
                with a gripper, the last entry is the gripper position in 0..1.
        """
        velocity = 0.5
        acceleration = 0.5
        dt = 1.0 / 500  # 2ms
        lookahead_time = 0.2
        gain = 100

        robot_joints = joint_state[:6]
        t_start = self.robot.initPeriod()
        self.robot.servoJ(
            robot_joints, velocity, acceleration, dt, lookahead_time, gain
        )
        if self._use_gripper:
            # The gripper takes integer positions in 0..255.
            gripper_pos = int(round(joint_state[-1] * 255))
            self.gripper.move(gripper_pos, 255, 10)
        self.robot.waitPeriod(t_start)

    def freedrive_enabled(self) -> bool:
        """Check if the robot is in freedrive mode.

        Returns:
            bool: True if the robot is in freedrive mode, False otherwise.
        """
        return self._free_drive

    def set_freedrive_mode(self, enable: bool) -> None:
        """Set the freedrive mode of the robot.

        Args:
            enable (bool): True to enable freedrive mode, False to disable it.
        """
        if enable and not self._free_drive:
            self._free_drive = True
            self.robot.freedriveMode()
        elif not enable and self._free_drive:
            self._free_drive = False
            self.robot.endFreedriveMode()

    def get_observations(self) -> Dict[str, np.ndarray]:
        """Return the observation dict for the robot.

        ``joint_velocities`` repeats the joint positions and ``ee_pos_quat``
        is all zeros; neither is measured here.

        Returns:
            Dict[str, np.ndarray]: Joint positions, the placeholder entries
            above, and a one-element ``gripper_position``.
        """
        joints = self.get_joint_state()
        pos_quat = np.zeros(7)
        if self._use_gripper:
            gripper_pos = np.array([joints[-1]])
        else:
            # Without a gripper the last entry is an arm joint, not a gripper
            # reading, so report a neutral value instead.
            gripper_pos = np.array([0.0])
        return {
            "joint_positions": joints,
            "joint_velocities": joints,
            "ee_pos_quat": pos_quat,
            "gripper_position": gripper_pos,
        }
@dataclasses.dataclass(frozen=True)
class RobotState:
    """Immutable snapshot of the arm: Cartesian position, joints, gripper, orientation."""

    # Cartesian position of the end effector.
    x: float
    y: float
    z: float
    # Gripper reading, stored as given by the caller.
    gripper: float
    # The seven arm joint values.
    j1: float
    j2: float
    j3: float
    j4: float
    j5: float
    j6: float
    j7: float
    # Orientation as an axis-angle vector (axis scaled by the angle).
    aa: np.ndarray

    @staticmethod
    def from_robot(
        cartesian: np.ndarray,
        joints: np.ndarray,
        gripper: float,
        aa: np.ndarray,
    ) -> "RobotState":
        """Build a state from raw arrays.

        Args:
            cartesian (np.ndarray): At least 3 entries; the first three are x, y, z.
            joints (np.ndarray): At least 7 entries; the first seven are j1..j7.
            gripper (float): Gripper reading.
            aa (np.ndarray): Axis-angle orientation.
        """
        return RobotState(
            cartesian[0],
            cartesian[1],
            cartesian[2],
            gripper,
            joints[0],
            joints[1],
            joints[2],
            joints[3],
            joints[4],
            joints[5],
            joints[6],
            aa,
        )

    def cartesian_pos(self) -> np.ndarray:
        """Return ``[x, y, z]`` as an array."""
        return np.array([self.x, self.y, self.z])

    def quat(self) -> np.ndarray:
        """Return the orientation as a quaternion in ``[x, y, z, w]`` order.

        A zero axis-angle vector has no axis to normalize, so it is mapped
        straight to the identity quaternion. This is the state produced when
        no robot is attached, which would otherwise trip the non-zero
        assertion in ``_quat_from_aa``.
        """
        aa = np.asarray(self.aa)
        if aa.shape == (3,) and not aa.any():
            return np.array([0.0, 0.0, 0.0, 1.0])
        return _quat_from_aa(self.aa)

    def joints(self) -> np.ndarray:
        """Return the seven joint values as an array."""
        return np.array([self.j1, self.j2, self.j3, self.j4, self.j5, self.j6, self.j7])

    def gripper_pos(self) -> float:
        """Return the stored gripper reading."""
        return self.gripper
len(joint_state) == 8: + self.set_command(joint_state[:7], joint_state[7]) + else: + raise ValueError( + f"Invalid joint state: {joint_state}, len={len(joint_state)}" + ) + + def stop(self): + self.running = False + if self.robot is not None: + self.robot.disconnect() + + if self.command_thread is not None: + self.command_thread.join() + + def __init__( + self, + ip: str = "192.168.1.226", + real: bool = True, + control_frequency: float = 50.0, + max_delta: float = DEFAULT_MAX_DELTA, + ): + print(ip) + self.real = real + self.max_delta = max_delta + if real: + from xarm.wrapper import XArmAPI + + self.robot = XArmAPI(ip, is_radian=True) + else: + self.robot = None + + self._control_frequency = control_frequency + self._clear_error_states() + self._set_gripper_position(self.GRIPPER_OPEN) + + self.last_state_lock = threading.Lock() + self.target_command_lock = threading.Lock() + self.last_state = self._update_last_state() + self.target_command = { + "joints": self.last_state.joints(), + "gripper": 0, + } + self.running = True + self.command_thread = None + if real: + self.command_thread = threading.Thread(target=self._robot_thread) + self.command_thread.start() + + def get_state(self) -> RobotState: + with self.last_state_lock: + return self.last_state + + def set_command(self, joints: np.ndarray, gripper: Optional[float] = None) -> None: + with self.target_command_lock: + self.target_command = { + "joints": joints, + "gripper": gripper, + } + + def _clear_error_states(self): + if self.robot is None: + return + self.robot.clean_error() + self.robot.clean_warn() + self.robot.motion_enable(True) + time.sleep(1) + self.robot.set_mode(1) + time.sleep(1) + self.robot.set_collision_sensitivity(0) + time.sleep(1) + self.robot.set_state(state=0) + time.sleep(1) + self.robot.set_gripper_enable(True) + time.sleep(1) + self.robot.set_gripper_mode(0) + time.sleep(1) + self.robot.set_gripper_speed(3000) + time.sleep(1) + + def _get_gripper_pos(self) -> float: + if self.robot is 
def _robot_thread(self):
    """Run the low-level command loop until ``self.running`` is cleared.

    Each iteration:
      1. refreshes ``self.last_state`` from the robot,
      2. reads the current target joints / gripper under
         ``self.target_command_lock``,
      3. limits the joint step so its Euclidean norm is at most
         ``self.max_delta`` and sends the resulting joint target,
      4. forwards the gripper set point (if one is set), mapped linearly
         between ``GRIPPER_OPEN`` and ``GRIPPER_CLOSE``,
      5. refreshes ``self.last_state`` again and sleeps via ``Rate``.

    Every 1000 iterations the achieved loop frequency is printed. The
    std/min/max figures are computed over the per-iteration frequencies;
    the previous code applied them to the single mean value, which always
    reported ``std == 0`` and ``min == max == mean``.
    """
    rate = Rate(
        duration=1 / self._control_frequency
    )  # command and update rate for robot
    step_times = []
    count = 0

    while self.running:
        s_t = time.time()
        # update last state
        self.last_state = self._update_last_state()
        with self.target_command_lock:
            joint_delta = np.array(
                self.target_command["joints"] - self.last_state.joints()
            )
            gripper_command = self.target_command["gripper"]

        norm = np.linalg.norm(joint_delta)

        # Limit the per-step joint motion to at most max_delta (L2 norm),
        # keeping the direction of the requested change.
        if norm > self.max_delta:
            delta = joint_delta / norm * self.max_delta
        else:
            delta = joint_delta

        # command position
        self._set_position(
            self.last_state.joints() + delta,
        )

        if gripper_command is not None:
            set_point = gripper_command
            self._set_gripper_position(
                self.GRIPPER_OPEN
                + set_point * (self.GRIPPER_CLOSE - self.GRIPPER_OPEN)
            )
        self.last_state = self._update_last_state()

        rate.sleep()
        step_times.append(time.time() - s_t)
        count += 1
        if count % 1000 == 0:
            # Guard against a zero-length step so the reciprocal is finite.
            durations = np.maximum(
                np.asarray(step_times, dtype=float), np.finfo(float).eps
            )
            frequencies = 1.0 / durations
            # Mean rate over the window = iterations / elapsed time.
            frequency = 1 / np.mean(durations)
            print(
                f"Low Level Frequency - mean: {frequency:10.3f}, std: {np.std(frequencies):10.3f}, min: {np.min(frequencies):10.3f}, max: {np.max(frequencies):10.3f}"
            )
            step_times = []
class ZMQClientCamera(CameraDriver):
    """Camera driver that reads frames from a camera server over ZMQ.

    The client owns a REQ socket connected to ``tcp://{host}:{port}``; every
    ``read`` call is one pickled request followed by one pickled reply.
    """

    def __init__(self, port: int = DEFAULT_CAMERA_PORT, host: str = "127.0.0.1"):
        """Create the ZMQ context and connect the REQ socket to the server.

        Args:
            port: TCP port of the camera server.
            host: Host name or address of the camera server.
        """
        endpoint = f"tcp://{host}:{port}"
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.connect(endpoint)

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Request one camera reading from the server.

        Args:
            img_size: Optional image size, pickled and sent as the request.

        Returns:
            The unpickled reply. The ``ZMQServerCamera`` in this module
            answers with whatever its camera's ``read(img_size)`` returned.
        """
        # The request body is just the pickled image size.
        self._socket.send(pickle.dumps(img_size))
        reply = self._socket.recv()
        return pickle.loads(reply)
class ZMQServerRobot:
    """Serve a ``Robot`` over a ZMQ REP socket.

    Requests are pickled dicts of the form
    ``{"method": <name>, "args": {...}}``; each request gets exactly one
    pickled reply.
    """

    def __init__(
        self,
        robot: Robot,
        port: int = DEFAULT_ROBOT_PORT,
        host: str = "127.0.0.1",
    ):
        """Bind the REP socket to ``tcp://{host}:{port}``.

        Args:
            robot: The robot whose methods are exposed to clients.
            port: TCP port to bind.
            host: Interface address to bind.
        """
        self._robot = robot
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        debug_message = f"Robot Sever Binding to {addr}, Robot: {robot}"
        print(debug_message)
        self._timout_message = f"Timeout in Robot Server, Robot: {robot}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def serve(self) -> None:
        """Answer client requests until ``stop`` is called.

        The receive timeout is one second so the stop event is re-checked
        regularly even when no client is talking to the server.

        An unknown method name is answered with
        ``{"error": "Invalid method"}`` instead of raising: a REP socket has
        to send a reply for every request it receives, otherwise the
        requesting client blocks forever in ``recv`` and this loop dies.
        """
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                # Wait for next request from client
                message = self._socket.recv()
                # NOTE(security): pickle.loads executes arbitrary code from
                # the payload; only expose this socket to trusted peers.
                request = pickle.loads(message)

                # Call the appropriate method based on the request
                method = request.get("method")
                args = request.get("args", {})
                result: Any
                if method == "num_dofs":
                    result = self._robot.num_dofs()
                elif method == "get_joint_state":
                    result = self._robot.get_joint_state()
                elif method == "command_joint_state":
                    result = self._robot.command_joint_state(**args)
                elif method == "get_observations":
                    result = self._robot.get_observations()
                else:
                    # Report the problem to both the log and the client.
                    result = {"error": "Invalid method"}
                    print(f"Invalid method: {method}, {args, result}")

                self._socket.send(pickle.dumps(result))
            except zmq.Again:
                print(self._timout_message)
                # Timeout occurred, check if the stop event is set

    def stop(self) -> None:
        """Signal the server to stop serving."""
        self._stop_event.set()
def get_observations(self) -> Dict[str, np.ndarray]:
    """Fetch the robot's current observations from the server.

    Sends a pickled ``{"method": "get_observations"}`` request on the client
    socket and blocks until the reply arrives.

    Returns:
        Dict[str, np.ndarray]: The unpickled reply from the server.
    """
    payload = pickle.dumps({"method": "get_observations"})
    self._socket.send(payload)
    reply = self._socket.recv()
    return pickle.loads(reply)
def main(args: Args) -> None:
    """Open the dm_control viewer on the UR5e block-play task.

    With ``args.use_gello`` set, actions come from the joint state of a GELLO
    device built from the module-level ``config``; otherwise actions are
    sampled uniformly from the environment's action spec.

    Args:
        args: Parsed command-line options.
    """
    # Seven reset angles in degrees; the task gets all but the last one,
    # while the GELLO device is started with all seven.
    reset_joints_left = np.deg2rad([90, -90, -90, -90, 90, 0, 0])
    env = composer.Environment(
        task=BlockPlay(UR5e(), Floor(), reset_joints=reset_joints_left[:-1])
    )
    action_space = env.action_spec()

    if args.use_gello:
        gello = config.make_robot(
            port="/dev/cu.usbserial-FT7WBEIA", start_joints=reset_joints_left
        )

    def policy(timestep) -> np.ndarray:
        """Return the next action for the viewer (``timestep`` is unused)."""
        if not args.use_gello:
            return np.random.uniform(action_space.minimum, action_space.maximum)

        command = np.array(gello.get_joint_state()).copy()
        # The last entry is rescaled by 255 (presumably the gripper command
        # mapped onto the simulated actuator's range - TODO confirm).
        command[-1] = command[-1] * 255
        return command

    viewer.launch(env, policy=policy)
@dataclass
class Args:
    """Command-line options for the GELLO joint-offset script.

    The bare strings under each field are kept as they were: they appear to
    serve as the per-option help text for the CLI parser.

    Raises:
        ValueError: If ``joint_signs`` and ``start_joints`` differ in length,
            or if any joint sign is not -1 or 1.
    """

    port: str = "/dev/ttyUSB0"
    """The port that GELLO is connected to."""

    start_joints: Tuple[float, ...] = (0, 0, 0, 0, 0, 0)
    """The joint angles that the GELLO is placed in at (in radians)."""

    joint_signs: Tuple[float, ...] = (1, 1, -1, 1, 1, 1)
    """The sign of each joint, either 1 or -1 (one entry per start joint)."""

    gripper: bool = True
    """Whether or not the gripper is attached."""

    def __post_init__(self):
        # Explicit raises instead of asserts: asserts are stripped under -O,
        # which would silently let malformed options through.
        if len(self.joint_signs) != len(self.start_joints):
            raise ValueError(
                f"joint_signs has {len(self.joint_signs)} entries but "
                f"start_joints has {len(self.start_joints)}."
            )
        for idx, j in enumerate(self.joint_signs):
            if j not in (-1, 1):
                raise ValueError(
                    f"Joint idx: {idx} should be -1 or 1, but got {j}."
                )

    @property
    def num_robot_joints(self) -> int:
        """Number of arm joints, i.e. the length of ``start_joints``."""
        return len(self.start_joints)

    @property
    def num_joints(self) -> int:
        """Number of arm joints plus one if the gripper is attached."""
        extra_joints = 1 if self.gripper else 0
        return self.num_robot_joints + extra_joints
def run_docker_container():
    """Start an interactive ``gello:latest`` Docker container.

    The container is named ``gello_<user>``, mounts the repository root
    (two levels above this file) at ``/gello``, and runs a shell that first
    installs the Dynamixel SDK from ``third_party`` in editable mode.

    The user name comes from ``$USER``, then ``$LOGNAME``; if neither is set
    a fixed placeholder is used so the container is never named
    ``gello_None``.
    """
    user = os.getenv("USER") or os.getenv("LOGNAME") or "unknown"
    container_name = f"gello_{user}"
    gello_path = os.path.abspath(os.path.join(current_file_path, "../../"))
    volume_mapping = f"{gello_path}:/gello"

    cmd = [
        "docker",
        "run",
        "--runtime=nvidia",
        "--rm",
        "--name",
        container_name,
        "--privileged",
        "--volume",
        volume_mapping,
        "--volume",
        "/home/gello:/homefolder",
        "--net=host",
        "--volume",
        "/dev/serial/by-id/:/dev/serial/by-id/",
        "-it",
        "gello:latest",
        "bash",
        "-c",
        "pip install -e third_party/DynamixelSDK/python && exec bash",
    ]

    # Argument list, no shell: nothing in the command is re-parsed.
    # The exit status of the interactive session is deliberately ignored.
    subprocess.run(cmd)
import setuptools

# Read the README explicitly as UTF-8 so the build does not depend on the
# locale's default encoding.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# Package metadata for the ``gello`` distribution; numpy is the only
# dependency declared here.
setuptools.setup(
    name="gello",
    version="0.0.1",
    author="Philipp Wu",
    author_email="philippwu@berkeley.edu",
    description="software for GELLO",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/wuphilipp/gello_software",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
    ],
    python_requires=">=3.8",
    license="MIT",
    install_requires=[
        "numpy",
    ],
)