Skip to content

Conversation

@farazeid
Copy link
Owner

No description provided.

@farazeid
Copy link
Owner Author

farazeid commented Jan 16, 2026

TODO:

  • Evaluate functools.partial to JAX-native static argnames, etc — Result: functools.partial is (i) convention for configs and functions, (ii) doesn't require jit of internal functions (instead can, nicely, just jit outer-most function) unlike static_argnames, (iii) more versatile as partial creates a closure so frozen params don't need to be hashable unlike static_argnames. Reference
  • Refactor intrinsic motivation functionality also — ideally learning-algorithm-agnostic add-ons are defined in separate files to the learning algorithm
  • Refactor RunState and UpdateState use
  • Replace ppo.py with this refactor
  • Move metric logging function definitions out of batch_step()

greptile-apps[bot]

This comment was marked as off-topic.

@farazeid
Copy link
Owner Author

farazeid commented Jan 19, 2026

@gemini-code-assist

gemini-code-assist[bot]

This comment was marked as outdated.

Repository owner deleted a comment from greptile-apps bot Jan 19, 2026
@farazeid
Copy link
Owner Author

@gemini-code-assist for JAX and Flax NNX, should I be using functools.partial or static_argnames or something else to what I am trying to do?

@gemini-code-assist
Copy link

Hi @farazeid,

That's a great question, and it touches on an important aspect of writing efficient JAX/Flax NNX code. Let's break down functools.partial and static_argnames in this context.

functools.partial

As you've used it in this PR, functools.partial is an excellent way to pre-fill arguments of a function. When you use partial with a function that is then passed to jax.jit or nnx.scan (which implicitly uses jax.jit internally for its body function), the arguments bound by partial are treated as static arguments by JAX's JIT compiler.

Pros of functools.partial for static arguments:

  • Simplicity for Python objects: It's very natural for passing Python objects like dictionaries (config), environment instances (env), lists (logging_threads), or callables (lr_schedule) that are meant to be constant during the JIT-compiled execution. JAX needs these to be static to compile the function.
  • Readability: It can make the function signature cleaner by moving constant parameters out of the main argument list.
  • Implicit static behavior: JAX automatically treats these pre-bound arguments as static, avoiding recompilation if only the dynamic (non-partial'd) arguments change.

static_argnames

static_argnames is an argument directly to jax.jit (or jax.vmap, jax.scan, etc.) that explicitly tells JAX which arguments of the jitted function itself should be treated as static. This is typically used when you want to pass JAX arrays or other JAX-compatible types as static arguments, or when you need more explicit control.

Pros of static_argnames:

  • Explicit control: You explicitly declare which arguments are static, which can be clearer in some cases.
  • Flexibility for JAX types: It allows you to pass JAX arrays as static arguments, which partial can also do, but static_argnames might be preferred for direct function calls that are jitted.

Which one to use for JAX and Flax NNX?

For your use case, where you are passing configuration dictionaries, environment objects, and other Python-native types that define the structure or behavior of your JAX computation, functools.partial is generally the idiomatic and recommended approach.

In your run and batch_step functions, you are passing config, env, env_params, n_batches, batch_size, logging_threads, and lr_schedule via partial to the nnx.scan'd functions. These are exactly the kinds of arguments that should be static (i.e., constant across scan iterations or JIT compilations) and are well-suited for partial.

If you were to use static_argnames, you would apply it directly to run_fn or batch_step_fn if they were standalone jax.jit'd functions, and list the names of the arguments you want to be static. However, since nnx.scan is handling the iteration, and partial is used to pre-bind the arguments before passing to nnx.scan, the effect is largely the same: those arguments are treated as static for the compiled scan body.

In summary: Your current approach using functools.partial is perfectly appropriate and effective for making these configuration and environment-related parameters static within your JAX/Flax NNX setup. It's a common and good practice for this kind of scenario.

Let me know if you have any more questions!

@farazeid farazeid force-pushed the refactor--legible-ppo-using-functools.partial-instead-of-purejaxrl's-paradigm branch from 5210e1e to f6d7bca Compare January 19, 2026 17:41
@farazeid farazeid force-pushed the refactor--legible-ppo-using-functools.partial-instead-of-purejaxrl's-paradigm branch from f6d7bca to 4046299 Compare January 19, 2026 17:43
@farazeid

This comment was marked as outdated.

gemini-code-assist[bot]

This comment was marked as outdated.

Repository owner deleted a comment from gemini-code-assist bot Jan 19, 2026
gemini-code-assist[bot]

This comment was marked as outdated.

@farazeid
Copy link
Owner Author

@gemini-code-assist review

gemini-code-assist[bot]

This comment was marked as outdated.

@farazeid farazeid force-pushed the refactor--legible-ppo-using-functools.partial-instead-of-purejaxrl's-paradigm branch from a30fd9d to d70aac9 Compare January 20, 2026 09:45
Repository owner deleted a comment from gemini-code-assist bot Jan 20, 2026
Repository owner deleted a comment from gemini-code-assist bot Jan 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants