Skip to content

Commit ec93658

Browse files
authored
Merge pull request #159 from AstroAI-Lab/dependabot/github_actions/pypa/gh-action-pypi-publish-1.13.0
Bump pypa/gh-action-pypi-publish from 1.12.4 to 1.13.0 additionally fix order of @jit and @typechecker decorator to fix failing tests.
2 parents b51d566 + 81b5671 commit ec93658

File tree

3 files changed

+17
-20
lines changed

3 files changed

+17
-20
lines changed

.github/workflows/pypi.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ jobs:
2525
- name: Build SDist
2626
run: pipx run build --sdist
2727

28-
- uses: pypa/gh-action-pypi-publish@v1.12.4
28+
- uses: pypa/gh-action-pypi-publish@v1.13.0

rubix/cosmology/base.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def __init__(self, Om0: float, w0: float, wa: float, h: float):
5656
self.wa = jnp.float32(wa)
5757
self.h = jnp.float32(h)
5858

59-
@jaxtyped(typechecker=typechecker)
6059
@jit
60+
@jaxtyped(typechecker=typechecker)
6161
def scale_factor_to_redshift(
6262
self, a: Union[Float[Array, "..."], float]
6363
) -> Float[Array, "..."]:
@@ -79,33 +79,33 @@ def scale_factor_to_redshift(
7979
z = 1.0 / a - 1.0
8080
return z
8181

82-
@jaxtyped(typechecker=typechecker)
8382
@jit
83+
@jaxtyped(typechecker=typechecker)
8484
def _rho_de_z(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
8585
a = 1.0 / (1.0 + z)
8686
de_z = a ** (-3.0 * (1.0 + self.w0 + self.wa)) * lax.exp(
8787
-3.0 * self.wa * (1.0 - a)
8888
)
8989
return de_z
9090

91-
@jaxtyped(typechecker=typechecker)
9291
@jit
92+
@jaxtyped(typechecker=typechecker)
9393
def _Ez(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
9494
zp1 = 1.0 + z
9595
Ode0 = 1.0 - self.Om0
9696
t = self.Om0 * zp1**3 + Ode0 * self._rho_de_z(z)
9797
E = jnp.sqrt(t)
9898
return E
9999

100-
@jaxtyped(typechecker=typechecker)
101100
@jit
101+
@jaxtyped(typechecker=typechecker)
102102
def _integrand_oneOverEz(
103103
self, z: Union[Float[Array, "..."], float]
104104
) -> Float[Array, "..."]:
105105
return 1 / self._Ez(z)
106106

107-
@jaxtyped(typechecker=typechecker)
108107
@jit
108+
@jaxtyped(typechecker=typechecker)
109109
def comoving_distance_to_z(
110110
self, redshift: Union[Float[Array, "..."], float]
111111
) -> Float[Array, "..."]:
@@ -128,8 +128,8 @@ def comoving_distance_to_z(
128128
integrand = self._integrand_oneOverEz(z_table)
129129
return trapz(z_table, integrand) * C_SPEED * 1e-5 / self.h
130130

131-
@jaxtyped(typechecker=typechecker)
132131
@jit
132+
@jaxtyped(typechecker=typechecker)
133133
def luminosity_distance_to_z(
134134
self, redshift: Union[Float[Array, "..."], float]
135135
) -> Float[Array, "..."]:
@@ -150,8 +150,8 @@ def luminosity_distance_to_z(
150150
"""
151151
return self.comoving_distance_to_z(redshift) * (1 + redshift)
152152

153-
@jaxtyped(typechecker=typechecker)
154153
@jit
154+
@jaxtyped(typechecker=typechecker)
155155
def angular_diameter_distance_to_z(
156156
self, redshift: Union[Float[Array, "..."], float]
157157
) -> Float[Array, "..."]:
@@ -172,8 +172,8 @@ def angular_diameter_distance_to_z(
172172
"""
173173
return self.comoving_distance_to_z(redshift) / (1 + redshift)
174174

175-
@jaxtyped(typechecker=typechecker)
176175
@jit
176+
@jaxtyped(typechecker=typechecker)
177177
def distance_modulus_to_z(
178178
self, redshift: Union[Float[Array, "..."], float]
179179
) -> Float[Array, "..."]:
@@ -196,15 +196,15 @@ def distance_modulus_to_z(
196196
mu = 5.0 * jnp.log10(d_lum * 1e5)
197197
return mu
198198

199-
@jaxtyped(typechecker=typechecker)
200199
@jit
200+
@jaxtyped(typechecker=typechecker)
201201
def _hubble_time(self, z: Union[Float[Array, "..."], float]) -> Float[Array, "..."]:
202202
E0 = self._Ez(z)
203203
htime = 1e-16 * MPC / YEAR / self.h / E0
204204
return htime
205205

206-
@jaxtyped(typechecker=typechecker)
207206
@jit
207+
@jaxtyped(typechecker=typechecker)
208208
def lookback_to_z(
209209
self, redshift: Union[Float[Array, "..."], float]
210210
) -> Float[Array, "..."]:
@@ -229,8 +229,8 @@ def lookback_to_z(
229229
th = self._hubble_time(0.0)
230230
return th * res
231231

232-
@jaxtyped(typechecker=typechecker)
233232
@jit
233+
@jaxtyped(typechecker=typechecker)
234234
def age_at_z0(self) -> Float[Array, "..."]:
235235
"""
236236
The function calculates the age of the universe at redshift 0.
@@ -250,17 +250,17 @@ def age_at_z0(self) -> Float[Array, "..."]:
250250
th = self._hubble_time(0.0)
251251
return th * res
252252

253-
@jaxtyped(typechecker=typechecker)
254253
@jit
254+
@jaxtyped(typechecker=typechecker)
255255
def _age_at_z_kern(
256256
self, redshift: Union[Float[Array, "..."], float]
257257
) -> Float[Array, "..."]:
258258
t0 = self.age_at_z0()
259259
tlook = self.lookback_to_z(redshift)
260260
return t0 - tlook
261261

262-
@jaxtyped(typechecker=typechecker)
263262
@jit
263+
@jaxtyped(typechecker=typechecker)
264264
def age_at_z(
265265
self, redshift: Union[Float[Array, "..."], float]
266266
) -> Float[Array, "..."]:
@@ -285,8 +285,8 @@ def age_at_z(
285285
def _age_at_z_vmap(self):
286286
return jit(vmap(self._age_at_z_kern))
287287

288-
@jaxtyped(typechecker=typechecker)
289288
@jit
289+
@jaxtyped(typechecker=typechecker)
290290
def angular_scale(
291291
self, z: Union[Float[Array, "..."], float]
292292
) -> Float[Array, "..."]:
@@ -327,9 +327,6 @@ def _Om_at_z(self, z):
327327
E = self._Ez(z)
328328
return self.Om0 * (1.0 + z) ** 3 / E / E
329329
330-
331-
332-
333330
@jit
334331
def _delta_vir(self, z):
335332
x = self._Om(z) - 1.0

rubix/cosmology/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L247C1-L256C1
11-
@jaxtyped(typechecker=typechecker)
1211
@jit
12+
@jaxtyped(typechecker=typechecker)
1313
def _cumtrapz_scan_func(carryover, el):
1414
"""
1515
Integral helper function, which uses the formula for trapezoidal integration.
@@ -37,8 +37,8 @@ def _cumtrapz_scan_func(carryover, el):
3737

3838

3939
# Source: https://github.com/ArgonneCPAC/dsps/blob/b81bac59e545e2d68ccf698faba078d87cfa2dd8/dsps/utils.py#L278C1-L298C1
40-
@jaxtyped(typechecker=typechecker)
4140
@jit
41+
@jaxtyped(typechecker=typechecker)
4242
def trapz(
4343
xarr: Union[jnp.ndarray, Float[Array, "n"]],
4444
yarr: Union[jnp.ndarray, Float[Array, "n"]],

0 commit comments

Comments
 (0)