running
This commit is contained in:
parent
e05749660a
commit
bf62f03e06
4 changed files with 91 additions and 51 deletions
|
@ -117,7 +117,7 @@ def main(args):
|
|||
)
|
||||
if args.start_joints is None:
|
||||
reset_joints = np.deg2rad(
|
||||
[0, 0, 0, 90, 0, 90, 0, 0]
|
||||
[-90, 0, 270, 90, 0, 90, 0, 0]
|
||||
) # Change this to your own reset joints
|
||||
else:
|
||||
reset_joints = args.start_joints
|
||||
|
@ -154,7 +154,7 @@ def main(args):
|
|||
abs_deltas = np.abs(start_pos - joints)
|
||||
id_max_joint_delta = np.argmax(abs_deltas)
|
||||
|
||||
max_joint_delta = 0.8
|
||||
max_joint_delta = 0.9
|
||||
if abs_deltas[id_max_joint_delta] > max_joint_delta:
|
||||
id_mask = abs_deltas > max_joint_delta
|
||||
print()
|
||||
|
@ -175,30 +175,30 @@ def main(args):
|
|||
joints
|
||||
), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
|
||||
|
||||
max_delta = 0.05
|
||||
for _ in range(25):
|
||||
obs = env.get_obs()
|
||||
command_joints = agent.act(obs)
|
||||
current_joints = obs["joint_positions"]
|
||||
delta = command_joints - current_joints
|
||||
max_joint_delta = np.abs(delta).max()
|
||||
if max_joint_delta > max_delta:
|
||||
delta = delta / max_joint_delta * max_delta
|
||||
env.step(current_joints + delta)
|
||||
# max_delta = 0.05
|
||||
# for _ in range(25):
|
||||
# obs = env.get_obs()
|
||||
# command_joints = agent.act(obs)
|
||||
# current_joints = obs["joint_positions"]
|
||||
# delta = command_joints - current_joints
|
||||
# max_joint_delta = np.abs(delta).max()
|
||||
# if max_joint_delta > max_delta:
|
||||
# delta = delta / max_joint_delta * max_delta
|
||||
# env.step(current_joints + delta)
|
||||
|
||||
obs = env.get_obs()
|
||||
joints = obs["joint_positions"]
|
||||
action = agent.act(obs)
|
||||
if (action - joints > 0.5).any():
|
||||
print("Action is too big")
|
||||
|
||||
# print which joints are too big
|
||||
joint_index = np.where(action - joints > 0.8)
|
||||
for j in joint_index:
|
||||
print(
|
||||
f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
|
||||
)
|
||||
exit()
|
||||
# obs = env.get_obs()
|
||||
# joints = obs["joint_positions"]
|
||||
# action = agent.act(obs)
|
||||
# if (action - joints > 0.5).any():
|
||||
# print("Action is too big")
|
||||
#
|
||||
# # print which joints are too big
|
||||
# joint_index = np.where(action - joints > 0.8)
|
||||
# for j in joint_index:
|
||||
# print(
|
||||
# f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
|
||||
# )
|
||||
# exit()
|
||||
|
||||
if args.use_save_interface:
|
||||
from gello.data_utils.keyboard_interface import KBReset
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue