Skip to content

Commit

Permalink
rename binary mag model
Browse files Browse the repository at this point in the history
  • Loading branch information
CoastEgo committed Jan 7, 2025
1 parent 9f838c6 commit 0ac4cb8
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/api/model.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Magnification Model

Light curve for the binary lens with adaptive contour integration and point sorce approximation.
::: microlux.model
::: microlux.extended_light_curve
---
Light curve with point source approximation.
::: microlux.point_light_curve
Expand Down
2 changes: 1 addition & 1 deletion src/microlux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from .model import (
contour_integral as contour_integral,
model as model,
extended_light_curve as extended_light_curve,
point_light_curve as point_light_curve,
)
from .utils import (
Expand Down
5 changes: 3 additions & 2 deletions src/microlux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def point_light_curve(trajectory_l, s, q, rho, tol, return_num=False):


@partial(jax.jit, static_argnames=["return_info", "default_strategy", "analytic"])
def model(
def extended_light_curve(
t_0,
u_0,
t_E,
Expand All @@ -101,7 +101,8 @@ def model(
analytic=True,
):
"""
Compute the microlensing model for a binary lens system using JAX.
Compute the light curve of a binary lens system with finite source effects.
This function will dynamically choose full contour integration or point source approximation based on the quadrupole test.
Args:
t_0 (float):
Expand Down
8 changes: 4 additions & 4 deletions test/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import gridspec
from microlux import model
from microlux import extended_light_curve
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.optimize import approx_fprime
from test_util import VBBL_light_curve
Expand All @@ -18,7 +18,7 @@ def grad_test(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, tol):
vbbl_fun,
)

grad_fun = jax.jacfwd(model, argnums=(0, 1, 2, 3, 4, 5, 6))
grad_fun = jax.jacfwd(extended_light_curve, argnums=(0, 1, 2, 3, 4, 5, 6))
jacobian = jnp.array(
grad_fun(t_0, u_0, t_E, rho, q, s, alpha_deg, times, tol, retol)
)
Expand All @@ -31,7 +31,7 @@ def grad_test(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, tol):
plt.figure(figsize=(10, 8))
# fig = plt.figure(figsize=(8,6))
# gc = gridspec.GridSpec(2, 1,height_ratios=[1,1])
mag = model(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, retol)
mag = extended_light_curve(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, retol)
mag_vbl = np.array(
VBBL_light_curve(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol)
)
Expand Down Expand Up @@ -118,7 +118,7 @@ def format_func(value, tick_number):
# times=jnp.linspace(t_0-1.*t_E,t_0+1*t_E,trajectory_n)
times = jnp.linspace(8260, 8320, trajectory_n)[250:1000]

grad_fun = jax.jacfwd(model, argnums=(0, 1, 2, 3, 4, 5, 6))
grad_fun = jax.jacfwd(extended_light_curve, argnums=(0, 1, 2, 3, 4, 5, 6))
grad_fun = jax.jit(grad_fun)

grad_test(t_0, b, t_E, rho, q, s, alphadeg, times, retol, tol)
4 changes: 2 additions & 2 deletions test/test_image_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

# from microlux import model_numpy
from microlux import model, to_centroid, to_lowmass
from microlux import extended_light_curve, to_centroid, to_lowmass
from MulensModel import caustics


Expand Down Expand Up @@ -38,7 +38,7 @@
def contour_plot(t_0, b, t_E, rho, q, s, alphadeg, times, retol=1e-3, tol=1e-3):
alpha = alphadeg * 2 * np.pi / 360

mag, info = model(
mag, info = extended_light_curve(
t_0,
b,
t_E,
Expand Down
6 changes: 3 additions & 3 deletions test/test_light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from microlux import model
from microlux import extended_light_curve
from test_util import timeit, VBBL_light_curve


Expand All @@ -12,13 +12,13 @@

def time_test(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, tol):
####################编译时间
uniform_mag, time = timeit(model)(
uniform_mag, time = timeit(extended_light_curve)(
t_0, b, t_E, rho, q, s, alpha_deg, times, tol, retol=retol
)
# with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# jax.profiler.start_trace("/tmp/tensorboard")
# with jax.disable_jit():
# model(**parm).block_until_ready()
# extended_light_curve(**parm).block_until_ready()
# jax.profiler.stop_trace()
VBBL_mag, _ = timeit(VBBL_light_curve)(
t_0, b, t_E, rho, q, s, alpha_deg, times, retol=1e-3, tol=1e-3
Expand Down
4 changes: 2 additions & 2 deletions test/test_mag_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def mag_map_vbbl(i, all_params, fix_params):

# jax mag map test

from microlux import model
from microlux import extended_light_curve

def mag_jax(i, rho, q, s, parm):
t_0, b_map, t_E, alphadeg, times_jax, tol = parm
uniform_mag, info = model(
uniform_mag, info = extended_light_curve(
t_0,
b_map[i],
t_E,
Expand Down

0 comments on commit 0ac4cb8

Please sign in to comment.