diff --git a/DDPG/ddpg.py b/DDPG/ddpg.py
index 5bbcaa9..c5aebca 100644
--- a/DDPG/ddpg.py
+++ b/DDPG/ddpg.py
@@ -246,7 +246,7 @@ def _learn(self, experiences):
         states = torch.from_numpy(np.vstack([e.state for e in experiences])
                                   ).float().to(device)
         actions = torch.from_numpy(np.vstack([e.action for e in experiences])
-                                   ).long().to(device)
+                                   ).float().to(device)
         rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
                                    ).float().to(device)
         next_states = torch.from_numpy(
diff --git a/TD3/TD3Agent_100.gif b/TD3/TD3Agent_100.gif
index 18df994..fee5553 100644
Binary files a/TD3/TD3Agent_100.gif and b/TD3/TD3Agent_100.gif differ
diff --git a/TD3/td3.py b/TD3/td3.py
index 66b9c57..a39c6d8 100644
--- a/TD3/td3.py
+++ b/TD3/td3.py
@@ -260,7 +260,7 @@ def _learn(self, experiences):
         states = torch.from_numpy(np.vstack([e.state for e in experiences])
                                   ).float().to(device)
         actions = torch.from_numpy(np.vstack([e.action for e in experiences])
-                                   ).long().to(device)
+                                   ).float().to(device)
         rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
                                    ).float().to(device)
         next_states = torch.from_numpy(
diff --git a/XDDPG/xddpg.py b/XDDPG/xddpg.py
index 1dd9b06..53bb0e1 100644
--- a/XDDPG/xddpg.py
+++ b/XDDPG/xddpg.py
@@ -300,7 +300,7 @@ def _learn(self, experiences):
         states = torch.from_numpy(np.vstack([e.state for e in experiences])
                                   ).float().to(device)
         actions = torch.from_numpy(np.vstack([e.action for e in experiences])
-                                   ).long().to(device)
+                                   ).float().to(device)
         rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
                                    ).float().to(device)
         next_states = torch.from_numpy(
diff --git a/XTD3/xtd3.py b/XTD3/xtd3.py
index 7b01e4a..d13adfb 100644
--- a/XTD3/xtd3.py
+++ b/XTD3/xtd3.py
@@ -318,7 +318,7 @@ def _learn(self, experiences):
         states = torch.from_numpy(np.vstack([e.state for e in experiences])
                                   ).float().to(device)
         actions = torch.from_numpy(np.vstack([e.action for e in experiences])
-                                   ).long().to(device)
+                                   ).float().to(device)
         rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])
                                    ).float().to(device)
         next_states = torch.from_numpy(
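
Note (not part of the diff): the .long() -> .float() change is applied in every _learn() above because DDPG, TD3, and their X-variants act in continuous action spaces, so the replayed actions are real-valued vectors. Casting them to long truncates them to integers before they reach the critic, silently corrupting the Q-value inputs, and, depending on the PyTorch version, concatenating a long actions tensor with float states can also raise a dtype error. A minimal sketch of the failure mode follows; the toy critic, its layer sizes, and the sample values are assumptions for illustration, not code from this repository.

# Minimal sketch (illustrative only; names and layer sizes are assumptions,
# not code from this repository) of why replayed actions must stay float
# for a DDPG/TD3-style critic.
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A toy critic that concatenates state (dim 3) and action (dim 1).
critic = nn.Sequential(nn.Linear(3 + 1, 64), nn.ReLU(), nn.Linear(64, 1)).to(device)

batch_states = [np.zeros(3), np.ones(3)]
batch_actions = [np.array([0.37]), np.array([-0.82])]  # continuous actions

states = torch.from_numpy(np.vstack(batch_states)).float().to(device)

# With .long() the continuous actions 0.37 and -0.82 would both truncate to 0,
# destroying the action information; .float() preserves the values.
actions = torch.from_numpy(np.vstack(batch_actions)).float().to(device)

q_values = critic(torch.cat([states, actions], dim=1))
print(q_values.shape)  # torch.Size([2, 1])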