I'm trying to write a transform using `Trace` (let's call it `TransformTrace`). For this particular transform, I need to be able to transform all primitives, not just the primitives that have a data dependence on the original function's parameters. I read about omnistaging in the autodidax tutorial (very helpful, by the way!) and tried passing `dynamic=True` to `jax.core.new_main`. Before I go any further, a disclaimer: I'm still very new to JAX's core, so my vocabulary may not be perfect. Please correct me if I use any jargon incorrectly.
The transform is supposed to bind other primitives, so that further transforms can be applied after this one, before the evaluation trace does its job. However, I'm running into trouble with this step. Because of the way `Primitive.bind` is implemented, when none of the primitive's arguments are tracers it always gives priority to the dynamic trace over the trace at the top of the trace stack. In this case, though, the dynamic trace is the `TransformTrace`'s `MainTrace`, so I end up in an infinite recursion: `TransformTrace.process_primitive` calls `Primitive.bind`, which dispatches right back to `TransformTrace.process_primitive`.
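To make the recursion concrete, here is a minimal, self-contained toy model in plain Python, loosely modeled on the trace-selection rule from the autodidax tutorial. Every name in it (`MainTrace`, `find_top_trace`, `new_main`, `dynamic_set_to`, `NaiveTransformTrace`, `FixedTransformTrace`, `add_p`) is a hypothetical stand-in, not JAX's real internal API, which has changed between versions. It shows that when the transform forwards a constant-only primitive with plain values, `bind` selects the dynamic trace, i.e. the transform itself, again. The `FixedTransformTrace` variant sketches one possible workaround under these assumptions: while forwarding, temporarily point the dynamic trace at the level below.

```python
import contextlib
# Hypothetical toy model of JAX-style trace dispatch, NOT JAX's real internals.

class MainTrace:
    def __init__(self, level, trace_type):
        self.level, self.trace_type = level, trace_type  # stack position, Trace class

class Tracer:
    def __init__(self, trace, val):
        self._trace, self.val = trace, val

class EvalTrace:
    def __init__(self, main):
        self.main = main
    def process_primitive(self, prim, args):
        return prim.impl(*args)  # bottom of the stack: actually compute

trace_stack = [MainTrace(0, EvalTrace)]
dynamic_trace = None  # the MainTrace pushed with dynamic=True, if any

def find_top_trace(args):
    # Highest-level trace among tracer arguments (the eval trace if none)...
    top = max((a._trace.main for a in args if isinstance(a, Tracer)),
              default=trace_stack[0], key=lambda m: m.level)
    # ...unless the dynamic trace sits higher, in which case it wins.
    if dynamic_trace is not None and dynamic_trace.level > top.level:
        top = dynamic_trace
    return top.trace_type(top)

class Primitive:
    def __init__(self, name, impl):
        self.name, self.impl = name, impl
    def bind(self, *args):
        return find_top_trace(args).process_primitive(self, args)

@contextlib.contextmanager
def new_main(trace_type, dynamic=False):
    global dynamic_trace
    main = MainTrace(len(trace_stack), trace_type)
    trace_stack.append(main)
    prev = dynamic_trace
    if dynamic:
        dynamic_trace = main
    try:
        yield main
    finally:
        dynamic_trace = prev
        trace_stack.pop()

@contextlib.contextmanager
def dynamic_set_to(main):
    # Hypothetical helper: temporarily replace the dynamic trace.
    global dynamic_trace
    prev, dynamic_trace = dynamic_trace, main
    try:
        yield
    finally:
        dynamic_trace = prev

seen = []  # names of primitives intercepted by the transform

class NaiveTransformTrace:
    def __init__(self, main):
        self.main = main
    def process_primitive(self, prim, args):
        seen.append(prim.name)
        vals = [a.val if isinstance(a, Tracer) else a for a in args]
        # vals are plain values, so find_top_trace picks the dynamic trace,
        # i.e. this same trace again: unbounded recursion.
        return prim.bind(*vals)

class FixedTransformTrace(NaiveTransformTrace):
    def process_primitive(self, prim, args):
        seen.append(prim.name)
        vals = [a.val if isinstance(a, Tracer) else a for a in args]
        # Forward one level down by lowering the dynamic trace while binding.
        below = trace_stack[self.main.level - 1]
        with dynamic_set_to(below):
            return prim.bind(*vals)

add_p = Primitive("add", lambda x, y: x + y)

def f():
    return add_p.bind(1.0, 2.0)  # constants only: no data dependence on inputs

naive_recursed = False
try:
    with new_main(NaiveTransformTrace, dynamic=True):
        f()
except RecursionError:
    naive_recursed = True

seen.clear()
with new_main(FixedTransformTrace, dynamic=True):
    out = f()
print(naive_recursed, out, seen)  # True 3.0 ['add']
```

In this model, primitives forwarded by the fixed transform reach the traces below it but are not intercepted by the transform a second time. Whether real JAX lets a trace reset the dynamic trace in a comparable way is an open assumption here, not something this sketch establishes.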
I'm at a loss about what to do and would appreciate some help. Is this even possible? Or does omnistaging require the dynamic trace to sit at the bottom of the stack, so that it never binds new primitives?