Commit

Merge pull request #10 from grburgess/dev
Dev
grburgess authored Sep 4, 2023
2 parents f1b4f18 + 4d510de commit c48c448
Showing 8 changed files with 390 additions and 26 deletions.
2 changes: 1 addition & 1 deletion ronswanson/__init__.py
@@ -10,7 +10,7 @@
from .simulation import Simulation
from .simulation_builder import SimulationBuilder
from .grids import ParameterGrid
from .utils import ronswanson_config, show_configuration
from .utils import ronswanson_config, show_configuration, generate_lhs_unit_cube
from .utils.logging import update_logging_level

__all__ = [
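The newly exported generate_lhs_unit_cube points at Latin-hypercube sampling support. A minimal usage sketch, assuming it returns samples on the unit hypercube with one row per sample; the argument names and counts below are an assumption, not taken from this diff:

import numpy as np
from ronswanson import generate_lhs_unit_cube

# assumed call: 50 samples over 3 parameters (check the real signature)
cube = np.asarray(generate_lhs_unit_cube(50, 3))

# whatever the exact shape, the samples should lie on the unit cube
assert np.all((cube >= 0.0) & (cube <= 1.0))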
212 changes: 208 additions & 4 deletions ronswanson/database.py
@@ -1,23 +1,26 @@
import collections
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from pathlib import Path
from typing import Dict, List, Optional, Union


import astropy.units as u
import h5py
import numpy as np
import plotly.graph_objects as go
from astromodels import TemplateModel, TemplateModelFactory
from astromodels.functions.template_model import TemplateFile
from astromodels.utils import get_user_data_path
from astromodels.utils.logging import silence_console_log
from joblib import Parallel, delayed
from tqdm.auto import tqdm

from ronswanson.grids import Parameter, ParameterGrid
from ronswanson.utils.cartesian_product import cartesian_jit
from ronswanson.utils.color import Colors

from .utils.logging import setup_logger

from joblib import Parallel, delayed

log = setup_logger(__name__)


@@ -66,6 +69,12 @@ def __init__(

self._parameter_names: List[str] = parameter_names

for i, name in enumerate(self._parameter_names):

if not isinstance(name, str):

self._parameter_names[i] = name.decode()

self._energy_grid: np.ndarray = energy_grid

# self._values: Dict[str, np.ndarray] = values
@@ -157,6 +166,12 @@ def energy_grid(self) -> np.ndarray:

return self._energy_grid

@property
def energy_grid_nu(self) -> np.ndarray:

return (self._energy_grid * u.keV).to("Hz", equivalencies=u.spectral())


@property
def meta_data(self) -> Optional[Dict[str, np.ndarray]]:

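The new energy_grid_nu property converts the keV energy grid to frequency via astropy's spectral equivalence. The same conversion, stand-alone and outside the class, looks like this (the grid values are placeholders):

import astropy.units as u
import numpy as np

energy_grid = np.geomspace(1.0, 1.0e4, 50)  # keV, placeholder values
frequency_grid = (energy_grid * u.keV).to("Hz", equivalencies=u.spectral())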
@@ -275,6 +290,48 @@ def from_file(cls, file_name: str, output: int = 0) -> "Database":
meta_data=meta_data,
)

def to_hdf5(
self, file_name: Union[str, Path], overwrite: bool = False
) -> None:

path = Path(file_name).absolute()

if path.exists() and (not overwrite):

msg = f"{path} exists!"

log.error(msg)

raise RuntimeError(msg)

with h5py.File(path.as_posix(), "w") as f:

energy_grp: h5py.Group = f.create_group("energy_grid")

energy_grp.create_dataset("energy_grid_0", data=self._energy_grid)

values_grp = f.create_group("values")

values_grp.create_dataset("output_0", data=self._values)

par_name_grp = f.create_group("parameter_names")

for i, name in enumerate(self._parameter_names):

par_name_grp.attrs[f"par{i}"] = name

f.create_dataset("parameters", data=self._grid_points)

f.create_dataset("run_time", data=self._run_time)

if self._meta_data is not None:

meta_grp = f.create_group("meta")

for k, v in self._meta_data.items():

meta_grp.create_dataset(k, data=v)
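
A short usage sketch for the new to_hdf5 serializer; Database.from_file is the existing constructor shown above, and both file names here are placeholders:

from ronswanson.database import Database

db = Database.from_file("simulation_database.h5")  # placeholder path
db.to_hdf5("database_copy.h5", overwrite=True)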

def replace_nan_inf_with(self, value: float = 0.0) -> None:

"""
@@ -342,6 +399,22 @@ def _get_sub_selection(
selection=selection,
)

def _get_sub_selection_via_index(
self, selection_index: np.ndarray
) -> SelectionContainer:

sub_grid = self._grid_points[selection_index, ...]
sub_values = self._values[selection_index, ...]

sub_parameter_ranges = {}

return SelectionContainer(
sub_grid=sub_grid,
sub_values=sub_values,
sub_range=sub_parameter_ranges,
selection=selection_index,
)

# @classmethod
# def create_sub_selected_database(self, **selection) -> "Database":

@@ -408,6 +481,96 @@ def to_3ml(

return TemplateModel(name)

@classmethod
def from_astromodels(cls, model_name: str) -> "Database":
# Get the data directory

data_dir_path: Path = get_user_data_path()

# Sanitize the data file

filename_sanitized = data_dir_path.absolute() / f"{model_name}.h5"

if not filename_sanitized.exists():

msg = f"The data file {filename_sanitized} does not exists. Did you use the TemplateFactory?"

log.error(msg)

raise RuntimeError(msg)

# Open the template definition and read from it

data_file: Path = filename_sanitized

# use the file shadow to read

template_file: TemplateFile = TemplateFile.from_file(
filename_sanitized.as_posix()
)

parameters_grids = []

for key in template_file.parameter_order:

try:

# sometimes this is
# stored binary

k = key.decode()

except (AttributeError):

# if not, then we
# load as a normal str

k = key

parameters_grids.append(np.array(template_file.parameters[key]))

parameter_grid_cart = cartesian_jit(parameters_grids)

energies = template_file.energies

shape = 1
for dim in template_file.grid.shape[:-1]:
shape *= dim

values = template_file.grid.reshape(shape, template_file.grid.shape[-1])

return cls(
grid_points=parameter_grid_cart,
parameter_names=template_file.parameter_order,
energy_grid=energies,
run_time=np.zeros(parameter_grid_cart.shape[0]),
values=values,
)
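
A sketch of the new from_astromodels constructor, which rebuilds a Database from a template previously saved through astromodels' TemplateModelFactory; the template name is a placeholder:

from ronswanson.database import Database

# "my_template" must name an existing <model_name>.h5 in the astromodels user data directory
db = Database.from_astromodels("my_template")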

def new_from_selections(
self, selection_index: Optional[np.ndarray] = None, **selections
) -> "Database":

if selection_index is None:

selection_container: SelectionContainer = self._get_sub_selection(
selections
)

else:

selection_container: SelectionContainer = (
self._get_sub_selection_via_index(selection_index)
)

return Database(
selection_container.sub_grid,
self.parameter_names,
self.energy_grid,
self._run_time[selection_container.selection],
selection_container.sub_values,
)
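
A sketch of the new selection_index path through new_from_selections; the index array is arbitrary and only illustrates the expected input (keyword selections go through the existing _get_sub_selection instead):

import numpy as np
from ronswanson.database import Database

db = Database.from_file("simulation_database.h5")  # placeholder path
idx = np.array([0, 3, 7])  # arbitrary grid-point indices
sub_db = db.new_from_selections(selection_index=idx)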

def check_for_missing_parameters(
self, parameter_grid: ParameterGrid, create_new_grid: bool = False
) -> None:
@@ -630,6 +793,47 @@ def update_database(
f["run_time"][idx] = r.attrs["run_time"]


def merge_outputs(
*files_names: List[Union[str, Path]], out_file_name: Union[str, Path]
) -> None:

with h5py.File(out_file_name, "w") as out_file:

energy_grp: h5py.Group = out_file.create_group("energy_grid")
values_grp = out_file.create_group("values")
par_name_grp = out_file.create_group("parameter_names")

for n_output, file_name in enumerate(files_names):

with h5py.File(file_name, "r") as f:

if n_output == 0:

for k, v in f["parameter_names"].attrs.items():

par_name_grp.attrs[k] = v

if "meta" in f.keys():

meta_grp = out_file.create_group("meta")

for key in list(f["meta"].keys()):

meta_grp.create_dataset(key, data=f[f"meta/{key}"])

out_file.create_dataset("run_time", data=f["run_time"])

out_file.create_dataset("parameters", data=f["parameters"])

energy_grp.create_dataset(
f"energy_grid_{n_output}",
data=f["energy_grid/energy_grid_0"],
)
values_grp.create_dataset(
f"output_{n_output}", data=f["values/output_0"]
)
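
merge_outputs stitches several single-output database files into one multi-output file, copying the shared pieces (parameter names, parameters, run times, meta) from the first file only. A sketch with placeholder file names:

from ronswanson.database import merge_outputs

merge_outputs(
    "output_photons.h5",
    "output_neutrinos.h5",
    out_file_name="combined_outputs.h5",
)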


def merge_databases(
*file_names: List[str], new_name: str = "merged_db.h5"
) -> None:
24 changes: 24 additions & 0 deletions ronswanson/grids.py
@@ -209,6 +209,9 @@ def grid(self) -> np.ndarray:
@classmethod
def from_dict(cls, name: str, d: Dict[str, Any]) -> "Parameter":

log.debug(f"read parameter: {name}")
log.debug(f"inputs: {d}")

inputs = {}
inputs["custom"] = d["custom"]

@@ -258,6 +261,19 @@ def __post_init__(self):
cartesian_jit([p.grid for p in self.parameter_list]),
)

pass

@property
def min_max_values(self) -> np.ndarray:

out = []

for p in self.parameter_list:

out.append([min(p.grid), max(p.grid)])

return np.array(out)
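
The new min_max_values property exposes each parameter's bounds, which pairs naturally with unit-cube sampling: scale samples from [0, 1] into each parameter's range. A sketch assuming a parameter YAML file exists (the file name and the random samples are placeholders):

import numpy as np
from ronswanson import ParameterGrid

pg = ParameterGrid.from_yaml("parameters.yml")  # placeholder path
bounds = pg.min_max_values            # shape (n_parameters, 2)
lo, hi = bounds[:, 0], bounds[:, 1]

# scale unit-cube samples (e.g. from generate_lhs_unit_cube) into the bounds
cube = np.random.default_rng(0).uniform(size=(16, len(lo)))
points = lo + cube * (hi - lo)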

@property
def n_points(self) -> int:

@@ -296,6 +312,8 @@ def from_dict(cls, d: Dict[str, Dict[str, Any]]) -> "ParameterGrid":

is_multi_output = True

log.debug(f"found {n_energy_grids} energy grids")

if not is_multi_output:

energy_grid = [EnergyGrid.from_dict(d.pop("energy_grid"))]
@@ -307,6 +325,8 @@ def from_dict(cls, d: Dict[str, Dict[str, Any]]) -> "ParameterGrid":
for i in range(n_energy_grids)
]

log.debug("now reading parameters")

pars = list(d.keys())

pars.sort()
@@ -315,11 +335,15 @@ def from_dict(cls, d: Dict[str, Dict[str, Any]]) -> "ParameterGrid":
Parameter.from_dict(par_name, d[par_name]) for par_name in pars
]

log.debug("parameters have been read in")

return cls(par_list, energy_grid)

@classmethod
def from_yaml(cls, file_name: str) -> "ParameterGrid":

log.debug(f"reading: {file_name}")

with open(file_name, 'r') as f:

inputs = yaml.load(stream=f, Loader=yaml.SafeLoader)

0 comments on commit c48c448
