diff --git a/README.md b/README.md
index 7a4a217..e2ebeb7 100644
--- a/README.md
+++ b/README.md
@@ -102,7 +102,11 @@ For example (DDPG):
 
   ![AC](AC/A2CAgent_600.gif)
 
-- improve `AWR`, `DDPG` with Gumbel Distribution Regression from [`XQL`](https://div99.github.io/XQL):
+- [x] [TD3](https://arxiv.org/pdf/1802.09477.pdf)
+
+  ![TD3](TD3/TD3Agent_100.gif)
+
+- improve `AWR`, `DDPG`, `TD3` with Gumbel Distribution Regression from [`XQL`](https://div99.github.io/XQL):
   - XAWR
 
     ![XAWR](XAWR/XAWRAgent_100.gif)
@@ -111,6 +115,10 @@ For example (DDPG):
 
     ![XDDPG](XDDPG/XDDPGAgent_200.gif)
 
+  - XTD3
+
+    ![XTD3](XTD3/XTD3Agent_100.gif)
+
 ## Reference
 
 - TrainMonitor and Generategif modified from [coax](https://github.com/coax-dev/coax)
diff --git a/TD3/TD3Agent_100.gif b/TD3/TD3Agent_100.gif
new file mode 100644
index 0000000..18df994
Binary files /dev/null and b/TD3/TD3Agent_100.gif differ
diff --git a/TD3/__init__.py b/TD3/__init__.py
new file mode 100644
index 0000000..8875f60
--- /dev/null
+++ b/TD3/__init__.py
@@ -0,0 +1 @@
+# pylint: disable=all
diff --git a/TD3/main.py b/TD3/main.py
new file mode 100644
index 0000000..c9fe3e2
--- /dev/null
+++ b/TD3/main.py
@@ -0,0 +1,136 @@
+"""main executable file for TD3"""
+import os
+import logging
+from itertools import repeat
+import gymnasium as gym
+import torch
+import numpy as np
+from util import generate_gif
+from util.wrappers import TrainMonitor
+from util.buffer import Experience
+from collections import deque
+# pylint: disable=invalid-name
+from TD3.td3 import TD3Agent as TD3_torch
+
+Agent = TD3_torch
+logging.basicConfig(level=logging.INFO)
+
+torch.manual_seed(0)
+np.random.seed(0)
+
+EPSILON_DECAY_STEPS = 100
+
+
+def main(
+    n_episodes=2000,
+    max_t=200,
+    eps_start=1.0,
+    eps_end=0.01,
+    eps_decay=0.996,
+    score_term_rules=lambda s: False,
+    time_interval="25ms"
+):
+  # pylint: disable=line-too-long
+  """Train a TD3 agent.
+
+  Params
+  ======
+      n_episodes (int): maximum number of training episodes
+      max_t (int): maximum number of timesteps per episode
+      eps_start (float): starting value of epsilon, for epsilon-greedy action selection
+      eps_end (float): minimum value of epsilon
+      eps_decay (float): multiplicative factor (per episode) for decreasing epsilon
+
+  """
+  scores = []  # list containing score from each episode
+  scores_window = deque(maxlen=100)  # last 100 scores
+  eps = eps_start
+
+  env = gym.make(
+      "Pendulum-v1",
+      render_mode="rgb_array",
+  )
+  # env = gym.make(
+  #     "LunarLander-v2",
+  #     render_mode="rgb_array",
+  #     continuous=True,
+  # )
+  # env = gym.make("MountainCarContinuous-v0", render_mode="rgb_array")
+  env = TrainMonitor(env, tensorboard_dir="./logs", tensorboard_write_all=True)
+
+  gamma = 0.99
+  batch_size = 64
+  learn_iteration = 16
+  update_tau = 0.5
+
+  lr_actor = 0.0001
+  lr_critic = 0.001
+
+  mu = 0.0
+  theta = 0.15
+  max_sigma = 0.3
+  min_sigma = 0.3
+  decay_period = 100000
+  value_noise_clip = 0.5
+  value_noise_sigma = 0.5
+
+  agent = Agent(
+      state_dims=env.observation_space,
+      action_space=env.action_space,
+      lr_actor=lr_actor,
+      lr_critic=lr_critic,
+      gamma=gamma,
+      batch_size=batch_size,
+      forget_experience=False,
+      update_tau=update_tau,
+      mu=mu,
+      theta=theta,
+      max_sigma=max_sigma,
+      min_sigma=min_sigma,
+      decay_period=decay_period,
+      value_noise_clip=value_noise_clip,
+      value_noise_sigma=value_noise_sigma
+  )
+  dump_gif_dir = f"images/{agent.__class__.__name__}/{agent.__class__.__name__}_{{}}.gif"
+  for i_episode in range(1, n_episodes + 1):
+    state, _ = env.reset()
+    score = 0
+    for t, _ in enumerate(repeat(0, max_t)):
+      action = agent.take_action(state=state, explore=True, step=t * i_episode)
+      next_state, reward, done, _, _ = env.step(action)
+      agent.remember(Experience(state, action, reward, next_state, done))
+      agent.learn(learn_iteration)
+
+      state = next_state
+      score += reward
+
+      if done or score_term_rules(score):
+        break
+
+    scores_window.append(score)  ## save the most recent score
+    scores.append(score)  ## save the most recent score
+    eps = max(eps * eps_decay, eps_end)  ## decrease epsilon (TD3 explores via OU noise; eps is otherwise unused)
+    print(" " * os.get_terminal_size().columns, end="\r")
+    print(
+        f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}",
+        end="\r"
+    )
+
+    if i_episode and i_episode % 100 == 0:
+      print(" " * os.get_terminal_size().columns, end="\r")
+      print(
+          f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}"
+      )
+      generate_gif(
+          env,
+          filepath=dump_gif_dir.format(i_episode),
+          policy=lambda s: agent.take_action(s, explore=False),
+          duration=float(time_interval.split("ms")[0]),
+          max_episode_steps=max_t
+      )
+
+  return scores
+
+
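+# How to run (a sketch; assumes the repository root is on PYTHONPATH so the
+# TD3 and util packages are importable):
+#   python -m TD3.main
+# Training GIFs are then written under images/TD3Agent/ every 100 episodes.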
+if __name__ == "__main__":
+  main()
diff --git a/TD3/td3.py b/TD3/td3.py
new file mode 100644
index 0000000..a50d556
--- /dev/null
+++ b/TD3/td3.py
@@ -0,0 +1,385 @@
+"""TD3 implementation with pytorch."""
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from util.buffer import ReplayBuffer
+from util.agent import Agent
+from util.buffer import Experience
+
+
+class OUNoise(object):
+  # pylint: disable=line-too-long
+  """
+  Taken from https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py
+  """
+
+  def __init__(
+      self,
+      action_space,
+      mu=0.0,
+      theta=0.15,
+      max_sigma=0.3,
+      min_sigma=0.3,
+      decay_period=100000
+  ):
+    self.mu = mu
+    self.theta = theta
+    self.sigma = max_sigma
+    self.max_sigma = max_sigma
+    self.min_sigma = min_sigma
+    self.decay_period = decay_period
+    self.action_dim = action_space.shape[0]
+    self.low = action_space.low
+    self.high = action_space.high
+    self.reset()
+
+  def reset(self):
+    self.state = np.ones(self.action_dim) * self.mu
+
+  def evolve_state(self):
+    x = self.state
+    dx = (self.theta * (self.mu - x) +
+          self.sigma * np.random.randn(self.action_dim))
+    self.state = x + dx
+    return self.state
+
+  def get_action(self, action, t=0):
+    ou_state = self.evolve_state()
+    self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(
+        1.0, t / self.decay_period
+    )
+    return np.clip(action + ou_state, self.low, self.high)
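+
+
+# Illustrative usage (a sketch, not executed by this module): with the default
+# max_sigma == min_sigma the noise scale never actually decays; pass e.g.
+# min_sigma=0.05 to anneal exploration over decay_period steps:
+#   noise = OUNoise(env.action_space, max_sigma=0.3, min_sigma=0.05)
+#   noisy_action = noise.get_action(action, t)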
+
+
+class Actor(nn.Module):
+  """Actor (Policy) Model."""
+
+  def __init__(
+      self,
+      state_dim,
+      action_space,
+      seed=0,
+      fc1_unit=64,
+      fc2_unit=64,
+      max_action=1,
+      init_weight_gain=np.sqrt(2),
+      init_policy_weight_gain=1,
+      init_bias=0
+  ):
+    """
+    Initialize parameters and build model.
+
+    Params
+    =======
+        state_dim (int): Dimension of each state
+        action_space (int): Dimension of each action
+        seed (int): Random seed
+        fc1_unit (int): Number of nodes in first hidden layer
+        fc2_unit (int): Number of nodes in second hidden layer
+    """
+    super().__init__()  ## calls __init__ method of nn.Module class
+    self.seed = torch.manual_seed(seed)
+    self.fc1 = nn.Linear(state_dim, fc1_unit)
+    self.fc1_ln = nn.LayerNorm(fc1_unit)
+    self.fc2 = nn.Linear(fc1_unit, fc2_unit)
+    self.fc2_ln = nn.LayerNorm(fc2_unit)
+    self.fc_policy = nn.Linear(fc2_unit, action_space)
+    self.max_action = max_action
+
+    nn.init.orthogonal_(self.fc1.weight, gain=init_weight_gain)
+    nn.init.orthogonal_(self.fc2.weight, gain=init_weight_gain)
+    nn.init.uniform_(
+        self.fc_policy.weight, -init_policy_weight_gain, init_policy_weight_gain
+    )
+
+    nn.init.constant_(self.fc1.bias, init_bias)
+    nn.init.constant_(self.fc2.bias, init_bias)
+    nn.init.constant_(self.fc_policy.bias, init_bias)
+
+  def forward(self, x):
+    """
+    Build a network that maps states -> actions.
+    """
+    x = F.relu(self.fc1_ln(self.fc1(x)))
+    x = F.relu(self.fc2_ln(self.fc2(x)))
+    pi = self.max_action * torch.tanh(self.fc_policy(x))
+    return pi
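+
+
+# Worked example of the squashing above (an illustration, not part of the
+# model): tanh bounds the raw policy output to (-1, 1) and max_action rescales
+# it, so for Pendulum-v1 (max_action == 2.0) a raw output of 0.5 becomes
+#   2.0 * tanh(0.5) ≈ 0.92,
+# and the action can never leave [-max_action, max_action].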
+ """ + x = torch.concat([x, y], dim=1) + x = F.relu(self.fc1_ln(self.fc1(x))) + x = F.relu(self.fc2_ln(self.fc2(x))) + return self.fc3(x) + + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class TD3Agent(Agent): + """Interacts with and learns form environment.""" + + def __init__( + self, + state_dims, + action_space, + gamma=0.99, + lr_actor=0.001, + lr_critic=0.001, + batch_size=64, + epsilon=0.01, + mem_size=None, + forget_experience=True, + update_tau=0.5, + n_steps=0, + gae_lambda=None, + beta=0, + seed=0, + mu=0.0, + theta=0.15, + max_sigma=0.3, + min_sigma=0.3, + decay_period=100000, + value_noise_clip=0.5, + value_noise_sigma=0.5 + ): + + self.state_dims = state_dims.shape[0] + self.action_space_env = action_space + self.action_space = action_space.shape[0] + self.gamma = gamma + self.batch_size = batch_size + self.epsilon = epsilon + self.seed = np.random.seed(seed) + self.n_steps = n_steps + self.gae_lambda = gae_lambda + self.lr_actor = lr_actor + self.lr_critic = lr_critic + self.beta = beta + self.noise = OUNoise( + action_space, + mu=mu, + theta=theta, + max_sigma=max_sigma, + min_sigma=min_sigma, + decay_period=decay_period, + ) + self.update_tau = update_tau + self.value_noise_clip = value_noise_clip + self.value_noise_sigma = value_noise_sigma + + # Theta 1 network + self.actor = Actor( + self.state_dims, + self.action_space, + max_action=self.action_space_env.high[0] + ).to(device) + self.actor_target = Actor( + self.state_dims, + self.action_space, + max_action=self.action_space_env.high[0] + ).to(device) + self.actor_target.load_state_dict(self.actor.state_dict()) + + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=self.lr_actor + ) + + # Theta 1 Critic network + self.critic = Critic(self.state_dims, self.action_space).to(device) + self.critic_target = Critic(self.state_dims, self.action_space).to(device) + self.critic_target.load_state_dict(self.critic.state_dict()) + + # Theta 2 Critic network + self.critic_1 = Critic(self.state_dims, self.action_space).to(device) + self.critic_target_1 = Critic(self.state_dims, self.action_space).to(device) + self.critic_target_1.load_state_dict(self.critic_1.state_dict()) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=self.lr_critic + ) + + # Replay memory + self.memory = ReplayBuffer(max_size=mem_size) + + self.forget_experience = forget_experience + + self.val_loss = nn.MSELoss() + self.val_1_loss = nn.MSELoss() + self.policy_loss = nn.MSELoss() + + def learn(self, iteration): + if len(self.memory) > self.batch_size: + for _ in range(iteration): + experience = self.memory.sample_from(num_samples=self.batch_size) + self._learn(experience) + + def action(self, state, mode="train"): + if mode == "train": + self.actor.train() + else: + self.actor.eval() + + with torch.no_grad(): + action = self.actor.forward(state) + return action.cpu().data.numpy() + + def take_action(self, state, explore=False, step=0): + """Returns action for given state as per current policy + Params + ======= + state (array_like): current state + epsilon (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + action_values = self.action(state=state, mode="eval").squeeze(0) + if explore: + action_values = self.noise.get_action(action_values, step) + + # Clip the output according to the action space of the env + action_values = np.clip( + action_values, self.action_space_env.low[0], + 
+
+  def learn(self, iteration):
+    if len(self.memory) > self.batch_size:
+      for _ in range(iteration):
+        experience = self.memory.sample_from(num_samples=self.batch_size)
+        self._learn(experience)
+
+  def action(self, state, mode="train"):
+    if mode == "train":
+      self.actor.train()
+    else:
+      self.actor.eval()
+
+    with torch.no_grad():
+      action = self.actor.forward(state)
+    return action.cpu().data.numpy()
+
+  def take_action(self, state, explore=False, step=0):
+    """Returns action for given state as per current policy.
+
+    Params
+    =======
+        state (array_like): current state
+        explore (bool): whether to add OU exploration noise
+        step (int): global step, used to anneal the noise scale
+    """
+    state = torch.from_numpy(state).float().unsqueeze(0).to(device)
+    action_values = self.action(state=state, mode="eval").squeeze(0)
+    if explore:
+      action_values = self.noise.get_action(action_values, step)
+
+    # Clip the output according to the action space of the env
+    action_values = np.clip(
+        action_values, self.action_space_env.low[0],
+        self.action_space_env.high[0]
+    )
+    return action_values
+
+  def remember(self, scenario: Experience):
+    self.memory.enqueue(scenario)
+
+  def _learn(self, experiences):
+    # pylint: disable=line-too-long
+    """Update value parameters using given batch of experience tuples.
+
+    Params
+    =======
+        experiences (List[Experience]): batch of (s, a, r, s', done) tuples
+    """
+    states = torch.from_numpy(
+        np.vstack([e.state for e in experiences])
+    ).float().to(device)
+    # Actions are continuous; cast to float (a long cast would truncate them).
+    actions = torch.from_numpy(
+        np.vstack([e.action for e in experiences])
+    ).float().to(device)
+    rewards = torch.from_numpy(
+        np.vstack([e.reward for e in experiences])
+    ).float().to(device)
+    next_states = torch.from_numpy(
+        np.vstack([e.next_state for e in experiences])
+    ).float().to(device)
+    terminate = torch.from_numpy(
+        np.vstack([e.done for e in experiences])
+    ).float().to(device)
+
+    self.critic.train()
+    self.critic_target.eval()
+    self.critic_1.train()
+    self.critic_target_1.eval()
+    self.actor_target.eval()
+
+    # Target policy smoothing: noise ~ clip(N(0, sigma), -c, c)
+    noise = torch.clamp(
+        torch.normal(mean=0.0, std=self.value_noise_sigma, size=actions.size()),
+        -self.value_noise_clip, self.value_noise_clip
+    ).to(device)
+
+    # Compute the target Q value from both target critics.
+    # (Note: the TD3 paper additionally clips the smoothed target action
+    # back into the valid action range.)
+    target_q = self.critic_target.forward(
+        next_states,
+        self.actor_target.forward(next_states) + noise
+    )
+    target_q_1 = self.critic_target_1.forward(
+        next_states,
+        self.actor_target.forward(next_states) + noise
+    )
+
+    # Clipped double Q-learning: bootstrap from the element-wise minimum
+    # of the two target critics.
+    min_target_q_value = torch.min(
+        torch.cat((target_q, target_q_1), dim=1), dim=1
+    ).values.unsqueeze(dim=1)
+
+    target_q = rewards + (
+        (1 - terminate) * self.gamma * min_target_q_value
+    ).detach()
+
+    # Get current Q estimates from both critics
+    current_q = self.critic.forward(states, actions)
+    current_q_1 = self.critic_1.forward(states, actions)
+
+    # Compute critic loss
+    critic_loss = (
+        self.val_loss(current_q, target_q) +
+        self.val_1_loss(current_q_1, target_q)
+    )
+
+    # Optimize both critics
+    self.critic_optimizer.zero_grad()
+    critic_loss.backward()
+    self.critic_optimizer.step()
+
+    # TD3 proper uses "delayed policy updates": the actor and the targets are
+    # updated only every policy_freq critic updates. This implementation
+    # assumes policy_freq = 1.
+    # Compute actor loss
+    actor_loss = -self.critic.forward(states, self.actor.forward(states)).mean()
+
+    # Optimize the actor
+    self.actor_optimizer.zero_grad()
+    actor_loss.backward()
+    self.actor_optimizer.step()
+
+    self.update_actor_target_network()
+    self.update_critic_target_network()
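+
+  # A delayed-updates variant closer to the TD3 paper would gate the actor and
+  # target updates in _learn, e.g. (hypothetical `policy_freq`/`update_count`):
+  #   self.update_count += 1
+  #   if self.update_count % policy_freq == 0:
+  #     <actor update and target-network soft updates>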
+
+  def soft_update(self, local_model, target_model):
+    """
+    Soft update model parameters:
+    θ_target = τ * θ_local + (1 - τ) * θ_target
+
+    Taken from
+    https://github.com/udacity/deep-reinforcement-learning/blob/master/dqn/exercise/dqn_agent.py
+    """
+    for target_param, local_param in zip(
+        target_model.parameters(), local_model.parameters()
+    ):
+      target_param.data.copy_(
+          self.update_tau * local_param.data +
+          (1.0 - self.update_tau) * target_param.data
+      )
+
+  def update_actor_target_network(self):
+    self.soft_update(self.actor, self.actor_target)
+
+  def update_critic_target_network(self):
+    self.soft_update(self.critic, self.critic_target)
+    # Keep the second critic's target in sync as well.
+    self.soft_update(self.critic_1, self.critic_target_1)
diff --git a/TODO.md b/TODO.md
index 114834a..663714b 100644
--- a/TODO.md
+++ b/TODO.md
@@ -10,5 +10,6 @@
 - [x] [AWR](https://openreview.net/attachment?id=H1gdF34FvS&name=original_pdf)
 - [x] [AC](https://proceedings.neurips.cc/paper/1999/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf) (A2C & [A3C](https://arxiv.org/pdf/1602.01783.pdf))
 - [x] [XQL](https://div99.github.io/XQL)
+- [x] [TD3](https://arxiv.org/pdf/1802.09477.pdf)
 - [ ] [IMPALA](https://arxiv.org/pdf/1802.01561.pdf)
 - [ ] [ERLA](https://arxiv.org/pdf/2101.03958.pdf)
diff --git a/XTD3/XTD3Agent_100.gif b/XTD3/XTD3Agent_100.gif
new file mode 100644
index 0000000..5bf9796
Binary files /dev/null and b/XTD3/XTD3Agent_100.gif differ
diff --git a/XTD3/__init__.py b/XTD3/__init__.py
new file mode 100644
index 0000000..8875f60
--- /dev/null
+++ b/XTD3/__init__.py
@@ -0,0 +1 @@
+# pylint: disable=all
diff --git a/XTD3/main.py b/XTD3/main.py
new file mode 100644
index 0000000..5a2156e
--- /dev/null
+++ b/XTD3/main.py
@@ -0,0 +1,141 @@
+"""main executable file for XTD3"""
+import os
+import logging
+from itertools import repeat
+import gymnasium as gym
+import torch
+import numpy as np
+from util import generate_gif
+from util.wrappers import TrainMonitor
+from util.buffer import Experience
+from collections import deque
+# pylint: disable=invalid-name
+from XTD3.xtd3 import XTD3Agent as XTD3_torch
+
+Agent = XTD3_torch
+logging.basicConfig(level=logging.INFO)
+
+torch.manual_seed(0)
+np.random.seed(0)
+
+EPSILON_DECAY_STEPS = 100
+
+
+def main(
+    n_episodes=2000,
+    max_t=200,
+    eps_start=1.0,
+    eps_end=0.01,
+    eps_decay=0.996,
+    score_term_rules=lambda s: False,
+    time_interval="25ms"
+):
+  # pylint: disable=line-too-long
+  """Train an XTD3 agent.
+
+  Params
+  ======
+      n_episodes (int): maximum number of training episodes
+      max_t (int): maximum number of timesteps per episode
+      eps_start (float): starting value of epsilon, for epsilon-greedy action selection
+      eps_end (float): minimum value of epsilon
+      eps_decay (float): multiplicative factor (per episode) for decreasing epsilon
+
+  """
+  scores = []  # list containing score from each episode
+  scores_window = deque(maxlen=100)  # last 100 scores
+  eps = eps_start
+
+  env = gym.make(
+      "Pendulum-v1",
+      render_mode="rgb_array",
+  )
+  # env = gym.make(
+  #     "LunarLander-v2",
+  #     render_mode="rgb_array",
+  #     continuous=True,
+  # )
+  # env = gym.make("MountainCarContinuous-v0", render_mode="rgb_array")
+  env = TrainMonitor(env, tensorboard_dir="./logs", tensorboard_write_all=True)
+
+  gamma = 0.99
+  batch_size = 64
+  learn_iteration = 16
+  update_tau = 0.5
+
+  lr_actor = 0.0001
+  lr_critic = 0.001
+
+  mu = 0.0
+  theta = 0.15
+  max_sigma = 0.3
+  min_sigma = 0.3
+  decay_period = 100000
+  value_noise_clip = 0.5
+  value_noise_sigma = 0.5
+  gumbel_loss_beta = 2
+  gumbel_loss_clip = None
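+  # gumbel_loss_beta is the XQL temperature: as beta -> 0 the implied value
+  # backup approaches a hard max over actions, while larger beta gives a
+  # softer, closer-to-mean backup; beta = 2 here is an empirical choice.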
+
+  agent = Agent(
+      state_dims=env.observation_space,
+      action_space=env.action_space,
+      lr_actor=lr_actor,
+      lr_critic=lr_critic,
+      gamma=gamma,
+      batch_size=batch_size,
+      forget_experience=False,
+      update_tau=update_tau,
+      mu=mu,
+      theta=theta,
+      max_sigma=max_sigma,
+      min_sigma=min_sigma,
+      decay_period=decay_period,
+      value_noise_clip=value_noise_clip,
+      value_noise_sigma=value_noise_sigma,
+      gumbel_loss_beta=gumbel_loss_beta,
+      gumbel_loss_clip=gumbel_loss_clip,
+  )
+
+  dump_gif_dir = f"images/{agent.__class__.__name__}/{agent.__class__.__name__}_{{}}.gif"
+  for i_episode in range(1, n_episodes + 1):
+    state, _ = env.reset()
+    score = 0
+    for t, _ in enumerate(repeat(0, max_t)):
+      action = agent.take_action(state=state, explore=True, step=t * i_episode)
+      next_state, reward, done, _, _ = env.step(action)
+      agent.remember(Experience(state, action, reward, next_state, done))
+      agent.learn(learn_iteration)
+
+      state = next_state
+      score += reward
+
+      if done or score_term_rules(score):
+        break
+
+    scores_window.append(score)  ## save the most recent score
+    scores.append(score)  ## save the most recent score
+    eps = max(eps * eps_decay, eps_end)  ## decrease the epsilon
+    print(" " * os.get_terminal_size().columns, end="\r")
+    print(
+        f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}",
+        end="\r"
+    )
+
+    if i_episode and i_episode % 100 == 0:
+      print(" " * os.get_terminal_size().columns, end="\r")
+      print(
+          f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}"
+      )
+      generate_gif(
+          env,
+          filepath=dump_gif_dir.format(i_episode),
+          policy=lambda s: agent.take_action(s, explore=False),
+          duration=float(time_interval.split("ms")[0]),
+          max_episode_steps=max_t
+      )
+
+  return scores
+
+
+if __name__ == "__main__":
+  main()
diff --git a/XTD3/xtd3.py b/XTD3/xtd3.py
new file mode 100644
index 0000000..f5600f5
--- /dev/null
+++ b/XTD3/xtd3.py
@@ -0,0 +1,451 @@
+"""XTD3 implementation with pytorch."""
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from util.buffer import ReplayBuffer
+from util.agent import Agent
+from util.buffer import Experience
+from functools import partial
+
+
+def gumbel_loss(pred, label, beta, clip):
+  """
+  Gumbel loss function
+
+  Described in Appendix D.3 of
+
+  https://arxiv.org/pdf/2301.02328.pdf
+
+  Taken from
+  https://github.com/Div99/XQL/blob/dff09afb893fe782be259c2420903f8dfb50ef2c/online/research/algs/gumbel_sac.py#L10
+  """
+  assert pred.shape == label.shape, "Shapes were incorrect"
+  z = (label - pred) / beta
+  if clip is not None:
+    z = torch.clamp(z, -clip, clip)
+  loss = torch.exp(z) - z - 1
+  return loss.mean()
+
+
+def gumbel_rescale_loss(pred, label, beta, clip):
+  """
+  Gumbel rescale loss function
+
+  Described in Appendix D.3 (NUMERIC STABILITY) of
+
+  https://arxiv.org/pdf/2301.02328.pdf
+
+  Taken from
+  https://github.com/Div99/XQL/blob/dff09afb893fe782be259c2420903f8dfb50ef2c/online/research/algs/gumbel_sac.py#L18
+  """
+  assert pred.shape == label.shape, "Shapes were incorrect"
+  z = (label - pred) / beta
+  if clip is not None:
+    z = torch.clamp(z, -clip, clip)
+  max_z = torch.max(z)
+  max_z = torch.where(
+      max_z < -1.0, torch.tensor(-1.0, dtype=torch.float, device=max_z.device),
+      max_z
+  )
+  max_z = max_z.detach()  # Detach the gradients
+  loss = torch.exp(z - max_z) - z * torch.exp(-max_z) - torch.exp(-max_z)
+  return loss.mean()
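+
+
+# Illustrative sanity check (a hypothetical helper, not called by the agent):
+# for small residuals z = (label - pred) / beta, exp(z) - z - 1 ≈ z**2 / 2,
+# i.e. the Gumbel loss degenerates to a scaled squared error near convergence.
+def _gumbel_loss_quadratic_check():
+  pred = torch.zeros(4)
+  label = torch.tensor([0.01, -0.02, 0.03, -0.01])
+  loss = gumbel_loss(pred, label, beta=1.0, clip=None)
+  approx = (((label - pred) / 1.0)**2 / 2).mean()
+  assert torch.isclose(loss, approx, atol=1e-4)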
+
+
+class OUNoise(object):
+  # pylint: disable=line-too-long
+  """
+  Taken from https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py
+  """
+
+  def __init__(
+      self,
+      action_space,
+      mu=0.0,
+      theta=0.15,
+      max_sigma=0.3,
+      min_sigma=0.3,
+      decay_period=100000
+  ):
+    self.mu = mu
+    self.theta = theta
+    self.sigma = max_sigma
+    self.max_sigma = max_sigma
+    self.min_sigma = min_sigma
+    self.decay_period = decay_period
+    self.action_dim = action_space.shape[0]
+    self.low = action_space.low
+    self.high = action_space.high
+    self.reset()
+
+  def reset(self):
+    self.state = np.ones(self.action_dim) * self.mu
+
+  def evolve_state(self):
+    x = self.state
+    dx = (self.theta * (self.mu - x) +
+          self.sigma * np.random.randn(self.action_dim))
+    self.state = x + dx
+    return self.state
+
+  def get_action(self, action, t=0):
+    ou_state = self.evolve_state()
+    self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(
+        1.0, t / self.decay_period
+    )
+    return np.clip(action + ou_state, self.low, self.high)
+
+
+class Actor(nn.Module):
+  """Actor (Policy) Model."""
+
+  def __init__(
+      self,
+      state_dim,
+      action_space,
+      seed=0,
+      fc1_unit=64,
+      fc2_unit=64,
+      max_action=1,
+      init_weight_gain=np.sqrt(2),
+      init_policy_weight_gain=1,
+      init_bias=0
+  ):
+    """
+    Initialize parameters and build model.
+
+    Params
+    =======
+        state_dim (int): Dimension of each state
+        action_space (int): Dimension of each action
+        seed (int): Random seed
+        fc1_unit (int): Number of nodes in first hidden layer
+        fc2_unit (int): Number of nodes in second hidden layer
+    """
+    super().__init__()  ## calls __init__ method of nn.Module class
+    self.seed = torch.manual_seed(seed)
+    self.fc1 = nn.Linear(state_dim, fc1_unit)
+    self.fc1_ln = nn.LayerNorm(fc1_unit)
+    self.fc2 = nn.Linear(fc1_unit, fc2_unit)
+    self.fc2_ln = nn.LayerNorm(fc2_unit)
+    self.fc_policy = nn.Linear(fc2_unit, action_space)
+    self.max_action = max_action
+
+    nn.init.orthogonal_(self.fc1.weight, gain=init_weight_gain)
+    nn.init.orthogonal_(self.fc2.weight, gain=init_weight_gain)
+    nn.init.uniform_(
+        self.fc_policy.weight, -init_policy_weight_gain, init_policy_weight_gain
+    )
+
+    nn.init.constant_(self.fc1.bias, init_bias)
+    nn.init.constant_(self.fc2.bias, init_bias)
+    nn.init.constant_(self.fc_policy.bias, init_bias)
+
+  def forward(self, x):
+    """
+    Build a network that maps states -> actions.
+    """
+    x = F.relu(self.fc1_ln(self.fc1(x)))
+    x = F.relu(self.fc2_ln(self.fc2(x)))
+    pi = self.max_action * torch.tanh(self.fc_policy(x))
+    return pi
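+
+
+# The OUNoise, Actor and Critic definitions in this file mirror TD3/td3.py;
+# the XQL-specific change is confined to the critic loss in XTD3Agent._learn.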
+ """ + x = torch.concat([x, y], dim=1) + x = F.relu(self.fc1_ln(self.fc1(x))) + x = F.relu(self.fc2_ln(self.fc2(x))) + return self.fc3(x) + + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class XTD3Agent(Agent): + """Interacts with and learns form environment.""" + + def __init__( + self, + state_dims, + action_space, + gamma=0.99, + lr_actor=0.001, + lr_critic=0.001, + batch_size=64, + epsilon=0.01, + mem_size=None, + forget_experience=True, + update_tau=0.5, + n_steps=0, + gae_lambda=None, + beta=0, + seed=0, + mu=0.0, + theta=0.15, + max_sigma=0.3, + min_sigma=0.3, + decay_period=100000, + value_noise_clip=0.5, + value_noise_sigma=0.5, + gumbel_loss_beta=2, + gumbel_loss_clip=None, + ): + + self.state_dims = state_dims.shape[0] + self.action_space_env = action_space + self.action_space = action_space.shape[0] + self.gamma = gamma + self.batch_size = batch_size + self.epsilon = epsilon + self.seed = np.random.seed(seed) + self.n_steps = n_steps + self.gae_lambda = gae_lambda + self.lr_actor = lr_actor + self.lr_critic = lr_critic + self.beta = beta + self.noise = OUNoise( + action_space, + mu=mu, + theta=theta, + max_sigma=max_sigma, + min_sigma=min_sigma, + decay_period=decay_period, + ) + self.update_tau = update_tau + self.value_noise_clip = value_noise_clip + self.value_noise_sigma = value_noise_sigma + + self.gumbel_loss_beta = gumbel_loss_beta + self.gumbel_loss_clip = gumbel_loss_clip + + # Theta 1 network + self.actor = Actor( + self.state_dims, + self.action_space, + max_action=self.action_space_env.high[0] + ).to(device) + self.actor_target = Actor( + self.state_dims, + self.action_space, + max_action=self.action_space_env.high[0] + ).to(device) + self.actor_target.load_state_dict(self.actor.state_dict()) + + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=self.lr_actor + ) + + # Theta 1 Critic network + self.critic = Critic(self.state_dims, self.action_space).to(device) + self.critic_target = Critic(self.state_dims, self.action_space).to(device) + self.critic_target.load_state_dict(self.critic.state_dict()) + + # Theta 2 Critic network + self.critic_1 = Critic(self.state_dims, self.action_space).to(device) + self.critic_target_1 = Critic(self.state_dims, self.action_space).to(device) + self.critic_target_1.load_state_dict(self.critic_1.state_dict()) + + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=self.lr_critic + ) + + # Replay memory + self.memory = ReplayBuffer(max_size=mem_size) + + self.forget_experience = forget_experience + + self.val_loss = partial( + gumbel_rescale_loss, + beta=self.gumbel_loss_beta, + clip=self.gumbel_loss_clip + ) + self.val_1_loss = partial( + gumbel_rescale_loss, + beta=self.gumbel_loss_beta, + clip=self.gumbel_loss_clip + ) + self.policy_loss = nn.MSELoss() + + def learn(self, iteration): + if len(self.memory) > self.batch_size: + for _ in range(iteration): + experience = self.memory.sample_from(num_samples=self.batch_size) + self._learn(experience) + + def action(self, state, mode="train"): + if mode == "train": + self.actor.train() + else: + self.actor.eval() + + with torch.no_grad(): + action = self.actor.forward(state) + return action.cpu().data.numpy() + + def take_action(self, state, explore=False, step=0): + """Returns action for given state as per current policy + Params + ======= + state (array_like): current state + epsilon (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + 
+
+  def learn(self, iteration):
+    if len(self.memory) > self.batch_size:
+      for _ in range(iteration):
+        experience = self.memory.sample_from(num_samples=self.batch_size)
+        self._learn(experience)
+
+  def action(self, state, mode="train"):
+    if mode == "train":
+      self.actor.train()
+    else:
+      self.actor.eval()
+
+    with torch.no_grad():
+      action = self.actor.forward(state)
+    return action.cpu().data.numpy()
+
+  def take_action(self, state, explore=False, step=0):
+    """Returns action for given state as per current policy.
+
+    Params
+    =======
+        state (array_like): current state
+        explore (bool): whether to add OU exploration noise
+        step (int): global step, used to anneal the noise scale
+    """
+    state = torch.from_numpy(state).float().unsqueeze(0).to(device)
+    action_values = self.action(state=state, mode="eval").squeeze(0)
+    if explore:
+      action_values = self.noise.get_action(action_values, step)
+
+    # Clip the output according to the action space of the env
+    action_values = np.clip(
+        action_values, self.action_space_env.low[0],
+        self.action_space_env.high[0]
+    )
+    return action_values
+
+  def remember(self, scenario: Experience):
+    self.memory.enqueue(scenario)
+
+  def _learn(self, experiences):
+    # pylint: disable=line-too-long
+    """Update value parameters using given batch of experience tuples.
+
+    Params
+    =======
+        experiences (List[Experience]): batch of (s, a, r, s', done) tuples
+    """
+    states = torch.from_numpy(
+        np.vstack([e.state for e in experiences])
+    ).float().to(device)
+    # Actions are continuous; cast to float (a long cast would truncate them).
+    actions = torch.from_numpy(
+        np.vstack([e.action for e in experiences])
+    ).float().to(device)
+    rewards = torch.from_numpy(
+        np.vstack([e.reward for e in experiences])
+    ).float().to(device)
+    next_states = torch.from_numpy(
+        np.vstack([e.next_state for e in experiences])
+    ).float().to(device)
+    terminate = torch.from_numpy(
+        np.vstack([e.done for e in experiences])
+    ).float().to(device)
+
+    self.critic.train()
+    self.critic_target.eval()
+    self.critic_1.train()
+    self.critic_target_1.eval()
+    self.actor_target.eval()
+
+    # Target policy smoothing: noise ~ clip(N(0, sigma), -c, c)
+    noise = torch.clamp(
+        torch.normal(mean=0.0, std=self.value_noise_sigma, size=actions.size()),
+        -self.value_noise_clip, self.value_noise_clip
+    ).to(device)
+
+    # Compute the target Q value from both target critics.
+    # (Note: the TD3 paper additionally clips the smoothed target action
+    # back into the valid action range.)
+    target_q = self.critic_target.forward(
+        next_states,
+        self.actor_target.forward(next_states) + noise
+    )
+    target_q_1 = self.critic_target_1.forward(
+        next_states,
+        self.actor_target.forward(next_states) + noise
+    )
+
+    # Clipped double Q-learning: bootstrap from the element-wise minimum
+    # of the two target critics.
+    min_target_q_value = torch.min(
+        torch.cat((target_q, target_q_1), dim=1), dim=1
+    ).values.unsqueeze(dim=1)
+
+    target_q = rewards + (
+        (1 - terminate) * self.gamma * min_target_q_value
+    ).detach()
+
+    # Get current Q estimates from both critics
+    current_q = self.critic.forward(states, actions)
+    current_q_1 = self.critic_1.forward(states, actions)
+
+    # Compute critic loss.
+    # Because the Bellman error is modelled as Gumbel-distributed, the only
+    # change from TD3 is the critic objective: MSE is replaced by the Gumbel
+    # regression loss. Details: https://arxiv.org/pdf/2301.02328.pdf,
+    # Appendix C (EXTREME Q-LEARNING) and the Q-value iteration part of
+    # C.1 (X-QL).
+    critic_loss = (
+        self.val_loss(current_q, target_q) +
+        self.val_1_loss(current_q_1, target_q)
+    )
+
+    # Optimize both critics
+    self.critic_optimizer.zero_grad()
+    critic_loss.backward()
+    self.critic_optimizer.step()
+
+    # TD3 proper uses "delayed policy updates": the actor and the targets are
+    # updated only every policy_freq critic updates. This implementation
+    # assumes policy_freq = 1.
+    # Compute actor loss
+    actor_loss = -self.critic.forward(states, self.actor.forward(states)).mean()
+
+    # Optimize the actor
+    self.actor_optimizer.zero_grad()
+    actor_loss.backward()
+    self.actor_optimizer.step()
+
+    self.update_actor_target_network()
+    self.update_critic_target_network()
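+
+  # Side-by-side, per critic, with z = (target_q - current_q) / beta:
+  #   TD3  (MSE):                                    (current_q - target_q)**2
+  #   XTD3 (Gumbel, up to the rescaling in gumbel_rescale_loss): exp(z) - z - 1
+  # The Gumbel form penalises underestimation (z > 0) exponentially but
+  # overestimation only linearly, matching the XQL regression objective.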
+
+  def soft_update(self, local_model, target_model):
+    """
+    Soft update model parameters:
+    θ_target = τ * θ_local + (1 - τ) * θ_target
+
+    Taken from
+    https://github.com/udacity/deep-reinforcement-learning/blob/master/dqn/exercise/dqn_agent.py
+    """
+    for target_param, local_param in zip(
+        target_model.parameters(), local_model.parameters()
+    ):
+      target_param.data.copy_(
+          self.update_tau * local_param.data +
+          (1.0 - self.update_tau) * target_param.data
+      )
+
+  def update_actor_target_network(self):
+    self.soft_update(self.actor, self.actor_target)
+
+  def update_critic_target_network(self):
+    self.soft_update(self.critic, self.critic_target)
+    # Keep the second critic's target in sync as well.
+    self.soft_update(self.critic_1, self.critic_target_1)