Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit 1dd31c8

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Fix issues related to new behavior of JAX DeviceArray.copy()
PiperOrigin-RevId: 438711926
1 parent 83252f9 commit 1dd31c8

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

trax/rl/task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def play(env, policy, dm_suite=False, max_steps=None, last_observation=None):
271271
cur_trajectory = Trajectory(last_observation)
272272
while not done and (max_steps is None or cur_step < max_steps):
273273
action, dist_inputs = policy(cur_trajectory)
274+
action = np.asarray(action)
274275
step = env.step(action)
275276
if dm_suite:
276277
(observation, reward, done) = (

0 commit comments

Comments
 (0)