Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CoastEgo committed Jan 11, 2025
1 parent 00d0bff commit 4bc0fb9
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 23 deletions.
1 change: 1 addition & 0 deletions src/microlux/basic_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def Quadrupole_test(rho, s, q, zeta, z, cond, tol=1e-2):
def get_poly_coff(zeta_l, s, m2):
"""
get the polynomial cofficients of the polynomial equation of the lens equation. The low mass object is at the origin and the primary is at s.
The input zeta_l should have the shape of (n,1) for broadcasting.
"""
zeta_conj = jnp.conj(zeta_l)
c0 = s**2 * zeta_l * m2**2
Expand Down
24 changes: 1 addition & 23 deletions test/test_caustic_mag.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from itertools import product

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import VBBinaryLensing
from microlux import contour_integral, extended_light_curve, to_lowmass
from microlux.limb_darkening import LinearLimbDarkening
from MulensModel import caustics
from test_util import get_caustic_permutation


rho_values = [1e-2, 1e-3, 1e-4]
Expand All @@ -16,26 +14,6 @@
limb_a_values = [0.1, 0.5, 1]


def get_caustic_permutation(rho, q, s, n_points=1000):
"""
Test around the caustic, apadpted from https://github.com/fbartolic/caustics/blob/main/tests/test_extended_source.py
**returns**:
- return the permutation of the caustic in the central of mass coordinate system
"""
caustic = caustics.Caustics(q, s)
x, y = caustic.get_caustics(n_points)
z_centeral = jnp.array(jnp.array(x) + 1j * jnp.array(y))
## random change the position of the source
key = jax.random.key(42)
key, subkey1, subkey2 = jax.random.split(key, num=3)
phi = jax.random.uniform(subkey1, z_centeral.shape, minval=-np.pi, maxval=np.pi)
r = jax.random.uniform(subkey2, z_centeral.shape, minval=0.0, maxval=2 * rho)
z_centeral = z_centeral + r * jnp.exp(1j * phi)
return z_centeral


@pytest.mark.parametrize("rho, q, s", product(rho_values, q_values, s_values))
def test_extend_sorce(rho, q, s, retol=1e-3):
"""
Expand Down
61 changes: 61 additions & 0 deletions test/test_polynomial_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from itertools import product

import jax
import jax.numpy as jnp
import pytest
from microlux.basic_function import get_poly_coff, to_lowmass
from microlux.polynomial_solver import Aberth_Ehrlich, AE_roots0
from test_util import get_caustic_permutation


rho_values = [1e-2, 1e-3, 1e-4]
q_values = [1e-1, 1e-2, 1e-3]
s_values = [0.6, 1.0, 1.4]


@pytest.mark.parametrize("rho, q, s", product(rho_values, q_values, s_values))
def test_polynomial_caustic(rho, q, s):
trajectory_c = get_caustic_permutation(rho, q, s, n_points=100)
theta_sample = jnp.linspace(0, 2 * jnp.pi, 100)
contours = (trajectory_c + rho * jnp.exp(1j * theta_sample)[:, None]).ravel()

z_lowmass = to_lowmass(s, q, contours)

coff = get_poly_coff(z_lowmass[:, None], s, q / (1 + q))

get_AE_roots = lambda x: Aberth_Ehrlich(x, AE_roots0(x), MAX_ITER=50).sort()
AE_roots = jax.jit(jax.vmap(get_AE_roots))(coff)

get_numpy_roots = lambda x: jnp.roots(x, strip_zeros=False).sort()
numpy_roots = jax.jit(jax.vmap(get_numpy_roots))(coff)

error = jnp.abs(AE_roots - numpy_roots)

print("max absolute error is", jnp.max(error))
assert jnp.allclose(AE_roots, numpy_roots, atol=1e-10)


@pytest.mark.parametrize("q, s", product(q_values, s_values))
def test_polynomial_uniform(q, s):
x, y = jax.random.uniform(jax.random.PRNGKey(0), (2, 100000), minval=-2, maxval=2)

trajectory_c = x + 1j * y
z_lowmass = to_lowmass(s, q, trajectory_c)

coff = get_poly_coff(z_lowmass[:, None], s, q / (1 + q))

get_AE_roots = lambda x: Aberth_Ehrlich(x, AE_roots0(x), MAX_ITER=50).sort()
AE_roots = jax.jit(jax.vmap(get_AE_roots))(coff)

get_numpy_roots = lambda x: jnp.roots(x, strip_zeros=False).sort()
numpy_roots = jax.jit(jax.vmap(get_numpy_roots))(coff)

error = jnp.abs(AE_roots - numpy_roots)
print("max absolute error is", jnp.max(error))

assert jnp.allclose(AE_roots, numpy_roots, atol=1e-10)


if __name__ == "__main__":
test_polynomial_caustic(1e-2, 0.2, 0.9)
test_polynomial_uniform(0.2, 0.9)
22 changes: 22 additions & 0 deletions test/test_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import time

import jax
import jax.numpy as jnp
import numpy as np
import VBBinaryLensing
from MulensModel import caustics


def timeit(f, iters=10, verbose=True):
Expand Down Expand Up @@ -98,3 +100,23 @@ def get_trajectory(tau, u_0, alpha_deg):
alpha = alpha_deg / 180 * np.pi
trajectory = tau * np.exp(1j * alpha) + 1j * u_0 * np.exp(1j * alpha)
return trajectory


def get_caustic_permutation(rho, q, s, n_points=1000):
"""
Test around the caustic, apadpted from https://github.com/fbartolic/caustics/blob/main/tests/test_extended_source.py
**returns**:
- return the permutation of the caustic in the central of mass coordinate system
"""
caustic = caustics.Caustics(q, s)
x, y = caustic.get_caustics(n_points)
z_centeral = jnp.array(jnp.array(x) + 1j * jnp.array(y))
## random change the position of the source
key = jax.random.key(42)
key, subkey1, subkey2 = jax.random.split(key, num=3)
phi = jax.random.uniform(subkey1, z_centeral.shape, minval=-np.pi, maxval=np.pi)
r = jax.random.uniform(subkey2, z_centeral.shape, minval=0.0, maxval=2 * rho)
z_centeral = z_centeral + r * jnp.exp(1j * phi)
return z_centeral

0 comments on commit 4bc0fb9

Please sign in to comment.