
Jaxpr string representation of dataclasses #26202

Open
botev opened this issue Jan 30, 2025 · 1 comment
Labels
enhancement New feature or request


botev (Contributor) commented Jan 30, 2025

I would like dataclasses that appear in a jaxpr to be printed the way dicts are (e.g. each attribute on its own line) rather than with everything on a single line.

E.g. instead of:

```
o:f32[3] p:f32[3,3] q:f32[3] = kernel_func[
      kernel_spec=Spec(name='lower_forward', f=<function kernel_fwd at 0x7e16cc5cdd00>, in_format=[ShapedArray(float32[3,3]), ShapedArray(float32[3])], out_format=[ShapedArray(float32[3]), ShapedArray(float32[3,3]), ShapedArray(float32[3])], kwargs={'jvp_residuals': True})
] n b
```

To be:

```
o:f32[3] p:f32[3,3] q:f32[3] = kernel_func[
      kernel_spec=Spec(
          name='lower_forward',
          f=<function kernel_fwd at 0x7e16cc5cdd00>,
          in_format=[ShapedArray(float32[3,3]), ShapedArray(float32[3])],
          out_format=[ShapedArray(float32[3]), ShapedArray(float32[3,3]), ShapedArray(float32[3])],
          kwargs={'jvp_residuals': True}
      ),
] n b
```

Happy to open a PR if someone can point me to where the relevant code lives.

botev added the enhancement (New feature or request) label on Jan 30, 2025

jakevdp (Collaborator) commented Jan 30, 2025

Am I understanding correctly that Spec is not registered as a pytree, and is a static parameter to a primitive named kernel_func? If so, then I think the jaxpr is printing the standard __repr__ of the dataclass, and to change that you could override its __repr__ method.
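
A minimal sketch of that __repr__ override. Spec below is a hypothetical stand-in whose fields are illustrative, not the actual class from the issue; it is declared frozen (an assumption) so that instances stay hashable, since static primitive parameters generally need to be hashable.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for the Spec dataclass from the issue."""
    name: str
    in_format: tuple = ()
    jvp_residuals: bool = True

    def __repr__(self) -> str:
        # One "field=value," line per dataclass field, indented by two spaces.
        # Each value still uses its own (single-line) repr.
        lines = [
            f"  {field.name}={getattr(self, field.name)!r},"
            for field in dataclasses.fields(self)
        ]
        return f"{type(self).__name__}(\n" + "\n".join(lines) + "\n)"


print(repr(Spec(name="lower_forward", in_format=("f32[3,3]", "f32[3]"))))
```

The @dataclass decorator leaves an explicitly defined __repr__ in place, and str() falls back to __repr__, so this multi-line string is what gets shown wherever the object is printed. Nested values such as lists keep their own single-line repr, and the jaxpr printer likely embeds the string as-is, so continuation lines may not be indented to match the surrounding equation.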

If I'm misunderstanding the context, it would help to add a minimal reproducible example to make things clearer.
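
A minimal reproducible example along these lines might look like the sketch below. It is hypothetical: the primitive name kernel_func and the Spec fields are stand-ins modeled on the issue, the primitive only has an identity implementation and an abstract-evaluation rule that returns the input's abstract value, and it assumes a recent JAX where Primitive is importable from jax.extend.core (falling back to jax.core on older versions).

```python
import dataclasses

import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases


@dataclasses.dataclass(frozen=True)  # frozen, so the static param is hashable
class Spec:
    """Hypothetical stand-in for the Spec dataclass from the issue."""
    name: str
    jvp_residuals: bool = True


# Hypothetical primitive carrying a Spec instance as a static parameter.
kernel_func_p = Primitive("kernel_func")
kernel_func_p.def_impl(lambda x, *, kernel_spec: x)           # identity on concrete values
kernel_func_p.def_abstract_eval(lambda x, *, kernel_spec: x)  # output aval == input aval


def f(x):
    return kernel_func_p.bind(x, kernel_spec=Spec(name="lower_forward"))


# The kernel_spec param is printed with the dataclass's default single-line repr.
print(jax.make_jaxpr(f)(jnp.ones(3)))
```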

jakevdp self-assigned this on Jan 30, 2025