Skip to content

Commit

Permalink
Fix single agent mobilefranka
Browse files Browse the repository at this point in the history
  • Loading branch information
ranzuh committed Apr 25, 2023
1 parent b4b05f3 commit 90d938f
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 27 deletions.
2 changes: 1 addition & 1 deletion omniisaacgymenvs/cfg/task/MobileFranka.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ sim:
gpu_temp_buffer_capacity: 16777216
gpu_max_num_partitions: 8

franka:
mobile_franka:
# -1 to use default values
override_usd_defaults: False
fixed_base: False
Expand Down
4 changes: 2 additions & 2 deletions omniisaacgymenvs/cfg/train/MobileFrankaPPO.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ params:
val: 0
fixed_sigma: True
mlp:
units: [256, 128, 64]
units: [512, 256, 128] #[256, 128, 64]
activation: elu
d2rl: False

Expand Down Expand Up @@ -62,7 +62,7 @@ params:
truncate_grads: True
e_clip: 0.2
horizon_length: 16
minibatch_size: 128 #128 #1024
minibatch_size: 4096 #128 #1024
mini_epochs: 8
critic_coef: 4
clip_value: True
Expand Down
100 changes: 76 additions & 24 deletions omniisaacgymenvs/tasks/mobile_franka.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from omni.isaac.core.utils.torch.transformations import *
#from omni.isaac.core.utils.rotations import euler_angles_to_quat, quat_to_euler_angles
from omni.isaac.core.utils.torch.rotations import get_euler_xyz
from omni.isaac.core.prims import GeometryPrimView


from omni.isaac.cloner import Cloner

Expand Down Expand Up @@ -69,8 +71,16 @@ def __init__(
control_frequency = 120.0 / self._task_cfg["env"]["controlFrequencyInv"] # 30
self.dt = 1/control_frequency

self._num_observations = 30 #23
self._num_actions = 12
self._num_observations = 32 - 3 - 2 #23
self._num_actions = 11
self._num_agents = 1

self.initial_target_pos = np.array([2.0, 0.0, 0.5])

# set the ranges for the target randomization
self.x_lim = [-3, 3]
self.y_lim = [-3, 3]
self.z_lim = [0.2, 1.2]

RLTask.__init__(self, name, env)
return
Expand All @@ -79,6 +89,13 @@ def set_up_scene(self, scene) -> None:

self.get_franka()
#self.get_cabinet()
target_cube = VisualCuboid(
prim_path=self.default_zero_env_path + "/target_cube",
#position=[3.0, 0.0, 0.5],
translation=self.initial_target_pos,
scale=np.array([0.1, 0.1, 0.1]),
color=np.array([1, 0, 0]),
)

super().set_up_scene(scene, replicate_physics=False)

Expand All @@ -90,15 +107,10 @@ def set_up_scene(self, scene) -> None:
scene.add(self._mobilefrankas._lfingers)
scene.add(self._mobilefrankas._rfingers)
scene.add(self._mobilefrankas._base)
target_cube = VisualCuboid(
prim_path=self.default_zero_env_path + "/target_cube",
#position=[3.0, 0.0, 0.5],
translation=[3.0, 1.0, 0.5],
scale=np.array([0.1, 0.1, 0.1]),
color=np.array([1, 0, 0]),
)

scene.add(target_cube)
self._targets = GeometryPrimView(prim_paths_expr="/World/envs/.*/target_cube", name="target_view")
scene.add(self._targets)
#scene.add(target_cube)
#scene.add(self._cabinets)
#scene.add(self._cabinets._drawers)

Expand All @@ -107,7 +119,7 @@ def set_up_scene(self, scene) -> None:

def get_franka(self):
mobile_franka = MobileFranka(prim_path=self.default_zero_env_path + "/mobile_franka", name="mobile_franka")
self._sim_config.apply_articulation_settings("franka", get_prim_at_path(mobile_franka.prim_path), self._sim_config.parse_actor_config("franka"))
self._sim_config.apply_articulation_settings("mobile_franka", get_prim_at_path(mobile_franka.prim_path), self._sim_config.parse_actor_config("mobile_franka"))

def init_data(self) -> None:
def get_env_local_pose(env_pos, xformable, device):
Expand Down Expand Up @@ -184,6 +196,7 @@ def get_observations(self) -> dict:
# yaw is returned in the range 0 to 2*pi; TODO: decide whether it should be wrapped to -pi..pi instead
roll, pitch, base_yaw = get_euler_xyz(base_rot)
base_yaw = base_yaw.unsqueeze(1)
#print("base_rot, base_yaw", base_rot, base_yaw)
# for rot in base_rot:
# base_rot_z.append(quat_to_euler_angles(rot)[2])
# base_rot_z = torch.tensor(base_rot_z).unsqueeze(1).to(self._device)
Expand Down Expand Up @@ -230,20 +243,26 @@ def get_observations(self) -> dict:

self.to_target = self.target_positions - self.franka_lfinger_pos

self.obs_buf = torch.hstack((
obs = torch.hstack((
base_pos_xy,
base_yaw,
arm_dof_pos_scaled,
base_vel_xy,
base_angvel_z,
#base_vel_xy,
#base_angvel_z,
franka_dof_vel[:, 3:] * self.dof_vel_scale,
self.franka_lfinger_pos,
self.target_positions
)).to(dtype=torch.float32)

#print("obs", obs)
#input()

#print(obs)
#print(obs.shape)
self.obs_buf = obs

#input()

#print(self.obs_buf)
#print(self.obs_buf.shape)
#input()

#print("rotation", rot)
Expand Down Expand Up @@ -293,8 +312,29 @@ def pre_physics_step(self, actions) -> None:
if len(reset_env_ids) > 0:
self.reset_idx(reset_env_ids)

self.actions = actions.clone().to(self._device)
targets = self.franka_dof_targets + self.franka_dof_speed_scales * self.dt * self.actions * self.action_scale
raw_actions = actions.clone().to(self._device)

base_actions = raw_actions[:, :2]
arm_actions = raw_actions[:, 2:]


combined_actions = torch.hstack((
base_actions[:,0].unsqueeze(1),
torch.zeros((base_actions.shape[0], 1), device=self._device),
base_actions[:,1].unsqueeze(1),
arm_actions
))

# print("actions", actions.shape)
# print("base_actions", base_actions.shape)
# print("arm_actions", arm_actions.shape)
# print("combined_actions", combined_actions.shape)
# print(combined_actions)
# input()

self.actions = combined_actions

targets = self.franka_dof_targets + self.franka_dof_speed_scales * self.dt * combined_actions * self.action_scale # * 0.1
self.franka_dof_targets[:] = torch.clamp(targets, self.franka_dof_lower_limits, self.franka_dof_upper_limits)
env_ids_int32 = torch.arange(self._mobilefrankas.count, dtype=torch.int32, device=self._device)

Expand All @@ -303,14 +343,14 @@ def pre_physics_step(self, actions) -> None:
#self.actions[:, 2] = 0.5 # angular z

# TODO make the scaling values part of configs
action_x = self.actions[:, 0] * 1.0
action_x = combined_actions[:, 0] * 1.0 # * 0.5
action_y = torch.zeros(self._mobilefrankas.count, device=self._device)
action_yaw = self.actions[:, 2] * 0.75
action_yaw = combined_actions[:, 2] * 0.75 # * 0.5

vel_targets = self._calculate_velocity_targets(action_x, action_y, action_yaw)

# set the position targets for base joints to the current position
#self.franka_dof_targets[:, :3] = self.franka_dof_pos[:, :3]
self.franka_dof_targets[:, :3] = self.franka_dof_pos[:, :3]
#print("self.franka_dof_targets", self.franka_dof_targets)
artic_vel_targets = torch.zeros_like(self.franka_dof_targets)
artic_vel_targets[:, :3] = vel_targets
Expand Down Expand Up @@ -341,17 +381,29 @@ def reset_idx(self, env_ids):
self.franka_dof_lower_limits,
self.franka_dof_upper_limits,
)
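# randomize the base yaw uniformly in [0, 359) -- NOTE(review): this value is written as a joint position; confirm the joint expects degrees rather than radians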
pos[:, 2] = torch.rand((len(env_ids),), device=self._device) * 359.0
dof_pos = torch.zeros((num_indices, self._mobilefrankas.num_dof), device=self._device)
dof_vel = torch.zeros((num_indices, self._mobilefrankas.num_dof), device=self._device)
dof_pos[:, :] = pos
self.franka_dof_targets[env_ids, :] = pos
self.franka_dof_pos[env_ids, :] = pos

#print(self.franka_dof_targets[env_ids])
self._mobilefrankas.set_joint_position_targets(self.franka_dof_targets[env_ids], indices=indices)
self._mobilefrankas.set_joint_positions(dof_pos, indices=indices)
self._mobilefrankas.set_joint_velocities(dof_vel, indices=indices)

self.target_positions[:] = torch.tensor([3.0, 1.0, 0.5])
#self.target_positions[:] = torch.tensor(self.initial_target_pos, device=self._device)
rands = torch.rand((num_indices, 3), device=self._device)

# rescale the uniform [0, 1) samples into the configured x/y/z target limits
rands[:, 0] = rands[:, 0] * (self.x_lim[1] - self.x_lim[0]) + self.x_lim[0]
rands[:, 1] = rands[:, 1] * (self.y_lim[1] - self.y_lim[0]) + self.y_lim[0]
rands[:, 2] = rands[:, 2] * (self.z_lim[1] - self.z_lim[0]) + self.z_lim[0]

self.target_positions[env_ids] = rands
self._targets.set_world_poses(self._env_pos + self.target_positions)

# bookkeeping
self.reset_buf[env_ids] = 0
Expand Down Expand Up @@ -401,7 +453,7 @@ def calculate_metrics(self) -> None:
#print("penalty_joint_limit", penalty_joint_limit)

reward = torch.zeros_like(self.rew_buf)
reward = reward - self.action_penalty_scale * action_penalty - 0.2 * distance_to_target - 0.02 * penalty_joint_limit
reward = reward - self.action_penalty_scale * action_penalty - 0.2 * distance_to_target - 0.03 * penalty_joint_limit
#print("action penalty", action_penalty, "scaled", self.action_penalty_scale * action_penalty)
#print("distance", distance_to_target, "scaled", 0.01 * distance_to_target)
#print("reward", reward)
Expand All @@ -411,7 +463,7 @@ def _joint_limit_penalty(self, values):
# neutral position of joints
neutral = torch.tensor([0,0,0,-1.5,0,2.0,0], device=self._device)
# per-joint weights: how strongly each joint is penalized for deviating from its neutral position
weights = torch.tensor([1.0, 1, 1.5, 1, 1, 2.0, 1], device=self._device)
weights = torch.tensor([1.5, 1, 1.5, 1, 1, 2.0, 1], device=self._device)
return torch.sum(torch.abs(values-neutral) * weights, axis=1)

def is_done(self) -> None:
Expand Down

0 comments on commit 90d938f

Please sign in to comment.