@@ -56,8 +56,8 @@ def __init__(self, Om0: float, w0: float, wa: float, h: float):
5656 self .wa = jnp .float32 (wa )
5757 self .h = jnp .float32 (h )
5858
59- @jaxtyped (typechecker = typechecker )
6059 @jit
60+ @jaxtyped (typechecker = typechecker )
6161 def scale_factor_to_redshift (
6262 self , a : Union [Float [Array , "..." ], float ]
6363 ) -> Float [Array , "..." ]:
@@ -79,33 +79,33 @@ def scale_factor_to_redshift(
7979 z = 1.0 / a - 1.0
8080 return z
8181
82- @jaxtyped (typechecker = typechecker )
8382 @jit
83+ @jaxtyped (typechecker = typechecker )
8484 def _rho_de_z (self , z : Union [Float [Array , "..." ], float ]) -> Float [Array , "..." ]:
8585 a = 1.0 / (1.0 + z )
8686 de_z = a ** (- 3.0 * (1.0 + self .w0 + self .wa )) * lax .exp (
8787 - 3.0 * self .wa * (1.0 - a )
8888 )
8989 return de_z
9090
91- @jaxtyped (typechecker = typechecker )
9291 @jit
92+ @jaxtyped (typechecker = typechecker )
9393 def _Ez (self , z : Union [Float [Array , "..." ], float ]) -> Float [Array , "..." ]:
9494 zp1 = 1.0 + z
9595 Ode0 = 1.0 - self .Om0
9696 t = self .Om0 * zp1 ** 3 + Ode0 * self ._rho_de_z (z )
9797 E = jnp .sqrt (t )
9898 return E
9999
100- @jaxtyped (typechecker = typechecker )
101100 @jit
101+ @jaxtyped (typechecker = typechecker )
102102 def _integrand_oneOverEz (
103103 self , z : Union [Float [Array , "..." ], float ]
104104 ) -> Float [Array , "..." ]:
105105 return 1 / self ._Ez (z )
106106
107- @jaxtyped (typechecker = typechecker )
108107 @jit
108+ @jaxtyped (typechecker = typechecker )
109109 def comoving_distance_to_z (
110110 self , redshift : Union [Float [Array , "..." ], float ]
111111 ) -> Float [Array , "..." ]:
@@ -128,8 +128,8 @@ def comoving_distance_to_z(
128128 integrand = self ._integrand_oneOverEz (z_table )
129129 return trapz (z_table , integrand ) * C_SPEED * 1e-5 / self .h
130130
131- @jaxtyped (typechecker = typechecker )
132131 @jit
132+ @jaxtyped (typechecker = typechecker )
133133 def luminosity_distance_to_z (
134134 self , redshift : Union [Float [Array , "..." ], float ]
135135 ) -> Float [Array , "..." ]:
@@ -150,8 +150,8 @@ def luminosity_distance_to_z(
150150 """
151151 return self .comoving_distance_to_z (redshift ) * (1 + redshift )
152152
153- @jaxtyped (typechecker = typechecker )
154153 @jit
154+ @jaxtyped (typechecker = typechecker )
155155 def angular_diameter_distance_to_z (
156156 self , redshift : Union [Float [Array , "..." ], float ]
157157 ) -> Float [Array , "..." ]:
@@ -172,8 +172,8 @@ def angular_diameter_distance_to_z(
172172 """
173173 return self .comoving_distance_to_z (redshift ) / (1 + redshift )
174174
175- @jaxtyped (typechecker = typechecker )
176175 @jit
176+ @jaxtyped (typechecker = typechecker )
177177 def distance_modulus_to_z (
178178 self , redshift : Union [Float [Array , "..." ], float ]
179179 ) -> Float [Array , "..." ]:
@@ -196,15 +196,15 @@ def distance_modulus_to_z(
196196 mu = 5.0 * jnp .log10 (d_lum * 1e5 )
197197 return mu
198198
199- @jaxtyped (typechecker = typechecker )
200199 @jit
200+ @jaxtyped (typechecker = typechecker )
201201 def _hubble_time (self , z : Union [Float [Array , "..." ], float ]) -> Float [Array , "..." ]:
202202 E0 = self ._Ez (z )
203203 htime = 1e-16 * MPC / YEAR / self .h / E0
204204 return htime
205205
206- @jaxtyped (typechecker = typechecker )
207206 @jit
207+ @jaxtyped (typechecker = typechecker )
208208 def lookback_to_z (
209209 self , redshift : Union [Float [Array , "..." ], float ]
210210 ) -> Float [Array , "..." ]:
@@ -229,8 +229,8 @@ def lookback_to_z(
229229 th = self ._hubble_time (0.0 )
230230 return th * res
231231
232- @jaxtyped (typechecker = typechecker )
233232 @jit
233+ @jaxtyped (typechecker = typechecker )
234234 def age_at_z0 (self ) -> Float [Array , "..." ]:
235235 """
236236 The function calculates the age of the universe at redshift 0.
@@ -250,17 +250,17 @@ def age_at_z0(self) -> Float[Array, "..."]:
250250 th = self ._hubble_time (0.0 )
251251 return th * res
252252
253- @jaxtyped (typechecker = typechecker )
254253 @jit
254+ @jaxtyped (typechecker = typechecker )
255255 def _age_at_z_kern (
256256 self , redshift : Union [Float [Array , "..." ], float ]
257257 ) -> Float [Array , "..." ]:
258258 t0 = self .age_at_z0 ()
259259 tlook = self .lookback_to_z (redshift )
260260 return t0 - tlook
261261
262- @jaxtyped (typechecker = typechecker )
263262 @jit
263+ @jaxtyped (typechecker = typechecker )
264264 def age_at_z (
265265 self , redshift : Union [Float [Array , "..." ], float ]
266266 ) -> Float [Array , "..." ]:
@@ -285,8 +285,8 @@ def age_at_z(
285285 def _age_at_z_vmap (self ):
286286 return jit (vmap (self ._age_at_z_kern ))
287287
288- @jaxtyped (typechecker = typechecker )
289288 @jit
289+ @jaxtyped (typechecker = typechecker )
290290 def angular_scale (
291291 self , z : Union [Float [Array , "..." ], float ]
292292 ) -> Float [Array , "..." ]:
@@ -327,9 +327,6 @@ def _Om_at_z(self, z):
327327 E = self._Ez(z)
328328 return self.Om0 * (1.0 + z) ** 3 / E / E
329329
330-
331-
332-
333330 @jit
334331 def _delta_vir(self, z):
335332 x = self._Om(z) - 1.0
0 commit comments