From 9f37aa256880b8c5604ac613e39a81bfd7b38458 Mon Sep 17 00:00:00 2001
From: Kelvinson
Date: Thu, 19 Apr 2018 07:13:18 -0700
Subject: [PATCH 1/3] fix traj dimension error

---
 epg/agents.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/epg/agents.py b/epg/agents.py
index ee5715c..267c8e4 100644
--- a/epg/agents.py
+++ b/epg/agents.py
@@ -117,7 +117,7 @@ def _compute_ppo_loss(self, obs, acts, at, vt, old_params):
         return ppo_surr_loss
 
     def update(self, obs, acts, rews, dones, ppo_factor, inner_opt_freq):
-
+        epg_rews = rews
         # Want to zero out rewards to the EPG loss function?
         # epg_rews = np.zeros_like(rews)
 
@@ -139,8 +139,8 @@ def update(self, obs, acts, rews, dones, ppo_factor, inner_opt_freq):
             act_pad = np.zeros((self._buffer_size - acts.shape[0], acts.shape[1]), dtype=np.float32)
             pad = np.hstack([obs_pad, act_pad, rew_pad[:, None], auxs_pad[:, None], done_pad[:, None]])
             traj = np.vstack([pad, traj])
-        traj[:, obs.shape[1] + acts.shape[1]] = epg_rews
-        traj[:, -1] = dones
+        traj[0:len(epg_rews), obs.shape[1] + acts.shape[1]] = epg_rews
+        traj[0:len(dones), -1] = dones
 
         # Since the buffer length can be larger than the set of new samples, we truncate the
         # trajectories here for PPO.

From 3cbad6fd7af46dfbb275bcadc9e1d6a35413dae9 Mon Sep 17 00:00:00 2001
From: Kelvinson
Date: Thu, 19 Apr 2018 07:16:08 -0700
Subject: [PATCH 2/3] format

---
 epg/agents.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/epg/agents.py b/epg/agents.py
index 267c8e4..c09337b 100644
--- a/epg/agents.py
+++ b/epg/agents.py
@@ -117,7 +117,7 @@ def _compute_ppo_loss(self, obs, acts, at, vt, old_params):
         return ppo_surr_loss
 
     def update(self, obs, acts, rews, dones, ppo_factor, inner_opt_freq):
-
+        epg_rews = rews
         # Want to zero out rewards to the EPG loss function?
         # epg_rews = np.zeros_like(rews)
 

From 298d36523786214161929de15ae308dce823ab9e Mon Sep 17 00:00:00 2001
From: Kelvinson
Date: Thu, 19 Apr 2018 07:26:01 -0700
Subject: [PATCH 3/3] format fix

---
 epg/agents.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/epg/agents.py b/epg/agents.py
index c09337b..fd1dd0d 100644
--- a/epg/agents.py
+++ b/epg/agents.py
@@ -117,7 +117,7 @@ def _compute_ppo_loss(self, obs, acts, at, vt, old_params):
         return ppo_surr_loss
 
     def update(self, obs, acts, rews, dones, ppo_factor, inner_opt_freq):
-
+        epg_rews = rews
        # Want to zero out rewards to the EPG loss function?
         # epg_rews = np.zeros_like(rews)
 
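
Note (not part of the patches above): a minimal NumPy sketch of the failure mode the
first commit works around. Once the buffer has been padded out to self._buffer_size
rows, assigning a full column with only len(epg_rews) new values raises a broadcast
error; restricting the row slice to the length of the new data avoids it. The buffer
size, dimensions, and variable names below are made up for illustration and are not
taken from epg/agents.py.

import numpy as np

buffer_size = 8        # hypothetical buffer length (stands in for self._buffer_size)
n_new = 5              # hypothetical number of fresh samples collected this update
obs_dim, act_dim = 3, 2

# Columns mirror the layout built in update(): [obs | acts | reward | aux | done].
traj = np.zeros((buffer_size, obs_dim + act_dim + 3), dtype=np.float32)
epg_rews = np.arange(n_new, dtype=np.float32)
dones = np.zeros(n_new, dtype=np.float32)

rew_col = obs_dim + act_dim

# Pre-patch assignment: tries to broadcast n_new values into buffer_size rows.
try:
    traj[:, rew_col] = epg_rews    # ValueError: cannot broadcast shape (5,) into (8,)
except ValueError as err:
    print("unpatched assignment fails:", err)

# Patched assignment: only writes as many rows as there are new samples.
traj[0:len(epg_rews), rew_col] = epg_rews
traj[0:len(dones), -1] = dones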