Skip to content

Commit

Permalink
fix wrong data type for actions in DDPG, TD3, XDDPG, and XTD3 agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Alwaysproblem committed Jan 3, 2024
1 parent 978c29f commit 57071e3
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion DDPG/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _learn(self, experiences):
states = torch.from_numpy(np.vstack([e.state for e in experiences])
).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences])
- ).long().to(device)
+ ).float().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
).float().to(device)
next_states = torch.from_numpy(
Expand Down
Binary file modified TD3/TD3Agent_100.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion TD3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _learn(self, experiences):
states = torch.from_numpy(np.vstack([e.state for e in experiences])
).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences])
- ).long().to(device)
+ ).float().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
).float().to(device)
next_states = torch.from_numpy(
Expand Down
2 changes: 1 addition & 1 deletion XDDPG/xddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _learn(self, experiences):
states = torch.from_numpy(np.vstack([e.state for e in experiences])
).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences])
- ).long().to(device)
+ ).float().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
).float().to(device)
next_states = torch.from_numpy(
Expand Down
2 changes: 1 addition & 1 deletion XTD3/xtd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _learn(self, experiences):
states = torch.from_numpy(np.vstack([e.state for e in experiences])
).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences])
- ).long().to(device)
+ ).float().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
).float().to(device)
next_states = torch.from_numpy(
Expand Down

0 comments on commit 57071e3

Please sign in to comment.