diff --git a/mimickit/engines/isaac_lab_engine.py b/mimickit/engines/isaac_lab_engine.py index c87e1a9..8649d88 100644 --- a/mimickit/engines/isaac_lab_engine.py +++ b/mimickit/engines/isaac_lab_engine.py @@ -77,6 +77,7 @@ def __init__(self, config, num_envs, device, visualize): self._env_spacing = config["env_spacing"] self._obj_cfgs = [] self._obj_control_modes = [] + self._has_body_forces = [] if ("control_mode" in config): self._control_mode = engine.ControlMode[config["control_mode"]] @@ -483,6 +484,7 @@ def set_body_forces(self, env_id, obj_id, body_id, forces): env_ids=env_id, body_ids=sim_body_id, is_global=True) + self._has_body_forces[obj_id] = True return def get_obj_torque_limits(self, env_id, obj_id): @@ -722,13 +724,37 @@ def _post_sim_step(self): return def _clear_forces(self): - for obj in self._objs: - if (obj.has_external_wrench): - forces = torch.zeros([1, 3], dtype=torch.float, device=self._device) - torques = torch.zeros([1, 3], dtype=torch.float, device=self._device) - obj.set_external_force_and_torque(forces=forces, torques=torques, - positions=None, env_ids=None, - body_ids=None, is_global=True) + num_envs = self.get_num_envs() + for obj_id, obj in enumerate(self._objs): + has_external_wrench = getattr(obj, "has_external_wrench", None) + if (has_external_wrench is None): + data = getattr(obj, "data", None) + if (data is not None): + has_external_wrench = getattr(data, "has_external_wrench", None) + + if isinstance(has_external_wrench, torch.Tensor): + has_external_wrench = bool(has_external_wrench.any().item()) + elif isinstance(has_external_wrench, np.ndarray): + has_external_wrench = bool(has_external_wrench.any()) + + if (has_external_wrench is None): + has_external_wrench = self._has_body_forces[obj_id] + + if (not has_external_wrench): + continue + + obj_type = self.get_obj_type(obj_id) + if (obj_type == engine.ObjType.articulated): + num_bodies = int(self._body_order_sim2common[obj_id].numel()) + forces = torch.zeros([num_envs, num_bodies, 3], dtype=torch.float, device=self._device) + else: + forces = torch.zeros([num_envs, 3], dtype=torch.float, device=self._device) + + torques = torch.zeros_like(forces) + obj.set_external_force_and_torque(forces=forces, torques=torques, + positions=None, env_ids=None, + body_ids=None, is_global=True) + self._has_body_forces[obj_id] = False return def _validate_envs(self): @@ -1075,6 +1101,7 @@ def _build_sim_tensors(self): num_envs = self.get_num_envs() num_objs = self.get_objs_per_env() self._objs_need_reset = torch.zeros([num_envs, num_objs], device=self._device, dtype=torch.bool) + self._has_body_forces = [False for _ in range(num_objs)] return def _filter_env_collisions(self): @@ -1147,4 +1174,4 @@ def _on_keyboard_event(self, event): if (event.input in self._keyboard_callbacks): callback = self._keyboard_callbacks[event.input] callback() - return \ No newline at end of file + return