Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Zimtohrli python version easier to jit. #128

Merged
merged 1 commit into from
Sep 26, 2024
Merged

Commits on Sep 26, 2024

  1. Make Zimtohrli python version easier to jit.

    Changes:
    
    * Make the `Cam` class static from the point of view of jax.
    To allow jitting, the _hz_freqs array should not be a jax array, because depending on the values of the hyperparameters, the array shape should be different, making it impossible to jit. When attempting to jit the previous version of the code, creating that array under a jit context would fail.
    -> This requires adjusting the test tolerance from 1e-7 to 1e-5.
    The equivalent test in c++ uses 1e-2.
    
    * For `Channels` and `Signals` class, mark which fields are data_fields vs. meta_fields (that can be used for JIT cache)
    
    * Remove default arguments that are jnp.array by np.array
    
    * Make the `Masking` class static and remove the `jit` over `non_masked_energy`
    `self` being the first argument makes it unsuitable for jitting.
    
    * Added a test to check that:
      - Jitting works
      - Results are the same with/without jitting.
    bunelr committed Sep 26, 2024
    Configuration menu
    Copy the full SHA
    b42eb03 View commit details
    Browse the repository at this point in the history