c51 working
manila95 committed Jan 18, 2024
1 parent 10fafbc commit ac29639
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions cleanrl/c51.py
@@ -12,6 +12,9 @@
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
+from island_navigation import *
+
+import gymnasium.spaces as spaces


def parse_args():
@@ -43,13 +46,13 @@ def parse_args():
# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="CartPole-v1",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=500000,
parser.add_argument("--total-timesteps", type=int, default=50000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--n-atoms", type=int, default=101,
parser.add_argument("--n-atoms", type=int, default=51,
help="the number of atoms")
parser.add_argument("--v-min", type=float, default=-100,
help="the return lower bound")
@@ -69,9 +72,9 @@ def parse_args():
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.5,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
parser.add_argument("--learning-starts", type=int, default=10000,
parser.add_argument("--learning-starts", type=int, default=1000,
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=10,
parser.add_argument("--train-frequency", type=int, default=1,
help="the frequency of training")
args = parser.parse_args()
# fmt: on
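For reference (not shown in the diff itself), the new defaults keep the same return range but make the categorical support coarser; the atom grid they produce, using the same `torch.linspace` call that appears in `QNetwork` below:

```python
import torch

v_min, v_max, n_atoms = -100.0, 100.0, 51
atoms = torch.linspace(v_min, v_max, steps=n_atoms)
print(atoms[:3])                        # tensor([-100.,  -96.,  -92.])
print((v_max - v_min) / (n_atoms - 1))  # 4.0 return units between adjacent atoms
```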
@@ -102,9 +105,9 @@ def __init__(self, env, n_atoms=101, v_min=-100, v_max=100):
self.env = env
self.n_atoms = n_atoms
self.register_buffer("atoms", torch.linspace(v_min, v_max, steps=n_atoms))
-self.n = env.single_action_space.n
+self.n = 4
self.network = nn.Sequential(
-nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
+nn.Linear(np.array(env.observation_spec()["board"].shape).prod(), 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
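For context, the distributional head constructed here outputs `self.n * n_atoms` logits, and action selection takes the expectation of each action's categorical return distribution over the atoms. A minimal sketch of that selection step, following the usual C51 recipe rather than quoting this file verbatim:

```python
import torch

def select_action(network, atoms, n_actions, n_atoms, x):
    # x: a batch of flattened board observations, shape (batch, obs_dim)
    logits = network(x)
    # one categorical distribution over the atoms for every action
    pmfs = torch.softmax(logits.view(len(x), n_actions, n_atoms), dim=2)
    # expected return per action, then greedy selection
    q_values = (pmfs * atoms).sum(2)
    actions = torch.argmax(q_values, dim=1)
    return actions, pmfs[torch.arange(len(x)), actions]
```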
@@ -165,10 +168,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
-envs = gym.vector.SyncVectorEnv(
-    [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
-)
-assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+envs = IslandNavigationEnvironment()
+
+# assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs, n_atoms=args.n_atoms, v_min=args.v_min, v_max=args.v_max).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate, eps=0.01 / args.batch_size)
@@ -177,44 +179,64 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

rb = ReplayBuffer(
args.buffer_size,
-envs.single_observation_space,
-envs.single_action_space,
+spaces.Box(shape=(48,), low=0, high=10),
+spaces.Discrete(4),
device,
handle_timeout_termination=False,
)
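The buffer spaces are now hard-coded instead of taken from a vector env; presumably `Box(shape=(48,), low=0, high=10)` is the flattened IslandNavigation board (for example a 6x8 grid of cell codes) and `Discrete(4)` its four movement actions. A sketch of deriving the same spaces from the environment's own spec, reusing the `observation_spec()["board"]` interface used in `QNetwork` above (helper name and defaults are hypothetical):

```python
import numpy as np
import gymnasium.spaces as spaces

def buffer_spaces(env, n_actions=4, high=10):
    # Hypothetical helper: build the replay-buffer spaces from the board spec
    # rather than hard-coding shape=(48,).
    board_shape = env.observation_spec()["board"].shape
    obs_dim = int(np.prod(board_shape))  # e.g. a 6x8 board flattens to 48
    obs_space = spaces.Box(low=0, high=high, shape=(obs_dim,), dtype=np.float32)
    act_space = spaces.Discrete(n_actions)
    return obs_space, act_space
```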
start_time = time.time()

# TRY NOT TO MODIFY: start the game
-obs, _ = envs.reset(seed=args.seed)
+_, _, _, obs = envs.reset()
+episodic_return = 0
+num_terminations = 0
+scores = []
+episode = 0
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
-epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
+epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * 3000, episode)
if random.random() < epsilon:
-actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
+actions = np.array([np.random.choice(range(4)) for _ in range(args.num_envs)])
else:
-actions, pmf = q_network.get_action(torch.Tensor(obs).to(device))
+actions, pmf = q_network.get_action(torch.Tensor(obs["board"].reshape(1, -1)).to(device))
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
-next_obs, rewards, terminated, truncated, infos = envs.step(actions)
+_, reward, not_done, next_obs = envs.step(actions)
+if not_done is None:
+    _, _, _, obs = envs.reset()
+    done = False
+    episodic_return = 0
+    continue
+else:
+    done = 1 - not_done
+episodic_return += reward
# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
# Skip the envs that are not done
if "episode" not in info:
continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)

+if done:
+    num_terminations += (envs.environment_data['safety'] < 1)
+    scores.append(episodic_return)
+    print(f"global_step={global_step}, episodic_return={episodic_return}, total cost={num_terminations}")
+    writer.add_scalar("charts/episodic_return", np.mean(scores[-100:]), global_step)
+    # writer.add_scalar("charts/Total Goals", total_goals, global_step)
+    # writer.add_scalar("charts/Avg Return", np.mean(scores[-100:]), global_step)
+    # writer.add_scalar("charts/total_cost", total_cost, global_step)
+    # writer.add_scalar("charts/episodic_cost", cost, global_step)
+    # writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
+    writer.add_scalar("charts/epsilon", epsilon, global_step)
+    writer.add_scalar("charts/Num terminations", num_terminations, global_step)
+    _, _, _, obs = envs.reset()
+    episodic_return = 0
+    writer.add_scalar("charts/epsilon", epsilon, global_step)
+    episode += 1
+    if episode > 3000:
+        break
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
-real_next_obs = next_obs.copy()
-for idx, d in enumerate(truncated):
-    if d:
-        real_next_obs[idx] = infos["final_observation"][idx]
-rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
+# real_next_obs = next_obs.copy()
+# for idx, d in enumerate(truncated):
+#     if d:
+#         real_next_obs[idx] = infos["final_observation"][idx]
+rb.add(obs["board"].reshape(1, -1), next_obs["board"].reshape(1, -1), actions, reward, done, {})

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -291,5 +313,5 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "C51", f"runs/{run_name}", f"videos/{run_name}-eval")

-envs.close()
+# envs.close()
writer.close()
