Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Zimtohrli python version easier to jit. #128

Merged
merged 1 commit into from
Sep 26, 2024
Merged

Conversation

bunelr
Copy link
Contributor

@bunelr bunelr commented Sep 26, 2024

Changes:

  • Make the Cam class static from the point of view of jax. To allow jitting, the _hz_freqs array should not be a jax array, because depending on the values of the hyperparameters, the array shape should be different, making it impossible to jit. When attempting to jit the previous version of the code, creating that array under a jit context would fail. -> This requires adjusting the test tolerance from 1e-7 to 1e-5. The equivalent test in c++ uses 1e-2.

  • For Channels and Signals class, mark which fields are data_fields vs. meta_fields (that can be used for JIT cache)

  • Remove default arguments that are jnp.array by np.array

  • Make the Masking class static and remove the jit over non_masked_energy self being the first argument makes it unsuitable for jitting.

  • Added a test to check that:

    • Jitting works
    • Results are the same with/without jitting.

Changes:

* Make the `Cam` class static from the point of view of jax.
To allow jitting, the _hz_freqs array should not be a jax array, because depending on the values of the hyperparameters, the array shape should be different, making it impossible to jit. When attempting to jit the previous version of the code, creating that array under a jit context would fail.
-> This requires adjusting the test tolerance from 1e-7 to 1e-5.
The equivalent test in c++ uses 1e-2.

* For `Channels` and `Signals` class, mark which fields are data_fields vs. meta_fields (that can be used for JIT cache)

* Remove default arguments that are jnp.array by np.array

* Make the `Masking` class static and remove the `jit` over `non_masked_energy`
`self` being the first argument makes it unsuitable for jitting.

* Added a test to check that:
  - Jitting works
  - Results are the same with/without jitting.
@zond zond enabled auto-merge (rebase) September 26, 2024 10:23
@zond
Copy link
Collaborator

zond commented Sep 26, 2024

Great, thank you!

I enabled merge-as-soon-as-tests-pass.

@zond zond merged commit 4341af5 into google:main Sep 26, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants