From 4bc0fb9415a185c5bf399145e04c66f433e4666c Mon Sep 17 00:00:00 2001 From: CoastEgo Date: Sat, 11 Jan 2025 17:36:25 +0800 Subject: [PATCH] add tests --- src/microlux/basic_function.py | 1 + test/test_caustic_mag.py | 24 +------------ test/test_polynomial_solver.py | 61 ++++++++++++++++++++++++++++++++++ test/test_util.py | 22 ++++++++++++ 4 files changed, 85 insertions(+), 23 deletions(-) create mode 100644 test/test_polynomial_solver.py diff --git a/src/microlux/basic_function.py b/src/microlux/basic_function.py index 81e3288..86c275a 100644 --- a/src/microlux/basic_function.py +++ b/src/microlux/basic_function.py @@ -100,6 +100,7 @@ def Quadrupole_test(rho, s, q, zeta, z, cond, tol=1e-2): def get_poly_coff(zeta_l, s, m2): """ get the polynomial cofficients of the polynomial equation of the lens equation. The low mass object is at the origin and the primary is at s. + The input zeta_l should have the shape of (n,1) for broadcasting. """ zeta_conj = jnp.conj(zeta_l) c0 = s**2 * zeta_l * m2**2 diff --git a/test/test_caustic_mag.py b/test/test_caustic_mag.py index dc6db37..1094629 100644 --- a/test/test_caustic_mag.py +++ b/test/test_caustic_mag.py @@ -1,13 +1,11 @@ from itertools import product -import jax -import jax.numpy as jnp import numpy as np import pytest import VBBinaryLensing from microlux import contour_integral, extended_light_curve, to_lowmass from microlux.limb_darkening import LinearLimbDarkening -from MulensModel import caustics +from test_util import get_caustic_permutation rho_values = [1e-2, 1e-3, 1e-4] @@ -16,26 +14,6 @@ limb_a_values = [0.1, 0.5, 1] -def get_caustic_permutation(rho, q, s, n_points=1000): - """ - Test around the caustic, apadpted from https://github.com/fbartolic/caustics/blob/main/tests/test_extended_source.py - - **returns**: - - - return the permutation of the caustic in the central of mass coordinate system - """ - caustic = caustics.Caustics(q, s) - x, y = caustic.get_caustics(n_points) - z_centeral = jnp.array(jnp.array(x) + 1j * jnp.array(y)) - ## random change the position of the source - key = jax.random.key(42) - key, subkey1, subkey2 = jax.random.split(key, num=3) - phi = jax.random.uniform(subkey1, z_centeral.shape, minval=-np.pi, maxval=np.pi) - r = jax.random.uniform(subkey2, z_centeral.shape, minval=0.0, maxval=2 * rho) - z_centeral = z_centeral + r * jnp.exp(1j * phi) - return z_centeral - - @pytest.mark.parametrize("rho, q, s", product(rho_values, q_values, s_values)) def test_extend_sorce(rho, q, s, retol=1e-3): """ diff --git a/test/test_polynomial_solver.py b/test/test_polynomial_solver.py new file mode 100644 index 0000000..518e714 --- /dev/null +++ b/test/test_polynomial_solver.py @@ -0,0 +1,61 @@ +from itertools import product + +import jax +import jax.numpy as jnp +import pytest +from microlux.basic_function import get_poly_coff, to_lowmass +from microlux.polynomial_solver import Aberth_Ehrlich, AE_roots0 +from test_util import get_caustic_permutation + + +rho_values = [1e-2, 1e-3, 1e-4] +q_values = [1e-1, 1e-2, 1e-3] +s_values = [0.6, 1.0, 1.4] + + +@pytest.mark.parametrize("rho, q, s", product(rho_values, q_values, s_values)) +def test_polynomial_caustic(rho, q, s): + trajectory_c = get_caustic_permutation(rho, q, s, n_points=100) + theta_sample = jnp.linspace(0, 2 * jnp.pi, 100) + contours = (trajectory_c + rho * jnp.exp(1j * theta_sample)[:, None]).ravel() + + z_lowmass = to_lowmass(s, q, contours) + + coff = get_poly_coff(z_lowmass[:, None], s, q / (1 + q)) + + get_AE_roots = lambda x: Aberth_Ehrlich(x, AE_roots0(x), MAX_ITER=50).sort() + AE_roots = jax.jit(jax.vmap(get_AE_roots))(coff) + + get_numpy_roots = lambda x: jnp.roots(x, strip_zeros=False).sort() + numpy_roots = jax.jit(jax.vmap(get_numpy_roots))(coff) + + error = jnp.abs(AE_roots - numpy_roots) + + print("max absolute error is", jnp.max(error)) + assert jnp.allclose(AE_roots, numpy_roots, atol=1e-10) + + +@pytest.mark.parametrize("q, s", product(q_values, s_values)) +def test_polynomial_uniform(q, s): + x, y = jax.random.uniform(jax.random.PRNGKey(0), (2, 100000), minval=-2, maxval=2) + + trajectory_c = x + 1j * y + z_lowmass = to_lowmass(s, q, trajectory_c) + + coff = get_poly_coff(z_lowmass[:, None], s, q / (1 + q)) + + get_AE_roots = lambda x: Aberth_Ehrlich(x, AE_roots0(x), MAX_ITER=50).sort() + AE_roots = jax.jit(jax.vmap(get_AE_roots))(coff) + + get_numpy_roots = lambda x: jnp.roots(x, strip_zeros=False).sort() + numpy_roots = jax.jit(jax.vmap(get_numpy_roots))(coff) + + error = jnp.abs(AE_roots - numpy_roots) + print("max absolute error is", jnp.max(error)) + + assert jnp.allclose(AE_roots, numpy_roots, atol=1e-10) + + +if __name__ == "__main__": + test_polynomial_caustic(1e-2, 0.2, 0.9) + test_polynomial_uniform(0.2, 0.9) diff --git a/test/test_util.py b/test/test_util.py index 01b8bdd..03ed5ef 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,8 +1,10 @@ import time import jax +import jax.numpy as jnp import numpy as np import VBBinaryLensing +from MulensModel import caustics def timeit(f, iters=10, verbose=True): @@ -98,3 +100,23 @@ def get_trajectory(tau, u_0, alpha_deg): alpha = alpha_deg / 180 * np.pi trajectory = tau * np.exp(1j * alpha) + 1j * u_0 * np.exp(1j * alpha) return trajectory + + +def get_caustic_permutation(rho, q, s, n_points=1000): + """ + Test around the caustic, apadpted from https://github.com/fbartolic/caustics/blob/main/tests/test_extended_source.py + + **returns**: + + - return the permutation of the caustic in the central of mass coordinate system + """ + caustic = caustics.Caustics(q, s) + x, y = caustic.get_caustics(n_points) + z_centeral = jnp.array(jnp.array(x) + 1j * jnp.array(y)) + ## random change the position of the source + key = jax.random.key(42) + key, subkey1, subkey2 = jax.random.split(key, num=3) + phi = jax.random.uniform(subkey1, z_centeral.shape, minval=-np.pi, maxval=np.pi) + r = jax.random.uniform(subkey2, z_centeral.shape, minval=0.0, maxval=2 * rho) + z_centeral = z_centeral + r * jnp.exp(1j * phi) + return z_centeral