Skip to content

Commit

Permalink
Merge branch 'train_fluxes' of https://github.com/haukekoehn/fiesta i…
Browse files Browse the repository at this point in the history
…nto haukekoehn-train_fluxes
  • Loading branch information
ThibeauWouters committed Dec 19, 2024
2 parents 5ba014e + a09e86c commit aaf0eec
Show file tree
Hide file tree
Showing 175 changed files with 3,394 additions and 196 deletions.
21 changes: 16 additions & 5 deletions benchmarks/GRB/benchmark_afterglowpy_tophat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,40 @@
model_dir = f"../../trained_models/GRB/afterglowpy/{name}/"
FILTERS = ["radio-6GHz", "radio-3GHz"]#["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"]


for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]:
if metric_name == "$\\mathcal{L}_2$":
file_ending = "L2"
else:
file_ending = "Linf"


B = Benchmarker(name = "tophat",
B = Benchmarker(name = name,
parameter_grid = parameter_grid,
model_dir = model_dir,
MODEL = AfterglowpyLightcurvemodel,
filters = FILTERS,
n_test_data = 2000,
metric_name = metric_name,
remake_test_data = True,
remake_test_data = False,
jet_type = -1,
)

fig, ax = B.plot_error_distribution("radio-6GHz")


for filt in FILTERS:

fig, ax = B.plot_lightcurves_mismatch(filter =filt)
fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)
fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{core}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"])
fig.savefig(f"./benchmarks/{name}/benchmark_{filt}_{file_ending}.pdf", dpi = 200)

B.print_correlations(filter = filt)


if metric_name == "$\\mathcal{L}_\infty$":
fig, ax = B.plot_error_distribution(filt)
fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200)


fig, ax = B.plot_worst_lightcurve(filter = filt)
fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)

11 changes: 7 additions & 4 deletions benchmarks/KN/benchmark_Bu2019lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)

B.print_correlations(filter = filt)


fig, ax = B.plot_worst_lightcurve(filter = filt)
fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)


if metric_name == "$\\mathcal{L}_\infty$":
fig, ax = B.plot_error_distribution(filt)
fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200)


fig, ax = B.plot_worst_lightcurves()
fig.savefig(f"./benchmarks/{name}/worst_lightcurves_{file_ending}.pdf", dpi = 200)


213 changes: 213 additions & 0 deletions examples/GRB/injection_gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Injection runs with afterglowpy gaussian"""

import os
import jax
print(f"GPU found? {jax.devices()}")
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
import corner

from fiesta.inference.lightcurve_model import AfterglowpyPCA, PCALightcurveModel
from fiesta.inference.injection import InjectionRecoveryAfterglowpy
from fiesta.inference.likelihood import EMLikelihood
from fiesta.inference.prior import Uniform, CompositePrior, Constraint
from fiesta.inference.prior_dict import ConstrainedPrior
from fiesta.inference.fiesta import Fiesta
from fiesta.utils import load_event_data, write_event_data

import time
start_time = time.time()

################
### Preamble ###
################

jax.config.update("jax_enable_x64", True)

params = {"axes.grid": True,
"text.usetex" : True,
"font.family" : "serif",
"ytick.color" : "black",
"xtick.color" : "black",
"axes.labelcolor" : "black",
"axes.edgecolor" : "black",
"font.serif" : ["Computer Modern Serif"],
"xtick.labelsize": 16,
"ytick.labelsize": 16,
"axes.labelsize": 16,
"legend.fontsize": 16,
"legend.title_fontsize": 16,
"figure.titlesize": 16}

plt.rcParams.update(params)

default_corner_kwargs = dict(bins=40,
smooth=1.,
label_kwargs=dict(fontsize=16),
title_kwargs=dict(fontsize=16),
color="blue",
# quantiles=[],
# levels=[0.9],
plot_density=True,
plot_datapoints=False,
fill_contours=True,
max_n_ticks=4,
min_n_ticks=3,
save=False,
truth_color="red")


##############
### MODEL ###
##############

name = "gaussian"
model_dir = f"../../flux_models/afterglowpy_{name}/model"
FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]

model = AfterglowpyPCA(name,
model_dir,
filters = FILTERS)


###################
### INJECT ###
### AFTERGLOWPY ###
###################

trigger_time = 58849 # 01-01-2020 in mjd
remake_injection = False
injection_dict = {"inclination_EM": 0.174, "log10_E0": 54.4, "thetaCore": 0.14, "alphaWing": 3, "p": 2.6, "log10_n0": -2, "log10_epsilon_e": -2.06, "log10_epsilon_B": -4.2, "luminosity_distance": 40.0}

if remake_injection:
injection = InjectionRecoveryAfterglowpy(injection_dict, jet_type = 0, filters = FILTERS, N_datapoints = 70, error_budget = 0.5, tmin = 1, tmax = 2000, trigger_time = trigger_time)
injection.create_injection()
data = injection.data
write_event_data("./injection_gaussian/injection_gaussian.dat", data)

data = load_event_data("./injection_gaussian/injection_gaussian.dat")
#############################
### PRIORS AND LIKELIHOOD ###
#############################

inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0'])
thetaCore = Uniform(xmin=0.01, xmax=np.pi/5, naming=['thetaCore'])
alphaWing = Uniform(xmin = 0.2, xmax = 3.5, naming= ["alphaWing"])
thetaWing = Constraint(xmin = 0, xmax = np.pi/2, naming = ["thetaWing"])
log10_n0 = Uniform(xmin=-6.0, xmax=2.0, naming=['log10_n0'])
p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
log10_epsilon_e = Uniform(xmin=-4.0, xmax=0.0, naming=['log10_epsilon_e'])
log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])
epsilon_tot = Constraint(xmin = 0, xmax = 1, naming = ["epsilon_tot"])

# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])
def conversion_function(sample):
converted_sample = sample
converted_sample["thetaWing"] = converted_sample["thetaCore"] * converted_sample["alphaWing"]
converted_sample["epsilon_tot"] = 10**(converted_sample["log10_epsilon_B"]) + 10**(converted_sample["log10_epsilon_e"])
return converted_sample

prior_list = [inclination_EM,
log10_E0,
thetaCore,
alphaWing,
log10_n0,
p,
log10_epsilon_e,
log10_epsilon_B,
thetaWing,
epsilon_tot]

prior = ConstrainedPrior(prior_list, conversion_function)

detection_limit = None
likelihood = EMLikelihood(model,
data,
FILTERS,
tmax = 2000.0,
trigger_time=trigger_time,
detection_limit = detection_limit,
fixed_params={"luminosity_distance": 40.0},
error_budget = 1e-5)


##############
### FIESTA ###
##############

mass_matrix = jnp.eye(prior.n_dim)
eps = 5e-3
local_sampler_arg = {"step_size": mass_matrix * eps}

# Save for postprocessing
outdir = f"./injection_{name}/"
if not os.path.exists(outdir):
os.makedirs(outdir)

fiesta = Fiesta(likelihood,
prior,
n_chains = 1_000,
n_loop_training = 7,
n_loop_production = 3,
num_layers = 4,
hidden_size = [64, 64],
n_epochs = 20,
n_local_steps = 50,
n_global_steps = 200,
local_sampler_arg=local_sampler_arg,
outdir = outdir)

fiesta.sample(jax.random.PRNGKey(42))

fiesta.print_summary()

name = outdir + f'results_training.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, log_prob=log_prob, local_accs=local_accs,
global_accs=global_accs, loss_vals=loss_vals)

# - production phase
name = outdir + f'results_production.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=False)
chains, log_prob, local_accs, global_accs = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, chains=chains, log_prob=log_prob,
local_accs=local_accs, global_accs=global_accs)

################
### PLOTTING ###
################
# Fixed names: do not include them in the plotting, as will break corner
parameter_names = prior.naming
truths = [injection_dict[key] for key in parameter_names]

n_chains, n_steps, n_dim = np.shape(chains)
samples = np.reshape(chains, (n_chains * n_steps, n_dim))
samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption

corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs)
plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight')
plt.close()

end_time = time.time()
runtime_seconds = end_time - start_time
number_of_minutes = runtime_seconds // 60
number_of_seconds = np.round(runtime_seconds % 60, 2)
print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s")

print("Plotting lightcurves")
fiesta.plot_lightcurves()
print("Plotting lightcurves . . . done")

print("DONE")
Binary file added examples/GRB/injection_gaussian/corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
70 changes: 70 additions & 0 deletions examples/GRB/injection_gaussian/injection_gaussian.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
2020-01-02T00:00:00.000 radio-3GHz 6.473406 0.500000
2020-04-10T22:48:00.000 radio-3GHz 12.246090 0.500000
2020-07-19T21:36:00.000 radio-3GHz 13.860922 0.500000
2020-10-27T20:24:00.000 radio-3GHz 15.145347 0.500000
2021-02-04T19:12:00.000 radio-3GHz 16.200114 0.500000
2021-05-15T18:00:00.000 radio-3GHz 17.037075 0.500000
2021-08-23T16:48:00.000 radio-3GHz 17.709738 0.500000
2021-12-01T15:36:00.000 radio-3GHz 18.265443 0.500000
2022-03-11T14:24:00.000 radio-3GHz 18.738014 0.500000
2022-06-19T13:12:00.000 radio-3GHz 19.142404 0.500000
2022-09-27T12:00:00.000 radio-3GHz 19.496765 0.500000
2023-01-05T10:48:00.000 radio-3GHz 19.810069 0.500000
2023-04-15T09:36:00.000 radio-3GHz 20.091715 0.500000
2023-07-24T08:24:00.000 radio-3GHz 20.346828 0.500000
2023-11-01T07:12:00.000 radio-3GHz 20.578875 0.500000
2024-02-09T06:00:00.000 radio-3GHz 20.791826 0.500000
2024-05-19T04:48:00.000 radio-3GHz 20.989639 0.500000
2024-08-27T03:36:00.000 radio-3GHz 21.171644 0.500000
2024-12-05T02:24:00.000 radio-3GHz 21.341979 0.500000
2025-03-15T01:12:00.000 radio-3GHz 21.500904 0.500000
2025-06-23T00:00:00.000 radio-3GHz 21.652456 0.500000
2020-01-02T00:00:00.000 radio-6GHz 6.642397 0.500000
2020-05-05T22:30:00.000 radio-6GHz 13.306995 0.500000
2020-09-07T21:00:00.000 radio-6GHz 15.133077 0.500000
2021-01-10T19:30:00.000 radio-6GHz 16.559615 0.500000
2021-05-15T18:00:00.000 radio-6GHz 17.639135 0.500000
2021-09-17T16:30:00.000 radio-6GHz 18.460232 0.500000
2022-01-20T15:00:00.000 radio-6GHz 19.112092 0.500000
2022-05-25T13:30:00.000 radio-6GHz 19.646921 0.500000
2022-09-27T12:00:00.000 radio-6GHz 20.098825 0.500000
2023-01-30T10:30:00.000 radio-6GHz 20.487108 0.500000
2023-06-04T09:00:00.000 radio-6GHz 20.824446 0.500000
2023-10-07T07:30:00.000 radio-6GHz 21.125451 0.500000
2024-02-09T06:00:00.000 radio-6GHz 21.393886 0.500000
2024-06-13T04:30:00.000 radio-6GHz 21.637560 0.500000
2024-10-16T03:00:00.000 radio-6GHz 21.859776 0.500000
2025-02-18T01:30:00.000 radio-6GHz 22.064226 0.500000
2025-06-23T00:00:00.000 radio-6GHz 22.254516 0.500000
2020-01-02T00:00:00.000 X-ray-1keV 21.209144 0.500000
2020-04-06T04:34:17.143 X-ray-1keV 27.962903 0.500000
2020-07-10T09:08:34.286 X-ray-1keV 29.537845 0.500000
2020-10-13T13:42:51.429 X-ray-1keV 30.789609 0.500000
2021-01-16T18:17:08.571 X-ray-1keV 31.828374 0.500000
2021-04-21T22:51:25.714 X-ray-1keV 32.666384 0.500000
2021-07-26T03:25:42.857 X-ray-1keV 33.344842 0.500000
2021-10-29T08:00:00.000 X-ray-1keV 33.904345 0.500000
2022-02-01T12:34:17.143 X-ray-1keV 34.378401 0.500000
2022-05-07T17:08:34.286 X-ray-1keV 34.787830 0.500000
2022-08-10T21:42:51.429 X-ray-1keV 35.144792 0.500000
2022-11-14T02:17:08.571 X-ray-1keV 35.464131 0.500000
2023-02-17T06:51:25.714 X-ray-1keV 35.747110 0.500000
2023-05-23T11:25:42.857 X-ray-1keV 36.004242 0.500000
2023-08-26T16:00:00.000 X-ray-1keV 36.238919 0.500000
2023-11-29T20:34:17.143 X-ray-1keV 36.454925 0.500000
2024-03-04T01:08:34.286 X-ray-1keV 36.654020 0.500000
2024-06-07T05:42:51.429 X-ray-1keV 36.837070 0.500000
2024-09-10T10:17:08.571 X-ray-1keV 37.010718 0.500000
2024-12-14T14:51:25.714 X-ray-1keV 37.169814 0.500000
2025-03-19T19:25:42.857 X-ray-1keV 37.321108 0.500000
2025-06-23T00:00:00.000 X-ray-1keV 37.465122 0.500000
2020-01-02T00:00:00.000 bessellv 15.916906 0.500000
2020-08-11T02:40:00.000 bessellv 24.685939 0.500000
2021-03-21T05:20:00.000 bessellv 27.115049 0.500000
2021-10-29T08:00:00.000 bessellv 28.612108 0.500000
2022-06-08T10:40:00.000 bessellv 29.619683 0.500000
2023-01-16T13:20:00.000 bessellv 30.363185 0.500000
2023-08-26T16:00:00.000 bessellv 30.946682 0.500000
2024-04-04T18:40:00.000 bessellv 31.423505 0.500000
2024-11-12T21:20:00.000 bessellv 31.826433 0.500000
2025-06-23T00:00:00.000 bessellv 32.172885 0.500000
Binary file added examples/GRB/injection_gaussian/lightcurves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit aaf0eec

Please sign in to comment.