diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 9c4640b..70bc3d9 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -25,4 +25,4 @@ jobs: - name: Build SDist run: pipx run build --sdist - - uses: pypa/gh-action-pypi-publish@v1.12.4 + - uses: pypa/gh-action-pypi-publish@v1.13.0 diff --git a/rubix/cosmology/base.py b/rubix/cosmology/base.py index b5ce7d2..1790b71 100644 --- a/rubix/cosmology/base.py +++ b/rubix/cosmology/base.py @@ -56,8 +56,8 @@ def __init__(self, Om0: float, w0: float, wa: float, h: float): self.wa = jnp.float32(wa) self.h = jnp.float32(h) - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def scale_factor_to_redshift( self, a: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -79,8 +79,8 @@ def scale_factor_to_redshift( z = 1.0 / a - 1.0 return z - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def _rho_de_z(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]: a = 1.0 / (1.0 + z) de_z = a ** (-3.0 * (1.0 + self.w0 + self.wa)) * lax.exp( @@ -88,8 +88,8 @@ def _rho_de_z(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."] ) return de_z - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def _Ez(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]: zp1 = 1.0 + z Ode0 = 1.0 - self.Om0 @@ -97,15 +97,15 @@ def _Ez(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]: E = jnp.sqrt(t) return E - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def _integrand_oneOverEz( self, z: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: return 1 / self._Ez(z) - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def comoving_distance_to_z( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -128,8 +128,8 @@ def comoving_distance_to_z( integrand = self._integrand_oneOverEz(z_table) return trapz(z_table, integrand) * C_SPEED * 1e-5 / self.h - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def luminosity_distance_to_z( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -150,8 +150,8 @@ def luminosity_distance_to_z( """ return self.comoving_distance_to_z(redshift) * (1 + redshift) - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def angular_diameter_distance_to_z( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -172,8 +172,8 @@ def angular_diameter_distance_to_z( """ return self.comoving_distance_to_z(redshift) / (1 + redshift) - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def distance_modulus_to_z( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -196,15 +196,15 @@ def distance_modulus_to_z( mu = 5.0 * jnp.log10(d_lum * 1e5) return mu - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def _hubble_time(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]: E0 = self._Ez(z) htime = 1e-16 * MPC / YEAR / self.h / E0 return htime - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def lookback_to_z( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -229,8 +229,8 @@ def lookback_to_z( th = self._hubble_time(0.0) return th * res - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def age_at_z0(self) -> Float[Array, "..."]: """ The function calculates the age of the universe at redshift 0. @@ -250,8 +250,8 @@ def age_at_z0(self) -> Float[Array, "..."]: th = self._hubble_time(0.0) return th * res - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def _age_at_z_kern( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -259,8 +259,8 @@ def _age_at_z_kern( tlook = self.lookback_to_z(redshift) return t0 - tlook - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def age_at_z( self, redshift: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -285,8 +285,8 @@ def age_at_z( def _age_at_z_vmap(self): return jit(vmap(self._age_at_z_kern)) - @jaxtyped(typechecker=typechecker) @jit + @jaxtyped(typechecker=typechecker) def angular_scale( self, z: Union[Float[Array, "..."], float] ) -> Float[Array, "..."]: @@ -327,9 +327,6 @@ def _Om_at_z(self, z): E = self._Ez(z) return self.Om0 * (1.0 + z) ** 3 / E / E - - - @jit def _delta_vir(self, z): x = self._Om(z) - 1.0 diff --git a/rubix/cosmology/utils.py b/rubix/cosmology/utils.py index 0579fec..60a6f9d 100644 --- a/rubix/cosmology/utils.py +++ b/rubix/cosmology/utils.py @@ -8,8 +8,8 @@ # Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L247C1-L256C1 -@jaxtyped(typechecker=typechecker) @jit +@jaxtyped(typechecker=typechecker) def _cumtrapz_scan_func(carryover, el): """ Integral helper function, which uses the formula for trapezoidal integration. @@ -37,8 +37,8 @@ def _cumtrapz_scan_func(carryover, el): # Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L278C1-L298C1 -@jaxtyped(typechecker=typechecker) @jit +@jaxtyped(typechecker=typechecker) def trapz( xarr: Union[jnp.ndarray, Float[Array, "n"]], yarr: Union[jnp.ndarray, Float[Array, "n"]],