From 0ac4cb8ab4c373757f063d20b92fa94f8df87d06 Mon Sep 17 00:00:00 2001 From: CoastEgo Date: Tue, 7 Jan 2025 10:23:10 +0800 Subject: [PATCH] rename binary mag model --- docs/api/model.md | 2 +- src/microlux/__init__.py | 2 +- src/microlux/model.py | 5 +++-- test/test_grad.py | 8 ++++---- test/test_image_contour.py | 4 ++-- test/test_light_curve.py | 6 +++--- test/test_mag_map.py | 4 ++-- 7 files changed, 16 insertions(+), 15 deletions(-) diff --git a/docs/api/model.md b/docs/api/model.md index b0d8a44..f3f435b 100644 --- a/docs/api/model.md +++ b/docs/api/model.md @@ -1,7 +1,7 @@ # Magnification Model Light curve for the binary lens with adaptive contour integration and point sorce approximation. -::: microlux.model +::: microlux.extended_light_curve --- Light curve with point source approximation. ::: microlux.point_light_curve diff --git a/src/microlux/__init__.py b/src/microlux/__init__.py index 881d20f..4091525 100644 --- a/src/microlux/__init__.py +++ b/src/microlux/__init__.py @@ -15,7 +15,7 @@ ) from .model import ( contour_integral as contour_integral, - model as model, + extended_light_curve as extended_light_curve, point_light_curve as point_light_curve, ) from .utils import ( diff --git a/src/microlux/model.py b/src/microlux/model.py index 503f5fc..ea8d278 100644 --- a/src/microlux/model.py +++ b/src/microlux/model.py @@ -85,7 +85,7 @@ def point_light_curve(trajectory_l, s, q, rho, tol, return_num=False): @partial(jax.jit, static_argnames=["return_info", "default_strategy", "analytic"]) -def model( +def extended_light_curve( t_0, u_0, t_E, @@ -101,7 +101,8 @@ def model( analytic=True, ): """ - Compute the microlensing model for a binary lens system using JAX. + Compute the light curve of a binary lens system with finite source effects. + This function will dynamically choose full contour integration or point source approximation based on the quadrupole test. Args: t_0 (float): diff --git a/test/test_grad.py b/test/test_grad.py index 80273b1..8a9759f 100644 --- a/test/test_grad.py +++ b/test/test_grad.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib import gridspec -from microlux import model +from microlux import extended_light_curve from mpl_toolkits.axes_grid1.inset_locator import inset_axes from scipy.optimize import approx_fprime from test_util import VBBL_light_curve @@ -18,7 +18,7 @@ def grad_test(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, tol): vbbl_fun, ) - grad_fun = jax.jacfwd(model, argnums=(0, 1, 2, 3, 4, 5, 6)) + grad_fun = jax.jacfwd(extended_light_curve, argnums=(0, 1, 2, 3, 4, 5, 6)) jacobian = jnp.array( grad_fun(t_0, u_0, t_E, rho, q, s, alpha_deg, times, tol, retol) ) @@ -31,7 +31,7 @@ def grad_test(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, tol): plt.figure(figsize=(10, 8)) # fig = plt.figure(figsize=(8,6)) # gc = gridspec.GridSpec(2, 1,height_ratios=[1,1]) - mag = model(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, retol) + mag = extended_light_curve(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, retol) mag_vbl = np.array( VBBL_light_curve(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol) ) @@ -118,7 +118,7 @@ def format_func(value, tick_number): # times=jnp.linspace(t_0-1.*t_E,t_0+1*t_E,trajectory_n) times = jnp.linspace(8260, 8320, trajectory_n)[250:1000] - grad_fun = jax.jacfwd(model, argnums=(0, 1, 2, 3, 4, 5, 6)) + grad_fun = jax.jacfwd(extended_light_curve, argnums=(0, 1, 2, 3, 4, 5, 6)) grad_fun = jax.jit(grad_fun) grad_test(t_0, b, t_E, rho, q, s, alphadeg, times, retol, tol) diff --git a/test/test_image_contour.py b/test/test_image_contour.py index 89574e8..892fe93 100644 --- a/test/test_image_contour.py +++ b/test/test_image_contour.py @@ -2,7 +2,7 @@ import numpy as np # from microlux import model_numpy -from microlux import model, to_centroid, to_lowmass +from microlux import extended_light_curve, to_centroid, to_lowmass from MulensModel import caustics @@ -38,7 +38,7 @@ def contour_plot(t_0, b, t_E, rho, q, s, alphadeg, times, retol=1e-3, tol=1e-3): alpha = alphadeg * 2 * np.pi / 360 - mag, info = model( + mag, info = extended_light_curve( t_0, b, t_E, diff --git a/test/test_light_curve.py b/test/test_light_curve.py index e6a7fe7..66b53d7 100644 --- a/test/test_light_curve.py +++ b/test/test_light_curve.py @@ -2,7 +2,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np -from microlux import model +from microlux import extended_light_curve from test_util import timeit, VBBL_light_curve @@ -12,13 +12,13 @@ def time_test(t_0, u_0, t_E, rho, q, s, alpha_deg, times, retol, tol): ####################编译时间 - uniform_mag, time = timeit(model)( + uniform_mag, time = timeit(extended_light_curve)( t_0, b, t_E, rho, q, s, alpha_deg, times, tol, retol=retol ) # with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): # jax.profiler.start_trace("/tmp/tensorboard") # with jax.disable_jit(): - # model(**parm).block_until_ready() + # extended_light_curve(**parm).block_until_ready() # jax.profiler.stop_trace() VBBL_mag, _ = timeit(VBBL_light_curve)( t_0, b, t_E, rho, q, s, alpha_deg, times, retol=1e-3, tol=1e-3 diff --git a/test/test_mag_map.py b/test/test_mag_map.py index 95b8d16..bfb2f0b 100644 --- a/test/test_mag_map.py +++ b/test/test_mag_map.py @@ -82,11 +82,11 @@ def mag_map_vbbl(i, all_params, fix_params): # jax mag map test - from microlux import model + from microlux import extended_light_curve def mag_jax(i, rho, q, s, parm): t_0, b_map, t_E, alphadeg, times_jax, tol = parm - uniform_mag, info = model( + uniform_mag, info = extended_light_curve( t_0, b_map[i], t_E,