396 lines
12 KiB
Python
396 lines
12 KiB
Python
import dataclasses
|
|
import threading
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
import numpy as np
|
|
from pyquaternion import Quaternion
|
|
|
|
from gello.robots.robot import Robot
|
|
|
|
|
|
def _aa_from_quat(quat: np.ndarray) -> np.ndarray:
|
|
"""Convert a quaternion to an axis-angle representation.
|
|
|
|
Args:
|
|
quat (np.ndarray): The quaternion to convert.
|
|
|
|
Returns:
|
|
np.ndarray: The axis-angle representation of the quaternion.
|
|
"""
|
|
assert quat.shape == (4,), "Input quaternion must be a 4D vector."
|
|
norm = np.linalg.norm(quat)
|
|
assert norm != 0, "Input quaternion must not be a zero vector."
|
|
quat = quat / norm # Normalize the quaternion
|
|
|
|
Q = Quaternion(w=quat[3], x=quat[0], y=quat[1], z=quat[2])
|
|
angle = Q.angle
|
|
axis = Q.axis
|
|
aa = axis * angle
|
|
return aa
|
|
|
|
|
|
def _quat_from_aa(aa: np.ndarray) -> np.ndarray:
|
|
"""Convert an axis-angle representation to a quaternion.
|
|
|
|
Args:
|
|
aa (np.ndarray): The axis-angle representation to convert.
|
|
|
|
Returns:
|
|
np.ndarray: The quaternion representation of the axis-angle.
|
|
"""
|
|
assert aa.shape == (3,), "Input axis-angle must be a 3D vector."
|
|
norm = np.linalg.norm(aa)
|
|
assert norm != 0, "Input axis-angle must not be a zero vector."
|
|
axis = aa / norm # Normalize the axis-angle
|
|
|
|
Q = Quaternion(axis=axis, angle=norm)
|
|
quat = np.array([Q.x, Q.y, Q.z, Q.w])
|
|
return quat
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class RobotState:
|
|
x: float
|
|
y: float
|
|
z: float
|
|
gripper: float
|
|
j1: float
|
|
j2: float
|
|
j3: float
|
|
j4: float
|
|
j5: float
|
|
j6: float
|
|
j7: float
|
|
aa: np.ndarray
|
|
|
|
@staticmethod
|
|
def from_robot(
|
|
cartesian: np.ndarray,
|
|
joints: np.ndarray,
|
|
gripper: float,
|
|
aa: np.ndarray,
|
|
) -> "RobotState":
|
|
return RobotState(
|
|
cartesian[0],
|
|
cartesian[1],
|
|
cartesian[2],
|
|
gripper,
|
|
joints[0],
|
|
joints[1],
|
|
joints[2],
|
|
joints[3],
|
|
joints[4],
|
|
joints[5],
|
|
joints[6],
|
|
aa,
|
|
)
|
|
|
|
def cartesian_pos(self) -> np.ndarray:
|
|
return np.array([self.x, self.y, self.z])
|
|
|
|
def quat(self) -> np.ndarray:
|
|
return _quat_from_aa(self.aa)
|
|
|
|
def joints(self) -> np.ndarray:
|
|
return np.array([self.j1, self.j2, self.j3, self.j4, self.j5, self.j6, self.j7])
|
|
|
|
def gripper_pos(self) -> float:
|
|
return self.gripper
|
|
|
|
|
|
class Rate:
|
|
def __init__(self, *, duration):
|
|
self.duration = duration
|
|
self.last = time.time()
|
|
|
|
def sleep(self, duration=None) -> None:
|
|
duration = self.duration if duration is None else duration
|
|
assert duration >= 0
|
|
now = time.time()
|
|
passed = now - self.last
|
|
remaining = duration - passed
|
|
assert passed >= 0
|
|
if remaining > 0.0001:
|
|
time.sleep(remaining)
|
|
self.last = time.time()
|
|
|
|
|
|
class XArmRobot(Robot):
|
|
GRIPPER_OPEN = 800
|
|
GRIPPER_CLOSE = 100
|
|
# MAX_DELTA = 0.2
|
|
DEFAULT_MAX_DELTA = 0.05
|
|
|
|
def num_dofs(self) -> int:
|
|
return 8
|
|
|
|
def get_joint_state(self) -> np.ndarray:
|
|
state = self.get_state()
|
|
gripper = state.gripper_pos()
|
|
all_dofs = np.concatenate([state.joints(), np.array([gripper])])
|
|
return all_dofs
|
|
|
|
def command_joint_state(self, joint_state: np.ndarray) -> None:
|
|
if len(joint_state) == 7:
|
|
self.set_command(joint_state, None)
|
|
elif len(joint_state) == 8:
|
|
self.set_command(joint_state[:7], joint_state[7])
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid joint state: {joint_state}, len={len(joint_state)}"
|
|
)
|
|
|
|
def stop(self):
|
|
self.running = False
|
|
if self.robot is not None:
|
|
self.robot.disconnect()
|
|
|
|
if self.command_thread is not None:
|
|
self.command_thread.join()
|
|
|
|
def __init__(
|
|
self,
|
|
ip: str = "192.168.1.226",
|
|
real: bool = True,
|
|
control_frequency: float = 50.0,
|
|
max_delta: float = DEFAULT_MAX_DELTA,
|
|
use_robotiq: bool = True,
|
|
):
|
|
print(ip)
|
|
self.real = real
|
|
self.use_robotiq = use_robotiq
|
|
self.max_delta = max_delta
|
|
if real:
|
|
from xarm.wrapper import XArmAPI
|
|
|
|
self.robot = XArmAPI(ip, is_radian=True)
|
|
else:
|
|
self.robot = None
|
|
|
|
if self.use_robotiq:
|
|
import pyRobotiqGripper
|
|
gripper = pyRobotiqGripper.RobotiqGripper()
|
|
self.gripper = gripper
|
|
#gripper.activate()
|
|
|
|
self._control_frequency = control_frequency
|
|
self._clear_error_states()
|
|
self._set_gripper_position(self.GRIPPER_OPEN)
|
|
|
|
self.last_state_lock = threading.Lock()
|
|
self.target_command_lock = threading.Lock()
|
|
self.last_state = self._update_last_state()
|
|
self.target_command = {
|
|
"joints": self.last_state.joints(),
|
|
"gripper": 0,
|
|
}
|
|
self.running = True
|
|
self.command_thread = None
|
|
if real:
|
|
self.command_thread = threading.Thread(target=self._robot_thread)
|
|
self.command_thread.start()
|
|
|
|
|
|
def get_state(self) -> RobotState:
|
|
with self.last_state_lock:
|
|
return self.last_state
|
|
|
|
def set_command(self, joints: np.ndarray, gripper: Optional[float] = None) -> None:
|
|
with self.target_command_lock:
|
|
self.target_command = {
|
|
"joints": joints,
|
|
"gripper": gripper,
|
|
}
|
|
|
|
def _clear_error_states(self):
|
|
if self.robot is None:
|
|
return
|
|
self.robot.clean_error()
|
|
self.robot.clean_warn()
|
|
self.robot.motion_enable(True)
|
|
time.sleep(1)
|
|
self.robot.set_mode(1)
|
|
time.sleep(1)
|
|
self.robot.set_collision_sensitivity(0)
|
|
time.sleep(1)
|
|
self.robot.set_state(state=0)
|
|
time.sleep(1)
|
|
if self.use_robotiq:
|
|
pass
|
|
else:
|
|
self.robot.set_gripper_enable(True)
|
|
time.sleep(1)
|
|
self.robot.set_gripper_mode(0)
|
|
time.sleep(1)
|
|
self.robot.set_gripper_speed(3000)
|
|
time.sleep(1)
|
|
|
|
def _get_gripper_pos(self) -> float:
|
|
if self.robot is None:
|
|
return 0.0
|
|
if self.use_robotiq:
|
|
return self.gripper.getPosition() / 255
|
|
else:
|
|
code, gripper_pos = self.robot.get_gripper_position()
|
|
while code != 0 or gripper_pos is None:
|
|
print(f"Error code {code} in get_gripper_position(). {gripper_pos}")
|
|
time.sleep(0.001)
|
|
code, gripper_pos = self.robot.get_gripper_position()
|
|
if code == 22:
|
|
self._clear_error_states()
|
|
|
|
normalized_gripper_pos = (gripper_pos - self.GRIPPER_OPEN) / (
|
|
self.GRIPPER_CLOSE - self.GRIPPER_OPEN
|
|
)
|
|
return normalized_gripper_pos
|
|
|
|
def _set_gripper_position(self, pos: int) -> None:
|
|
if self.robot is None:
|
|
return
|
|
if self.use_robotiq:
|
|
pos = 255 - (pos / 800 * 255)
|
|
try:
|
|
self.gripper.goTo(pos, wait=False)
|
|
except Exception as e:
|
|
print(e)
|
|
print(pos)
|
|
raise e
|
|
else:
|
|
self.robot.set_gripper_position(pos, wait=False)
|
|
# while self.robot.get_is_moving():
|
|
# time.sleep(0.01)
|
|
|
|
def _robot_thread(self):
|
|
rate = Rate(
|
|
duration=1 / self._control_frequency
|
|
) # command and update rate for robot
|
|
step_times = []
|
|
count = 0
|
|
|
|
while self.running:
|
|
s_t = time.time()
|
|
# update last state
|
|
self.last_state = self._update_last_state()
|
|
with self.target_command_lock:
|
|
joint_delta = np.array(
|
|
self.target_command["joints"] - self.last_state.joints()
|
|
)
|
|
gripper_command = self.target_command["gripper"]
|
|
|
|
norm = np.linalg.norm(joint_delta)
|
|
|
|
# threshold delta to be at most 0.01 in norm space
|
|
if norm > self.max_delta:
|
|
delta = joint_delta / norm * self.max_delta
|
|
else:
|
|
delta = joint_delta
|
|
|
|
# command position
|
|
self._set_position(
|
|
self.last_state.joints() + delta,
|
|
)
|
|
|
|
if gripper_command is not None:
|
|
set_point = gripper_command
|
|
self._set_gripper_position(
|
|
self.GRIPPER_OPEN
|
|
+ set_point * (self.GRIPPER_CLOSE - self.GRIPPER_OPEN)
|
|
)
|
|
self.last_state = self._update_last_state()
|
|
|
|
rate.sleep()
|
|
step_times.append(time.time() - s_t)
|
|
count += 1
|
|
if count % 1000 == 0:
|
|
# Mean, Std, Min, Max, only show 3 decimal places and string pad with 10 spaces
|
|
frequency = 1 / np.mean(step_times)
|
|
# print(f"Step time - mean: {np.mean(step_times):10.3f}, std: {np.std(step_times):10.3f}, min: {np.min(step_times):10.3f}, max: {np.max(step_times):10.3f}")
|
|
print(
|
|
f"Low Level Frequency - mean: {frequency:10.3f}, std: {np.std(frequency):10.3f}, min: {np.min(frequency):10.3f}, max: {np.max(frequency):10.3f}"
|
|
)
|
|
step_times = []
|
|
|
|
def _update_last_state(self) -> RobotState:
|
|
with self.last_state_lock:
|
|
if self.robot is None:
|
|
return RobotState(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, np.zeros(3))
|
|
|
|
gripper_pos = self._get_gripper_pos()
|
|
|
|
code, servo_angle = self.robot.get_servo_angle(is_radian=True)
|
|
while code != 0:
|
|
print(f"Error code {code} in get_servo_angle().")
|
|
self._clear_error_states()
|
|
code, servo_angle = self.robot.get_servo_angle(is_radian=True)
|
|
|
|
code, cart_pos = self.robot.get_position_aa(is_radian=True)
|
|
while code != 0:
|
|
print(f"Error code {code} in get_position().")
|
|
self._clear_error_states()
|
|
code, cart_pos = self.robot.get_position_aa(is_radian=True)
|
|
|
|
cart_pos = np.array(cart_pos)
|
|
aa = cart_pos[3:]
|
|
cart_pos[:3] /= 1000
|
|
|
|
return RobotState.from_robot(
|
|
cart_pos,
|
|
servo_angle,
|
|
gripper_pos,
|
|
aa,
|
|
)
|
|
|
|
def _set_position(
|
|
self,
|
|
joints: np.ndarray,
|
|
) -> None:
|
|
if self.robot is None:
|
|
return
|
|
# threhold xyz to be in min max
|
|
ret = self.robot.set_servo_angle_j(joints, wait=False, is_radian=True)
|
|
if ret in [1, 9]:
|
|
self._clear_error_states()
|
|
|
|
def get_observations(self) -> Dict[str, np.ndarray]:
|
|
state = self.get_state()
|
|
pos_quat = np.concatenate([state.cartesian_pos(), state.quat()])
|
|
joints = self.get_joint_state()
|
|
return {
|
|
"joint_positions": joints, # rotational joint + gripper state
|
|
"joint_velocities": joints,
|
|
"ee_pos_quat": pos_quat,
|
|
"gripper_position": np.array(state.gripper_pos()),
|
|
}
|
|
|
|
|
|
def main():
|
|
ip = "192.168.1.226"
|
|
robot = XArmRobot(ip)
|
|
import time
|
|
|
|
time.sleep(1)
|
|
print(robot.get_state())
|
|
print(robot.get_state())
|
|
print(robot.get_state())
|
|
print(robot.get_state())
|
|
print(robot.get_state())
|
|
|
|
time.sleep(1)
|
|
print(robot.get_state())
|
|
print("end")
|
|
robot.command_joint_state(np.zeros(7))
|
|
robot.stop()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
ip = "192.168.1.226"
|
|
# from xarm.wrapper import XArmAPI
|
|
# robot = XArmAPI(ip, is_radian=True)
|
|
# robot.set_mode(1)
|
|
# robot.set_mode(1)
|
|
# robot.set_mode(1)
|
|
# print(robot.get_servo_angle())
|
|
# robot.set_servo_angle_j(np.ones(7)*0.1 + robot.get_servo_angle()[1])
|
|
# print(robot.get_state())
|
|
main()
|