From ec1faf0faf727d511c7d599e436f54a006dd6bab Mon Sep 17 00:00:00 2001 From: Krishnan Date: Wed, 17 May 2023 19:17:36 -0400 Subject: [PATCH 1/2] added doubleqcriticmlp and options for critic networks in shac.yaml --- scripts/cfg/alg/shac.yaml | 1 + scripts/cfg/alg/shac2.yaml | 17 ++- scripts/cfg/config.yaml | 23 +++- scripts/cfg/env/hopper.yaml | 18 +++ src/shac/algorithms/shac.py | 139 ++++++++++++++++------- src/shac/algorithms/shac2.py | 206 +++++++++++++++++++++++----------- src/shac/envs/hopper.py | 3 + src/shac/models/critic.py | 18 ++- src/shac/utils/hydra_utils.py | 1 + 9 files changed, 311 insertions(+), 115 deletions(-) diff --git a/scripts/cfg/alg/shac.yaml b/scripts/cfg/alg/shac.yaml index dc5f899f..215ad2f3 100644 --- a/scripts/cfg/alg/shac.yaml +++ b/scripts/cfg/alg/shac.yaml @@ -34,6 +34,7 @@ params: num_actors: ${env.config.num_envs} # ${resolve_default:64,${..env.config.num_envs}} save_interval: 400 # ${resolve_default:400,${..env.save_interval}} contact_truncation: False + min_steps_before_truncation: 4 player: determenistic: True diff --git a/scripts/cfg/alg/shac2.yaml b/scripts/cfg/alg/shac2.yaml index 436d8403..65c08971 100644 --- a/scripts/cfg/alg/shac2.yaml +++ b/scripts/cfg/alg/shac2.yaml @@ -9,8 +9,23 @@ params: units: ${env.shac2.actor_mlp.units} activation: elu + # critic_name: double_q_network + # critic: + # _target_: rl_games.algos_torch.network_builder.DoubleQCritic + # units: ${env.shac2.critic_mlp.units} + # activation: elu + # norm_func_name: layer_norm + # d2rl: True + # critic_name: q_network + # critic: + # _target_: shac.models.critic.QCriticMLP + # cfg_network: + # critic_mlp: + # units: ${env.shac2.critic_mlp.units} + # activation: elu + critic_name: value_network critic: - _target_: shac.models.critic.QCriticMLP + _target_: shac.models.critic.CriticMLP cfg_network: critic_mlp: units: ${env.shac2.critic_mlp.units} diff --git a/scripts/cfg/config.yaml b/scripts/cfg/config.yaml index 7995e7a9..145490f3 100644 --- a/scripts/cfg/config.yaml +++ b/scripts/cfg/config.yaml @@ -2,7 +2,6 @@ defaults: - _self_ - env: cartpole - alg: shac - # - override hydra/launcher: submitit_local hydra: @@ -25,7 +24,7 @@ general: no_time_stamp: False render: False device: cuda:0 - run_wandb: True + run_wandb: False seed: 42 train: True checkpoint: @@ -80,7 +79,25 @@ env: - 0.7 - 0.95 - shac2: ${.shac} + shac2: + lambda: 0.95 + actor_mlp: + units: + - 64 + - 64 + critic_mlp: + units: + - 64 + - 64 + target_critic_alpha: 0.4 + actor_lr: 1e-3 + critic_lr: 1e-3 + max_epochs: 2000 + save_interval: 400 + steps_num: 32 + betas: + - 0.7 + - 0.95 wandb: project: shac diff --git a/scripts/cfg/env/hopper.yaml b/scripts/cfg/env/hopper.yaml index 9ce06d95..b24f4910 100644 --- a/scripts/cfg/env/hopper.yaml +++ b/scripts/cfg/env/hopper.yaml @@ -32,6 +32,24 @@ shac: - 64 save_interval: 400 +shac2: + actor_lr: 2e-3 + critic_lr: 4e-3 + max_epochs: 500 + betas: + - 0.7 + - 0.95 + actor_mlp: + units: + - 128 + - 64 + - 32 + critic_mlp: + units: + - 64 + - 64 + save_interval: 400 + ppo: max_epochs: 2000 minibatch_size: 8192 diff --git a/src/shac/algorithms/shac.py b/src/shac/algorithms/shac.py index 3f8e447a..c5c5367e 100644 --- a/src/shac/algorithms/shac.py +++ b/src/shac/algorithms/shac.py @@ -22,6 +22,7 @@ import shac.models.critic as critic_models from shac.utils.common import * import shac.utils.torch_utils as tu +from rl_games.algos_torch.network_builder import DoubleQCritic from shac.utils.running_mean_std import RunningMeanStd from shac.utils.dataset import CriticDataset, QCriticDataset 
from shac.utils.time_report import TimeReport @@ -98,6 +99,7 @@ def __init__(self, cfg): self.truncate_grad = cfg["params"]["config"]["truncate_grads"] self.grad_norm = cfg["params"]["config"]["grad_norm"] self.contact_truncation = cfg["params"]["config"].get("contact_truncation", False) + self.min_steps_from_truncation = cfg["params"]["config"].get("min_steps_before_truncation", 4) if cfg["params"]["general"]["train"]: self.log_dir = cfg["params"]["general"]["logdir"] @@ -134,9 +136,11 @@ def __init__(self, cfg): self.actor = actor_fn(self.num_obs, self.num_actions, cfg["params"]["network"], device=self.device) critic_fn = getattr(critic_models, self.critic_name) if self.critic_name == "CriticMLP": - self.critic = critic_fn(self.num_obs, cfg["params"]["network"], device=self.device) - else: - self.critic = critic_fn(self.num_obs, self.num_actions, cfg["params"]["network"], device=self.device) + input_size = self.num_obs + elif self.critic_name == "QCriticMLP" or self.critic_name == "DoubleQCriticMLP": + input_size = self.num_obs + self.num_actions + self.critic = critic_fn(input_size, cfg["params"]["network"], device=self.device) + self.all_params = list(self.actor.parameters()) + list(self.critic.parameters()) self.target_critic = copy.deepcopy(self.critic) @@ -161,12 +165,18 @@ def __init__(self, cfg): dtype=torch.float32, device=self.device, ) - if self.critic_name == "QCriticMLP": + if self.critic_name.endswith("QCriticMLP"): self.act_buf = torch.zeros( (self.steps_num, self.num_envs, self.num_actions), dtype=torch.float32, device=self.device, ) + # self.next_obs_buf = torch.zeros( + # (self.steps_num, self.num_envs, self.num_obs), + # dtype=torch.float32, + # device=self.device, + # ) + self.rew_buf = torch.zeros((self.steps_num, self.num_envs), dtype=torch.float32, device=self.device) self.done_mask = torch.zeros((self.steps_num, self.num_envs), dtype=torch.float32, device=self.device) self.next_values = torch.zeros((self.steps_num, self.num_envs), dtype=torch.float32, device=self.device) @@ -225,6 +235,18 @@ def __init__(self, cfg): # timer self.time_report = TimeReport() + def compute_values(self, obs): + """Compute values for the given observations with target critic.""" + if self.critic_name == "CriticMLP": + values = self.target_critic(obs).squeeze(-1) + elif self.critic_name == "QCriticMLP": + action = torch.tanh(self.actor(obs, deterministic=True)) + values = self.target_critic(obs, action).squeeze(-1) + elif self.critic_name == "DoubleQCriticMLP": + action = torch.tanh(self.actor(obs, deterministic=True)) + values = torch.minimum(*self.target_critic(obs, action)).squeeze(-1) + return values + def compute_actor_loss(self, deterministic=False): rew_acc = torch.zeros((self.steps_num + 1, self.num_envs), dtype=torch.float32, device=self.device) gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device) @@ -240,26 +262,31 @@ def compute_actor_loss(self, deterministic=False): ret_var = self.ret_rms.var.clone() # initialize trajectory to cut off gradients between episodes. 
- obs = self.env.initialize_trajectory() + next_obs = self.env.initialize_trajectory() if self.obs_rms is not None: # update obs rms with torch.no_grad(): - self.obs_rms.update(obs) + self.obs_rms.update(next_obs) # normalize the current obs - obs = obs_rms.normalize(obs) + next_obs = obs_rms.normalize(next_obs) # collect data for critic training for i in range(self.steps_num): + obs = next_obs with torch.no_grad(): self.obs_buf[i] = obs.clone() actions = self.actor(obs, deterministic=deterministic) - if self.critic_name == "QCriticMLP": + if self.critic_name.endswith("QCriticMLP"): with torch.no_grad(): self.act_buf[i] = actions.clone() - obs, rew, done, extra_info = self.env.step(torch.tanh(actions)) + next_obs, rew, done, extra_info = self.env.step(torch.tanh(actions)) + + # if self.critic_name.endswith("QCriticMLP"): + # with torch.no_grad(): + # self.next_obs_buf[i] = obs.clone() with torch.no_grad(): raw_rew = rew.clone() @@ -270,9 +297,9 @@ def compute_actor_loss(self, deterministic=False): if self.obs_rms is not None: # update obs rms with torch.no_grad(): - self.obs_rms.update(obs) + self.obs_rms.update(next_obs) # normalize the current obs - obs = obs_rms.normalize(obs) + next_obs = obs_rms.normalize(next_obs) if self.ret_rms is not None: # update ret rms @@ -285,17 +312,16 @@ def compute_actor_loss(self, deterministic=False): self.episode_length += 1 trunc = extra_info.get("contact_changed", torch.zeros_like(done)) - if not self.contact_truncation: - is_done = done.clone() + trunc_on_contact = extra_info.get("num_contact_changed", torch.zeros_like(done)) >= 1 + if self.contact_truncation: + is_done = done.clone() | trunc_on_contact + trunc_env_ids = trunc_on_contact.nonzero(as_tuple=False).squeeze(-1) else: - is_done = done.clone() # | trunc + is_done = done.clone() + trunc_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1) done_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1) - if self.critic_name == "CriticMLP": - next_values[i + 1] = self.target_critic(obs).squeeze(-1) - else: - next_action = torch.tanh(self.actor(obs, deterministic=True)) - next_values[i + 1] = self.target_critic(obs, next_action).squeeze(-1) + next_values[i + 1] = self.compute_values(next_obs) for id in done_env_ids: if ( @@ -304,21 +330,34 @@ def compute_actor_loss(self, deterministic=False): or (torch.abs(extra_info["obs_before_reset"][id]) > 1e6).sum() > 0 ): # ugly fix for nan values next_values[i + 1, id] = 0.0 - elif self.episode_length[id] < self.max_episode_length and ( - not self.contact_truncation or not trunc[id] - ): + elif self.episode_length[id] < self.max_episode_length: + # set values to 0 due to early termination next_values[i + 1, id] = 0.0 - else: # otherwise, use terminal value critic to estimate the long-term performance + else: + # otherwise, use terminal value critic to estimate the long-term performance if self.obs_rms is not None: real_obs = obs_rms.normalize(extra_info["obs_before_reset"][id]) else: real_obs = extra_info["obs_before_reset"][id] - - if self.critic_name == "CriticMLP": - next_values[i + 1, id] = self.target_critic(real_obs).squeeze(-1) + # if truncating on contact, compute model-free value + # and at least min_steps have passed since last truncation/short horizon + last_zero = self.rew_buf[:i, id].eq(0).nonzero(as_tuple=False).squeeze(-1) + if len(last_zero) == 0: + last_zero = i else: - real_act = torch.tanh(self.actor(real_obs, deterministic=True)) - next_values[i + 1, id] = self.target_critic(real_obs, real_act).squeeze(-1) + last_zero = last_zero[-1] + + 
if ( + self.contact_truncation + and id in trunc_env_ids + and i - last_zero >= self.min_steps_from_truncation + ): + real_obs = real_obs.detach() + + next_values[i + 1, id] = self.compute_values(real_obs) + + is_done = done.clone() + done_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1) if (next_values[i + 1] > 1e6).sum() > 0 or (next_values[i + 1] < -1e6).sum() > 0: print("next value error") @@ -338,9 +377,8 @@ def compute_actor_loss(self, deterministic=False): gamma = gamma * self.gamma # clear up gamma and rew_acc for done envs - gamma[done_env_ids] = 1.0 - rew_acc[i + 1, done_env_ids] = 0.0 - # done_env_ids = done.nonzero(as_tuple=False).squeeze(-1) + gamma[trunc_env_ids] = 1.0 + rew_acc[i + 1, trunc_env_ids] = 0.0 # collect data for critic training with torch.no_grad(): @@ -471,12 +509,16 @@ def compute_target_values(self): raise NotImplementedError def compute_critic_loss(self, batch_sample): - if self.critic_name == "CriticMLP": - predicted_values = self.critic(batch_sample["obs"]).squeeze(-1) - else: - predicted_values = self.critic(batch_sample["obs"], batch_sample["act"]).squeeze(-1) target_values = batch_sample["target_values"] - critic_loss = ((predicted_values - target_values) ** 2).mean() + if self.critic_name == "DoubleQCriticMLP": + q1, q2 = self.critic(batch_sample["obs"], batch_sample["act"]) + critic_loss = ((q1 - target_values) ** 2).mean() + ((q2 - target_values) ** 2).mean() + else: + if self.critic_name == "CriticMLP": + predicted_values = self.critic(batch_sample["obs"]).squeeze(-1) + elif self.critic_name == "QCriticMLP": + predicted_values = self.critic(batch_sample["obs"], batch_sample["act"]).squeeze(-1) + critic_loss = ((predicted_values - target_values) ** 2).mean() return critic_loss @@ -576,7 +618,11 @@ def actor_closure(): dataset = CriticDataset(self.batch_size, self.obs_buf, self.target_values, drop_last=False) else: dataset = QCriticDataset( - self.batch_size, self.obs_buf, self.act_buf, self.target_values, drop_last=False + self.batch_size, + self.obs_buf, + self.act_buf, + self.target_values, + drop_last=False, ) self.time_report.end_timer("prepare critic dataset") @@ -733,16 +779,29 @@ def play(self, cfg): def save(self, filename=None): if filename is None: filename = "best_policy" + # double Q critic must be saved as state dict to be pickleable + if self.critic_name == "DoubleQCriticMLP": + print("saving critic state dict") + critic = self.critic.state_dict() + target_critic = self.target_critic.state_dict() + else: + critic = self.critic + target_critic = self.target_critic torch.save( - [self.actor, self.critic, self.target_critic, self.obs_rms, self.ret_rms], + [self.actor, critic, target_critic, self.obs_rms, self.ret_rms], os.path.join(self.log_dir, "{}.pt".format(filename)), ) def load(self, path): checkpoint = torch.load(path) self.actor = checkpoint[0].to(self.device) - self.critic = checkpoint[1].to(self.device) - self.target_critic = checkpoint[2].to(self.device) + # double Q critic must be loaded as state dict to be pickleable + if self.critic_name == "DoubleQCriticMLP": + self.critic.load_state_dict(checkpoint[1]) + self.target_critic.load_state_dict(checkpoint[2]) + else: + self.critic = checkpoint[1].to(self.device) + self.target_critic = checkpoint[2].to(self.device) self.obs_rms = checkpoint[3].to(self.device) self.ret_rms = checkpoint[4].to(self.device) if checkpoint[4] is not None else checkpoint[4] diff --git a/src/shac/algorithms/shac2.py b/src/shac/algorithms/shac2.py index dfa7d47e..9ff34d90 100644 --- 
a/src/shac/algorithms/shac2.py +++ b/src/shac/algorithms/shac2.py @@ -18,7 +18,7 @@ from rl_games.algos_torch import torch_ext from shac.utils.average_meter import AverageMeter from shac.utils.common import * -from shac.utils.dataset import QCriticDataset +from shac.utils.dataset import QCriticDataset, CriticDataset from shac.utils.running_mean_std import RunningMeanStd from shac.utils.time_report import TimeReport import shac.utils.torch_utils as tu @@ -137,13 +137,31 @@ def __init__(self, cfg: DictConfig): action_dim=self.num_actions, device=self.device, ) - self.critic = instantiate( - cfg.alg.params.network.critic, - obs_dim=self.num_obs, - action_dim=self.num_actions, - device=self.device, - ) + + self.critic_name = cfg.alg.params.network.critic_name + if self.critic_name == "q_network": + self.critic = instantiate( + cfg.alg.params.network.critic, + obs_dim=self.num_obs, + action_dim=self.num_actions, + device=self.device, + ) + elif self.critic_name == "value_network": + self.critic = instantiate( + cfg.alg.params.network.critic, + obs_dim=self.num_obs, + device=self.device, + ) + elif self.alg.params.critic_name == "double_q_network": + self.critic = instantiate( + cfg.alg.params.network.critic, + input_size=self.num_obs + self.num_actions, + dense_func=torch.nn.Linear, + output_dim=1, + ).to(self.device) + self.all_params = list(self.actor.parameters()) + list(self.critic.parameters()) + __import__("ipdb").set_trace() self.target_critic = copy.deepcopy(self.critic) # initialize optimizer @@ -305,47 +323,77 @@ def compute_actor_loss(self, deterministic=False): self.episode_length += 1 trunc = extra_info.get("contact_changed", torch.zeros_like(done)) + is_done = done.clone() + done_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1) + if self.contact_truncation: - is_done = done.clone() | trunc + is_done = is_done | trunc + trunc_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1) else: - is_done = done.clone() - done_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1) + trunc_env_ids = done_env_ids next_actions = torch.tanh(self.actor(obs, deterministic=True)) - next_values[i + 1] = self.target_critic(obs, next_actions).squeeze(-1) - next_values_model_free[i + 1] = self.target_critic(obs.detach(), next_actions).squeeze(-1) - # next_values_model_free[i + 1] = torch.minimum( - # self.critic(obs.detach(), next_actions).squeeze(-1), - # self.target_critic(obs.detach(), next_actions).squeeze(-1), - # ) + if self.critic_name == "q_network": + val = self.target_critic(obs, next_actions).squeeze(-1) + val_model_free = self.target_critic(obs.detach(), next_actions).squeeze(-1) + elif self.critic_name == "double_q_network": + val = torch.minimum(*self.target_critic(obs, next_actions)).squeeze(-1) + val_model_free = torch.minimum(*self.target_critic(obs.detach(), next_actions)).squeeze(-1) + elif self.critic_name == "value_network": + val = self.target_critic(obs).squeeze(-1) + val_model_free = self.target_critic(obs.detach()).squeeze(-1) + next_values[i + 1] = val + next_values_model_free[i + 1] = val_model_free # zero next_values for done envs with inf, nan, or >1e6 values in obs_before_reset # or early termination - if done_env_ids.shape[0] > 0: - zero_next_values = torch.where( - torch.isnan(extra_info["obs_before_reset"][done_env_ids]).any(dim=-1) - | torch.isinf(extra_info["obs_before_reset"][done_env_ids]).any(dim=-1) - | (torch.abs(extra_info["obs_before_reset"][done_env_ids]) > 1e6).any(dim=-1) - | (self.episode_length[done_env_ids] < self.max_episode_length), - 
torch.ones_like(done_env_ids, dtype=bool), - torch.zeros_like(done_env_ids, dtype=bool), - ) - zero_next_values, assign_next_values = done_env_ids[zero_next_values], done_env_ids[~zero_next_values] - if zero_next_values.shape[0] > 0: - next_values[i + 1, zero_next_values] = 0.0 - next_values_model_free[i + 1, zero_next_values] = 0.0 - # use terminal value critic to estimate the long-term performance - if assign_next_values.shape[0] > 0: + # if done_env_ids.shape[0] > 0: + # zero_next_values = torch.where( + # torch.isnan(extra_info["obs_before_reset"][done_env_ids]).any(dim=-1) + # | torch.isinf(extra_info["obs_before_reset"][done_env_ids]).any(dim=-1) + # | (torch.abs(extra_info["obs_before_reset"][done_env_ids]) > 1e6).any(dim=-1) + # | (self.episode_length[done_env_ids] < self.max_episode_length), + # torch.ones_like(done_env_ids, dtype=bool), + # torch.zeros_like(done_env_ids, dtype=bool), + # ) + # zero_next_values, assign_next_values = done_env_ids[zero_next_values], done_env_ids[~zero_next_values] + # if zero_next_values.shape[0] > 0: + # next_values[i + 1, zero_next_values] = 0.0 + # next_values_model_free[i + 1, zero_next_values] = 0.0 + # # use terminal value critic to estimate the long-term performance + # if assign_next_values.shape[0] > 0: + # if self.obs_rms is not None: + # real_obs = obs_rms.normalize(extra_info["obs_before_reset"][assign_next_values]) + # else: + # real_obs = extra_info["obs_before_reset"][assign_next_values] + # real_act = torch.tanh(self.actor(real_obs, deterministic=True)) + # next_values[i + 1, assign_next_values] = torch.minimum(*self.critic(real_obs, real_act)).squeeze(-1) + # next_values_model_free[i + 1, assign_next_values] = torch.minimum( + # *self.critic(real_obs.detach(), real_act) + # ).squeeze(-1) + + for id in done_env_ids: + if ( + torch.isnan(extra_info["obs_before_reset"][id]).sum() > 0 + or torch.isinf(extra_info["obs_before_reset"][id]).sum() > 0 + or (torch.abs(extra_info["obs_before_reset"][id]) > 1e6).sum() > 0 + ): # ugly fix for nan values + next_values[i + 1, id] = 0.0 + elif self.episode_length[id] < self.max_episode_length: + next_values[i + 1, id] = 0.0 + else: # otherwise, use terminal value critic to estimate the long-term performance if self.obs_rms is not None: - real_obs = obs_rms.normalize(extra_info["obs_before_reset"][assign_next_values]) + real_obs = obs_rms.normalize(extra_info["obs_before_reset"][id]) else: - real_obs = extra_info["obs_before_reset"][assign_next_values] - real_act = torch.tanh(self.actor(real_obs, deterministic=True)) - next_values[i + 1, assign_next_values] = self.critic(real_obs, real_act).squeeze(-1) - next_values_model_free[i + 1, assign_next_values] = self.critic( - real_obs.detach(), real_act - ).squeeze(-1) + real_obs = extra_info["obs_before_reset"][id] + + if self.critic_name == "CriticMLP": + next_values[i + 1, id] = self.target_critic(real_obs).squeeze(-1) + else: + real_act = torch.tanh(self.actor(real_obs, deterministic=True)) + next_values[i + 1, id] = self.target_critic(real_obs, real_act).squeeze(-1) + if (next_values[i + 1] > 1e6).sum() > 0 or (next_values[i + 1] < -1e6).sum() > 0: print("next value error") if self.multi_gpu: @@ -354,29 +402,43 @@ def compute_actor_loss(self, deterministic=False): rew_acc[i + 1, :] = rew_acc[i, :] + gamma * rew + # if i < self.steps_num - 1: + # if len(done_env_ids) > 0 and (self.rollout_lens[done_env_ids] >= self.min_steps).any(): + # done_steps = self.rollout_lens[done_env_ids] >= self.min_steps + # + # next_value = torch.where( + # done_steps, + # 
next_values_model_free[i + 1, done_env_ids], + # next_values[i + 1, done_env_ids], + # ) + # else: + # next_value = next_values[i + 1, done_env_ids] + # actor_loss += (-rew_acc[i + 1, done_env_ids] - self.gamma * gamma[done_env_ids] * next_value).sum() + # # actor_model_free_loss += ( + # # -rew_acc[i + 1, done_env_ids].detach() + # # - self.gamma * gamma[done_env_ids] * next_values_model_free[i + 1, done_env_ids] + # # ).sum() + # else: + # # terminate all envs at the end of optimization iteration + # actor_loss += (-rew_acc[i + 1, :] - self.gamma * gamma * next_values[i + 1, :]).sum() + # # actor_model_free_loss += ( + # # -rew_acc[i + 1, :].detach() - self.gamma * gamma * next_values_model_free[i + 1, :] + # # ).sum() + if i < self.steps_num - 1: - if self.rollout_len[done_env_ids] == self.min_steps: - next_value = next_values_model_free[i + 1, done_env_ids] - else: - next_value = next_values[i + 1, done_env_ids] - actor_loss += (-rew_acc[i + 1, done_env_ids] - self.gamma * gamma[done_env_ids] * next_value).sum() - # actor_model_free_loss += ( - # -rew_acc[i + 1, done_env_ids].detach() - # - self.gamma * gamma[done_env_ids] * next_values_model_free[i + 1, done_env_ids] - # ).sum() + actor_loss += ( + -rew_acc[i + 1, done_env_ids] - self.gamma * gamma[done_env_ids] * next_values[i + 1, done_env_ids] + ).sum() else: # terminate all envs at the end of optimization iteration actor_loss += (-rew_acc[i + 1, :] - self.gamma * gamma * next_values[i + 1, :]).sum() - # actor_model_free_loss += ( - # -rew_acc[i + 1, :].detach() - self.gamma * gamma * next_values_model_free[i + 1, :] - # ).sum() # compute gamma for next step gamma = gamma * self.gamma # clear up gamma and rew_acc for done envs - gamma[done_env_ids] = 1.0 - rew_acc[i + 1, done_env_ids] = 0.0 + gamma[trunc_env_ids] = 1.0 + rew_acc[i + 1, trunc_env_ids] = 0.0 # collect data for critic training with torch.no_grad(): @@ -436,7 +498,7 @@ def compute_actor_loss(self, deterministic=False): idx = np.append(idx, len(rew_acc_) - 1) rollout_lens = np.append(rollout_lens, np.diff(idx)) # rollout_lens is a 1D array of lenghts of each rollout - self.rollout_len = np.mean(rollout_lens) + self.rollout_lens = tu.to_torch(rollout_lens, dtype=float) return actor_loss @torch.no_grad() @@ -514,9 +576,16 @@ def compute_target_values(self): raise NotImplementedError def compute_critic_loss(self, batch_sample): - predicted_values = self.critic(batch_sample["obs"], batch_sample["act"]).squeeze(-1) target_values = batch_sample["target_values"] - critic_loss = ((predicted_values - target_values) ** 2).mean() + if self.critic_name == "double_q_network": + q1, q2 = self.critic(batch_sample["obs"], batch_sample["act"]) + critic_loss = ((q1 - target_values) ** 2).mean() + ((q2 - target_values) ** 2).mean() + elif self.critic_name == "q_network": + q = self.critic(batch_sample["obs"], batch_sample["act"]) + critic_loss = ((q - target_values) ** 2).mean() + elif self.critic_name == "value_network": + v = self.critic(batch_sample["obs"]) + critic_loss = ((v - target_values) ** 2).mean() return critic_loss @@ -572,7 +641,7 @@ def train(self): self.episode_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self.episode_discounted_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self.episode_length = torch.zeros(self.num_envs, dtype=int, device=self.device) - self.rollout_len = torch.zeros(self.num_envs, dtype=int, device=self.device) + self.rollout_lens = torch.zeros(self.num_envs, dtype=float, device=self.device) 
self.episode_gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device) def actor_closure(): @@ -671,13 +740,16 @@ def actor_closure(): self.time_report.start_timer("prepare critic dataset") with torch.no_grad(): self.compute_target_values() - dataset = QCriticDataset( - self.batch_size, - self.obs_buf, - self.act_buf, - self.target_values, - drop_last=False, - ) + if self.critic_name == "value_network": + dataset = CriticDataset(self.batch_size, self.obs_buf, self.target_values, drop_last=False) + else: + dataset = QCriticDataset( + self.batch_size, + self.obs_buf, + self.act_buf, + self.target_values, + drop_last=False, + ) # compute KL divergence of the current policy self.ep_kls.append( torch_ext.policy_kl( @@ -795,9 +867,11 @@ def actor_closure(): self.writer.add_scalar("ac_std/iter", ac_stddev, self.iter_count) self.writer.add_scalar("ac_std/step", ac_stddev, self.step_count) self.writer.add_scalar("ac_std/time", ac_stddev, time_elapse) - self.writer.add_scalar("rollout_len/step", self.rollout_len, self.step_count) - self.writer.add_scalar("rollout_len/iter", self.rollout_len, self.iter_count) + rollout_len_mean = self.rollout_lens.mean().cpu().item() + self.writer.add_scalar("rollout_len/step", rollout_len_mean, self.step_count) + self.writer.add_scalar("rollout_len/iter", rollout_len_mean, self.iter_count) else: + rollout_len_mean = self.rollout_lens.mean().cpu().item() mean_policy_loss = np.inf mean_policy_discounted_loss = np.inf mean_episode_length = 0 @@ -809,7 +883,7 @@ def actor_closure(): self.iter_count, mean_policy_loss, mean_policy_discounted_loss, - self.rollout_len, + rollout_len_mean, mean_episode_length, self.steps_num * self.num_envs * self.rank_size / (time_end_epoch - time_start_epoch), self.value_loss, diff --git a/src/shac/envs/hopper.py b/src/shac/envs/hopper.py index 07b0f40f..41cca570 100644 --- a/src/shac/envs/hopper.py +++ b/src/shac/envs/hopper.py @@ -205,7 +205,9 @@ def step(self, actions): self.MM_caching_frequency, ) contact_changed = next_state.contact_changed.clone() != self.state.contact_changed.clone() + num_contact_changed = next_state.contact_changed.clone() - self.state.contact_changed.clone() contact_changed = contact_changed.view(self.num_envs, -1).any(dim=1) + num_contact_changed = num_contact_changed.view(self.num_envs, -1).sum(dim=1) self.state = next_state self.sim_time += self.sim_dt @@ -227,6 +229,7 @@ def step(self, actions): } self.extras["contact_changed"] = contact_changed + self.extras["num_contact_changed"] = num_contact_changed if len(env_ids) > 0: self.reset(env_ids) diff --git a/src/shac/models/critic.py b/src/shac/models/critic.py index 231cfc0d..95872225 100644 --- a/src/shac/models/critic.py +++ b/src/shac/models/critic.py @@ -10,6 +10,17 @@ import numpy as np from shac.models import model_utils +from rl_games.algos_torch.network_builder import DoubleQCritic + + +def DoubleQCriticMLP(input_size, cfg_network, device="cuda:0"): + units = cfg_network["critic_mlp"]["units"] + activation = cfg_network["critic_mlp"]["activation"] + critic = DoubleQCritic(1, input_size=input_size, units=units, activation=activation, dense_func=torch.nn.Linear).to( + device + ) + print(critic) + return critic class CriticMLP(nn.Module): @@ -40,12 +51,12 @@ def forward(self, observations): class QCriticMLP(nn.Module): - def __init__(self, obs_dim, action_dim, cfg_network, device="cuda:0"): + def __init__(self, input_size, cfg_network, device="cuda:0"): super(QCriticMLP, self).__init__() self.device = device - self.layer_dims = [obs_dim + 
action_dim] + cfg_network["critic_mlp"]["units"] + [1] + self.layer_dims = [input_size] + cfg_network["critic_mlp"]["units"] + [1] init_ = lambda m: model_utils.init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)) @@ -58,9 +69,6 @@ def __init__(self, obs_dim, action_dim, cfg_network, device="cuda:0"): self.q_function = nn.Sequential(*modules).to(device) - self.obs_dim = obs_dim - self.action_dim = action_dim - print(self.q_function) def forward(self, observations, actions): diff --git a/src/shac/utils/hydra_utils.py b/src/shac/utils/hydra_utils.py index 504457ff..54be697f 100644 --- a/src/shac/utils/hydra_utils.py +++ b/src/shac/utils/hydra_utils.py @@ -1,4 +1,5 @@ import numpy as np +import torch from omegaconf import OmegaConf, DictConfig from typing import Dict From 91ea192888fa375869c5ea2e2f14ab42a0a6e34c Mon Sep 17 00:00:00 2001 From: krishnan Date: Wed, 7 Jun 2023 10:56:17 -0700 Subject: [PATCH 2/2] updated shac2 and shac to conform to warp envs syntax --- src/shac/algorithms/shac.py | 5 ++--- src/shac/algorithms/shac2.py | 6 +++--- src/shac/envs/__init__.py | 5 +++++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/shac/algorithms/shac.py b/src/shac/algorithms/shac.py index c5c5367e..764ba782 100644 --- a/src/shac/algorithms/shac.py +++ b/src/shac/algorithms/shac.py @@ -46,7 +46,6 @@ def __init__(self, cfg): seed=cfg["params"]["general"]["seed"], episode_length=cfg["params"]["diff_env"].get("episode_length", 250), stochastic_init=stochastic_init, - no_grad=False, ) config.update(cfg["params"].get("diff_env", {})) @@ -216,7 +215,7 @@ def __init__(self, cfg): self.episode_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self.episode_discounted_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self.episode_gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device) - self.episode_length = torch.zeros(self.num_envs, dtype=int) + self.episode_length = torch.zeros(self.num_envs, dtype=int, device=self.device) self.horizon_length = torch.zeros(self.num_envs, dtype=int) self.best_policy_loss = np.inf self.actor_loss = np.inf @@ -557,7 +556,7 @@ def train(self): self.initialize_env() self.episode_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) self.episode_discounted_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) - self.episode_length = torch.zeros(self.num_envs, dtype=int) + self.episode_length = torch.zeros(self.num_envs, dtype=int, device=self.device) self.episode_gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device) def actor_closure(): diff --git a/src/shac/algorithms/shac2.py b/src/shac/algorithms/shac2.py index 9ff34d90..2c55dd57 100644 --- a/src/shac/algorithms/shac2.py +++ b/src/shac/algorithms/shac2.py @@ -36,7 +36,7 @@ class SHAC: def __init__(self, cfg: DictConfig): seeding(cfg.general.seed) - self.env = instantiate(cfg.env.config) + self.env = instantiate(cfg.task.env, _convert_='partial') print("num_envs = ", self.env.num_envs) print("num_actions = ", self.env.num_actions) @@ -71,8 +71,8 @@ def __init__(self, cfg: DictConfig): self.steps_num = cfg.alg.params.config.steps_num self.max_epochs = cfg.alg.params.config.max_epochs - self.actor_lr = float(cfg.env.shac2.actor_lr) - self.critic_lr = float(cfg.env.shac2.critic_lr) + self.actor_lr = float(cfg.task.shac2.actor_lr) + self.critic_lr = float(cfg.task.shac2.critic_lr) self.lr_schedule = cfg.alg.params.config.lr_schedule self.is_adaptive_lr = 
self.lr_schedule == "adaptive" diff --git a/src/shac/envs/__init__.py b/src/shac/envs/__init__.py index 70883bd1..eed51c20 100644 --- a/src/shac/envs/__init__.py +++ b/src/shac/envs/__init__.py @@ -17,6 +17,11 @@ from warp.envs.hopper import HopperEnv as HopperWarpEnv from warp.envs.ant import AntEnv as AntWarpEnv +from warp.envs.obj_env import ObjectTask +from warp.envs.hand_env import HandObjectTask +from warp.envs.repose_task import ReposeTask +from warp.envs.articulate_task import ArticulateTask + # dmanip envs try: from dmanip.envs import WarpEnv, ClawWarpEnv, AllegroWarpEnv