Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5
with:
# setuptools_scm requires a non-shallow clone of the repository
fetch-depth: 0

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -52,10 +52,10 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v5

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.11"

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ jobs:
id-token: write

steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v5
with:
# setuptools_scm requires a non-shallow clone of the repository
fetch-depth: 0

- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
name: Install Python

- name: Build SDist
run: pipx run build --sdist

- uses: pypa/gh-action-pypi-publish@v1.12.4
- uses: pypa/gh-action-pypi-publish@v1.13.0
31 changes: 14 additions & 17 deletions rubix/cosmology/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self, Om0: float, w0: float, wa: float, h: float):
self.wa = jnp.float32(wa)
self.h = jnp.float32(h)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def scale_factor_to_redshift(
self, a: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -79,33 +79,33 @@ def scale_factor_to_redshift(
z = 1.0 / a - 1.0
return z

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _rho_de_z(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
a = 1.0 / (1.0 + z)
de_z = a ** (-3.0 * (1.0 + self.w0 + self.wa)) * lax.exp(
-3.0 * self.wa * (1.0 - a)
)
return de_z

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _Ez(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
zp1 = 1.0 + z
Ode0 = 1.0 - self.Om0
t = self.Om0 * zp1**3 + Ode0 * self._rho_de_z(z)
E = jnp.sqrt(t)
return E

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _integrand_oneOverEz(
self, z: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
return 1 / self._Ez(z)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def comoving_distance_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -128,8 +128,8 @@ def comoving_distance_to_z(
integrand = self._integrand_oneOverEz(z_table)
return trapz(z_table, integrand) * C_SPEED * 1e-5 / self.h

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def luminosity_distance_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -150,8 +150,8 @@ def luminosity_distance_to_z(
"""
return self.comoving_distance_to_z(redshift) * (1 + redshift)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def angular_diameter_distance_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -172,8 +172,8 @@ def angular_diameter_distance_to_z(
"""
return self.comoving_distance_to_z(redshift) / (1 + redshift)

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def distance_modulus_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -196,15 +196,15 @@ def distance_modulus_to_z(
mu = 5.0 * jnp.log10(d_lum * 1e5)
return mu

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _hubble_time(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
E0 = self._Ez(z)
htime = 1e-16 * MPC / YEAR / self.h / E0
return htime

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def lookback_to_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -229,8 +229,8 @@ def lookback_to_z(
th = self._hubble_time(0.0)
return th * res

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def age_at_z0(self) -> Float[Array, "..."]:
"""
The function calculates the age of the universe at redshift 0.
Expand All @@ -250,17 +250,17 @@ def age_at_z0(self) -> Float[Array, "..."]:
th = self._hubble_time(0.0)
return th * res

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _age_at_z_kern(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
t0 = self.age_at_z0()
tlook = self.lookback_to_z(redshift)
return t0 - tlook

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def age_at_z(
self, redshift: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand All @@ -285,8 +285,8 @@ def age_at_z(
def _age_at_z_vmap(self):
return jit(vmap(self._age_at_z_kern))

@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def angular_scale(
self, z: Union[Float[Array, "..."], float]
) -> Float[Array, "..."]:
Expand Down Expand Up @@ -327,9 +327,6 @@ def _Om_at_z(self, z):
E = self._Ez(z)
return self.Om0 * (1.0 + z) ** 3 / E / E




@jit
def _delta_vir(self, z):
x = self._Om(z) - 1.0
Expand Down
4 changes: 2 additions & 2 deletions rubix/cosmology/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L247C1-L256C1
@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def _cumtrapz_scan_func(carryover, el):
"""
Integral helper function, which uses the formula for trapezoidal integration.
Expand Down Expand Up @@ -37,8 +37,8 @@ def _cumtrapz_scan_func(carryover, el):


# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L278C1-L298C1
@jaxtyped(typechecker=typechecker)
@jit
@jaxtyped(typechecker=typechecker)
def trapz(
xarr: Union[jnp.ndarray, Float[Array, "n"]],
yarr: Union[jnp.ndarray, Float[Array, "n"]],
Expand Down
Loading