-
Hey @hrbigelow, this is a great question! Planning on documenting it in the near future as an advanced guide. Currently the best documentation for how this works internally is in the code: lines 1060 to 1068 in commit 0679702. Basically I'm omitting many details here, but that is the core mechanism. In the end it's just …
-
Thanks @cgarciae, yes, I actually did study the graph.py and extract.py code. It helped a bit, although some parts are still mysterious. One follow-up: does taking a merge of a split of some argument values restore those same values (aside from side-effect changes to the graph context)? That is:

```python
round_trip = merge(split(original))
assert round_trip == original  # equal values
```

And just to fix ideas, maybe another way to write what you wrote: is this more or less the pattern? (Pretend everything inside the inner `jitted` function is traced by `jax.jit`.)

```python
import jax

def nnx_fn(*args):
    pure_args = split(args)  # 1

    @jax.jit
    def jitted(pure_args):
        args = merge(pure_args)  # 2
        out = fn(*args)  # Does args have the same values as provided to nnx_fn?
        args_out = clear_non_graph_nodes(args)
        pure_args_out, pure_out = split(args_out, out)  # 3
        return pure_out, pure_args_out

    pure_out, pure_args_out = jitted(pure_args)
    out = merge(pure_out)  # 4
    return out
```
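The round-trip question can be explored with a toy model. This is a sketch, not NNX's implementation: `Module`, `split`, and `merge` below are hypothetical stand-ins, assuming NNX's split/merge behave analogously (split turns a graph of mutable objects into a static structure plus a flat dict of leaf values; merge builds fresh objects from those two pieces). Shared references are recorded by index, so a node reachable through two attributes comes back as a single object.

```python
class Module:
    """Hypothetical stand-in for a mutable graph node (not nnx.Module)."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def split(root):
    """Return (graphdef, state): static structure plus a flat dict of leaf values."""
    index = {}     # id(node) -> position in graphdef; shared nodes are visited once
    graphdef = []  # graphdef[i] lists (attr, "leaf") or (attr, "node", j) for node i
    state = {}     # (i, attr) -> leaf value

    def visit(node):
        if id(node) in index:
            return index[id(node)]
        i = index[id(node)] = len(graphdef)
        spec = []
        graphdef.append(spec)
        for name, value in vars(node).items():
            if isinstance(value, Module):
                spec.append((name, "node", visit(value)))
            else:
                spec.append((name, "leaf"))
                state[(i, name)] = value
        return i

    visit(root)
    return graphdef, state


def merge(graphdef, state):
    """Build fresh objects from (graphdef, state), restoring shared references."""
    nodes = [Module() for _ in graphdef]
    for i, spec in enumerate(graphdef):
        for name, kind, *target in spec:
            value = nodes[target[0]] if kind == "node" else state[(i, name)]
            setattr(nodes[i], name, value)
    return nodes[0]


# Usage: one leaf node shared by two attributes of the root.
shared = Module(w=1.0)
original = Module(a=shared, b=shared, bias=0.5)

graphdef, state = split(original)
round_trip = merge(graphdef, state)
```

In this toy, values and sharing survive the round trip, but `round_trip` is a new set of Python objects. Because the toy `Module` defines no `__eq__`, `round_trip == original` would compare identities and be `False`; comparing the outputs of `split` is the value-level check.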
-
Okay, quick follow-up: when you say you get an identical graph, do you mean something like the container structure between …
-
Hi @cgarciae,

I wanted to get a basic understanding of how NNX allows an impure function to be traced by JAX. Taking `nnx.jit` as an example: the original non-pure function `f` is first wrapped in a function object, for example `JitFn`, which accepts pure (constant) arguments and returns both `f`'s output (in pure form) and the possibly modified versions of each pure input. `f` is called within a graph context which monitors various things done during the call to `f`.

Mathematically this makes sense: you can express modifications to an argument as if it's a constant argument, and you return a modified copy of that input. But ultimately, inside `JitFn`, `f` is still being called as-is, in unmodified form, complete with any statements which mutate variables:

```python
out = self.f(*args, **kwargs)
```

What happens when `jax.jit` traces this part of the code? I mean, I understand conceptually how the outer function `JitFn` appears pure, but `jax.jit` actually traces individual statements, and so it requires the whole body of the function to not modify tensors.

Thanks again!
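A minimal toy sketch of why calling an impure function inside a pure wrapper still traces fine, assuming only standard `jax.jit` behavior (this is not NNX's code; `Counter`, `impure_step`, and `pure_wrapper` are hypothetical names). `jax.jit` traces by running the Python function once with abstract tracer values. An attribute assignment on an ordinary Python object is just Python executing during that run; it is not recorded as an operation. What ends up in the compiled program is the array computation connecting the wrapper's inputs to its outputs, and because the wrapper rebuilds the object from its inputs and returns the new state, that computation is pure.

```python
import jax
import jax.numpy as jnp


class Counter:
    """Hypothetical mutable Python object (plays the role of a Module)."""

    def __init__(self, value):
        self.value = value


def impure_step(counter, x):
    """Impure function: mutates its argument, then returns a result."""
    counter.value = counter.value + x  # plain Python attribute assignment
    return counter.value * 2


def pure_wrapper(value, x):
    """Pure from the outside: arrays in, arrays out."""
    counter = Counter(value)       # rebuild the object from pure inputs ("merge")
    out = impure_step(counter, x)  # call the impure function as-is
    return out, counter.value      # hand the mutated state back as an output ("split")


jitted = jax.jit(pure_wrapper)

# Outside of jit, write the returned state back into a long-lived object.
counter = Counter(jnp.array(1.0))
out, counter.value = jitted(counter.value, jnp.array(2.0))

# The traced program should contain only array ops (an add and a multiply);
# the attribute assignment ran in Python during tracing and left no op behind.
print(jax.make_jaxpr(pure_wrapper)(1.0, 2.0))
```

Writing the returned state back into the long-lived `counter` outside of `jit` plays the role of propagating updates to the caller's objects. One caveat of the toy: any mutation the wrapper does not return is lost after tracing, which is presumably why a wrapper like `JitFn` returns every possibly modified input.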