Add TD3 main executable file #19

Merged
merged 6 commits on Dec 28, 2023
10 changes: 9 additions & 1 deletion README.md
@@ -102,7 +102,11 @@ For example (DDPG):

![AC](AC/A2CAgent_600.gif)

- improve `AWR`, `DDPG` with Gumbel Distribution Regression from [`XQL`](https://div99.github.io/XQL):
- [x] [TD3](https://arxiv.org/pdf/1802.09477.pdf)

![TD3](TD3/TD3Agent_100.gif)

- improve `AWR`, `DDPG`, `TD3` with Gumbel Distribution Regression from [`XQL`](https://div99.github.io/XQL):
- XAWR

![XAWR](XAWR/XAWRAgent_100.gif)
@@ -111,6 +115,10 @@ For example (DDPG):

![XDDPG](XDDPG/XDDPGAgent_200.gif)

- XTD3

![XTD3](XTD3/XTD3Agent_100.gif)

## Reference

- TrainMonitor and Generategif modified from [coax](https://github.com/coax-dev/coax)
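For context on the `XQL`-based variants listed in the README (XAWR, XDDPG, XTD3), below is a minimal sketch of the Gumbel (linex) regression loss from the XQL paper, used in place of the usual squared TD error. The function name, `beta`, and the stability clamp are illustrative assumptions, not code from this repository:

```python
import torch


def gumbel_regression_loss(td_error: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
  """Linex loss exp(z) - z - 1 with z = td_error / beta (Gumbel regression)."""
  z = td_error / beta
  z = z.clamp(max=5.0)  # illustrative clamp to keep exp(z) numerically stable
  return (torch.exp(z) - z - 1.0).mean()
```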
Binary file added TD3/TD3Agent_100.gif
1 change: 1 addition & 0 deletions TD3/__init__.py
@@ -0,0 +1 @@
# pylint: disable=all
136 changes: 136 additions & 0 deletions TD3/main.py
@@ -0,0 +1,136 @@
"""main executable file for TD3"""
import os
import logging
from itertools import repeat
import gymnasium as gym
import torch
import numpy as np
from util import generate_gif
from util.wrappers import TrainMonitor
from util.buffer import Experience
from collections import deque
# pylint: disable=invalid-name
from TD3.td3 import TD3Agent as TD3_torch

Agent = TD3_torch
logging.basicConfig(level=logging.INFO)

torch.manual_seed(0)
np.random.seed(0)

EPSILON_DECAY_STEPS = 100


def main(
    n_episodes=2000,
    max_t=200,
    eps_start=1.0,
    eps_end=0.01,
    eps_decay=0.996,
    score_term_rules=lambda s: False,
    time_interval="25ms"
):
  # pylint: disable=line-too-long
  """Twin Delayed DDPG (TD3) training loop.

  Params
  ======
      n_episodes (int): maximum number of training episodes
      max_t (int): maximum number of timesteps per episode
      eps_start (float): starting value of epsilon, for epsilon-greedy action selection
      eps_end (float): minimum value of epsilon
      eps_decay (float): multiplicative factor (per episode) for decreasing epsilon

  """
  scores = []  # list containing score from each episode
  scores_window = deque(maxlen=100)  # last 100 scores
  eps = eps_start

  env = gym.make(
      "Pendulum-v1",
      render_mode="rgb_array",
  )
  # env = gym.make(
  #     "LunarLander-v2",
  #     render_mode="rgb_array",
  #     continuous=True,
  # )
  # env = gym.make("MountainCarContinuous-v0", render_mode="rgb_array")
  env = TrainMonitor(env, tensorboard_dir="./logs", tensorboard_write_all=True)

  gamma = 0.99  # discount factor
  batch_size = 64  # minibatch size sampled from the replay buffer
  learn_iteration = 16  # update iterations per agent.learn() call
  update_tau = 0.5  # soft-update coefficient for the target networks

  lr_actor = 0.0001
  lr_critic = 0.001

  # Exploration-noise parameters passed to the agent (Ornstein-Uhlenbeck-style
  # mean, mean-reversion rate, and sigma schedule).
  mu = 0.0
  theta = 0.15
  max_sigma = 0.3
  min_sigma = 0.3
  decay_period = 100000
  # TD3 target-policy-smoothing noise: clip range and standard deviation.
  value_noise_clip = 0.5
  value_noise_sigma = 0.5

  agent = Agent(
      state_dims=env.observation_space,
      action_space=env.action_space,
      lr_actor=lr_actor,
      lr_critic=lr_critic,
      gamma=gamma,
      batch_size=batch_size,
      forget_experience=False,
      update_tau=update_tau,
      mu=mu,
      theta=theta,
      max_sigma=max_sigma,
      min_sigma=min_sigma,
      decay_period=decay_period,
      value_noise_clip=value_noise_clip,
      value_noise_sigma=value_noise_sigma
  )
  dump_gif_dir = f"images/{agent.__class__.__name__}/{agent.__class__.__name__}_{{}}.gif"
  for i_episode in range(1, n_episodes + 1):
    state, _ = env.reset()
    score = 0
    for t, _ in enumerate(repeat(0, max_t)):
      action = agent.take_action(state=state, explore=True, step=t * i_episode)
      next_state, reward, done, _, _ = env.step(action)
      agent.remember(Experience(state, action, reward, next_state, done))
      agent.learn(learn_iteration)

      state = next_state
      score += reward

      if done or score_term_rules(score):
        break

    scores_window.append(score)  ## save the most recent score in the window
    scores.append(score)  ## save the most recent score
    eps = max(eps * eps_decay, eps_end)  ## decrease the epsilon
    print(" " * os.get_terminal_size().columns, end="\r")
    print(
        f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}",
        end="\r"
    )

    if i_episode and i_episode % 100 == 0:
      print(" " * os.get_terminal_size().columns, end="\r")
      print(
          f"\rEpisode {i_episode}\tAverage Score {np.mean(scores_window):.2f}"
      )
      generate_gif(
          env,
          filepath=dump_gif_dir.format(i_episode),
          policy=lambda s: agent.take_action(s, explore=False),
          duration=float(time_interval.split("ms")[0]),
          max_episode_steps=max_t
      )

  return scores


if __name__ == "__main__":
  main()
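For readers scanning the diff, here is a minimal sketch of what the `value_noise_sigma` / `value_noise_clip` arguments are typically used for inside a TD3 agent: target policy smoothing with clipped double-Q targets. The function and network names below are placeholders for illustration, not the actual internals of `TD3/td3.py`:

```python
import torch


def td3_target(actor_t, critic1_t, critic2_t, next_state, reward, done,
               gamma=0.99, value_noise_sigma=0.5, value_noise_clip=0.5,
               act_low=-1.0, act_high=1.0):
  """Clipped double-Q target with target policy smoothing (illustrative sketch)."""
  with torch.no_grad():
    # Smooth the target policy: add clipped Gaussian noise to the target actor's action.
    noise = torch.randn_like(actor_t(next_state)) * value_noise_sigma
    noise = noise.clamp(-value_noise_clip, value_noise_clip)
    next_action = (actor_t(next_state) + noise).clamp(act_low, act_high)
    # Clipped double-Q: bootstrap from the smaller of the two target critics.
    target_q = torch.min(
        critic1_t(next_state, next_action),
        critic2_t(next_state, next_action),
    )
    return reward + gamma * (1.0 - done) * target_q
```

Assuming the package layout implied by the imports (`TD3/` and `util/` packages at the repository root), the new file would typically be launched from the repo root as `python -m TD3.main`.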