Skip to content

Commit

Permalink
Improve train script
Browse files Browse the repository at this point in the history
  • Loading branch information
Maximilian Weichart committed Aug 12, 2024
1 parent ab582c0 commit 0a87e1c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions examples/train_lin_grouped.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Args:
"""the target network update rate"""
target_network_frequency: int = 1
"""the timesteps it takes to update the target network"""
batch_size: int = 32
batch_size: int = 512
"""the batch size of samples drawn from the replay memory"""
start_e: float = 1
"""the starting epsilon for exploration"""
Expand All @@ -140,7 +140,7 @@ def thunk():
episode_trigger=lambda x: x % args.video_epoch_interval == 0,
)
else:
env = gym.make(env_id)
env = gym.make(env_id, render_mode="rgb_array", gravity=False)
env = GroupedActionsObservations(
env, observation_wrappers=[FeatureVectorObservation(env)]
)
Expand Down Expand Up @@ -371,6 +371,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
)
old_val = q_network(data.observations).squeeze(-1).squeeze(-1)

assert old_val.shape == td_target.shape
loss = F.mse_loss(old_val, td_target)

if global_step % 100 == 0:
Expand Down

0 comments on commit 0a87e1c

Please sign in to comment.