I think the easiest way to make something like this compatible with JAX tracing is to change your data model from one that uses an array of structs to one that uses a struct of arrays. For example, your agents struct might look like this:

agents = {'p': jnp.array(np.random.uniform(size=(n_players, n_actions)))}

Then you would have no issues indexing into the array with a tracer.
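A minimal sketch of the struct-of-arrays idea, assuming every agent shares the same pytree structure and array shapes. The sizes, the field names `w` and `b`, and the helper `select_agent` are hypothetical and only serve the illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np

n_players, n_actions = 3, 4  # hypothetical sizes

# Array of structs: a Python list with one pytree per agent.
# Indexing this list with a tracer raises the TypeError.
rng = np.random.default_rng(0)
agent_list = [
    {'w': jnp.array(rng.uniform(size=(n_actions, n_actions))),
     'b': jnp.array(rng.uniform(size=(n_actions,)))}
    for _ in range(n_players)
]

# Struct of arrays: the same pytree structure, but every leaf gains a
# leading axis of length n_players.
agents = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *agent_list)

@jax.jit
def select_agent(agents, turn):
    # `turn` is a tracer inside jit; indexing a jax array with it is
    # allowed, so we slice every leaf along the player axis.
    return jax.tree_util.tree_map(lambda leaf: leaf[turn], agents)

current = select_agent(agents, 1)
```

Here `current` has the same structure as a single entry of `agent_list`, so it can be passed to whatever function runs one agent's network.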
-
Hello,
I am looking to implement an RL setup for playing a card game. The game consists of multiple plies, and the winner of the previous ply always begins the next one. I want the agents to be controlled by NNs, which carry their state as pytrees. The problem is that in the game loop I need to dynamically index into a list of pytrees (to select and run the NN whose turn it is), which results in

TypeError: list indices must be integers or slices, not DynamicJaxprTracer

Is there some standard workaround? I am interested in two scenarios:

With 1., if I understand correctly, hand-rolling a function that collapses the entire network state into a single contiguous jax array and storing the agent states as an array of shape [n_agents, n_weights_per_agent] would allow me to dynamically select the weights in question. This looks very awkward, but at least seems technically possible. Is 2. even possible? I'd like this to be as fast as possible, but I am aware that executing the various functions would probably have to be sequential.

Below is a minimal example to show what I mean (the easier case).
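As a separate illustration (not the example from the original post), the flattening idea described for scenario 1 might be sketched with `jax.flatten_util.ravel_pytree`, which avoids hand-rolling the collapse function. The tiny network in `init_params`, the function `act`, and all sizes are hypothetical, and the sketch assumes every agent has identical parameter shapes:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def init_params(key):
    # Hypothetical one-layer "network" standing in for a real NN state.
    k1, k2 = jax.random.split(key)
    return {'w': jax.random.normal(k1, (4, 2)),
            'b': jax.random.normal(k2, (2,))}

n_agents = 3
keys = jax.random.split(jax.random.PRNGKey(0), n_agents)
params_list = [init_params(k) for k in keys]

# Flatten each agent's pytree into a 1-D vector. All agents share one
# structure, so a single `unravel` function works for every row.
flat_list = []
for p in params_list:
    flat, unravel = ravel_pytree(p)
    flat_list.append(flat)
all_weights = jnp.stack(flat_list)  # shape [n_agents, n_weights_per_agent]

@jax.jit
def act(all_weights, turn, obs):
    # Dynamic row selection with a traced index, then rebuild the pytree.
    params = unravel(all_weights[turn])
    return obs @ params['w'] + params['b']

out = act(all_weights, 2, jnp.ones(4))
```

Stacking the pytree leaves directly, without flattening, gives the same dynamic indexing and keeps the parameter structure intact, so the flat layout is mainly useful when a single contiguous array is wanted.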
Thank you