Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
- name: Build SDist
run: pipx run build --sdist

- uses: pypa/gh-action-pypi-publish@v1.12.4
- uses: pypa/gh-action-pypi-publish@v1.13.0
31 changes: 14 additions & 17 deletions rubix/cosmology/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self, Om0: float, w0: float, wa: float, h: float):
self.wa = jnp.float32(wa)
self.h = jnp.float32(h)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def scale_factor_to_redshift(
self, a: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -79,33 +79,33 @@ def scale_factor_to_redshift(
z = 1.0 / a - 1.0
return z

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _rho_de_z(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
a = 1.0 / (1.0 + z)
de_z = a ** (-3.0 * (1.0 + self.w0 + self.wa)) * lax.exp(
-3.0 * self.wa * (1.0 - a)
)
return de_z

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _Ez(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
zp1 = 1.0 + z
Ode0 = 1.0 - self.Om0
t = self.Om0 * zp1**3 + Ode0 * self._rho_de_z(z)
E = jnp.sqrt(t)
return E

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _integrand_oneOverEz(
self, z: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
return 1 / self._Ez(z)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def comoving_distance_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -128,8 +128,8 @@ def comoving_distance_to_z(
integrand = self._integrand_oneOverEz(z_table)
return trapz(z_table, integrand) * C_SPEED * 1e-5 / self.h

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def luminosity_distance_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -150,8 +150,8 @@ def luminosity_distance_to_z(
"""
return self.comoving_distance_to_z(redshift) * (1 + redshift)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def angular_diameter_distance_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -172,8 +172,8 @@ def angular_diameter_distance_to_z(
"""
return self.comoving_distance_to_z(redshift) / (1 + redshift)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def distance_modulus_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -196,15 +196,15 @@ def distance_modulus_to_z(
mu = 5.0 * jnp.log10(d_lum * 1e5)
return mu

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _hubble_time(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
E0 = self._Ez(z)
htime = 1e-16 * MPC / YEAR / self.h / E0
return htime

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def lookback_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -229,8 +229,8 @@ def lookback_to_z(
th = self._hubble_time(0.0)
return th * res

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def age_at_z0(self) -> Float[Array, "..."]:
"""
The function calculates the age of the universe at redshift 0.
Expand All @@ -250,17 +250,17 @@ def age_at_z0(self) -> Float[Array, "..."]:
th = self._hubble_time(0.0)
return th * res

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _age_at_z_kern(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
t0 = self.age_at_z0()
tlook = self.lookback_to_z(redshift)
return t0 - tlook

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def age_at_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -285,8 +285,8 @@ def age_at_z(
def _age_at_z_vmap(self):
return jit(vmap(self._age_at_z_kern))

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def angular_scale(
self, z: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand Down Expand Up @@ -327,9 +327,6 @@ def _Om_at_z(self, z):
E = self._Ez(z)
return self.Om0 * (1.0 + z) ** 3 / E / E




@jit
def _delta_vir(self, z):
x = self._Om(z) - 1.0
Expand Down
4 changes: 2 additions & 2 deletions rubix/cosmology/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L247C1-L256C1
@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _cumtrapz_scan_func(carryover, el):
"""
Integral helper function, which uses the formula for trapezoidal integration.
Expand Down Expand Up @@ -37,8 +37,8 @@ def _cumtrapz_scan_func(carryover, el):


# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L278C1-L298C1
@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def trapz(
xarr: Union[jnp.ndarray, Float[Array, "n"]],
yarr: Union[jnp.ndarray, Float[Array, "n"]],
Expand Down
Loading