Skip to content

Commit

Permalink
Revert "Be more strict about jax version, and change name from fiesta…
Browse files Browse the repository at this point in the history
… to fiestaEM in src code"

This reverts commit 65ce549.
  • Loading branch information
ThibeauWouters committed Dec 19, 2024
1 parent 65ce549 commit 5ba014e
Show file tree
Hide file tree
Showing 19 changed files with 117 additions and 40 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ packages_dir=
=src
packages = find:
install_requires =
jax<=0.4.31
jaxlib<=0.4.31
jax>=0.4.24
jaxlib>=0.4.24
numpy<2.0.0
pandas<2.0.0
jaxtyping
Expand Down
1 change: 0 additions & 1 deletion src/.gitignore

This file was deleted.

54 changes: 54 additions & 0 deletions src/fiesta.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
Metadata-Version: 2.1
Name: fiesta
Version: 0.0.1
Summary: Fast inference of electromagnetic signals with JAX
Home-page: https://github.com/thibeauwouters/fiesta
Author: Thibeau Wouters
Author-email: thibeauwouters@gmail.com
License: MIT
Keywords: sampling,inference,astrophysics,kilonovae,gamma-ray bursts
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.24
Requires-Dist: jaxlib>=0.4.24
Requires-Dist: numpy<2.0.0
Requires-Dist: pandas<2.0.0
Requires-Dist: jaxtyping
Requires-Dist: beartype
Requires-Dist: tqdm
Requires-Dist: scipy<=1.14.0
Requires-Dist: ml_collections
Requires-Dist: astropy
Requires-Dist: sncosmo
Requires-Dist: flowMC
Requires-Dist: joblib

# fiesta 🎉

`fiesta`: **F**ast **I**nference of **E**lectromagnetic **S**ignals and **T**ransients with j**A**x

![fiesta logo](docs/fiesta_logo.jpeg)

**NOTE:** `fiesta` is currently under development -- stay tuned!

## Installation

pip installation is currently work in progress. Install from source by cloning this Github repository and running
```
pip install -e .
```

NOTE: This is using an older and custom version of `flowMC`. Install by cloning the `flowMC` version at [this fork](https://github.com/ThibeauWouters/flowMC/tree/fiesta) (branch `fiesta`).

## Training surrogate models

To train your own surrogate models, have a look at some of the example scripts in the repository for inspiration, under `trained_models`

- `train_Bu2019lm.py`: Example script showing how to train a surrogate model for the POSSIS `Bu2019lm` kilonova model.
- `train_afterglowpy_tophat.py`: Example script showing how to train a surrogate model for `afterglowpy`, using a tophat jet structure.

## Examples

- `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model.
- `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet. **NOTE** This currently only uses one specific filter. The complete inference will be updated soon.
9 changes: 9 additions & 0 deletions src/fiesta.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
LICENSE
README.md
pyproject.toml
setup.cfg
src/fiesta.egg-info/PKG-INFO
src/fiesta.egg-info/SOURCES.txt
src/fiesta.egg-info/dependency_links.txt
src/fiesta.egg-info/requires.txt
src/fiesta.egg-info/top_level.txt
1 change: 1 addition & 0 deletions src/fiesta.egg-info/dependency_links.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

13 changes: 13 additions & 0 deletions src/fiesta.egg-info/requires.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
jax>=0.4.24
jaxlib>=0.4.24
numpy<2.0.0
pandas<2.0.0
jaxtyping
beartype
tqdm
scipy<=1.14.0
ml_collections
astropy
sncosmo
flowMC
joblib
1 change: 1 addition & 0 deletions src/fiesta.egg-info/top_level.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

File renamed without changes.
2 changes: 1 addition & 1 deletion src/fiestaEM/conversions.py → src/fiesta/conversions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fiestaEM.constants import pc_to_cm
from fiesta.constants import pc_to_cm
import jax
import jax.numpy as jnp
from jaxtyping import Array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import jax.numpy as jnp
from jaxtyping import Float, Array, PRNGKeyArray

from fiestaEM.inference.lightcurve_model import LightcurveModel
from fiestaEM.inference.prior import Prior
from fiestaEM.inference.likelihood import EMLikelihood
from fiestaEM.conversions import mag_app_from_mag_abs
from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.inference.prior import Prior
from fiesta.inference.likelihood import EMLikelihood
from fiesta.conversions import mag_app_from_mag_abs

from flowMC.sampler.Sampler import Sampler
from flowMC.sampler.MALA import MALA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import jax.numpy as jnp
from jaxtyping import Float, Array

from fiestaEM.inference.lightcurve_model import LightcurveModel
from fiestaEM.conversions import mag_app_from_mag_abs
from fiestaEM.utils import Filter
from fiestaEM.constants import days_to_seconds, c
from fiestaEM import conversions
from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.conversions import mag_app_from_mag_abs
from fiesta.utils import Filter
from fiesta.constants import days_to_seconds, c
from fiesta import conversions

import afterglowpy as grb

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from flax.training.train_state import TrainState
import pickle

import fiestaEM.train.neuralnets as fiestaEM_nn
from fiestaEM.utils import MinMaxScalerJax, inverse_svd_transform
import fiestaEM.conversions as conversions
from fiestaEM import models_utilities
import fiesta.train.neuralnets as fiesta_nn
from fiesta.utils import MinMaxScalerJax, inverse_svd_transform
import fiesta.conversions as conversions
from fiesta import models_utilities

########################
### ABSTRACT CLASSES ###
Expand Down Expand Up @@ -184,7 +184,7 @@ def load_networks(self) -> None:
self.models = {}
for filter in self.filters:
filename = os.path.join(self.directory, f"{filter}.pkl")
state, _ = fiestaEM_nn.load_model(filename)
state, _ = fiesta_nn.load_model(filename)
self.models[filter] = state

def load_parameter_names(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from jaxtyping import Float, Array
import jax.numpy as jnp

from fiestaEM.inference.lightcurve_model import LightcurveModel
from fiestaEM.utils import truncated_gaussian
from fiestaEM.conversions import mag_app_from_mag_abs
from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.utils import truncated_gaussian
from fiesta.conversions import mag_app_from_mag_abs

class EMLikelihood:

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fiestaEM.inference.lightcurve_model import AfterglowpyLightcurvemodel
from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
import afterglowpy as grb
from fiestaEM.constants import days_to_seconds
from fiestaEM import conversions
from fiestaEM import utils
from fiestaEM.utils import Filter
from fiesta.constants import days_to_seconds
from fiesta import conversions
from fiesta import utils
from fiesta.utils import Filter

from jaxtyping import Array, Float

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int
from fiestaEM.utils import MinMaxScalerJax
from fiestaEM import utils
from fiestaEM.utils import Filter
from fiestaEM import conversions
from fiestaEM.constants import days_to_seconds, c
from fiestaEM import models_utilities
import fiestaEM.train.neuralnets as fiestaEM_nn
from fiesta.utils import MinMaxScalerJax
from fiesta import utils
from fiesta.utils import Filter
from fiesta import conversions
from fiesta.constants import days_to_seconds, c
from fiesta import models_utilities
import fiesta.train.neuralnets as fiesta_nn

import matplotlib.pyplot as plt
import pickle
Expand Down Expand Up @@ -51,7 +51,7 @@ class SurrogateTrainer:
val_X_raw: Float[Array, "n_batch n_params"]
val_y_raw: dict[str, Float[Array, "n_batch n_times"]]

trained_states: dict[str, fiestaEM_nn.TrainState]
trained_states: dict[str, fiesta_nn.TrainState]

def __init__(self,
name: str,
Expand Down Expand Up @@ -107,7 +107,7 @@ def preprocess(self):
print("Preprocessing data . . . done")

def fit(self,
config: fiestaEM_nn.NeuralnetConfig = None,
config: fiesta_nn.NeuralnetConfig = None,
key: jax.random.PRNGKey = jax.random.PRNGKey(0),
verbose: bool = True):
"""
Expand All @@ -119,7 +119,7 @@ def fit(self,

# Get default choices if no config is given
if config is None:
config = fiestaEM_nn.NeuralnetConfig()
config = fiesta_nn.NeuralnetConfig()
self.config = config

trained_states = {}
Expand All @@ -128,12 +128,12 @@ def fit(self,
for filt in self.filters:

# Create neural network and initialize the state
net = fiestaEM_nn.MLP(layer_sizes=config.layer_sizes)
net = fiesta_nn.MLP(layer_sizes=config.layer_sizes)
key, subkey = jax.random.split(key)
state = fiestaEM_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config)
state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config)

# Perform training loop
state, train_losses, val_losses = fiestaEM_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose)
state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose)

# Plot and save the plot if so desired
if self.plots_dir is not None:
Expand Down Expand Up @@ -180,7 +180,7 @@ def save(self):

for filt in self.filters:
model = self.trained_states[filt.name]
fiestaEM_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl")
fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl")
save[filt.name] = self.preprocessing_metadata[filt.name]

with open(meta_filename, "wb") as meta_file:
Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit 5ba014e

Please sign in to comment.