This commit is contained in:
Your Name 2024-05-25 23:34:18 -07:00
parent e05749660a
commit bf62f03e06
4 changed files with 91 additions and 51 deletions

View file

@ -117,7 +117,7 @@ def main(args):
)
if args.start_joints is None:
reset_joints = np.deg2rad(
[0, 0, 0, 90, 0, 90, 0, 0]
[-90, 0, 270, 90, 0, 90, 0, 0]
) # Change this to your own reset joints
else:
reset_joints = args.start_joints
@ -154,7 +154,7 @@ def main(args):
abs_deltas = np.abs(start_pos - joints)
id_max_joint_delta = np.argmax(abs_deltas)
max_joint_delta = 0.8
max_joint_delta = 0.9
if abs_deltas[id_max_joint_delta] > max_joint_delta:
id_mask = abs_deltas > max_joint_delta
print()
@ -175,30 +175,30 @@ def main(args):
joints
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
max_delta = 0.05
for _ in range(25):
obs = env.get_obs()
command_joints = agent.act(obs)
current_joints = obs["joint_positions"]
delta = command_joints - current_joints
max_joint_delta = np.abs(delta).max()
if max_joint_delta > max_delta:
delta = delta / max_joint_delta * max_delta
env.step(current_joints + delta)
# max_delta = 0.05
# for _ in range(25):
# obs = env.get_obs()
# command_joints = agent.act(obs)
# current_joints = obs["joint_positions"]
# delta = command_joints - current_joints
# max_joint_delta = np.abs(delta).max()
# if max_joint_delta > max_delta:
# delta = delta / max_joint_delta * max_delta
# env.step(current_joints + delta)
obs = env.get_obs()
joints = obs["joint_positions"]
action = agent.act(obs)
if (action - joints > 0.5).any():
print("Action is too big")
# print which joints are too big
joint_index = np.where(action - joints > 0.8)
for j in joint_index:
print(
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
)
exit()
# obs = env.get_obs()
# joints = obs["joint_positions"]
# action = agent.act(obs)
# if (action - joints > 0.5).any():
# print("Action is too big")
#
# # print which joints are too big
# joint_index = np.where(action - joints > 0.8)
# for j in joint_index:
# print(
# f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
# )
# exit()
if args.use_save_interface:
from gello.data_utils.keyboard_interface import KBReset