-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'train_fluxes' of https://github.com/haukekoehn/fiesta i…
…nto haukekoehn-train_fluxes
- Loading branch information
Showing
175 changed files
with
3,394 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
"""Injection runs with afterglowpy gaussian""" | ||
|
||
import os | ||
import jax | ||
print(f"GPU found? {jax.devices()}") | ||
import jax.numpy as jnp | ||
jax.config.update("jax_enable_x64", True) | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import corner | ||
|
||
from fiesta.inference.lightcurve_model import AfterglowpyPCA, PCALightcurveModel | ||
from fiesta.inference.injection import InjectionRecoveryAfterglowpy | ||
from fiesta.inference.likelihood import EMLikelihood | ||
from fiesta.inference.prior import Uniform, CompositePrior, Constraint | ||
from fiesta.inference.prior_dict import ConstrainedPrior | ||
from fiesta.inference.fiesta import Fiesta | ||
from fiesta.utils import load_event_data, write_event_data | ||
|
||
import time | ||
start_time = time.time() | ||
|
||
################ | ||
### Preamble ### | ||
################ | ||
|
||
jax.config.update("jax_enable_x64", True) | ||
|
||
params = {"axes.grid": True, | ||
"text.usetex" : True, | ||
"font.family" : "serif", | ||
"ytick.color" : "black", | ||
"xtick.color" : "black", | ||
"axes.labelcolor" : "black", | ||
"axes.edgecolor" : "black", | ||
"font.serif" : ["Computer Modern Serif"], | ||
"xtick.labelsize": 16, | ||
"ytick.labelsize": 16, | ||
"axes.labelsize": 16, | ||
"legend.fontsize": 16, | ||
"legend.title_fontsize": 16, | ||
"figure.titlesize": 16} | ||
|
||
plt.rcParams.update(params) | ||
|
||
default_corner_kwargs = dict(bins=40, | ||
smooth=1., | ||
label_kwargs=dict(fontsize=16), | ||
title_kwargs=dict(fontsize=16), | ||
color="blue", | ||
# quantiles=[], | ||
# levels=[0.9], | ||
plot_density=True, | ||
plot_datapoints=False, | ||
fill_contours=True, | ||
max_n_ticks=4, | ||
min_n_ticks=3, | ||
save=False, | ||
truth_color="red") | ||
|
||
|
||
############## | ||
### MODEL ### | ||
############## | ||
|
||
name = "gaussian" | ||
model_dir = f"../../flux_models/afterglowpy_{name}/model" | ||
FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"] | ||
|
||
model = AfterglowpyPCA(name, | ||
model_dir, | ||
filters = FILTERS) | ||
|
||
|
||
################### | ||
### INJECT ### | ||
### AFTERGLOWPY ### | ||
################### | ||
|
||
trigger_time = 58849 # 01-01-2020 in mjd | ||
remake_injection = False | ||
injection_dict = {"inclination_EM": 0.174, "log10_E0": 54.4, "thetaCore": 0.14, "alphaWing": 3, "p": 2.6, "log10_n0": -2, "log10_epsilon_e": -2.06, "log10_epsilon_B": -4.2, "luminosity_distance": 40.0} | ||
|
||
if remake_injection: | ||
injection = InjectionRecoveryAfterglowpy(injection_dict, jet_type = 0, filters = FILTERS, N_datapoints = 70, error_budget = 0.5, tmin = 1, tmax = 2000, trigger_time = trigger_time) | ||
injection.create_injection() | ||
data = injection.data | ||
write_event_data("./injection_gaussian/injection_gaussian.dat", data) | ||
|
||
data = load_event_data("./injection_gaussian/injection_gaussian.dat") | ||
############################# | ||
### PRIORS AND LIKELIHOOD ### | ||
############################# | ||
|
||
inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM']) | ||
log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0']) | ||
thetaCore = Uniform(xmin=0.01, xmax=np.pi/5, naming=['thetaCore']) | ||
alphaWing = Uniform(xmin = 0.2, xmax = 3.5, naming= ["alphaWing"]) | ||
thetaWing = Constraint(xmin = 0, xmax = np.pi/2, naming = ["thetaWing"]) | ||
log10_n0 = Uniform(xmin=-6.0, xmax=2.0, naming=['log10_n0']) | ||
p = Uniform(xmin=2.01, xmax=3.0, naming=['p']) | ||
log10_epsilon_e = Uniform(xmin=-4.0, xmax=0.0, naming=['log10_epsilon_e']) | ||
log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B']) | ||
epsilon_tot = Constraint(xmin = 0, xmax = 1, naming = ["epsilon_tot"]) | ||
|
||
# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance']) | ||
def conversion_function(sample): | ||
converted_sample = sample | ||
converted_sample["thetaWing"] = converted_sample["thetaCore"] * converted_sample["alphaWing"] | ||
converted_sample["epsilon_tot"] = 10**(converted_sample["log10_epsilon_B"]) + 10**(converted_sample["log10_epsilon_e"]) | ||
return converted_sample | ||
|
||
prior_list = [inclination_EM, | ||
log10_E0, | ||
thetaCore, | ||
alphaWing, | ||
log10_n0, | ||
p, | ||
log10_epsilon_e, | ||
log10_epsilon_B, | ||
thetaWing, | ||
epsilon_tot] | ||
|
||
prior = ConstrainedPrior(prior_list, conversion_function) | ||
|
||
detection_limit = None | ||
likelihood = EMLikelihood(model, | ||
data, | ||
FILTERS, | ||
tmax = 2000.0, | ||
trigger_time=trigger_time, | ||
detection_limit = detection_limit, | ||
fixed_params={"luminosity_distance": 40.0}, | ||
error_budget = 1e-5) | ||
|
||
|
||
############## | ||
### FIESTA ### | ||
############## | ||
|
||
mass_matrix = jnp.eye(prior.n_dim) | ||
eps = 5e-3 | ||
local_sampler_arg = {"step_size": mass_matrix * eps} | ||
|
||
# Save for postprocessing | ||
outdir = f"./injection_{name}/" | ||
if not os.path.exists(outdir): | ||
os.makedirs(outdir) | ||
|
||
fiesta = Fiesta(likelihood, | ||
prior, | ||
n_chains = 1_000, | ||
n_loop_training = 7, | ||
n_loop_production = 3, | ||
num_layers = 4, | ||
hidden_size = [64, 64], | ||
n_epochs = 20, | ||
n_local_steps = 50, | ||
n_global_steps = 200, | ||
local_sampler_arg=local_sampler_arg, | ||
outdir = outdir) | ||
|
||
fiesta.sample(jax.random.PRNGKey(42)) | ||
|
||
fiesta.print_summary() | ||
|
||
name = outdir + f'results_training.npz' | ||
print(f"Saving samples to {name}") | ||
state = fiesta.Sampler.get_sampler_state(training=True) | ||
chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[ | ||
"log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"] | ||
local_accs = jnp.mean(local_accs, axis=0) | ||
global_accs = jnp.mean(global_accs, axis=0) | ||
np.savez(name, log_prob=log_prob, local_accs=local_accs, | ||
global_accs=global_accs, loss_vals=loss_vals) | ||
|
||
# - production phase | ||
name = outdir + f'results_production.npz' | ||
print(f"Saving samples to {name}") | ||
state = fiesta.Sampler.get_sampler_state(training=False) | ||
chains, log_prob, local_accs, global_accs = state["chains"], state[ | ||
"log_prob"], state["local_accs"], state["global_accs"] | ||
local_accs = jnp.mean(local_accs, axis=0) | ||
global_accs = jnp.mean(global_accs, axis=0) | ||
np.savez(name, chains=chains, log_prob=log_prob, | ||
local_accs=local_accs, global_accs=global_accs) | ||
|
||
################ | ||
### PLOTTING ### | ||
################ | ||
# Fixed names: do not include them in the plotting, as will break corner | ||
parameter_names = prior.naming | ||
truths = [injection_dict[key] for key in parameter_names] | ||
|
||
n_chains, n_steps, n_dim = np.shape(chains) | ||
samples = np.reshape(chains, (n_chains * n_steps, n_dim)) | ||
samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption | ||
|
||
corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs) | ||
plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight') | ||
plt.close() | ||
|
||
end_time = time.time() | ||
runtime_seconds = end_time - start_time | ||
number_of_minutes = runtime_seconds // 60 | ||
number_of_seconds = np.round(runtime_seconds % 60, 2) | ||
print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s") | ||
|
||
print("Plotting lightcurves") | ||
fiesta.plot_lightcurves() | ||
print("Plotting lightcurves . . . done") | ||
|
||
print("DONE") |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
2020-01-02T00:00:00.000 radio-3GHz 6.473406 0.500000 | ||
2020-04-10T22:48:00.000 radio-3GHz 12.246090 0.500000 | ||
2020-07-19T21:36:00.000 radio-3GHz 13.860922 0.500000 | ||
2020-10-27T20:24:00.000 radio-3GHz 15.145347 0.500000 | ||
2021-02-04T19:12:00.000 radio-3GHz 16.200114 0.500000 | ||
2021-05-15T18:00:00.000 radio-3GHz 17.037075 0.500000 | ||
2021-08-23T16:48:00.000 radio-3GHz 17.709738 0.500000 | ||
2021-12-01T15:36:00.000 radio-3GHz 18.265443 0.500000 | ||
2022-03-11T14:24:00.000 radio-3GHz 18.738014 0.500000 | ||
2022-06-19T13:12:00.000 radio-3GHz 19.142404 0.500000 | ||
2022-09-27T12:00:00.000 radio-3GHz 19.496765 0.500000 | ||
2023-01-05T10:48:00.000 radio-3GHz 19.810069 0.500000 | ||
2023-04-15T09:36:00.000 radio-3GHz 20.091715 0.500000 | ||
2023-07-24T08:24:00.000 radio-3GHz 20.346828 0.500000 | ||
2023-11-01T07:12:00.000 radio-3GHz 20.578875 0.500000 | ||
2024-02-09T06:00:00.000 radio-3GHz 20.791826 0.500000 | ||
2024-05-19T04:48:00.000 radio-3GHz 20.989639 0.500000 | ||
2024-08-27T03:36:00.000 radio-3GHz 21.171644 0.500000 | ||
2024-12-05T02:24:00.000 radio-3GHz 21.341979 0.500000 | ||
2025-03-15T01:12:00.000 radio-3GHz 21.500904 0.500000 | ||
2025-06-23T00:00:00.000 radio-3GHz 21.652456 0.500000 | ||
2020-01-02T00:00:00.000 radio-6GHz 6.642397 0.500000 | ||
2020-05-05T22:30:00.000 radio-6GHz 13.306995 0.500000 | ||
2020-09-07T21:00:00.000 radio-6GHz 15.133077 0.500000 | ||
2021-01-10T19:30:00.000 radio-6GHz 16.559615 0.500000 | ||
2021-05-15T18:00:00.000 radio-6GHz 17.639135 0.500000 | ||
2021-09-17T16:30:00.000 radio-6GHz 18.460232 0.500000 | ||
2022-01-20T15:00:00.000 radio-6GHz 19.112092 0.500000 | ||
2022-05-25T13:30:00.000 radio-6GHz 19.646921 0.500000 | ||
2022-09-27T12:00:00.000 radio-6GHz 20.098825 0.500000 | ||
2023-01-30T10:30:00.000 radio-6GHz 20.487108 0.500000 | ||
2023-06-04T09:00:00.000 radio-6GHz 20.824446 0.500000 | ||
2023-10-07T07:30:00.000 radio-6GHz 21.125451 0.500000 | ||
2024-02-09T06:00:00.000 radio-6GHz 21.393886 0.500000 | ||
2024-06-13T04:30:00.000 radio-6GHz 21.637560 0.500000 | ||
2024-10-16T03:00:00.000 radio-6GHz 21.859776 0.500000 | ||
2025-02-18T01:30:00.000 radio-6GHz 22.064226 0.500000 | ||
2025-06-23T00:00:00.000 radio-6GHz 22.254516 0.500000 | ||
2020-01-02T00:00:00.000 X-ray-1keV 21.209144 0.500000 | ||
2020-04-06T04:34:17.143 X-ray-1keV 27.962903 0.500000 | ||
2020-07-10T09:08:34.286 X-ray-1keV 29.537845 0.500000 | ||
2020-10-13T13:42:51.429 X-ray-1keV 30.789609 0.500000 | ||
2021-01-16T18:17:08.571 X-ray-1keV 31.828374 0.500000 | ||
2021-04-21T22:51:25.714 X-ray-1keV 32.666384 0.500000 | ||
2021-07-26T03:25:42.857 X-ray-1keV 33.344842 0.500000 | ||
2021-10-29T08:00:00.000 X-ray-1keV 33.904345 0.500000 | ||
2022-02-01T12:34:17.143 X-ray-1keV 34.378401 0.500000 | ||
2022-05-07T17:08:34.286 X-ray-1keV 34.787830 0.500000 | ||
2022-08-10T21:42:51.429 X-ray-1keV 35.144792 0.500000 | ||
2022-11-14T02:17:08.571 X-ray-1keV 35.464131 0.500000 | ||
2023-02-17T06:51:25.714 X-ray-1keV 35.747110 0.500000 | ||
2023-05-23T11:25:42.857 X-ray-1keV 36.004242 0.500000 | ||
2023-08-26T16:00:00.000 X-ray-1keV 36.238919 0.500000 | ||
2023-11-29T20:34:17.143 X-ray-1keV 36.454925 0.500000 | ||
2024-03-04T01:08:34.286 X-ray-1keV 36.654020 0.500000 | ||
2024-06-07T05:42:51.429 X-ray-1keV 36.837070 0.500000 | ||
2024-09-10T10:17:08.571 X-ray-1keV 37.010718 0.500000 | ||
2024-12-14T14:51:25.714 X-ray-1keV 37.169814 0.500000 | ||
2025-03-19T19:25:42.857 X-ray-1keV 37.321108 0.500000 | ||
2025-06-23T00:00:00.000 X-ray-1keV 37.465122 0.500000 | ||
2020-01-02T00:00:00.000 bessellv 15.916906 0.500000 | ||
2020-08-11T02:40:00.000 bessellv 24.685939 0.500000 | ||
2021-03-21T05:20:00.000 bessellv 27.115049 0.500000 | ||
2021-10-29T08:00:00.000 bessellv 28.612108 0.500000 | ||
2022-06-08T10:40:00.000 bessellv 29.619683 0.500000 | ||
2023-01-16T13:20:00.000 bessellv 30.363185 0.500000 | ||
2023-08-26T16:00:00.000 bessellv 30.946682 0.500000 | ||
2024-04-04T18:40:00.000 bessellv 31.423505 0.500000 | ||
2024-11-12T21:20:00.000 bessellv 31.826433 0.500000 | ||
2025-06-23T00:00:00.000 bessellv 32.172885 0.500000 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.