Skip to content

Commit

Permalink
WIP fixing tests and CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
daviesje committed Oct 24, 2024
1 parent fc3647d commit 8d0b758
Show file tree
Hide file tree
Showing 22 changed files with 616 additions and 628 deletions.
1 change: 1 addition & 0 deletions src/py21cmfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .cache_tools import query_cache
from .drivers.coeval import Coeval, run_coeval
from .drivers.lightcone import LightCone, run_lightcone
from .drivers.param_config import InputParameters
from .drivers.single_field import (
brightness_temperature,
compute_halo_grid,
Expand Down
113 changes: 51 additions & 62 deletions src/py21cmfast/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module that contains the command line app."""

import attrs
import builtins
import click
import inspect
Expand All @@ -11,10 +12,19 @@
from astropy import units as un
from os import path, remove
from pathlib import Path
from wrapper._utils import camel_to_snake

from . import _cfg, cache_tools, global_params, plotting
from . import wrapper as lib
from .drivers.coeval import run_coeval
from .drivers.lightcone import run_lightcone
from .drivers.single_field import (
compute_initial_conditions,
compute_ionization_field,
perturb_field,
spin_temperature,
)
from .lightcones import RectilinearLightconer
from .wrapper.inputs import AstroParams, CosmoParams, FlagOptions, UserParams


def _get_config(config=None):
Expand Down Expand Up @@ -62,18 +72,27 @@ def _update(obj, ctx):
pass


def _override(ctx, *param_dicts):
def _get_params_from_ctx(ctx, cfg):
# Try to use the extra arguments as an override of config.
ctx = _ctx_to_dct(ctx.args) if ctx.args else {}
params = {}
for cls in (UserParams, CosmoParams, AstroParams, FlagOptions):
fieldnames = [camel_to_snake(field.name) for field in attrs.fields(cls)]
ctx_params = {k: v for k, v in ctx.items() if k in fieldnames}
[ctx.pop(k) for k in ctx_params.keys()]
params[camel_to_snake(cls.__name__)] = ctx_params

user_params = UserParams.new(params["user_params"])
cosmo_params = CosmoParams.new(params["cosmo_params"])
flag_options = FlagOptions.new(params["flag_options"])
astro_params = AstroParams.new(params["astro_params"], flag_options=flag_options)

if ctx.args:
ctx = _ctx_to_dct(ctx.args)
for p in param_dicts:
_update(p, ctx)
# Also update globals, always.
_update(global_params, ctx)
if ctx:
warnings.warn("The following arguments were not able to be set: %s" % ctx)

# Also update globals, always.
_update(global_params, ctx)
if ctx:
warnings.warn("The following arguments were not able to be set: %s" % ctx)
return user_params, cosmo_params, astro_params, flag_options


main = click.Group()
Expand Down Expand Up @@ -126,14 +145,10 @@ def init(ctx, config, regen, direc, seed):
Random seed used to generate data.
"""
cfg = _get_config(config)

# Set user/cosmo params from config.
user_params = lib.UserParams(**cfg.get("user_params", {}))
cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {}))

_override(ctx, user_params, cosmo_params)
user_params, cosmo_params, _, _ = _get_params_from_ctx(ctx, cfg)

lib.initial_conditions(
compute_initial_conditions(
user_params=user_params,
cosmo_params=cosmo_params,
regenerate=regen,
Expand Down Expand Up @@ -193,14 +208,10 @@ def perturb(ctx, redshift, config, regen, direc, seed):
Random seed used to generate data.
"""
cfg = _get_config(config)

# Set user/cosmo params from config.
user_params = lib.UserParams(**cfg.get("user_params", {}))
cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {}))

_override(ctx, user_params, cosmo_params)
user_params, cosmo_params, _, _ = _get_params_from_ctx(ctx, cfg)

lib.perturb_field(
perturb_field(
redshift=redshift,
user_params=user_params,
cosmo_params=cosmo_params,
Expand Down Expand Up @@ -271,17 +282,12 @@ def spin(ctx, redshift, prev_z, config, regen, direc, seed):
"""
cfg = _get_config(config)

# Set user/cosmo params from config.
user_params = lib.UserParams(**cfg.get("user_params", {}))
cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {}))
flag_options = lib.FlagOptions(**cfg.get("flag_options", {}))
astro_params = lib.AstroParams(
**cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO
# Set params from config
user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx(
ctx, cfg
)

_override(ctx, user_params, cosmo_params, astro_params, flag_options)

lib.spin_temperature(
spin_temperature(
redshift=redshift,
astro_params=astro_params,
flag_options=flag_options,
Expand Down Expand Up @@ -355,19 +361,12 @@ def ionize(ctx, redshift, prev_z, config, regen, direc, seed):
"""
cfg = _get_config(config)

# Set user/cosmo params from config.
user_params = lib.UserParams(**cfg.get("user_params", {}))
cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {}))
flag_options = lib.FlagOptions(
**cfg.get("flag_options", {}),
)
astro_params = lib.AstroParams(
**cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO
# Set params from config
user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx(
ctx, cfg
)

_override(ctx, user_params, cosmo_params, astro_params, flag_options)

lib.ionize_box(
compute_ionization_field(
redshift=redshift,
astro_params=astro_params,
flag_options=flag_options,
Expand Down Expand Up @@ -450,17 +449,12 @@ def coeval(ctx, redshift, config, out, regen, direc, seed):

cfg = _get_config(config)

# Set user/cosmo params from config.
user_params = lib.UserParams(**cfg.get("user_params", {}))
cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {}))
flag_options = lib.FlagOptions(**cfg.get("flag_options", {}))
astro_params = lib.AstroParams(
**cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO
# Set params from config
user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx(
ctx, cfg
)

_override(ctx, user_params, cosmo_params, astro_params, flag_options)

coeval = lib.run_coeval(
coeval = run_coeval(
redshift=redshift,
astro_params=astro_params,
flag_options=flag_options,
Expand Down Expand Up @@ -565,16 +559,11 @@ def lightcone(ctx, redshift, config, out, regen, direc, max_z, seed, lq):
elif not out.parent.exists():
out.parent.mkdir()

# Set user/cosmo params from config.
user_params = lib.UserParams(**cfg.get("user_params", {}))
cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {}))
flag_options = lib.FlagOptions(**cfg.get("flag_options", {}))
astro_params = lib.AstroParams(
**cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO
# Set params from config
user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx(
ctx, cfg
)

_override(ctx, user_params, cosmo_params, astro_params, flag_options)

# For now, always use the old default lightconing algorithm
lcn = RectilinearLightconer.with_equal_cdist_slices(
min_redshift=redshift,
Expand All @@ -584,7 +573,7 @@ def lightcone(ctx, redshift, config, out, regen, direc, max_z, seed, lq):
quantities=lq,
)

lc = lib.run_lightcone(
lc = run_lightcone(
lightconer=lcn,
astro_params=astro_params,
flag_options=flag_options,
Expand Down Expand Up @@ -780,7 +769,7 @@ def pr_feature(

if lightcone:
print("Running default lightcone...")
lc_default = lib.run_lightcone(
lc_default = run_lightcone(
redshift=redshift,
max_redshift=max_redshift,
random_seed=random_seed,
Expand All @@ -790,7 +779,7 @@ def pr_feature(
structs[struct][param] = value

print("Running lightcone with new feature...")
lc_new = lib.run_lightcone(
lc_new = run_lightcone(
redshift=redshift,
max_redshift=max_redshift,
random_seed=random_seed,
Expand Down
41 changes: 21 additions & 20 deletions src/py21cmfast/drivers/coeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from typing import Any, Sequence

from .. import __version__
from .. import utils as _ut
from .._cfg import config
from ..c_21cmfast import lib
from ..wrapper._utils import camel_to_snake
from ..wrapper.globals import global_params
from ..wrapper.inputs import AstroParams, CosmoParams, FlagOptions, UserParams
from ..wrapper.outputs import (
Expand Down Expand Up @@ -456,9 +456,9 @@ def _read_particular(cls, fname):
kwargs = {}

with h5py.File(fname, "r") as fl:
for output_class in _ut.OutputStruct._implementations():
for output_class in _OutputStruct._implementations():
if output_class.__name__ in fl:
kwargs[_ut.camel_to_snake(output_class.__name__)] = (
kwargs[camel_to_snake(output_class.__name__)] = (
output_class.from_file(fname)
)

Expand All @@ -468,6 +468,7 @@ def __eq__(self, other):
"""Determine if this is equal to another object."""
return (
isinstance(other, self.__class__)
and other.random_seed == self.random_seed
and other.redshift == self.redshift
and self.user_params == other.user_params
and self.cosmo_params == other.cosmo_params
Expand Down Expand Up @@ -583,20 +584,18 @@ def run_coeval(
else random_seed
)

inputs = InputParameters(
random_seed=random_seed,
user_params=user_params,
cosmo_params=cosmo_params,
astro_params=astro_params,
flag_options=flag_options,
)
# For the high-level, we need all the InputStruct initialised
cosmo_params = CosmoParams.new(cosmo_params)
user_params = UserParams.new(user_params)
flag_options = FlagOptions.new(flag_options)
astro_params = AstroParams.new(astro_params, flag_options=flag_options)

iokw = {"regenerate": regenerate, "hooks": hooks, "direc": direc}

if initial_conditions is None:
initial_conditions = sf.compute_initial_conditions(
user_params=inputs.user_params,
cosmo_params=inputs.cosmo_params,
user_params=user_params,
cosmo_params=cosmo_params,
random_seed=random_seed,
**iokw,
)
Expand All @@ -617,8 +616,8 @@ def run_coeval(

kw = {
**{
"astro_params": inputs.astro_params,
"flag_options": inputs.flag_options,
"astro_params": astro_params,
"flag_options": flag_options,
"initial_conditions": initial_conditions,
},
**iokw,
Expand Down Expand Up @@ -662,7 +661,7 @@ def run_coeval(
)
# get the halos (reverse redshift order)
pt_halos = []
if inputs.flag_options.USE_HALO_FIELD and not inputs.flag_options.FIXED_HALO_GRIDS:
if flag_options.USE_HALO_FIELD and not flag_options.FIXED_HALO_GRIDS:
halos_desc = None
for i, z in enumerate(node_redshifts[::-1]):
halos = sf.determine_halo_list(
Expand Down Expand Up @@ -716,6 +715,9 @@ def run_coeval(
z_halos = []
hbox_arr = []
for iz, z in enumerate(node_redshifts):
logger.info(
f"Computing Redshift {z} ({iz + 1}/{len(node_redshifts)}) iterations."
)
pf2 = perturbed_field[iz]
pf2.load_all()

Expand Down Expand Up @@ -755,7 +757,7 @@ def run_coeval(

ib2 = sf.compute_ionization_field(
redshift=z,
previous_ionize_box=ib,
previous_ionized_box=ib,
perturbed_field=pf2,
# perturb field *not* interpolated here.
previous_perturbed_field=pf,
Expand Down Expand Up @@ -797,7 +799,6 @@ def run_coeval(
)

bt[out_redshifts.index(z)] = _bt

else:
ib = ib2
pf = pf2
Expand All @@ -806,17 +807,17 @@ def run_coeval(
st = st2

perturb_files.append((z, os.path.join(direc, pf2.filename)))
if inputs.flag_options.USE_HALO_FIELD:
if flag_options.USE_HALO_FIELD:
hbox_files.append((z, os.path.join(direc, hb2.filename)))
pth_files.append((z, os.path.join(direc, ph2.filename)))
if inputs.flag_options.USE_TS_FLUCT:
if flag_options.USE_TS_FLUCT:
spin_temp_files.append((z, os.path.join(direc, st2.filename)))
ionize_files.append((z, os.path.join(direc, ib2.filename)))

if _bt is not None:
brightness_files.append((z, os.path.join(direc, _bt.filename)))

if inputs.flag_options.PHOTON_CONS_TYPE == "z-photoncons":
if flag_options.PHOTON_CONS_TYPE == "z-photoncons":
photon_nonconservation_data = _get_photon_nonconservation_data()

if lib.photon_cons_allocated:
Expand Down
Loading

0 comments on commit 8d0b758

Please sign in to comment.