Skip to content

Commit

Permalink
Fix bounded area
Browse files Browse the repository at this point in the history
  • Loading branch information
jackvial committed Oct 27, 2024
1 parent 565a46c commit 796ef59
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions lerobot/common/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,41 @@
# These are for the boundaries of the workspace. If the robot goes out of bounds, the episode is terminated
# and there is a negative reward. You might need to tweak these for your setup. Use `teleop_with_goals.py` to
# check rewards.
GRIPPER_TIP_Z_BOUNDS = (0.008, 0.065)
GRIPPER_TIP_X_BOUNDS = (-0.16, 0.16)
GRIPPER_TIP_Y_BOUNDS = (-0.25, -0.06)
GRIPPER_TIP_BOUNDS = np.row_stack([GRIPPER_TIP_X_BOUNDS, GRIPPER_TIP_Y_BOUNDS, GRIPPER_TIP_Z_BOUNDS])

# GRIPPER_TIP_Z_BOUNDS = (0.008, 0.065)
# GRIPPER_TIP_X_BOUNDS = (-0.16, 0.16)
# GRIPPER_TIP_Y_BOUNDS = (-0.25, -0.06)

GRIPPER_TIP_Z_BOUNDS = (-0.03, 0.065) # Extend bound under table so arm on table doesn't count as OOB

# Rotate the bounded area 90 degrees clockwise (looking at arm from behind)
GRIPPER_TIP_X_BOUNDS = (0.06, 0.25)
GRIPPER_TIP_Y_BOUNDS = (-0.16, 0.16)
GRIPPER_TIP_BOUNDS = np.row_stack([GRIPPER_TIP_X_BOUNDS, GRIPPER_TIP_Y_BOUNDS, GRIPPER_TIP_Z_BOUNDS])

def is_in_bounds(gripper_tip_pos, buffer: float | np.ndarray = 0):
print("gripper_tip_pos.shape", gripper_tip_pos.shape)
"""Check if gripper tip position is within the workspace bounds.
Args:
gripper_tip_pos: Position of gripper tip in robot base frame [x,y,z]
buffer: Additional buffer space from boundaries (can be scalar or per-axis)
Returns:
bool: True if position is within bounds, False otherwise
"""
if not isinstance(buffer, np.ndarray):
buffer = np.zeros_like(GRIPPER_TIP_BOUNDS) + buffer

for i, bounds in enumerate(GRIPPER_TIP_BOUNDS):
assert (bounds[1] - bounds[0]) > buffer[i].sum()
lower_bound_check = gripper_tip_pos[i] < bounds[0] + buffer[i][0]
# print(f"lower_bound_check: {lower_bound_check}")
upper_bound_check = gripper_tip_pos[i] > bounds[1] - buffer[i][1]
# print(f"upper_bound_check: {upper_bound_check}")
if gripper_tip_pos[i] < bounds[0] + buffer[i][0] or gripper_tip_pos[i] > bounds[1] - buffer[i][1]:
if (gripper_tip_pos[i] < bounds[0] + buffer[i][0] or
gripper_tip_pos[i] > bounds[1] - buffer[i][1]):
return False
return True




def calc_smoothness_reward(
action: np.ndarray,
prior_action: np.ndarray | None = None,
Expand Down

0 comments on commit 796ef59

Please sign in to comment.