Merge pull request NVlabs#1 from krishpop/warp-envs
Warp envs syntax and added doubleqcritic
krishpop authored Jun 7, 2023
2 parents 9a382a2 + 91ea192 commit 960c0d3
Showing 10 changed files with 321 additions and 121 deletions.
1 change: 1 addition & 0 deletions scripts/cfg/alg/shac.yaml
@@ -34,6 +34,7 @@ params:
num_actors: ${env.config.num_envs} # ${resolve_default:64,${..env.config.num_envs}}
save_interval: 400 # ${resolve_default:400,${..env.save_interval}}
contact_truncation: False
min_steps_before_truncation: 4

player:
determenistic: True
17 changes: 16 additions & 1 deletion scripts/cfg/alg/shac2.yaml
@@ -9,8 +9,23 @@ params:
units: ${env.shac2.actor_mlp.units}
activation: elu

# critic_name: double_q_network
# critic:
# _target_: rl_games.algos_torch.network_builder.DoubleQCritic
# units: ${env.shac2.critic_mlp.units}
# activation: elu
# norm_func_name: layer_norm
# d2rl: True
# critic_name: q_network
# critic:
# _target_: shac.models.critic.QCriticMLP
# cfg_network:
# critic_mlp:
# units: ${env.shac2.critic_mlp.units}
# activation: elu
critic_name: value_network
critic:
_target_: shac.models.critic.QCriticMLP
_target_: shac.models.critic.CriticMLP
cfg_network:
critic_mlp:
units: ${env.shac2.critic_mlp.units}
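
A side note on the `_target_` entries above: such nodes are typically resolved with `hydra.utils.instantiate`. The sketch below only illustrates that mechanism with a stand-in class (`torch.nn.Linear`); how this repository actually builds the shac2 critic is not shown in this diff (the SHAC trainer later in the diff looks its critic up by name instead).

```python
# Minimal illustration (not this repo's code path) of how a Hydra `_target_`
# entry becomes an object. torch.nn.Linear stands in for the real
# shac.models.critic.CriticMLP targeted above.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "critic": {
            "_target_": "torch.nn.Linear",  # stand-in target
            "in_features": 8,
            "out_features": 1,
        }
    }
)

critic = instantiate(cfg.critic)  # extra positional/keyword args could also be passed here
print(critic)
```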
23 changes: 20 additions & 3 deletions scripts/cfg/config.yaml
@@ -2,7 +2,6 @@ defaults:
- _self_
- env: cartpole
- alg: shac
# - override hydra/launcher: submitit_local


hydra:
@@ -25,7 +24,7 @@ general:
no_time_stamp: False
render: False
device: cuda:0
run_wandb: True
run_wandb: False
seed: 42
train: True
checkpoint:
@@ -80,7 +79,25 @@ env:
- 0.7
- 0.95

shac2: ${.shac}
shac2:
lambda: 0.95
actor_mlp:
units:
- 64
- 64
critic_mlp:
units:
- 64
- 64
target_critic_alpha: 0.4
actor_lr: 1e-3
critic_lr: 1e-3
max_epochs: 2000
save_interval: 400
steps_num: 32
betas:
- 0.7
- 0.95

wandb:
project: shac
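
For context on the removed `shac2: ${.shac}` line above: in OmegaConf, `${.shac}` is a relative interpolation that makes `shac2` resolve to the sibling `shac` node, so both algorithms shared one set of hyperparameters. A small sketch of that behavior with made-up keys:

```python
# Sketch of OmegaConf relative interpolation, showing what `shac2: ${.shac}`
# did before this commit: shac2 simply mirrored the sibling shac node.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    env:
      shac:
        steps_num: 32
      shac2: ${.shac}      # relative interpolation: points at env.shac
    """
)

print(cfg.env.shac2.steps_num)  # 32, resolved from env.shac
cfg.env.shac.steps_num = 64
print(cfg.env.shac2.steps_num)  # 64, still aliased; the explicit block in this commit decouples them
```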
18 changes: 18 additions & 0 deletions scripts/cfg/env/hopper.yaml
@@ -32,6 +32,24 @@ shac:
- 64
save_interval: 400

shac2:
actor_lr: 2e-3
critic_lr: 4e-3
max_epochs: 500
betas:
- 0.7
- 0.95
actor_mlp:
units:
- 128
- 64
- 32
critic_mlp:
units:
- 64
- 64
save_interval: 400

ppo:
max_epochs: 2000
minibatch_size: 8192
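
Because hopper.yaml now carries its own `shac2` block, per-environment values are picked up through Hydra's config groups at composition time. A hypothetical sketch using the Compose API (Hydra ≥ 1.2 assumed; the `config_path` and group names are inferred from the `scripts/cfg` layout in this diff):

```python
# Hypothetical composition of the config groups seen in this diff; the
# env-specific shac2 block (e.g. actor_lr: 2e-3 for hopper) is merged over
# the generic defaults from config.yaml.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="cfg"):
    cfg = compose(config_name="config", overrides=["env=hopper", "alg=shac2"])
    print(OmegaConf.to_yaml(cfg.env.shac2))
```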
144 changes: 101 additions & 43 deletions src/shac/algorithms/shac.py
@@ -22,6 +22,7 @@
import shac.models.critic as critic_models
from shac.utils.common import *
import shac.utils.torch_utils as tu
from rl_games.algos_torch.network_builder import DoubleQCritic
from shac.utils.running_mean_std import RunningMeanStd
from shac.utils.dataset import CriticDataset, QCriticDataset
from shac.utils.time_report import TimeReport
@@ -45,7 +46,6 @@ def __init__(self, cfg):
seed=cfg["params"]["general"]["seed"],
episode_length=cfg["params"]["diff_env"].get("episode_length", 250),
stochastic_init=stochastic_init,
no_grad=False,
)

config.update(cfg["params"].get("diff_env", {}))
@@ -98,6 +98,7 @@ def __init__(self, cfg):
self.truncate_grad = cfg["params"]["config"]["truncate_grads"]
self.grad_norm = cfg["params"]["config"]["grad_norm"]
self.contact_truncation = cfg["params"]["config"].get("contact_truncation", False)
self.min_steps_from_truncation = cfg["params"]["config"].get("min_steps_before_truncation", 4)

if cfg["params"]["general"]["train"]:
self.log_dir = cfg["params"]["general"]["logdir"]
@@ -134,9 +135,11 @@ def __init__(self, cfg):
self.actor = actor_fn(self.num_obs, self.num_actions, cfg["params"]["network"], device=self.device)
critic_fn = getattr(critic_models, self.critic_name)
if self.critic_name == "CriticMLP":
self.critic = critic_fn(self.num_obs, cfg["params"]["network"], device=self.device)
else:
self.critic = critic_fn(self.num_obs, self.num_actions, cfg["params"]["network"], device=self.device)
input_size = self.num_obs
elif self.critic_name == "QCriticMLP" or self.critic_name == "DoubleQCriticMLP":
input_size = self.num_obs + self.num_actions
self.critic = critic_fn(input_size, cfg["params"]["network"], device=self.device)

self.all_params = list(self.actor.parameters()) + list(self.critic.parameters())
self.target_critic = copy.deepcopy(self.critic)
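
The `input_size` switch above reflects that a Q-critic scores state-action pairs rather than states alone. A minimal stand-alone sketch of that idea follows; it is not the repository's `QCriticMLP` (whose constructor takes a network config), and its names and layer sizes are illustrative only.

```python
# Minimal sketch of a Q-critic that concatenates observation and action,
# illustrating why input_size = num_obs + num_actions above.
import torch
import torch.nn as nn


class TinyQCritic(nn.Module):
    def __init__(self, num_obs: int, num_actions: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_obs + num_actions, hidden),
            nn.ELU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([obs, act], dim=-1))  # Q(s, a)


q = TinyQCritic(num_obs=11, num_actions=3)
values = q(torch.randn(8, 11), torch.randn(8, 3)).squeeze(-1)  # shape: (8,)
```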

@@ -161,12 +164,18 @@ def __init__(self, cfg):
dtype=torch.float32,
device=self.device,
)
if self.critic_name == "QCriticMLP":
if self.critic_name.endswith("QCriticMLP"):
self.act_buf = torch.zeros(
(self.steps_num, self.num_envs, self.num_actions),
dtype=torch.float32,
device=self.device,
)
# self.next_obs_buf = torch.zeros(
# (self.steps_num, self.num_envs, self.num_obs),
# dtype=torch.float32,
# device=self.device,
# )

self.rew_buf = torch.zeros((self.steps_num, self.num_envs), dtype=torch.float32, device=self.device)
self.done_mask = torch.zeros((self.steps_num, self.num_envs), dtype=torch.float32, device=self.device)
self.next_values = torch.zeros((self.steps_num, self.num_envs), dtype=torch.float32, device=self.device)
@@ -206,7 +215,7 @@ def __init__(self, cfg):
self.episode_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
self.episode_discounted_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
self.episode_gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device)
self.episode_length = torch.zeros(self.num_envs, dtype=int)
self.episode_length = torch.zeros(self.num_envs, dtype=int, device=self.device)
self.horizon_length = torch.zeros(self.num_envs, dtype=int)
self.best_policy_loss = np.inf
self.actor_loss = np.inf
@@ -225,6 +234,18 @@ def __init__(self, cfg):
# timer
self.time_report = TimeReport()

def compute_values(self, obs):
"""Compute values for the given observations with target critic."""
if self.critic_name == "CriticMLP":
values = self.target_critic(obs).squeeze(-1)
elif self.critic_name == "QCriticMLP":
action = torch.tanh(self.actor(obs, deterministic=True))
values = self.target_critic(obs, action).squeeze(-1)
elif self.critic_name == "DoubleQCriticMLP":
action = torch.tanh(self.actor(obs, deterministic=True))
values = torch.minimum(*self.target_critic(obs, action)).squeeze(-1)
return values
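
The `DoubleQCriticMLP` branch above applies the usual clipped double-Q trick: the target critic returns two Q estimates and the smaller one is used as the value, which counteracts overestimation. Below is a self-contained sketch of that pattern; the real `DoubleQCritic` imported from rl_games may differ in constructor and head layout.

```python
# Sketch of the clipped double-Q value used in compute_values above:
# two independent Q heads, with torch.minimum(*heads) as the value estimate.
# This is an illustration, not rl_games' DoubleQCritic implementation.
import torch
import torch.nn as nn


class TinyDoubleQ(nn.Module):
    def __init__(self, num_obs: int, num_actions: int, hidden: int = 64):
        super().__init__()

        def head():
            return nn.Sequential(
                nn.Linear(num_obs + num_actions, hidden), nn.ELU(), nn.Linear(hidden, 1)
            )

        self.q1, self.q2 = head(), head()

    def forward(self, obs, act):
        x = torch.cat([obs, act], dim=-1)
        return self.q1(x), self.q2(x)  # two estimates, unpacked by torch.minimum(*...)


critic = TinyDoubleQ(num_obs=11, num_actions=3)
obs, act = torch.randn(8, 11), torch.randn(8, 3)
values = torch.minimum(*critic(obs, act)).squeeze(-1)  # pessimistic value estimate, shape (8,)
```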

def compute_actor_loss(self, deterministic=False):
rew_acc = torch.zeros((self.steps_num + 1, self.num_envs), dtype=torch.float32, device=self.device)
gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device)
@@ -240,26 +261,31 @@ def compute_actor_loss(self, deterministic=False):
ret_var = self.ret_rms.var.clone()

# initialize trajectory to cut off gradients between episodes.
obs = self.env.initialize_trajectory()
next_obs = self.env.initialize_trajectory()
if self.obs_rms is not None:
# update obs rms
with torch.no_grad():
self.obs_rms.update(obs)
self.obs_rms.update(next_obs)
# normalize the current obs
obs = obs_rms.normalize(obs)
next_obs = obs_rms.normalize(next_obs)

# collect data for critic training
for i in range(self.steps_num):
obs = next_obs
with torch.no_grad():
self.obs_buf[i] = obs.clone()

actions = self.actor(obs, deterministic=deterministic)

if self.critic_name == "QCriticMLP":
if self.critic_name.endswith("QCriticMLP"):
with torch.no_grad():
self.act_buf[i] = actions.clone()

obs, rew, done, extra_info = self.env.step(torch.tanh(actions))
next_obs, rew, done, extra_info = self.env.step(torch.tanh(actions))

# if self.critic_name.endswith("QCriticMLP"):
# with torch.no_grad():
# self.next_obs_buf[i] = obs.clone()

with torch.no_grad():
raw_rew = rew.clone()
@@ -270,9 +296,9 @@
if self.obs_rms is not None:
# update obs rms
with torch.no_grad():
self.obs_rms.update(obs)
self.obs_rms.update(next_obs)
# normalize the current obs
obs = obs_rms.normalize(obs)
next_obs = obs_rms.normalize(next_obs)

if self.ret_rms is not None:
# update ret rms
@@ -285,17 +311,16 @@
self.episode_length += 1

trunc = extra_info.get("contact_changed", torch.zeros_like(done))
if not self.contact_truncation:
is_done = done.clone()
trunc_on_contact = extra_info.get("num_contact_changed", torch.zeros_like(done)) >= 1
if self.contact_truncation:
is_done = done.clone() | trunc_on_contact
trunc_env_ids = trunc_on_contact.nonzero(as_tuple=False).squeeze(-1)
else:
is_done = done.clone() # | trunc
is_done = done.clone()
trunc_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1)
done_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1)

if self.critic_name == "CriticMLP":
next_values[i + 1] = self.target_critic(obs).squeeze(-1)
else:
next_action = torch.tanh(self.actor(obs, deterministic=True))
next_values[i + 1] = self.target_critic(obs, next_action).squeeze(-1)
next_values[i + 1] = self.compute_values(next_obs)

for id in done_env_ids:
if (
@@ -304,21 +329,34 @@
or (torch.abs(extra_info["obs_before_reset"][id]) > 1e6).sum() > 0
): # ugly fix for nan values
next_values[i + 1, id] = 0.0
elif self.episode_length[id] < self.max_episode_length and (
not self.contact_truncation or not trunc[id]
):
elif self.episode_length[id] < self.max_episode_length:
# set values to 0 due to early termination
next_values[i + 1, id] = 0.0
else: # otherwise, use terminal value critic to estimate the long-term performance
else:
# otherwise, use terminal value critic to estimate the long-term performance
if self.obs_rms is not None:
real_obs = obs_rms.normalize(extra_info["obs_before_reset"][id])
else:
real_obs = extra_info["obs_before_reset"][id]

if self.critic_name == "CriticMLP":
next_values[i + 1, id] = self.target_critic(real_obs).squeeze(-1)
# if truncating on contact, compute model-free value
# and at least min_steps have passed since last truncation/short horizon
last_zero = self.rew_buf[:i, id].eq(0).nonzero(as_tuple=False).squeeze(-1)
if len(last_zero) == 0:
last_zero = i
else:
real_act = torch.tanh(self.actor(real_obs, deterministic=True))
next_values[i + 1, id] = self.target_critic(real_obs, real_act).squeeze(-1)
last_zero = last_zero[-1]

if (
self.contact_truncation
and id in trunc_env_ids
and i - last_zero >= self.min_steps_from_truncation
):
real_obs = real_obs.detach()

next_values[i + 1, id] = self.compute_values(real_obs)

is_done = done.clone()
done_env_ids = is_done.nonzero(as_tuple=False).squeeze(-1)

if (next_values[i + 1] > 1e6).sum() > 0 or (next_values[i + 1] < -1e6).sum() > 0:
print("next value error")
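
The `last_zero` lookup above gates contact truncation on how many steps have elapsed since the reward buffer last held a zero (used as a proxy for the previous cut point). A small stand-alone sketch of that tensor idiom, with illustrative values:

```python
# Stand-alone sketch of the last_zero idiom above: find the index of the most
# recent zero entry in a 1-D reward slice and count steps since then.
import torch

rew = torch.tensor([0.3, 0.0, 0.7, 0.5, 0.9])  # stand-in for rew_buf[:i, id] of one env
i = rew.shape[0]

zeros = rew.eq(0).nonzero(as_tuple=False).squeeze(-1)
last_zero = i if len(zeros) == 0 else int(zeros[-1])

steps_since = i - last_zero  # here: 5 - 1 = 4; compared against min_steps_before_truncation
print(steps_since)
```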
@@ -338,9 +376,8 @@
gamma = gamma * self.gamma

# clear up gamma and rew_acc for done envs
gamma[done_env_ids] = 1.0
rew_acc[i + 1, done_env_ids] = 0.0
# done_env_ids = done.nonzero(as_tuple=False).squeeze(-1)
gamma[trunc_env_ids] = 1.0
rew_acc[i + 1, trunc_env_ids] = 0.0

# collect data for critic training
with torch.no_grad():
@@ -471,12 +508,16 @@ def compute_target_values(self):
raise NotImplementedError

def compute_critic_loss(self, batch_sample):
if self.critic_name == "CriticMLP":
predicted_values = self.critic(batch_sample["obs"]).squeeze(-1)
else:
predicted_values = self.critic(batch_sample["obs"], batch_sample["act"]).squeeze(-1)
target_values = batch_sample["target_values"]
critic_loss = ((predicted_values - target_values) ** 2).mean()
if self.critic_name == "DoubleQCriticMLP":
q1, q2 = self.critic(batch_sample["obs"], batch_sample["act"])
critic_loss = ((q1 - target_values) ** 2).mean() + ((q2 - target_values) ** 2).mean()
else:
if self.critic_name == "CriticMLP":
predicted_values = self.critic(batch_sample["obs"]).squeeze(-1)
elif self.critic_name == "QCriticMLP":
predicted_values = self.critic(batch_sample["obs"], batch_sample["act"]).squeeze(-1)
critic_loss = ((predicted_values - target_values) ** 2).mean()

return critic_loss
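
In the `DoubleQCriticMLP` branch above, both heads are regressed to the same target values. A trivial sketch of that loss, with random tensors standing in for the two critic outputs:

```python
# Sketch of the double-Q regression above: both heads fit the same targets.
import torch

q1 = torch.randn(256, requires_grad=True)  # stand-in for critic head 1 output
q2 = torch.randn(256, requires_grad=True)  # stand-in for critic head 2 output
target_values = torch.randn(256)           # stand-in for batch_sample["target_values"]

critic_loss = ((q1 - target_values) ** 2).mean() + ((q2 - target_values) ** 2).mean()
critic_loss.backward()                     # gradients flow into both heads
```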

@@ -515,7 +556,7 @@ def train(self):
self.initialize_env()
self.episode_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
self.episode_discounted_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
self.episode_length = torch.zeros(self.num_envs, dtype=int)
self.episode_length = torch.zeros(self.num_envs, dtype=int, device=self.device)
self.episode_gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device)

def actor_closure():
@@ -576,7 +617,11 @@ def actor_closure():
dataset = CriticDataset(self.batch_size, self.obs_buf, self.target_values, drop_last=False)
else:
dataset = QCriticDataset(
self.batch_size, self.obs_buf, self.act_buf, self.target_values, drop_last=False
self.batch_size,
self.obs_buf,
self.act_buf,
self.target_values,
drop_last=False,
)
self.time_report.end_timer("prepare critic dataset")
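
`QCriticDataset` comes from `shac.utils.dataset` and its interface is not shown in this diff; the sketch below is only a plausible stand-in that flattens the `(steps, envs, ...)` rollout buffers into shuffled minibatches keyed `obs`, `act`, and `target_values`, which is what `compute_critic_loss` consumes.

```python
# Hypothetical stand-in for QCriticDataset (the real class lives in
# shac.utils.dataset and may differ): flatten rollout buffers and yield
# shuffled minibatches keyed like batch_sample above.
import torch


class SketchQCriticDataset:
    def __init__(self, batch_size, obs_buf, act_buf, target_values, drop_last=False):
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.obs = obs_buf.reshape(-1, obs_buf.shape[-1])
        self.act = act_buf.reshape(-1, act_buf.shape[-1])
        self.tgt = target_values.reshape(-1)

    def __iter__(self):
        n = self.tgt.shape[0]
        idx = torch.randperm(n)
        stop = n - (n % self.batch_size if self.drop_last else 0)
        for start in range(0, stop, self.batch_size):
            b = idx[start:start + self.batch_size]
            yield {"obs": self.obs[b], "act": self.act[b], "target_values": self.tgt[b]}


# usage sketch with made-up shapes (steps_num=32, num_envs=4, obs=11, act=3)
ds = SketchQCriticDataset(64, torch.randn(32, 4, 11), torch.randn(32, 4, 3), torch.randn(32, 4))
for batch in ds:
    pass  # each batch would be fed to compute_critic_loss
```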

@@ -733,16 +778,29 @@ def play(self, cfg):
def save(self, filename=None):
if filename is None:
filename = "best_policy"
# double Q critic must be saved as state dict to be pickleable
if self.critic_name == "DoubleQCriticMLP":
print("saving critic state dict")
critic = self.critic.state_dict()
target_critic = self.target_critic.state_dict()
else:
critic = self.critic
target_critic = self.target_critic
torch.save(
[self.actor, self.critic, self.target_critic, self.obs_rms, self.ret_rms],
[self.actor, critic, target_critic, self.obs_rms, self.ret_rms],
os.path.join(self.log_dir, "{}.pt".format(filename)),
)

def load(self, path):
checkpoint = torch.load(path)
self.actor = checkpoint[0].to(self.device)
self.critic = checkpoint[1].to(self.device)
self.target_critic = checkpoint[2].to(self.device)
# double Q critic must be loaded as state dict to be pickleable
if self.critic_name == "DoubleQCriticMLP":
self.critic.load_state_dict(checkpoint[1])
self.target_critic.load_state_dict(checkpoint[2])
else:
self.critic = checkpoint[1].to(self.device)
self.target_critic = checkpoint[2].to(self.device)
self.obs_rms = checkpoint[3].to(self.device)
self.ret_rms = checkpoint[4].to(self.device) if checkpoint[4] is not None else checkpoint[4]
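
The state-dict branches above exist because some modules (here, the double-Q critic) do not pickle cleanly as whole objects; saving only the parameters and loading them back into a freshly constructed module is the standard workaround. A generic sketch of that pattern, with an illustrative model and path:

```python
# Generic sketch of the state_dict save/load pattern used above for
# DoubleQCriticMLP: persist parameters only, then restore them into a module
# constructed the same way. The tiny model and file name are illustrative.
import torch
import torch.nn as nn

critic = nn.Sequential(nn.Linear(8, 64), nn.ELU(), nn.Linear(64, 1))

# save: the state_dict is an ordinary (pickleable) mapping of tensors
torch.save({"critic": critic.state_dict()}, "best_policy.pt")

# load: rebuild the module first, then restore its parameters
restored = nn.Sequential(nn.Linear(8, 64), nn.ELU(), nn.Linear(64, 1))
checkpoint = torch.load("best_policy.pt")
restored.load_state_dict(checkpoint["critic"])
```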
