I would like dataclasses that are part of a jaxpr to be printed as though they were dicts (i.e. each attribute on a new line) rather than with everything on a single line.
For example, instead of:
o:f32[3] p:f32[3,3] q:f32[3] = kernel_func[
  kernel_spec=Spec(name='lower_forward', f=<function kernel_fwd at 0x7e16cc5cdd00>, in_format=[ShapedArray(float32[3,3]), ShapedArray(float32[3])], out_format=[ShapedArray(float32[3]), ShapedArray(float32[3,3]), ShapedArray(float32[3])], kwargs={'jvp_residuals': True})
] n b
I would like it to be:
o:f32[3] p:f32[3,3] q:f32[3] = kernel_func[
  kernel_spec=Spec(
    name='lower_forward',
    f=<function kernel_fwd at 0x7e16cc5cdd00>,
    in_format=[ShapedArray(float32[3,3]), ShapedArray(float32[3])],
    out_format=[ShapedArray(float32[3]), ShapedArray(float32[3,3]), ShapedArray(float32[3])],
    kwargs={'jvp_residuals': True}
  ),
] n b
Happy to do a PR if someone can point me to where the code for this is.
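As a rough illustration of the requested layout (plain Python, not JAX code, and not how the jaxpr printer works internally), a hypothetical helper named multiline_repr could walk a dataclass's fields with the standard-library dataclasses module and emit one field per line, recursing into nested dataclasses. Unlike the example above, this sketch puts a trailing comma after every field, including the last.

```python
import dataclasses
from typing import Any


def multiline_repr(obj: Any, indent: str = "  ") -> str:
    """Hypothetical helper: format a dataclass instance one field per line.

    Nested dataclass instances are formatted recursively; every other
    value falls back to its ordinary repr().
    """
    # is_dataclass() is also true for dataclass *types*; only format instances.
    if not dataclasses.is_dataclass(obj) or isinstance(obj, type):
        return repr(obj)
    lines = [f"{type(obj).__name__}("]
    for field in dataclasses.fields(obj):
        if not field.repr:
            continue  # honour field(repr=False), as the generated __repr__ does
        value = multiline_repr(getattr(obj, field.name), indent)
        # Shift continuation lines of nested multi-line values one level right.
        value = value.replace("\n", "\n" + indent)
        lines.append(f"{indent}{field.name}={value},")
    lines.append(")")
    return "\n".join(lines)


# Small hypothetical dataclasses, only to demonstrate the nesting.
@dataclasses.dataclass
class Inner:
    x: int
    y: list


@dataclasses.dataclass
class Outer:
    name: str
    inner: Inner


print(multiline_repr(Outer("demo", Inner(1, [2, 3]))))
```

Each nesting level adds one indent step, so a nested Inner value is laid out the same way as the Spec in the requested output.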
Am I understanding correctly that Spec is not registered as a pytree and is passed as a static parameter to a primitive named kernel_func? If so, I think the jaxpr is just printing the standard __repr__ of the dataclass, and to change that you could override its __repr__ method.
If I'm misunderstanding the context, then it might help to add a minimal reproducible example to make things clearer.
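A minimal sketch of the __repr__ override suggested above. The field names mirror the Spec shown in the issue, but the field types, the strings standing in for ShapedArray objects, and the kernel_fwd function are assumptions for illustration. Because the class body defines __repr__ itself, @dataclass keeps it instead of generating the default single-line version, and since dataclasses do not define __str__, str(spec) falls back to it as well. Whether the jaxpr printer re-indents the embedded newlines to line up with the surrounding equation has not been checked here.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class Spec:
    # Field names follow the Spec in the issue; the types are assumptions.
    name: str
    f: Callable[..., Any]
    in_format: list
    out_format: list
    kwargs: dict

    # A __repr__ defined in the class body is kept by @dataclass,
    # replacing the generated single-line version.
    def __repr__(self) -> str:
        body = ",\n".join(
            f"  {field.name}={getattr(self, field.name)!r}"
            for field in dataclasses.fields(self)
        )
        return f"{type(self).__name__}(\n{body}\n)"


def kernel_fwd(*args):
    # Hypothetical kernel, only here so Spec.f has something to point at.
    return args


spec = Spec(
    name="lower_forward",
    f=kernel_fwd,
    # Strings stand in for the ShapedArray objects in the real Spec.
    in_format=["float32[3,3]", "float32[3]"],
    out_format=["float32[3]", "float32[3,3]", "float32[3]"],
    kwargs={"jvp_residuals": True},
)
print(spec)  # one field per line; the address in f=<function ...> varies
```

Only the one class changes, so other dataclasses elsewhere in the program keep their usual single-line repr.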