From 8d0b758966c6a9b208a0d8309db3c3e1ef2eeabb Mon Sep 17 00:00:00 2001 From: James Davies Date: Thu, 24 Oct 2024 22:44:47 +0200 Subject: [PATCH] WIP fixing tests and CLI --- src/py21cmfast/__init__.py | 1 + src/py21cmfast/cli.py | 113 +++++------ src/py21cmfast/drivers/coeval.py | 41 ++-- src/py21cmfast/drivers/lightcone.py | 75 +++---- src/py21cmfast/drivers/param_config.py | 6 +- src/py21cmfast/drivers/single_field.py | 10 +- src/py21cmfast/src/PerturbHaloField.c | 2 +- src/py21cmfast/wrapper/inputs.py | 2 +- src/py21cmfast/wrapper/structs.py | 33 ++- tests/conftest.py | 61 +++++- tests/produce_integration_test_data.py | 49 +++-- tests/test_c_interpolation_tables.py | 267 ++++++++++++------------- tests/test_cli.py | 9 +- tests/test_config.py | 2 +- tests/test_exceptions.py | 24 ++- tests/test_filtering.py | 16 +- tests/test_halo_sampler.py | 237 ++++++---------------- tests/test_high_level_io.py | 27 ++- tests/test_initial_conditions.py | 16 +- tests/test_input_structs.py | 16 +- tests/test_output_structs.py | 38 ++-- tests/test_wrapper.py | 199 +++++++++--------- 22 files changed, 616 insertions(+), 628 deletions(-) diff --git a/src/py21cmfast/__init__.py b/src/py21cmfast/__init__.py index 792748ce..d7a7bbde 100644 --- a/src/py21cmfast/__init__.py +++ b/src/py21cmfast/__init__.py @@ -24,6 +24,7 @@ from .cache_tools import query_cache from .drivers.coeval import Coeval, run_coeval from .drivers.lightcone import LightCone, run_lightcone +from .drivers.param_config import InputParameters from .drivers.single_field import ( brightness_temperature, compute_halo_grid, diff --git a/src/py21cmfast/cli.py b/src/py21cmfast/cli.py index 7b3f6c24..3259909d 100644 --- a/src/py21cmfast/cli.py +++ b/src/py21cmfast/cli.py @@ -1,5 +1,6 @@ """Module that contains the command line app.""" +import attrs import builtins import click import inspect @@ -11,10 +12,19 @@ from astropy import units as un from os import path, remove from pathlib import Path +from wrapper._utils import camel_to_snake from . import _cfg, cache_tools, global_params, plotting -from . import wrapper as lib +from .drivers.coeval import run_coeval +from .drivers.lightcone import run_lightcone +from .drivers.single_field import ( + compute_initial_conditions, + compute_ionization_field, + perturb_field, + spin_temperature, +) from .lightcones import RectilinearLightconer +from .wrapper.inputs import AstroParams, CosmoParams, FlagOptions, UserParams def _get_config(config=None): @@ -62,18 +72,27 @@ def _update(obj, ctx): pass -def _override(ctx, *param_dicts): +def _get_params_from_ctx(ctx, cfg): # Try to use the extra arguments as an override of config. + ctx = _ctx_to_dct(ctx.args) if ctx.args else {} + params = {} + for cls in (UserParams, CosmoParams, AstroParams, FlagOptions): + fieldnames = [camel_to_snake(field.name) for field in attrs.fields(cls)] + ctx_params = {k: v for k, v in ctx.items() if k in fieldnames} + [ctx.pop(k) for k in ctx_params.keys()] + params[camel_to_snake(cls.__name__)] = ctx_params + + user_params = UserParams.new(params["user_params"]) + cosmo_params = CosmoParams.new(params["cosmo_params"]) + flag_options = FlagOptions.new(params["flag_options"]) + astro_params = AstroParams.new(params["astro_params"], flag_options=flag_options) - if ctx.args: - ctx = _ctx_to_dct(ctx.args) - for p in param_dicts: - _update(p, ctx) + # Also update globals, always. + _update(global_params, ctx) + if ctx: + warnings.warn("The following arguments were not able to be set: %s" % ctx) - # Also update globals, always. - _update(global_params, ctx) - if ctx: - warnings.warn("The following arguments were not able to be set: %s" % ctx) + return user_params, cosmo_params, astro_params, flag_options main = click.Group() @@ -126,14 +145,10 @@ def init(ctx, config, regen, direc, seed): Random seed used to generate data. """ cfg = _get_config(config) - # Set user/cosmo params from config. - user_params = lib.UserParams(**cfg.get("user_params", {})) - cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {})) - - _override(ctx, user_params, cosmo_params) + user_params, cosmo_params, _, _ = _get_params_from_ctx(ctx, cfg) - lib.initial_conditions( + compute_initial_conditions( user_params=user_params, cosmo_params=cosmo_params, regenerate=regen, @@ -193,14 +208,10 @@ def perturb(ctx, redshift, config, regen, direc, seed): Random seed used to generate data. """ cfg = _get_config(config) - # Set user/cosmo params from config. - user_params = lib.UserParams(**cfg.get("user_params", {})) - cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {})) - - _override(ctx, user_params, cosmo_params) + user_params, cosmo_params, _, _ = _get_params_from_ctx(ctx, cfg) - lib.perturb_field( + perturb_field( redshift=redshift, user_params=user_params, cosmo_params=cosmo_params, @@ -271,17 +282,12 @@ def spin(ctx, redshift, prev_z, config, regen, direc, seed): """ cfg = _get_config(config) - # Set user/cosmo params from config. - user_params = lib.UserParams(**cfg.get("user_params", {})) - cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {})) - flag_options = lib.FlagOptions(**cfg.get("flag_options", {})) - astro_params = lib.AstroParams( - **cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO + # Set params from config + user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx( + ctx, cfg ) - _override(ctx, user_params, cosmo_params, astro_params, flag_options) - - lib.spin_temperature( + spin_temperature( redshift=redshift, astro_params=astro_params, flag_options=flag_options, @@ -355,19 +361,12 @@ def ionize(ctx, redshift, prev_z, config, regen, direc, seed): """ cfg = _get_config(config) - # Set user/cosmo params from config. - user_params = lib.UserParams(**cfg.get("user_params", {})) - cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {})) - flag_options = lib.FlagOptions( - **cfg.get("flag_options", {}), - ) - astro_params = lib.AstroParams( - **cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO + # Set params from config + user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx( + ctx, cfg ) - _override(ctx, user_params, cosmo_params, astro_params, flag_options) - - lib.ionize_box( + compute_ionization_field( redshift=redshift, astro_params=astro_params, flag_options=flag_options, @@ -450,17 +449,12 @@ def coeval(ctx, redshift, config, out, regen, direc, seed): cfg = _get_config(config) - # Set user/cosmo params from config. - user_params = lib.UserParams(**cfg.get("user_params", {})) - cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {})) - flag_options = lib.FlagOptions(**cfg.get("flag_options", {})) - astro_params = lib.AstroParams( - **cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO + # Set params from config + user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx( + ctx, cfg ) - _override(ctx, user_params, cosmo_params, astro_params, flag_options) - - coeval = lib.run_coeval( + coeval = run_coeval( redshift=redshift, astro_params=astro_params, flag_options=flag_options, @@ -565,16 +559,11 @@ def lightcone(ctx, redshift, config, out, regen, direc, max_z, seed, lq): elif not out.parent.exists(): out.parent.mkdir() - # Set user/cosmo params from config. - user_params = lib.UserParams(**cfg.get("user_params", {})) - cosmo_params = lib.CosmoParams(**cfg.get("cosmo_params", {})) - flag_options = lib.FlagOptions(**cfg.get("flag_options", {})) - astro_params = lib.AstroParams( - **cfg.get("astro_params", {}), INHOMO_RECO=flag_options.INHOMO_RECO + # Set params from config + user_params, cosmo_params, astro_params, flag_options = _get_params_from_ctx( + ctx, cfg ) - _override(ctx, user_params, cosmo_params, astro_params, flag_options) - # For now, always use the old default lightconing algorithm lcn = RectilinearLightconer.with_equal_cdist_slices( min_redshift=redshift, @@ -584,7 +573,7 @@ def lightcone(ctx, redshift, config, out, regen, direc, max_z, seed, lq): quantities=lq, ) - lc = lib.run_lightcone( + lc = run_lightcone( lightconer=lcn, astro_params=astro_params, flag_options=flag_options, @@ -780,7 +769,7 @@ def pr_feature( if lightcone: print("Running default lightcone...") - lc_default = lib.run_lightcone( + lc_default = run_lightcone( redshift=redshift, max_redshift=max_redshift, random_seed=random_seed, @@ -790,7 +779,7 @@ def pr_feature( structs[struct][param] = value print("Running lightcone with new feature...") - lc_new = lib.run_lightcone( + lc_new = run_lightcone( redshift=redshift, max_redshift=max_redshift, random_seed=random_seed, diff --git a/src/py21cmfast/drivers/coeval.py b/src/py21cmfast/drivers/coeval.py index d915dbbd..668d9efc 100644 --- a/src/py21cmfast/drivers/coeval.py +++ b/src/py21cmfast/drivers/coeval.py @@ -10,9 +10,9 @@ from typing import Any, Sequence from .. import __version__ -from .. import utils as _ut from .._cfg import config from ..c_21cmfast import lib +from ..wrapper._utils import camel_to_snake from ..wrapper.globals import global_params from ..wrapper.inputs import AstroParams, CosmoParams, FlagOptions, UserParams from ..wrapper.outputs import ( @@ -456,9 +456,9 @@ def _read_particular(cls, fname): kwargs = {} with h5py.File(fname, "r") as fl: - for output_class in _ut.OutputStruct._implementations(): + for output_class in _OutputStruct._implementations(): if output_class.__name__ in fl: - kwargs[_ut.camel_to_snake(output_class.__name__)] = ( + kwargs[camel_to_snake(output_class.__name__)] = ( output_class.from_file(fname) ) @@ -468,6 +468,7 @@ def __eq__(self, other): """Determine if this is equal to another object.""" return ( isinstance(other, self.__class__) + and other.random_seed == self.random_seed and other.redshift == self.redshift and self.user_params == other.user_params and self.cosmo_params == other.cosmo_params @@ -583,20 +584,18 @@ def run_coeval( else random_seed ) - inputs = InputParameters( - random_seed=random_seed, - user_params=user_params, - cosmo_params=cosmo_params, - astro_params=astro_params, - flag_options=flag_options, - ) + # For the high-level, we need all the InputStruct initialised + cosmo_params = CosmoParams.new(cosmo_params) + user_params = UserParams.new(user_params) + flag_options = FlagOptions.new(flag_options) + astro_params = AstroParams.new(astro_params, flag_options=flag_options) iokw = {"regenerate": regenerate, "hooks": hooks, "direc": direc} if initial_conditions is None: initial_conditions = sf.compute_initial_conditions( - user_params=inputs.user_params, - cosmo_params=inputs.cosmo_params, + user_params=user_params, + cosmo_params=cosmo_params, random_seed=random_seed, **iokw, ) @@ -617,8 +616,8 @@ def run_coeval( kw = { **{ - "astro_params": inputs.astro_params, - "flag_options": inputs.flag_options, + "astro_params": astro_params, + "flag_options": flag_options, "initial_conditions": initial_conditions, }, **iokw, @@ -662,7 +661,7 @@ def run_coeval( ) # get the halos (reverse redshift order) pt_halos = [] - if inputs.flag_options.USE_HALO_FIELD and not inputs.flag_options.FIXED_HALO_GRIDS: + if flag_options.USE_HALO_FIELD and not flag_options.FIXED_HALO_GRIDS: halos_desc = None for i, z in enumerate(node_redshifts[::-1]): halos = sf.determine_halo_list( @@ -716,6 +715,9 @@ def run_coeval( z_halos = [] hbox_arr = [] for iz, z in enumerate(node_redshifts): + logger.info( + f"Computing Redshift {z} ({iz + 1}/{len(node_redshifts)}) iterations." + ) pf2 = perturbed_field[iz] pf2.load_all() @@ -755,7 +757,7 @@ def run_coeval( ib2 = sf.compute_ionization_field( redshift=z, - previous_ionize_box=ib, + previous_ionized_box=ib, perturbed_field=pf2, # perturb field *not* interpolated here. previous_perturbed_field=pf, @@ -797,7 +799,6 @@ def run_coeval( ) bt[out_redshifts.index(z)] = _bt - else: ib = ib2 pf = pf2 @@ -806,17 +807,17 @@ def run_coeval( st = st2 perturb_files.append((z, os.path.join(direc, pf2.filename))) - if inputs.flag_options.USE_HALO_FIELD: + if flag_options.USE_HALO_FIELD: hbox_files.append((z, os.path.join(direc, hb2.filename))) pth_files.append((z, os.path.join(direc, ph2.filename))) - if inputs.flag_options.USE_TS_FLUCT: + if flag_options.USE_TS_FLUCT: spin_temp_files.append((z, os.path.join(direc, st2.filename))) ionize_files.append((z, os.path.join(direc, ib2.filename))) if _bt is not None: brightness_files.append((z, os.path.join(direc, _bt.filename))) - if inputs.flag_options.PHOTON_CONS_TYPE == "z-photoncons": + if flag_options.PHOTON_CONS_TYPE == "z-photoncons": photon_nonconservation_data = _get_photon_nonconservation_data() if lib.photon_cons_allocated: diff --git a/src/py21cmfast/drivers/lightcone.py b/src/py21cmfast/drivers/lightcone.py index fee8f11b..3af6a2bb 100644 --- a/src/py21cmfast/drivers/lightcone.py +++ b/src/py21cmfast/drivers/lightcone.py @@ -222,7 +222,12 @@ def __eq__(self, other): """Determine if this is equal to another object.""" return ( isinstance(other, self.__class__) - and other.redshift == self.redshift + and other.random_seed == self.random_seed + and np.all( + np.isclose( + other.lightcone_redshifts, self.lightcone_redshifts, atol=1e-3 + ) + ) and np.all(np.isclose(other.node_redshifts, self.node_redshifts, atol=1e-3)) and self.user_params == other.user_params and self.cosmo_params == other.cosmo_params @@ -324,8 +329,8 @@ def _run_lightcone_from_perturbed_fields( initial_conditions: InitialConditions, perturbed_fields: Sequence[PerturbedField], lightconer: Lightconer, - astro_params: AstroParams | None = None, - flag_options: FlagOptions | None = None, + astro_params: AstroParams, + flag_options: FlagOptions, regenerate: bool | None = None, global_quantities: tuple[str] = ("brightness_temp", "xH_box"), direc: Path | str | None = None, @@ -529,14 +534,14 @@ def _run_lightcone_from_perturbed_fields( kw = { **{ "initial_conditions": initial_conditions, - "astro_params": astro_params, - "flag_options": flag_options, + "astro_params": inputs.astro_params, + "flag_options": inputs.flag_options, }, **iokw, } photon_nonconservation_data = None - if flag_options.PHOTON_CONS_TYPE != "no-photoncons": + if inputs.flag_options.PHOTON_CONS_TYPE != "no-photoncons": setup_photon_cons(**kw) # At first we don't have any "previous" fields. @@ -562,7 +567,7 @@ def _run_lightcone_from_perturbed_fields( astro_params=inputs.astro_params, ) try: - st = cached_boxes["TsBox"][0] if flag_options.USE_TS_FLUCT else None + st = cached_boxes["TsBox"][0] if inputs.flag_options.USE_TS_FLUCT else None pf = cached_boxes["PerturbedField"][0] ib = cached_boxes["IonizedBox"][0] except (KeyError, IndexError): @@ -575,12 +580,12 @@ def _run_lightcone_from_perturbed_fields( # Now we can purge init_box further. with contextlib.suppress(OSError): initial_conditions.prepare_for_halos( - flag_options=flag_options, force=always_purge + flag_options=inputs.flag_options, force=always_purge ) # we explicitly pass the descendant halos here since we have a redshift list prior # this will generate the extra fields if STOC_MINIMUM_Z is given pt_halos = [] - if flag_options.USE_HALO_FIELD and not flag_options.FIXED_HALO_GRIDS: + if inputs.flag_options.USE_HALO_FIELD and not inputs.flag_options.FIXED_HALO_GRIDS: halos_desc = None for iz, z in enumerate(scrollz[::-1]): halo_field = sf.determine_halo_list( @@ -600,7 +605,7 @@ def _run_lightcone_from_perturbed_fields( # Now that we've got all the perturb fields, we can purge init more. with contextlib.suppress(OSError): initial_conditions.prepare_for_spin_temp( - flag_options=flag_options, force=always_purge + flag_options=inputs.flag_options, force=always_purge ) # arrays to hold cache filenames @@ -635,8 +640,8 @@ def _run_lightcone_from_perturbed_fields( # This ensures that all the arrays that are required for spin_temp are there, # in case we dumped them from memory into file. pf2.load_all() - if flag_options.USE_HALO_FIELD: - if not flag_options.FIXED_HALO_GRIDS: + if inputs.flag_options.USE_HALO_FIELD: + if not inputs.flag_options.FIXED_HALO_GRIDS: ph2 = pt_halos[iz] ph2.load_all() @@ -649,7 +654,7 @@ def _run_lightcone_from_perturbed_fields( **kw, ) - if flag_options.USE_TS_FLUCT: + if inputs.flag_options.USE_TS_FLUCT: z_halos.append(z) hboxes.append(hbox2) xrs = sf.compute_xray_source_field( @@ -659,7 +664,7 @@ def _run_lightcone_from_perturbed_fields( **kw, ) - if flag_options.USE_TS_FLUCT: + if inputs.flag_options.USE_TS_FLUCT: st2 = sf.spin_temperature( redshift=z, previous_spin_temp=st, @@ -702,10 +707,13 @@ def _run_lightcone_from_perturbed_fields( ) perturb_files.append((z, direc / pf2.filename)) - if flag_options.USE_HALO_FIELD and not flag_options.FIXED_HALO_GRIDS: + if ( + inputs.flag_options.USE_HALO_FIELD + and not inputs.flag_options.FIXED_HALO_GRIDS + ): hbox_files.append((z, direc / hbox2.filename)) pth_files.append((z, direc / ph2.filename)) - if flag_options.USE_TS_FLUCT: + if inputs.flag_options.USE_TS_FLUCT: spin_temp_files.append((z, direc / st2.filename)) ionize_files.append((z, direc / ib2.filename)) brightness_files.append((z, direc / bt2.filename)) @@ -730,14 +738,7 @@ def _run_lightcone_from_perturbed_fields( lightcone_filename, redshift=z, index=lc_index ) - # Save current ones as old ones. - if flag_options.USE_TS_FLUCT: - st = st2 - ib = ib2 - if flag_options.USE_MINI_HALOS: - pf = pf2 - prev_coeval = coeval - + # purge arrays we don't need if pf is not None: with contextlib.suppress(OSError): pf.purge(force=always_purge) @@ -756,12 +757,17 @@ def _run_lightcone_from_perturbed_fields( ], force=always_purge, ) + + # Save current ones as old ones. pf = pf2 hbox = hbox2 + st = st2 + ib = ib2 + prev_coeval = coeval # last redshift things if iz == len(scrollz) - 1: - if flag_options.PHOTON_CONS_TYPE == "z-photoncons": + if inputs.flag_options.PHOTON_CONS_TYPE == "z-photoncons": photon_nonconservation_data = _get_photon_nonconservation_data() if lib.photon_cons_allocated: @@ -770,7 +776,7 @@ def _run_lightcone_from_perturbed_fields( lightcone.photon_nonconservation_data = photon_nonconservation_data if isinstance(lightcone, AngularLightcone) and lightconer.get_los_velocity: lightcone.compute_rsds( - fname=lightcone_filename, n_subcells=astro_params.N_RSD_STEPS + fname=lightcone_filename, n_subcells=inputs.astro_params.N_RSD_STEPS ) # Append some info to the lightcone before we return @@ -918,14 +924,11 @@ def run_lightcone( if cosmo_params is None and initial_conditions is None: cosmo_params = CosmoParams.from_astropy(lightconer.cosmo) - inputs = InputParameters.from_output_structs( - (initial_conditions, *perturbed_fields), - cosmo_params=cosmo_params, - user_params=user_params, - astro_params=astro_params, - flag_options=flag_options, - redshift=None, - ) + # For the high-level, we need all the InputStruct initialised + cosmo_params = CosmoParams.new(cosmo_params) + user_params = UserParams.new(user_params) + flag_options = AstroParams.new(flag_options) + astro_params = AstroParams.new(astro_params, flag_options=flag_options) if pf_given: node_redshifts = [pf.redshift for pf in perturbed_fields] @@ -984,8 +987,8 @@ def run_lightcone( initial_conditions=initial_conditions, perturbed_fields=perturbed_fields, lightconer=lightconer, - astro_params=inputs.astro_params, - flag_options=inputs.flag_options, + astro_params=astro_params, + flag_options=flag_options, regenerate=regenerate, global_quantities=global_quantities, direc=direc, diff --git a/src/py21cmfast/drivers/param_config.py b/src/py21cmfast/drivers/param_config.py index 39efb2fd..39f7567f 100644 --- a/src/py21cmfast/drivers/param_config.py +++ b/src/py21cmfast/drivers/param_config.py @@ -212,11 +212,9 @@ def __repr__(self): + f"flag_options: {repr(self.flag_options)}\n" ) + # TODO: methods for equality: w/o redshift, w/o seed + -# TODO: In order to fully combine this with the other paramter config, we -# need to pass in a boolean sequence to InputParamters.from_output_structs -# marking structs (previous, initial) as exempt from redshift comparison. -# This would make .merge and .is_compatible ignore their redshifts def check_redshift_consistency(inputs: InputParameters, output_structs): """Check the redshifts between provided OutputStruct objects and an InputParamters instance.""" for struct in output_structs: diff --git a/src/py21cmfast/drivers/single_field.py b/src/py21cmfast/drivers/single_field.py index 2f49bb1d..63b8e9fb 100644 --- a/src/py21cmfast/drivers/single_field.py +++ b/src/py21cmfast/drivers/single_field.py @@ -1136,8 +1136,14 @@ def spin_temperature( "that being evaluated." ) - if inputs.flag_options.USE_HALO_FIELD and xray_source_box is None: - raise ValueError("xray_source_box is required when USE_HALO_FIELD is True") + if xray_source_box is None: + if inputs.flag_options.USE_HALO_FIELD: + raise ValueError("xray_source_box is required when USE_HALO_FIELD is True") + else: + xray_source_box = XraySourceBox( + inputs=inputs.evolve(redshift=0.0), + dummy=True, + ) # Set up the box without computing anything. box = TsBox( diff --git a/src/py21cmfast/src/PerturbHaloField.c b/src/py21cmfast/src/PerturbHaloField.c index 62613737..5cb0e301 100644 --- a/src/py21cmfast/src/PerturbHaloField.c +++ b/src/py21cmfast/src/PerturbHaloField.c @@ -31,7 +31,7 @@ int ComputePerturbHaloField(float redshift, UserParams *user_params, CosmoParams LOG_DEBUG("input value:"); LOG_DEBUG("redshift=%f", redshift); -#if LOG_LEVEL >= DEBUG_LEVEL +#if LOG_LEVEL >= SUPER_DEBUG_LEVEL writeUserParams(user_params); writeCosmoParams(cosmo_params); writeAstroParams(flag_options, astro_params); diff --git a/src/py21cmfast/wrapper/inputs.py b/src/py21cmfast/wrapper/inputs.py index df9b9812..360c5856 100644 --- a/src/py21cmfast/wrapper/inputs.py +++ b/src/py21cmfast/wrapper/inputs.py @@ -141,7 +141,7 @@ def from_astropy(cls, cosmo: FLRW, **kwargs): values. """ return cls( - hlittle=cosmo.h, OMm=cosmo.Om0, OMb=cosmo.Ob0, _base_cosmo=cosmo, **kwargs + hlittle=cosmo.h, OMm=cosmo.Om0, OMb=cosmo.Ob0, base_cosmo=cosmo, **kwargs ) diff --git a/src/py21cmfast/wrapper/structs.py b/src/py21cmfast/wrapper/structs.py index 7cfcdf5e..603856b4 100644 --- a/src/py21cmfast/wrapper/structs.py +++ b/src/py21cmfast/wrapper/structs.py @@ -117,7 +117,7 @@ class InputStruct: _write_exclude_fields = () @classmethod - def new(cls, x: dict | InputStruct | None): + def new(cls, x: dict | InputStruct | None, **kwargs): """ Create a new instance of the struct. @@ -130,11 +130,11 @@ def new(cls, x: dict | InputStruct | None): struct will be initialised with default values. """ if isinstance(x, dict): - return cls(**x) + return cls(**x, **kwargs) elif isinstance(x, InputStruct): return x elif x is None: - return cls() + return cls(**kwargs) else: raise ValueError( f"Cannot instantiate {cls.__name__} with type {x.__class__}" @@ -266,6 +266,10 @@ def __init__(self, *, dummy=False, initial=False, **kwargs): raise KeyError( f"{self.__class__.__name__} requires the keyword argument {k}" ) from e + if getattr(self, k) is None: + raise KeyError( + f"{self.__class__.__name__} has required input {k} == None" + ) if kwargs: warnings.warn( @@ -773,7 +777,7 @@ def read( self, direc: str | Path | None = None, fname: str | Path | None | h5py.File | h5py.Group = None, - keys: Sequence[str] | None = None, + keys: Sequence[str] | None = (), ): """ Try find and read existing boxes from cache, which match the parameters of this instance. @@ -880,16 +884,9 @@ def from_file( fname = direc / fname with h5py.File(fname, "r") as fl: - if h5_group is not None: - self = cls(**cls._read_inputs(fl[h5_group])) - else: - self = cls(**cls._read_inputs(fl)) - - if h5_group is not None: - with h5py.File(fname, "r") as fl: - self.read(fname=fl[h5_group], keys=arrays_to_load) - else: - self.read(fname=fname, keys=arrays_to_load) + fl_inp = fl[h5_group] if h5_group else fl + self = cls(**cls._read_inputs(fl_inp)) + self.read(fname=fl_inp, keys=arrays_to_load) return self @@ -899,7 +896,6 @@ def _read_inputs(cls, grp: h5py.File | h5py.Group): # Read the input parameter dictionaries from file. kwargs = {} - inputstructs = {} for k in cls._inputs: kfile = k.lstrip("_") input_class_name = snake_to_camel(kfile) @@ -909,14 +905,11 @@ def _read_inputs(cls, grp: h5py.File | h5py.Group): input_classes.index(input_class_name) ] subgrp = grp[kfile] - logger.info( - {k: v for k, v in dict(subgrp.attrs).items() if v != "none"} - ) - inputstructs[k] = input_class.new( + kwargs[k] = input_class.new( {k: v for k, v in dict(subgrp.attrs).items() if v != "none"} ) else: - kwargs[kfile] = grp.attrs[kfile] + kwargs[k] = grp.attrs[kfile] return kwargs def __repr__(self): diff --git a/tests/conftest.py b/tests/conftest.py index 564f609a..44539658 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,20 @@ import logging import os from astropy import units as un - -from py21cmfast import UserParams, config, global_params, run_lightcone, wrapper +from collections import deque + +from py21cmfast import ( + AstroParams, + CosmoParams, + FlagOptions, + InputParameters, + UserParams, + compute_initial_conditions, + config, + global_params, + perturb_field, + run_lightcone, +) from py21cmfast.cache_tools import clear_cache from py21cmfast.lightcones import RectilinearLightconer @@ -109,9 +121,42 @@ def default_user_params(): return UserParams(HII_DIM=35, DIM=70, BOX_LEN=50, KEEP_3D_VELOCITIES=True) +@pytest.fixture(scope="session") +def default_flag_options(): + return FlagOptions( + USE_HALO_FIELD=False, + USE_EXP_FILTER=False, + CELL_RECOMB=False, + HALO_STOCHASTICITY=False, + ) + + +@pytest.fixture(scope="session") +def default_input_struct(default_user_params, default_flag_options): + return InputParameters( + redshift=10.0, + random_seed=1, + cosmo_params=CosmoParams.new(), + astro_params=AstroParams.new(default_flag_options), + user_params=default_user_params, + flag_options=default_flag_options, + ) + + +@pytest.fixture(scope="session") +def default_flag_options_ts(): + return FlagOptions( + USE_HALO_FIELD=False, + USE_EXP_FILTER=False, + CELL_RECOMB=False, + HALO_STOCHASTICITY=False, + USE_TS_FLUCT=True, + ) + + @pytest.fixture(scope="session") def ic(default_user_params, tmpdirec): - return wrapper.initial_conditions( + return compute_initial_conditions( user_params=default_user_params, write=True, direc=tmpdirec, random_seed=12 ) @@ -137,7 +182,7 @@ def low_redshift(): @pytest.fixture(scope="session") def perturbed_field(ic, redshift): """A default perturb_field""" - return wrapper.perturb_field(redshift=redshift, init_boxes=ic, write=True) + return perturb_field(redshift=redshift, initial_conditions=ic, write=True) @pytest.fixture(scope="session") @@ -151,5 +196,9 @@ def rectlcn(perturbed_field, max_redshift) -> RectilinearLightconer: @pytest.fixture(scope="session") -def lc(perturbed_field, rectlcn): - return run_lightcone(lightconer=rectlcn, perturb=perturbed_field) +def lc(perturbed_field, rectlcn, default_flag_options): + lc_gen = run_lightcone( + lightconer=rectlcn, perturb=perturbed_field, flag_options=default_flag_options + ) + iz, z, coev, lc = deque(lc_gen, maxlen=1) + return lc diff --git a/tests/produce_integration_test_data.py b/tests/produce_integration_test_data.py index b785abae..d62e7f04 100644 --- a/tests/produce_integration_test_data.py +++ b/tests/produce_integration_test_data.py @@ -7,6 +7,7 @@ fail at the tens-of-percent level. """ +import attrs import click import glob import h5py @@ -25,10 +26,10 @@ FlagOptions, InitialConditions, UserParams, + compute_initial_conditions, config, determine_halo_list, global_params, - initial_conditions, perturb_field, perturb_halo_list, run_coeval, @@ -41,12 +42,15 @@ SEED = 12345 DATA_PATH = Path(__file__).parent / "test_data" + +# NOTE: Since this is called in `evolve()` AFTER the OPTIONS kwargs, +# These should only contain dimensions, which don't show up in the +# OPTIONS dicts DEFAULT_USER_PARAMS = { "HII_DIM": 50, "DIM": 150, "BOX_LEN": 100, "NO_RNG": True, - "USE_INTERPOLATION_TABLES": True, } DEFAULT_ZPRIME_STEP_FACTOR = 1.04 @@ -274,23 +278,32 @@ raise ValueError("There is a non-unique option_halo name!") -def get_defaults(kwargs, cls): - return {k: kwargs.get(k, v) for k, v in cls._defaults_.items()} +def get_input_struct(kwargs, cls): + fieldnames = [field.name.lstrip("_") for field in attrs.fields(cls)] + subdict = {k: v for (k, v) in kwargs.items() if k in fieldnames} + return cls.new(subdict) + +def get_all_input_structs(kwargs): + flag_options = get_input_struct(kwargs, FlagOptions) + cosmo_params = get_input_struct(kwargs, CosmoParams) + user_params = get_input_struct(kwargs, UserParams) -def get_all_defaults(kwargs): - flag_options = get_defaults(kwargs, FlagOptions) - astro_params = get_defaults(kwargs, AstroParams) - cosmo_params = get_defaults(kwargs, CosmoParams) - user_params = get_defaults(kwargs, UserParams) + kwargs_a = kwargs.copy() + kwargs_a.update({"flag_options": flag_options}) + logger.info(kwargs_a) + astro_params = get_input_struct(kwargs_a, AstroParams) return user_params, cosmo_params, astro_params, flag_options def get_all_options(redshift, **kwargs): - user_params, cosmo_params, astro_params, flag_options = get_all_defaults(kwargs) - user_params.update(DEFAULT_USER_PARAMS) + user_params, cosmo_params, astro_params, flag_options = get_all_input_structs( + kwargs + ) + user_params = attrs.evolve(user_params, **DEFAULT_USER_PARAMS) + out = { - "redshift": redshift, + "out_redshifts": redshift, "user_params": user_params, "cosmo_params": cosmo_params, "astro_params": astro_params, @@ -306,7 +319,9 @@ def get_all_options(redshift, **kwargs): def get_all_options_ics(**kwargs): - user_params, cosmo_params, astro_params, flag_options = get_all_defaults(kwargs) + user_params, cosmo_params, astro_params, flag_options = get_all_input_structs( + kwargs + ) user_params.update(DEFAULT_USER_PARAMS) out = { "user_params": user_params, @@ -321,10 +336,12 @@ def get_all_options_ics(**kwargs): def get_all_options_halo(redshift, **kwargs): - user_params, cosmo_params, astro_params, flag_options = get_all_defaults(kwargs) + user_params, cosmo_params, astro_params, flag_options = get_all_input_structs( + kwargs + ) user_params.update(DEFAULT_USER_PARAMS) out = { - "redshift": redshift, + "out_redshifts": redshift, "user_params": user_params, "cosmo_params": cosmo_params, "astro_params": astro_params, @@ -411,7 +428,7 @@ def produce_perturb_field_data(redshift, **kwargs): velocity_normalisation = 1e16 with config.use(regenerate=True, write=False): - init_box = initial_conditions(**options_ics) + init_box = compute_initial_conditions(**options_ics) pt_box = perturb_field(redshift=redshift, init_boxes=init_box, **out) p_dens, k_dens = get_power( diff --git a/tests/test_c_interpolation_tables.py b/tests/test_c_interpolation_tables.py index ae085ffb..ce8428b4 100644 --- a/tests/test_c_interpolation_tables.py +++ b/tests/test_c_interpolation_tables.py @@ -1,5 +1,6 @@ import pytest +import attrs import matplotlib as mpl import numpy as np from astropy import constants as c @@ -21,26 +22,26 @@ RELATIVE_TOLERANCE = 2e-2 OPTIONS_PS = { - "EH": [10, {"POWER_SPECTRUM": 0}], - "BBKS": [10, {"POWER_SPECTRUM": 1}], - "BE": [10, {"POWER_SPECTRUM": 2}], - "Peebles": [10, {"POWER_SPECTRUM": 3}], - "White": [10, {"POWER_SPECTRUM": 4}], - "CLASS": [10, {"POWER_SPECTRUM": 5}], + "EH": [10, {"POWER_SPECTRUM": "EH"}], + "BBKS": [10, {"POWER_SPECTRUM": "BBKS"}], + "BE": [10, {"POWER_SPECTRUM": "EFSTATHIOU"}], + "Peebles": [10, {"POWER_SPECTRUM": "PEEBLES"}], + "White": [10, {"POWER_SPECTRUM": "WHITE"}], + "CLASS": [10, {"POWER_SPECTRUM": "CLASS"}], } OPTIONS_HMF = { - "PS": [10, {"HMF": 0}], - "ST": [10, {"HMF": 1}], - # "Watson": [10, {"HMF": 2}], - # "Watsonz": [10, {"HMF": 3}], - # "Delos": [10, {"HMF": 4}], + "PS": [10, {"HMF": "PS"}], + "ST": [10, {"HMF": "ST"}], + # "Watson": [10, {"HMF": "WATSON"}], + # "Watsonz": [10, {"HMF": "WATSON-Z"}], + # "Delos": [10, {"HMF": "DELOS"}], } OPTIONS_INTMETHOD = { - "QAG": 0, - "GL": 1, - "FFCOLL": 2, + "QAG": "GSL-QAG", + "GL": "GAUSS-LEGENDRE", + "FFCOLL": "GAMMA-APPROX", } R_PARAM_LIST = [1.5, 5, 10, 30, 60] @@ -64,10 +65,9 @@ def test_sigma_table(name, plt): redshift, kwargs = OPTIONS_PS[name] opts = prd.get_all_options(redshift, **kwargs) - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - up.update(USE_INTERPOLATION_TABLES=True) - lib.Broadcast_struct_global_noastro(up(), cp()) + up = opts["user_params"] + cp = opts["cosmo_params"] + lib.Broadcast_struct_global_noastro(up.cstruct, cp.cstruct) lib.init_ps() lib.initialiseSigmaMInterpTable( @@ -104,17 +104,16 @@ def test_inverse_cmf_tables(name, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - up.update(USE_INTERPOLATION_TABLES=True) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] hist_size = 1000 edges = np.logspace(7, 12, num=hist_size).astype("f4") edges_ln = np.log(edges) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) lib.init_ps() lib.initialiseSigmaMInterpTable( @@ -136,9 +135,9 @@ def test_inverse_cmf_tables(name, plt): sigma_cond_cell = lib.sigma_z0(cell_mass) sigma_cond_halo = np.vectorize(lib.sigma_z0)(edges) - delta_crit = lib.get_delta_crit(up.HMF, sigma_cond_cell, growth_in) + delta_crit = lib.get_delta_crit(up.cdict["HMF"], sigma_cond_cell, growth_in) delta_update = ( - np.vectorize(lib.get_delta_crit)(up.HMF, sigma_cond_halo, growth_in) + np.vectorize(lib.get_delta_crit)(up.cdict["HMF"], sigma_cond_halo, growth_in) * growth_out / growth_in ) @@ -269,19 +268,16 @@ def test_inverse_cmf_tables(name, plt): def test_Massfunc_conditional_tables(name, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) - - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - up.update(USE_INTERPOLATION_TABLES=True) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) hist_size = 1000 edges = np.logspace(7, 12, num=hist_size).astype("f4") edges_ln = np.log(edges) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) - lib.init_ps() lib.initialiseSigmaMInterpTable( global_params.M_MIN_INTEGRAL, global_params.M_MAX_INTEGRAL @@ -302,9 +298,9 @@ def test_Massfunc_conditional_tables(name, plt): sigma_cond_cell = lib.sigma_z0(cell_mass) sigma_cond_halo = np.vectorize(lib.sigma_z0)(edges) - delta_crit = lib.get_delta_crit(up.HMF, sigma_cond_cell, growth_in) + delta_crit = lib.get_delta_crit(up.cdict["HMF"], sigma_cond_cell, growth_in) delta_update = ( - np.vectorize(lib.get_delta_crit)(up.HMF, sigma_cond_halo, growth_in) + np.vectorize(lib.get_delta_crit)(up.cdict["HMF"], sigma_cond_halo, growth_in) * growth_out / growth_in ) @@ -475,14 +471,11 @@ def test_Massfunc_conditional_tables(name, plt): def test_FgtrM_conditional_tables(name, R, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) - - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - - up.update(USE_INTERPOLATION_TABLES=True) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) hist_size = 1000 M_min = global_params.M_MIN_INTEGRAL @@ -500,7 +493,7 @@ def test_FgtrM_conditional_tables(name, R, plt): .value ) sigma_cond = lib.sigma_z0(cond_mass) - delta_crit = lib.get_delta_crit(up.HMF, sigma_cond, growth_out) + delta_crit = lib.get_delta_crit(up.cdict["HMF"], sigma_cond, growth_out) edges_d = np.linspace(-1, delta_crit * 1.1, num=hist_size).astype( "f4" @@ -518,8 +511,8 @@ def test_FgtrM_conditional_tables(name, R, plt): edges_d[:-1], redshift, sigma_min, sigma_cond ) - up.update(USE_INTERPOLATION_TABLES=False) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + up = attrs.evolve(up, USE_INTERPOLATION_TABLES=False) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) fcoll_integrals = np.vectorize(lib.EvaluateFcoll_delta)( edges_d[:-1], growth_out, sigma_min, sigma_cond @@ -572,22 +565,18 @@ def test_FgtrM_conditional_tables(name, R, plt): def test_SFRD_z_tables(name, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - - up.update( - USE_INTERPOLATION_TABLES=True, - ) - fo.update( + fo = attrs.evolve( + fo, USE_MINI_HALOS=True, - USE_MASS_DEPENDENT_ZETA=True, INHOMO_RECO=True, USE_TS_FLUCT=True, ) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) hist_size = 1000 M_min = global_params.M_MIN_INTEGRAL @@ -701,22 +690,18 @@ def test_SFRD_z_tables(name, plt): def test_Nion_z_tables(name, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - - up.update( - USE_INTERPOLATION_TABLES=True, - ) - fo.update( + fo = attrs.evolve( + fo, USE_MINI_HALOS=True, - USE_MASS_DEPENDENT_ZETA=True, INHOMO_RECO=True, USE_TS_FLUCT=True, ) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) f10s = 10**ap.F_STAR10 f7s = 10**ap.F_STAR7_MINI @@ -855,24 +840,23 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - - up.update( - USE_INTERPOLATION_TABLES=True, + up = attrs.evolve( + up, INTEGRATION_METHOD_ATOMIC=OPTIONS_INTMETHOD[intmethod], INTEGRATION_METHOD_MINI=OPTIONS_INTMETHOD[intmethod], ) - fo.update( + fo = attrs.evolve( + fo, USE_MINI_HALOS=mini_flag, - USE_MASS_DEPENDENT_ZETA=True, INHOMO_RECO=True, USE_TS_FLUCT=True, ) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) hist_size = 1000 M_min = global_params.M_MIN_INTEGRAL @@ -880,7 +864,7 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): lib.init_ps() - if up.INTEGRATION_METHOD_ATOMIC == 1 or up.INTEGRATION_METHOD_MINI == 1: + if "GAUSS-LEGENDRE" in (up.INTEGRATION_METHOD_ATOMIC, up.INTEGRATION_METHOD_MINI): lib.initialise_GL(np.log(M_min), np.log(M_max)) growth_out = lib.dicke(redshift) @@ -890,7 +874,7 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): .value ) sigma_cond = lib.sigma_z0(cond_mass) - delta_crit = lib.get_delta_crit(up.HMF, sigma_cond, growth_out) + delta_crit = lib.get_delta_crit(up.cdict["HMF"], sigma_cond, growth_out) edges_d = np.linspace(-1, delta_crit * 1.1, num=hist_size).astype("f4") edges_m = np.logspace(5, 10, num=int(hist_size / 10)).astype("f4") @@ -925,8 +909,8 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): 10**ap.F_ESC7_MINI, Mlim_Fstar_MINI, Mlim_Fesc_MINI, - up.INTEGRATION_METHOD_ATOMIC, - up.INTEGRATION_METHOD_MINI, + up.cdict["INTEGRATION_METHOD_ATOMIC"], + up.cdict["INTEGRATION_METHOD_MINI"], mini_flag, False, ) @@ -957,7 +941,7 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): 10**ap.F_ESC10, Mlim_Fstar, Mlim_Fesc, - up.INTEGRATION_METHOD_ATOMIC, + up.cdict["INTEGRATION_METHOD_ATOMIC"], ) #### FIRST ASSERT #### @@ -991,7 +975,7 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): 10**ap.F_ESC7_MINI, Mlim_Fstar_MINI, Mlim_Fesc_MINI, - up.INTEGRATION_METHOD_MINI, + up.cdict["INTEGRATION_METHOD_MINI"], ) print_failure_stats( Nion_tables_mini, @@ -1039,26 +1023,26 @@ def test_Nion_conditional_tables(name, R, mini, intmethod, plt): def test_SFRD_conditional_table(name, R, intmethod, plt): if name != "PS" and intmethod == "FFCOLL": pytest.skip("FAST FFCOLL INTEGRALS WORK ONLY WITH EPS") + redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - - up.update( - USE_INTERPOLATION_TABLES=True, + up = attrs.evolve( + up, INTEGRATION_METHOD_ATOMIC=OPTIONS_INTMETHOD[intmethod], INTEGRATION_METHOD_MINI=OPTIONS_INTMETHOD[intmethod], ) - fo.update( + fo = attrs.evolve( + fo, USE_MINI_HALOS=True, - USE_MASS_DEPENDENT_ZETA=True, INHOMO_RECO=True, USE_TS_FLUCT=True, ) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) hist_size = 1000 M_min = global_params.M_MIN_INTEGRAL @@ -1066,7 +1050,7 @@ def test_SFRD_conditional_table(name, R, intmethod, plt): lib.init_ps() - if up.INTEGRATION_METHOD_ATOMIC == 1 or up.INTEGRATION_METHOD_MINI == 1: + if "GAUSS-LEGENDRE" in (up.INTEGRATION_METHOD_ATOMIC, up.INTEGRATION_METHOD_MINI): lib.initialise_GL(np.log(M_min), np.log(M_max)) growth_out = lib.dicke(redshift) @@ -1076,7 +1060,7 @@ def test_SFRD_conditional_table(name, R, intmethod, plt): .value ) sigma_cond = lib.sigma_z0(cond_mass) - delta_crit = lib.get_delta_crit(up.HMF, sigma_cond, growth_out) + delta_crit = lib.get_delta_crit(up.cdict["HMF"], sigma_cond, growth_out) edges_d = np.linspace(-1, delta_crit * 1.1, num=hist_size).astype("f4") edges_m = np.logspace(5, 10, num=int(hist_size / 10)).astype("f4") @@ -1098,8 +1082,8 @@ def test_SFRD_conditional_table(name, R, intmethod, plt): ap.ALPHA_STAR_MINI, 10**ap.F_STAR10, 10**ap.F_STAR7_MINI, - up.INTEGRATION_METHOD_ATOMIC, - up.INTEGRATION_METHOD_MINI, + up.cdict["INTEGRATION_METHOD_ATOMIC"], + up.cdict["INTEGRATION_METHOD_MINI"], fo.USE_MINI_HALOS, ) # since the turnover mass table edges are hardcoded, we make sure we are within those limits @@ -1133,7 +1117,7 @@ def test_SFRD_conditional_table(name, R, intmethod, plt): 1.0, Mlim_Fstar, 0.0, - up.INTEGRATION_METHOD_ATOMIC, + up.cdict["INTEGRATION_METHOD_ATOMIC"], ) SFRD_integrals_mini = np.vectorize(lib.Nion_ConditionalM_MINI)( @@ -1151,7 +1135,7 @@ def test_SFRD_conditional_table(name, R, intmethod, plt): 1.0, Mlim_Fstar_MINI, 0.0, - up.INTEGRATION_METHOD_MINI, + up.cdict["INTEGRATION_METHOD_MINI"], ) abs_tol = 5e-18 # minimum = exp(-40) ~1e-18 @@ -1203,33 +1187,32 @@ def test_SFRD_conditional_table(name, R, intmethod, plt): def test_conditional_integral_methods(R, name, integrand, plt): redshift, kwargs = OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) - - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - - up.update( - USE_INTERPOLATION_TABLES=True, - ) - fo.update( + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] + + up = attrs.evolve(up, USE_INTERPOLATION_TABLES=True) + fo = attrs.evolve( + fo, USE_MINI_HALOS=True, USE_MASS_DEPENDENT_ZETA=True, INHOMO_RECO=True, USE_TS_FLUCT=True, ) if "sfr" in integrand: - ap.update(F_ESC10=0.0, F_ESC7_MINI=0.0, ALPHA_ESC=0.0) # F_ESCX is in log10 + ap = attrs.evolve( + ap, F_ESC10=0.0, F_ESC7_MINI=0.0, ALPHA_ESC=0.0 + ) # F_ESCX is in log10 - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) hist_size = 1000 M_min = global_params.M_MIN_INTEGRAL M_max = global_params.M_MAX_INTEGRAL lib.init_ps() - - if up.INTEGRATION_METHOD_ATOMIC == 1 or up.INTEGRATION_METHOD_MINI == 1: + if "GAUSS-LEGENDRE" in (up.INTEGRATION_METHOD_ATOMIC, up.INTEGRATION_METHOD_MINI): lib.initialise_GL(np.log(M_min), np.log(M_max)) growth_out = lib.dicke(redshift) @@ -1239,7 +1222,7 @@ def test_conditional_integral_methods(R, name, integrand, plt): .value ) sigma_cond = lib.sigma_z0(cond_mass) - delta_crit = lib.get_delta_crit(up.HMF, sigma_cond, growth_out) + delta_crit = lib.get_delta_crit(up.cdict["HMF"], sigma_cond, growth_out) edges_d = np.linspace(-1, delta_crit * 1.1, num=hist_size).astype("f4") edges_m = np.logspace(5, 10, num=int(hist_size / 10)).astype("f4") @@ -1258,12 +1241,15 @@ def test_conditional_integral_methods(R, name, integrand, plt): integrals = [] integrals_mini = [] input_arr = np.meshgrid(edges_d[:-1], np.log10(edges_m[:-1]), indexing="ij") - for method in range(0, 3): - if name != "PS" and method == 2: + for method in ["GSL-QAG", "GAUSS-LEGENDRE", "GAMMA-APPROX"]: + print(f"Starting method {method}", flush=True) + if name != "PS" and method == "GAMMA-APPROX": continue - up.update(INTEGRATION_METHOD_ATOMIC=method, INTEGRATION_METHOD_MINI=method) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + up = attrs.evolve( + up, INTEGRATION_METHOD_ATOMIC=method, INTEGRATION_METHOD_MINI=method + ) + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) integrals.append( np.vectorize(lib.Nion_ConditionalM)( @@ -1280,7 +1266,7 @@ def test_conditional_integral_methods(R, name, integrand, plt): 10**ap.F_ESC10, Mlim_Fstar, Mlim_Fesc, - up.INTEGRATION_METHOD_ATOMIC, + up.cdict["INTEGRATION_METHOD_ATOMIC"], ) ) integrals_mini.append( @@ -1299,7 +1285,7 @@ def test_conditional_integral_methods(R, name, integrand, plt): 10**ap.F_ESC7_MINI, Mlim_Fstar_MINI, Mlim_Fesc_MINI, - up.INTEGRATION_METHOD_MINI, + up.cdict["INTEGRATION_METHOD_MINI"], ) ) @@ -1326,24 +1312,25 @@ def test_conditional_integral_methods(R, name, integrand, plt): # for the FAST_FFCOLL integrals, only the delta-Mturn behaviour matters (because of the mean fixing), so we divide by # the value at delta=0 (mturn ~ 5e7 for minihalos) and set a wider tolerance - sel_deltazero = np.argmin(np.fabs(edges_d)) - sel_mturn = np.argmin(np.fabs(edges_m - 5e7)) - ffcoll_deltazero = integrals[2][sel_deltazero] - ffcoll_deltazero_mini = integrals_mini[2][sel_deltazero, sel_mturn] - qag_deltazero = integrals[0][sel_deltazero] - qag_deltazero_mini = integrals_mini[0][sel_deltazero, sel_mturn] - np.testing.assert_allclose( - integrals[2] / ffcoll_deltazero, - integrals[0] / qag_deltazero, - atol=abs_tol, - rtol=1e-1, - ) - np.testing.assert_allclose( - integrals_mini[2] / ffcoll_deltazero_mini[None, :], - integrals_mini[0] / qag_deltazero_mini[None, :], - atol=abs_tol, - rtol=1e-1, - ) + if name == "PS": + sel_deltazero = np.argmin(np.fabs(edges_d)) + sel_mturn = np.argmin(np.fabs(edges_m - 5e7)) + ffcoll_deltazero = integrals[2][sel_deltazero] + ffcoll_deltazero_mini = integrals_mini[2][sel_deltazero, sel_mturn] + qag_deltazero = integrals[0][sel_deltazero] + qag_deltazero_mini = integrals_mini[0][sel_deltazero, sel_mturn] + np.testing.assert_allclose( + integrals[2] / ffcoll_deltazero, + integrals[0] / qag_deltazero, + atol=abs_tol, + rtol=1e-1, + ) + np.testing.assert_allclose( + integrals_mini[2] / ffcoll_deltazero_mini[None, :], + integrals_mini[0] / qag_deltazero_mini[None, :], + atol=abs_tol, + rtol=1e-1, + ) def make_table_comparison_plot( diff --git a/tests/test_cli.py b/tests/test_cli.py index 15611f88..3a1cd4db 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,11 +14,11 @@ def runner(): @pytest.fixture(scope="module") def cfg(default_user_params, tmpdirec): with open(tmpdirec / "cfg.yml", "w") as f: - yaml.dump({"user_params": default_user_params.self}, f) + yaml.dump({"user_params": default_user_params.asdict()}, f) return tmpdirec / "cfg.yml" -def test_init(module_direc, default_user_params, runner, cfg): +def test_init(module_direc, default_input_struct, runner, cfg): # Run the CLI. There's no way to turn off writing from the CLI (since that # would be useless). We produce a *new* initial conditions box in a new # directory and check that it exists. It gets auto-deleted after. @@ -32,7 +32,10 @@ def test_init(module_direc, default_user_params, runner, cfg): assert result.exit_code == 0 - ic = InitialConditions(user_params=default_user_params, random_seed=101010) + ic = InitialConditions( + inputs=default_input_struct, + random_seed=101010, + ) assert ic.exists(direc=str(module_direc)) diff --git a/tests/test_config.py b/tests/test_config.py index 6d2dd635..22f9390a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,7 +14,7 @@ def cfgdir(tmp_path_factory): def test_config_context(cfgdir, default_user_params): with p21.config.use(direc=cfgdir, write=True): - init = p21.initial_conditions(user_params=default_user_params) + init = p21.compute_initial_conditions(user_params=default_user_params) assert (cfgdir / init.filename).exists() assert "config_test_dir" not in p21.config["direc"] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 6f07c749..9e8c77d4 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,8 +1,13 @@ import pytest -from py21cmfast.c_21cmfast import lib -from py21cmfast.wrapper import _call_c_simple -from py21cmfast.wrapper._utils import PHOTONCONSERROR, ParameterError +import numpy as np + +from py21cmfast.c_21cmfast import ffi, lib +from py21cmfast.wrapper.exceptions import ( + PHOTONCONSERROR, + ParameterError, + _process_exitcode, +) @pytest.mark.parametrize("subfunc", [True, False]) @@ -13,10 +18,19 @@ def test_basic(subfunc): @pytest.mark.parametrize("subfunc", [True, False]) def test_simple(subfunc): + answer = np.array([0], dtype="f8") with pytest.raises(ParameterError): - _call_c_simple(lib.FunctionThatCatches, subfunc, False) + status = lib.FunctionThatCatches( + subfunc, False, ffi.cast("double *", ffi.from_buffer(answer)) + ) + _process_exitcode( + status, + lib.FunctionThatCatches, + (False, ffi.cast("double *", ffi.from_buffer(answer))), + ) def test_pass(): - answer = _call_c_simple(lib.FunctionThatCatches, True, True) + answer = np.array([0], dtype="f8") + lib.FunctionThatCatches(True, True, ffi.cast("double *", ffi.from_buffer(answer))) assert answer == 5.0 diff --git a/tests/test_filtering.py b/tests/test_filtering.py index b38aed85..93e45b84 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -95,10 +95,10 @@ def get_binned_stats(x_arr, y_arr, bins, stats): def test_filters(filter_flag, R, plt): opts = prd.get_all_options(redshift=10.0) - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] # testing a single pixel source input_box_centre = np.zeros((up.HII_DIM,) * 3, dtype="f4") @@ -113,10 +113,10 @@ def test_filters(filter_flag, R, plt): R_param = 0 lib.test_filter( - up(), - cp(), - ap(), - fo(), + up.cstruct, + cp.cstruct, + ap.cstruct, + fo.cstruct, ffi.cast("float *", input_box_centre.ctypes.data), R, R_param, diff --git a/tests/test_halo_sampler.py b/tests/test_halo_sampler.py index 84f36a45..694b2063 100644 --- a/tests/test_halo_sampler.py +++ b/tests/test_halo_sampler.py @@ -21,7 +21,7 @@ options_hmf = list(cint.OPTIONS_HMF.keys()) options_delta = [-0.9, 0, 1, 1.6] # cell densities to draw samples from -options_mass = [1e9, 1e10, 1e11, 1e12] # halo masses to draw samples from +options_mass = [9, 10, 11, 12] # halo masses to draw samples from @pytest.mark.parametrize("name", options_hmf) @@ -30,12 +30,14 @@ def test_sampler_from_catalog(name, mass, plt): redshift, kwargs = cint.OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - up.update(USE_INTERPOLATION_TABLES=True) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] + + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) + + mass = 10**mass l10min = np.log10(up.SAMPLER_MIN_MASS) l10max = np.log10(mass) @@ -49,7 +51,7 @@ def test_sampler_from_catalog(name, mass, plt): global_params.M_MIN_INTEGRAL, global_params.M_MAX_INTEGRAL ) - n_cond = 30000 + n_cond = 2000 z = 6.0 z_prev = 5.8 @@ -58,7 +60,9 @@ def test_sampler_from_catalog(name, mass, plt): sigma_cond_m = lib.sigma_z0(mass) delta_cond_m = ( - lib.get_delta_crit(up.HMF, sigma_cond_m, growth_prev) * growthf / growth_prev + lib.get_delta_crit(up.cdict["HMF"], sigma_cond_m, growth_prev) + * growthf + / growth_prev ) mass_dens = cp.cosmo.Om0 * cp.cosmo.critical_density(0).to("Mpc-3 M_sun").value volume_total_m = mass * n_cond / mass_dens @@ -76,10 +80,10 @@ def test_sampler_from_catalog(name, mass, plt): halocrd_out = np.zeros(int(3e8)).astype("i4") lib.single_test_sample( - up(), - cp(), - ap(), - fo(), + up.cstruct, + cp.cstruct, + ap.cstruct, + fo.cstruct, 12345, n_cond, ffi.cast("float *", cond_in.ctypes.data), @@ -139,19 +143,18 @@ def test_sampler_from_grid(name, delta, plt): redshift, kwargs = cint.OPTIONS_HMF[name] opts = prd.get_all_options(redshift, **kwargs) - up = UserParams(opts["user_params"]) - cp = CosmoParams(opts["cosmo_params"]) - ap = AstroParams(opts["astro_params"]) - fo = FlagOptions(opts["flag_options"]) - up.update(USE_INTERPOLATION_TABLES=True) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) lib.init_ps() lib.initialiseSigmaMInterpTable( global_params.M_MIN_INTEGRAL, global_params.M_MAX_INTEGRAL ) - n_cond = 30000 + n_cond = 2000 z = 6.0 growthf = lib.dicke(z) @@ -183,10 +186,10 @@ def test_sampler_from_grid(name, delta, plt): halocrd_out = np.zeros(int(3e8)).astype("i4") lib.single_test_sample( - up(), - cp(), - ap(), - fo(), + up.cstruct, + cp.cstruct, + ap.cstruct, + fo.cstruct, 12345, # TODO: homogenize n_cond, ffi.cast("float *", cond_in.ctypes.data), @@ -245,143 +248,26 @@ def test_sampler_from_grid(name, delta, plt): # changes to any scaling relation model will result in a test fail def test_halo_scaling_relations(): # specify parameters to use for this test - f_star10 = -1.0 - f_star7 = -2.0 - a_star = 1.0 - a_star_mini = 1.0 - t_star = 0.5 - f_esc10 = -1.0 - f_esc7 = -1.0 - a_esc = -0.5 # for the test we don't want a_esc = -a_star - lx = 40.0 - lx_mini = 40.0 - sigma_star = 0.3 - sigma_sfr_lim = 0.2 - sigma_sfr_index = -0.12 - sigma_lx = 0.5 - redshift = 10.0 + opts = prd.get_all_options(redshift, {}) - # setup specific parameters that so we know what the outcome should be - up = UserParams() - cp = CosmoParams() - ap = AstroParams( - F_STAR10=f_star10, - F_STAR7_MINI=f_star7, - ALPHA_STAR=a_star, - ALPHA_STAR_MINI=a_star_mini, - F_ESC10=f_esc10, - F_ESC7_MINI=f_esc7, - ALPHA_ESC=a_esc, - L_X=lx, - L_X_MINI=lx_mini, - SIGMA_STAR=sigma_star, - SIGMA_SFR_LIM=sigma_sfr_lim, - SIGMA_SFR_INDEX=sigma_sfr_index, - SIGMA_LX=sigma_lx, - t_STAR=0.5, - M_TURN=6.0, - ) - # NOTE: Not using upper turnover, this test should be extended - fo = FlagOptions( - USE_MINI_HALOS=True, - INHOMO_RECO=True, - USE_TS_FLUCT=True, - USE_HALO_FIELD=True, - FIXED_HALO_GRIDS=False, - HALO_STOCHASTICITY=True, - USE_UPPER_STELLAR_TURNOVER=False, - ) + up = opts["user_params"] + cp = opts["cosmo_params"] + ap = opts["astro_params"] + fo = opts["flag_options"] + lib.Broadcast_struct_global_all(up.cstruct, cp.cstruct, ap.cstruct, fo.cstruct) - lib.Broadcast_struct_global_all(up(), cp(), ap(), fo()) mturn_acg = np.maximum(lib.atomic_cooling_threshold(redshift), 10**ap.M_TURN) - mturn_mcg = ( - 10**ap.M_TURN - ) # I don't want to test the LW or reionisation feedback here - - print(f"turnovers [{mturn_acg},{mturn_mcg}]") + # mturn_mcg = 10**ap.M_TURN print(f"z={redshift} th = {1/cp.cosmo.H(redshift).to('s-1').value}") # setup the halo masses to test - halo_masses = np.array([1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12]) - halo_rng = np.ones_like( - halo_masses - ) # we set the RNG to one sigma above for the test - - # independently calculate properties for the halos from our scaling relations - exp_fstar = (10**f_star10) * (halo_masses / 1e10) ** a_star - exp_fesc = np.minimum((10**f_esc10) * (halo_masses / 1e10) ** a_esc, 1) - exp_fstar_mini = (10**f_star7) * (halo_masses / 1e7) ** a_star_mini - exp_fesc_mini = np.minimum((10**f_esc7) * (halo_masses / 1e7) ** a_esc, 1) - b_r = cp.OMb / cp.OMm - acg_turnover = np.exp(-mturn_acg / halo_masses) - mcg_turnovers = np.exp(-halo_masses / mturn_acg) * np.exp(-mturn_mcg / halo_masses) - - expected_hm = halo_masses - expected_sm = ( - np.minimum(exp_fstar * np.exp(halo_rng * sigma_star) * acg_turnover, 1) - * halo_masses - * b_r - ) - expected_sm_mini = ( - np.minimum(exp_fstar_mini * np.exp(halo_rng * sigma_star) * mcg_turnovers, 1) - * halo_masses - * b_r - ) - - sigma_sfr = ( - sigma_sfr_index * np.log10((expected_sm + expected_sm_mini) / 1e10) - + sigma_sfr_lim - ) - sigma_sfr = np.maximum(sigma_sfr, sigma_sfr_lim) - expected_sfr = ( - expected_sm - / t_star - * cp.cosmo.H(redshift).to("s-1").value - * np.exp(halo_rng * sigma_sfr) - ) - expected_sfr_mini = ( - expected_sm_mini - / t_star - * cp.cosmo.H(redshift).to("s-1").value - * np.exp(halo_rng * sigma_sfr) - ) - - expected_nion = ( - expected_sm * exp_fesc * global_params.Pop2_ion - + expected_sm_mini * exp_fesc_mini * global_params.Pop3_ion - ) - expected_wsfr = ( - expected_sfr * exp_fesc * global_params.Pop2_ion - + expected_sfr_mini * exp_fesc_mini * global_params.Pop3_ion - ) - - # NOTE: These are currently hardcoded in the backend, changes will result in this test failing - s_per_yr = 365.25 * 60 * 60 * 24 - expected_metals = ( - 1.28825e10 * ((expected_sfr + expected_sfr_mini) * s_per_yr) ** 0.56 - ) # SM denominator - expected_metals = ( - 0.296 - * ( - (1 + ((expected_sm + expected_sm_mini) / expected_metals) ** (-2.1)) - ** -0.148 - ) - * 10 ** (-0.056 * redshift + 0.064) - ) - - expected_xray = ( - (expected_sfr * s_per_yr) ** 1.03 - * expected_metals**-0.64 - * np.exp(halo_rng * sigma_lx) - * 10**lx - ) - expected_xray += ( - (expected_sfr_mini * s_per_yr) ** 1.03 - * expected_metals**-0.64 - * np.exp(halo_rng * sigma_lx) - * 10**lx_mini - ) + halo_mass_vals = [1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12] + n_halo_per_mass = 1000 + halo_masses = np.array( + [n_halo_per_mass * [val] for val in halo_mass_vals] + ).flatten() + halo_rng = np.random.normal(size=n_halo_per_mass * len(halo_mass_vals)) # HACK: Make the fake halo list fake_pthalos = PerturbHaloField( @@ -406,10 +292,10 @@ def test_halo_scaling_relations(): out_buffer = np.zeros(12 * halo_masses.size).astype("f4") lib.test_halo_props( redshift, - up(), - cp(), - ap(), - fo(), + up.cstruct, + cp.cstruct, + ap.cstruct, + fo.cstruct, zero_array, zero_array, zero_array, @@ -418,28 +304,31 @@ def test_halo_scaling_relations(): ffi.cast("float *", out_buffer.ctypes.data), ) - np.testing.assert_allclose(expected_hm, out_buffer[0::12], atol=1e0) + # (n_halo*n_mass*n_prop) --> (n_prop,n_mass,n_halo) - np.testing.assert_allclose(mturn_acg, out_buffer[8::12], atol=1e0) - np.testing.assert_allclose(mturn_mcg, out_buffer[9::12], atol=1e0) - np.testing.assert_allclose(0.0, out_buffer[10::12], atol=1e0) # no reion feedback + # mass,star,sfr,xray,nion,wsfr,starmini,sfrmini,mturna,mturnm,mturnr,Z + out_buffer = out_buffer.reshape(len(halo_mass_vals), n_halo_per_mass, 12) - np.testing.assert_allclose(expected_sm, out_buffer[1::12], atol=1e0) - np.testing.assert_allclose(expected_sm_mini, out_buffer[6::12], atol=1e0) - - # hubble differences between the two codes make % level changes TODO: change hubble to double precision in backend - np.testing.assert_allclose(expected_sfr, out_buffer[2::12], rtol=5e-2, atol=1e-20) - np.testing.assert_allclose( - expected_sfr_mini, out_buffer[7::12], rtol=5e-2, atol=1e-20 + exp_SHMR = ( + (10**ap.F_STAR10) + * halo_mass_vals**ap.ALPHA_STAR + * np.exp(-mturn_acg / halo_mass_vals) ) + sim_SHMR = out_buffer[:, :, 1] / out_buffer[:, :, 0] + np.testing.assert_allclose(exp_SHMR, sim_SHMR.mean(axis=1), rtol=1e-1) + np.testing.assert_allclose(ap.SIGMA_STAR, sim_SHMR.std(axis=1), rtol=1e-1) - np.testing.assert_allclose(expected_metals, out_buffer[11::12], rtol=1e-3) + exp_SSFR = cp.cosmo.H(redshift).to("s").value / (ap.t_STAR) + sim_SSFR = out_buffer[:, :, 2] / out_buffer[:, :, 1] + np.testing.assert_allclose(exp_SSFR, sim_SSFR.mean(axis=1), rtol=1e-1) np.testing.assert_allclose( - expected_xray, out_buffer[3::12].astype(float) * 1e38, rtol=5e-2 - ) + ap.SIGMA_SFR_LIM, sim_SSFR.std(axis=1), rtol=1e-1 + ) # WRONG - np.testing.assert_allclose(expected_nion, out_buffer[4::12]) - np.testing.assert_allclose(expected_wsfr, out_buffer[5::12], rtol=5e-2) + exp_LX = 10 ** (ap.L_X) # low-z approx + sim_LX = out_buffer[:, :, 3] / out_buffer[:, :, 2] + np.testing.assert_allclose(exp_LX, sim_LX.mean(axis=1), rtol=1e-1) + np.testing.assert_allclose(ap.SIGMA_LX, sim_LX.std(axis=1), rtol=1e-1) def plot_sampler_comparison( diff --git a/tests/test_high_level_io.py b/tests/test_high_level_io.py index 3a91d9d5..9f15bf6e 100644 --- a/tests/test_high_level_io.py +++ b/tests/test_high_level_io.py @@ -1,7 +1,9 @@ import pytest +import attrs import h5py import numpy as np +from collections import deque from py21cmfast import ( BrightnessTemp, @@ -18,30 +20,36 @@ @pytest.fixture(scope="module") -def coeval(ic): +def coeval(ic, default_flag_options_ts): return run_coeval( - redshift=25.0, init_box=ic, write=True, flag_options={"USE_TS_FLUCT": True} + out_redshifts=25.0, + initial_conditions=ic, + write=True, + flag_options=default_flag_options_ts, ) @pytest.fixture(scope="module") -def lightcone(ic): +def lightcone(ic, default_flag_options_ts): lcn = RectilinearLightconer.with_equal_cdist_slices( min_redshift=25.0, max_redshift=35.0, resolution=ic.user_params.cell_size, ) - return run_lightcone( + lc_gen = run_lightcone( lightconer=lcn, init_box=ic, write=True, - flag_options={"USE_TS_FLUCT": True}, + flag_options=default_flag_options_ts, ) + iz, z, coev, lc = deque(lc_gen, maxlen=1) + return lc + @pytest.fixture(scope="module") -def ang_lightcone(ic, lc): +def ang_lightcone(ic, lc, default_flag_options_ts): lcn = AngularLightconer.like_rectilinear( match_at_z=lc.lightcone_redshifts.min(), max_redshift=lc.lightcone_redshifts.max(), @@ -49,13 +57,16 @@ def ang_lightcone(ic, lc): get_los_velocity=True, ) - return run_lightcone( + lc_gen = run_lightcone( lightconer=lcn, init_box=ic, write=True, - flag_options={"APPLY_RSDS": False}, + flag_options=attrs.evolve(default_flag_options_ts, APPLY_RSDS=False), ) + iz, z, coev, lc = deque(lc_gen, maxlen=1) + return lc + def test_read_bad_file_lc(test_direc, lc): # create a bad hdf5 file with some good fields, diff --git a/tests/test_initial_conditions.py b/tests/test_initial_conditions.py index 95ad8df1..c0e11926 100644 --- a/tests/test_initial_conditions.py +++ b/tests/test_initial_conditions.py @@ -1,5 +1,5 @@ """ -Various tests of the initial_conditions() function and InitialConditions class. +Various tests of the compute_initial_conditions() function and InitialConditions class. """ import pytest @@ -7,7 +7,7 @@ import numpy as np from multiprocessing import cpu_count -from py21cmfast import wrapper +import py21cmfast as p21c def test_box_shape(ic): @@ -32,13 +32,13 @@ def test_box_shape(ic): assert not hasattr(ic, "lowres_vcb") - assert ic.cosmo_params == wrapper.CosmoParams() + assert ic.cosmo_params == p21c.CosmoParams() def test_modified_cosmo(ic): """Test using a modified cosmology""" - cosmo = wrapper.CosmoParams(SIGMA_8=0.9) - ic2 = wrapper.initial_conditions( + cosmo = p21c.CosmoParams(SIGMA_8=0.9) + ic2 = p21c.compute_initial_conditions( cosmo_params=cosmo, user_params=ic.user_params, ) @@ -51,7 +51,7 @@ def test_modified_cosmo(ic): def test_transfer_function(ic, default_user_params): """Test using a modified transfer function""" user_params = default_user_params.clone(POWER_SPECTRUM=5) - ic2 = wrapper.initial_conditions( + ic2 = p21c.compute_initial_conditions( random_seed=ic.random_seed, user_params=user_params, ) @@ -65,9 +65,9 @@ def test_transfer_function(ic, default_user_params): def test_relvels(): """Test for relative velocity initial conditions""" - ic = wrapper.initial_conditions( + ic = p21c.compute_initial_conditions( random_seed=1, - user_params=wrapper.UserParams( + user_params=p21c.UserParams( HII_DIM=100, DIM=300, BOX_LEN=300, diff --git a/tests/test_input_structs.py b/tests/test_input_structs.py index 7c496d57..2592837b 100644 --- a/tests/test_input_structs.py +++ b/tests/test_input_structs.py @@ -9,7 +9,7 @@ from py21cmfast import AstroParams # An example of a struct with defaults from py21cmfast import CosmoParams, FlagOptions, UserParams, __version__, global_params -from py21cmfast.wrapper.inputs import validate_all_inputs +from py21cmfast.drivers.param_config import InputParameters @pytest.fixture(scope="module") @@ -170,8 +170,11 @@ def test_validation(): with global_params.use(HII_FILTER=2): with pytest.warns(UserWarning, match="Setting R_BUBBLE_MAX to BOX_LEN"): - validate_all_inputs( - cosmo_params=c, astro_params=a, flag_options=f, user_params=u + InputParameters( + cosmo_params=c, + astro_params=a, + user_params=u, + flag_options=f, ) assert a.R_BUBBLE_MAX == u.BOX_LEN @@ -180,8 +183,11 @@ def test_validation(): with global_params.use(HII_FILTER=1): with pytest.raises(ValueError, match="Your R_BUBBLE_MAX is > BOX_LEN/3"): - validate_all_inputs( - cosmo_params=c, astro_params=a, flag_options=f, user_params=u + InputParameters( + cosmo_params=c, + astro_params=a, + user_params=u, + flag_options=f, ) diff --git a/tests/test_output_structs.py b/tests/test_output_structs.py index cc0ec5b5..eb8b4f8c 100644 --- a/tests/test_output_structs.py +++ b/tests/test_output_structs.py @@ -12,9 +12,14 @@ from py21cmfast import IonizedBox, PerturbedField, TsBox, global_params +@pytest.fixture(scope="module") +def input_struct_noseed(default_input_struct): + return default_input_struct.evolve(random_seed=None) + + @pytest.fixture(scope="function") -def init(default_user_params): - return InitialConditions(user_params=default_user_params) +def init(input_struct_noseed): + return InitialConditions(inputs=input_struct_noseed) @pytest.mark.parametrize("cls", [InitialConditions, PerturbedField, IonizedBox, TsBox]) @@ -46,8 +51,8 @@ def test_writeability(init): init.write() -def test_readability(ic, tmpdirec, default_user_params): - ic2 = InitialConditions(user_params=default_user_params) +def test_readability(ic, tmpdirec, input_struct_noseed): + ic2 = InitialConditions(inputs=input_struct_noseed) # without seeds, they are obviously exactly the same. assert ic._seedless_repr() == ic2._seedless_repr() @@ -63,8 +68,8 @@ def test_readability(ic, tmpdirec, default_user_params): assert ic is not ic2 -def test_different_seeds(init, default_user_params): - ic2 = InitialConditions(random_seed=2, user_params=default_user_params) +def test_different_seeds(init, input_struct_noseed): + ic2 = InitialConditions(inputs=input_struct_noseed.evolve(random_seed=2)) assert init is not ic2 assert init != ic2 @@ -77,8 +82,9 @@ def test_different_seeds(init, default_user_params): assert init._random_seed is None -def test_pickleability(default_user_params): - ic_ = InitialConditions(init=True, user_params=default_user_params) +def test_pickleability(input_struct_noseed): + # TODO: remove init kwarg which does nothing? + ic_ = InitialConditions(init=True, inputs=input_struct_noseed) ic_.filled = True ic_.random_seed @@ -88,14 +94,14 @@ def test_pickleability(default_user_params): assert repr(ic_) == repr(ic2) -def test_fname(default_user_params): - ic1 = InitialConditions(user_params=default_user_params) - ic2 = InitialConditions(user_params=default_user_params) +def test_fname(input_struct_noseed): + ic1 = InitialConditions(inputs=input_struct_noseed) + ic2 = InitialConditions(inputs=input_struct_noseed) # we didn't give them seeds, so can't access the filename attribute # (it is undefined until a seed is set) with pytest.raises(AttributeError): - assert ic1.filename != ic2.filename # random seeds are different + assert ic1.filename != ic2.filename # *but* should be able to get a skeleton filename: assert ic1._fname_skeleton == ic2._fname_skeleton @@ -107,15 +113,15 @@ def test_fname(default_user_params): assert ic1._fname_skeleton == ic2._fname_skeleton -def test_match_seed(tmpdirec, default_user_params): - ic2 = InitialConditions(random_seed=1, user_params=default_user_params) +def test_match_seed(tmpdirec, input_struct_noseed): + ic2 = InitialConditions(inputs=input_struct_noseed.evolve(random_seed=3)) # This fails because we've set the seed and it's different to the existing one. with pytest.raises(IOError): ic2.read(direc=tmpdirec) -def test_bad_class_definition(default_user_params): +def test_bad_class_definition(input_struct_noseed): class CustomInitialConditions(InitialConditions): _name = "InitialConditions" @@ -129,7 +135,7 @@ def _get_box_structures(self): return out with pytest.raises(TypeError): - CustomInitialConditions(init=True, user_params=default_user_params) + CustomInitialConditions(inputs=input_struct_noseed) def test_bad_write(init): diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index f6a82d11..97f61970 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -10,36 +10,42 @@ import numpy as np from astropy import units as un -from py21cmfast import wrapper +import py21cmfast as p21c @pytest.fixture(scope="module") def perturb_field_lowz(ic, low_redshift): """A default perturb_field""" - return wrapper.perturb_field(redshift=low_redshift, init_boxes=ic, write=True) + return p21c.perturb_field(redshift=low_redshift, initial_conditions=ic, write=True) @pytest.fixture(scope="module") -def ionize_box(perturbed_field): +def ionize_box(ic, perturbed_field): """A default ionize_box""" - return wrapper.ionize_box(perturbed_field=perturbed_field, write=True) + return p21c.compute_ionization_field( + initial_conditions=ic, perturbed_field=perturbed_field, write=True + ) @pytest.fixture(scope="module") -def ionize_box_lowz(perturb_field_lowz): +def ionize_box_lowz(ic, perturb_field_lowz): """A default ionize_box at lower redshift.""" - return wrapper.ionize_box(perturbed_field=perturb_field_lowz, write=True) + return p21c.compute_ionization_field( + initial_conditions=ic, perturbed_field=perturb_field_lowz, write=True + ) @pytest.fixture(scope="module") -def spin_temp(perturbed_field): +def spin_temp(ic, perturbed_field): """A default perturb_field""" - return wrapper.spin_temperature(perturbed_field=perturbed_field, write=True) + return p21c.spin_temperature( + initial_conditions=ic, perturbed_field=perturbed_field, write=True + ) def test_perturb_field_no_ic(default_user_params, redshift, perturbed_field): """Run a perturb field without passing an init box""" - pf = wrapper.perturb_field(redshift=redshift, user_params=default_user_params) + pf = p21c.perturb_field(redshift=redshift, user_params=default_user_params) assert len(pf.density) == pf.user_params.HII_DIM == default_user_params.HII_DIM assert pf.redshift == redshift assert pf.random_seed != perturbed_field.random_seed @@ -50,20 +56,20 @@ def test_perturb_field_no_ic(default_user_params, redshift, perturbed_field): def test_ib_no_z(ic): with pytest.raises(ValueError): - wrapper.ionize_box(init_boxes=ic) + p21c.compute_ionization_field(initial_conditions=ic) def test_pf_unnamed_param(): """Try using an un-named parameter.""" with pytest.raises(TypeError): - wrapper.perturb_field(7) + p21c.perturb_field(7) def test_perturb_field_ic(perturbed_field, ic): # this will run perturb_field again, since by default regenerate=True for tests. # BUT it should produce exactly the same as the default perturb_field since it has # the same seed. - pf = wrapper.perturb_field(redshift=perturbed_field.redshift, init_boxes=ic) + pf = p21c.perturb_field(redshift=perturbed_field.redshift, initial_conditions=ic) assert len(pf.density) == len(ic.lowres_density) assert pf.cosmo_params == ic.cosmo_params @@ -77,7 +83,7 @@ def test_perturb_field_ic(perturbed_field, ic): def test_cache_exists(default_user_params, perturbed_field, tmpdirec): - pf = wrapper.PerturbedField( + pf = p21c.PerturbedField( redshift=perturbed_field.redshift, cosmo_params=perturbed_field.cosmo_params, user_params=default_user_params, @@ -91,7 +97,7 @@ def test_cache_exists(default_user_params, perturbed_field, tmpdirec): def test_pf_new_seed(perturbed_field, tmpdirec): - pf = wrapper.perturb_field( + pf = p21c.perturb_field( redshift=perturbed_field.redshift, user_params=perturbed_field.user_params, random_seed=1, @@ -107,12 +113,12 @@ def test_pf_new_seed(perturbed_field, tmpdirec): def test_ib_new_seed(ionize_box_lowz, perturb_field_lowz, tmpdirec): # this should fail because perturb_field has a seed set already, which isn't 1. with pytest.raises(ValueError): - wrapper.ionize_box( + p21c.compute_ionization_field( perturbed_field=perturb_field_lowz, random_seed=1, ) - ib = wrapper.ionize_box( + ib = p21c.compute_ionization_field( cosmo_params=perturb_field_lowz.cosmo_params, redshift=perturb_field_lowz.redshift, user_params=perturb_field_lowz.user_params, @@ -128,12 +134,12 @@ def test_ib_new_seed(ionize_box_lowz, perturb_field_lowz, tmpdirec): def test_st_new_seed(spin_temp, perturbed_field, tmpdirec): # this should fail because perturb_field has a seed set already, which isn't 1. with pytest.raises(ValueError): - wrapper.spin_temperature( + p21c.spin_temperature( perturbed_field=perturbed_field, random_seed=1, ) - st = wrapper.spin_temperature( + st = p21c.spin_temperature( cosmo_params=spin_temp.cosmo_params, user_params=spin_temp.user_params, astro_params=spin_temp.astro_params, @@ -151,7 +157,7 @@ def test_st_new_seed(spin_temp, perturbed_field, tmpdirec): def test_st_from_z(perturb_field_lowz, spin_temp): # This one has all the same parameters as the nominal spin_temp, but is evaluated # with an interpolated perturb_field - st = wrapper.spin_temperature( + st = p21c.spin_temperature( perturbed_field=perturb_field_lowz, astro_params=spin_temp.astro_params, flag_options=spin_temp.flag_options, @@ -163,14 +169,14 @@ def test_st_from_z(perturb_field_lowz, spin_temp): def test_ib_from_pf(perturbed_field): - ib = wrapper.ionize_box(perturbed_field=perturbed_field) + ib = p21c.compute_ionization_field(perturbed_field=perturbed_field) assert ib.redshift == perturbed_field.redshift assert ib.user_params == perturbed_field.user_params assert ib.cosmo_params == perturbed_field.cosmo_params def test_ib_from_z(default_user_params, perturbed_field): - ib = wrapper.ionize_box( + ib = p21c.compute_ionization_field( redshift=perturbed_field.redshift, user_params=default_user_params, regenerate=False, @@ -183,7 +189,7 @@ def test_ib_from_z(default_user_params, perturbed_field): def test_ib_override_z(perturbed_field): with pytest.raises(ValueError): - wrapper.ionize_box( + p21c.compute_ionization_field( redshift=perturbed_field.redshift + 1, perturbed_field=perturbed_field, ) @@ -191,33 +197,33 @@ def test_ib_override_z(perturbed_field): def test_ib_override_z_heat_max(perturbed_field): # save previous z_heat_max - zheatmax = wrapper.global_params.Z_HEAT_MAX + zheatmax = p21c.global_params.Z_HEAT_MAX - wrapper.ionize_box( + p21c.compute_ionization_field( redshift=perturbed_field.redshift, perturbed_field=perturbed_field, z_heat_max=12.0, ) - assert wrapper.global_params.Z_HEAT_MAX == zheatmax + assert p21c.global_params.Z_HEAT_MAX == zheatmax def test_ib_bad_st(ic, redshift): with pytest.raises(ValueError): - wrapper.ionize_box(redshift=redshift, spin_temp=ic) + p21c.compute_ionization_field(redshift=redshift, spin_temp=ic) def test_bt(ionize_box, spin_temp, perturbed_field): with pytest.raises(TypeError): # have to specify param names - wrapper.brightness_temperature(ionize_box, spin_temp, perturbed_field) + p21c.brightness_temperature(ionize_box, spin_temp, perturbed_field) # this will fail because ionized_box was not created with spin temperature. with pytest.raises(ValueError): - wrapper.brightness_temperature( + p21c.brightness_temperature( ionized_box=ionize_box, perturbed_field=perturbed_field, spin_temp=spin_temp ) - bt = wrapper.brightness_temperature( + bt = p21c.brightness_temperature( ionized_box=ionize_box, perturbed_field=perturbed_field ) @@ -228,7 +234,7 @@ def test_bt(ionize_box, spin_temp, perturbed_field): def test_coeval_against_direct(ic, perturbed_field, ionize_box): - coeval = wrapper.run_coeval(perturb=perturbed_field, init_box=ic) + coeval = p21c.run_coeval(perturb=perturbed_field, initial_conditions=ic) assert coeval.init_struct == ic assert coeval.perturb_struct == perturbed_field @@ -242,8 +248,8 @@ def test_lightcone(lc, default_user_params, redshift, max_redshift): def test_lightcone_quantities(ic, max_redshift, perturbed_field): - lc = wrapper.run_lightcone( - init_box=ic, + lc = p21c.run_lightcone( + initial_conditions=ic, perturb=perturbed_field, max_redshift=max_redshift, lightcone_quantities=("dNrec_box", "density", "brightness_temp"), @@ -266,8 +272,8 @@ def test_lightcone_quantities(ic, max_redshift, perturbed_field): # Raise an error since we're not doing spin temp. with pytest.raises(AttributeError): - wrapper.run_lightcone( - init_box=ic, + p21c.run_lightcone( + initial_conditions=ic, perturb=perturbed_field, max_redshift=20.0, lightcone_quantities=("Ts_box", "density"), @@ -275,8 +281,8 @@ def test_lightcone_quantities(ic, max_redshift, perturbed_field): # And also raise an error for global quantities. with pytest.raises(AttributeError): - wrapper.run_lightcone( - init_box=ic, + p21c.run_lightcone( + initial_conditions=ic, perturb=perturbed_field, max_redshift=20.0, global_quantities=("Ts_box",), @@ -284,18 +290,16 @@ def test_lightcone_quantities(ic, max_redshift, perturbed_field): def test_run_lf(): - muv, mhalo, lf = wrapper.compute_luminosity_function(redshifts=[7, 8, 9], nbins=100) + muv, mhalo, lf = p21c.compute_luminosity_function(redshifts=[7, 8, 9], nbins=100) assert np.all(lf[~np.isnan(lf)] > -30) assert lf.shape == (3, 100) # Check that memory is in-tact and a second run also works: - muv, mhalo, lf2 = wrapper.compute_luminosity_function( - redshifts=[7, 8, 9], nbins=100 - ) + muv, mhalo, lf2 = p21c.compute_luminosity_function(redshifts=[7, 8, 9], nbins=100) assert lf2.shape == (3, 100) assert np.allclose(lf2[~np.isnan(lf2)], lf[~np.isnan(lf)]) - muv_minih, mhalo_minih, lf_minih = wrapper.compute_luminosity_function( + muv_minih, mhalo_minih, lf_minih = p21c.compute_luminosity_function( redshifts=[7, 8, 9], nbins=100, component=0, @@ -312,32 +316,35 @@ def test_run_lf(): def test_coeval_st(ic, perturbed_field): - coeval = wrapper.run_coeval( - init_box=ic, + coeval = p21c.run_coeval( + initial_conditions=ic, perturb=perturbed_field, flag_options={"USE_TS_FLUCT": True}, ) - assert isinstance(coeval.spin_temp_struct, wrapper.TsBox) + assert isinstance(coeval.spin_temp_struct, p21c.TsBox) def _global_Tb(coeval_box): - assert isinstance(coeval_box, wrapper.Coeval) + assert isinstance(coeval_box, p21c.Coeval) global_Tb = coeval_box.brightness_temp.mean(dtype=np.float64).astype(np.float32) assert np.isclose(global_Tb, coeval_box.brightness_temp_struct.global_Tb) return global_Tb -def test_coeval_callback(ic, max_redshift, perturbed_field): - lc, coeval_output = wrapper.run_lightcone( - init_box=ic, +def test_coeval_callback( + rectlcn, ic, max_redshift, perturbed_field, default_flag_options +): + lc, coeval_output = p21c.run_lightcone( + lightconer=rectlcn, + initial_conditions=ic, perturb=perturbed_field, - max_redshift=max_redshift, + flag_options=default_flag_options, lightcone_quantities=("brightness_temp",), global_quantities=("brightness_temp",), coeval_callback=_global_Tb, ) - assert isinstance(lc, wrapper.LightCone) + assert isinstance(lc, p21c.LightCone) assert isinstance(coeval_output, list) assert len(lc.node_redshifts) == len(coeval_output) assert np.allclose( @@ -345,15 +352,18 @@ def test_coeval_callback(ic, max_redshift, perturbed_field): ) -def test_coeval_callback_redshifts(ic, redshift, max_redshift, perturbed_field): +def test_coeval_callback_redshifts( + rectlcn, ic, redshift, max_redshift, perturbed_field, default_flag_options +): coeval_callback_redshifts = np.array( [max_redshift, max_redshift, (redshift + max_redshift) / 2, redshift], dtype=np.float32, ) - lc, coeval_output = wrapper.run_lightcone( - init_box=ic, + lc, coeval_output = p21c.run_lightcone( + lightconer=rectlcn, + initial_conditions=ic, perturb=perturbed_field, - max_redshift=max_redshift, + flag_options=default_flag_options, coeval_callback=lambda x: x.redshift, coeval_callback_redshifts=coeval_callback_redshifts, ) @@ -369,20 +379,23 @@ def Heaviside(x): return 1 if x > 0 else 0 -def test_coeval_callback_exceptions(ic, redshift, max_redshift, perturbed_field): +def test_coeval_callback_exceptions( + rectlcn, ic, redshift, max_redshift, perturbed_field, default_flag_options +): # should output warning in logs and not raise an error - lc, coeval_output = wrapper.run_lightcone( - init_box=ic, + lc, coeval_output = p21c.run_lightcone( + lightconer=rectlcn, + initial_conditions=ic, perturb=perturbed_field, - max_redshift=max_redshift, + flag_options=default_flag_options, coeval_callback=lambda x: 1 / Heaviside(x.redshift - (redshift + max_redshift) / 2), coeval_callback_redshifts=[max_redshift, redshift], ) # should raise an error with pytest.raises(RuntimeError) as excinfo: - lc, coeval_output = wrapper.run_lightcone( - init_box=ic, + lc, coeval_output = p21c.run_lightcone( + initial_conditions=ic, perturb=perturbed_field, max_redshift=max_redshift, coeval_callback=lambda x: 1 / 0, @@ -392,18 +405,18 @@ def test_coeval_callback_exceptions(ic, redshift, max_redshift, perturbed_field) def test_coeval_vs_low_level(ic): - coeval = wrapper.run_coeval( - redshift=20, - init_box=ic, + coeval = p21c.run_coeval( + out_redshifts=20, + initial_conditions=ic, zprime_step_factor=1.1, regenerate=True, flag_options={"USE_TS_FLUCT": True}, write=False, ) - st = wrapper.spin_temperature( + st = p21c.spin_temperature( redshift=20, - init_boxes=ic, + initial_conditions=ic, zprime_step_factor=1.1, regenerate=True, flag_options={"USE_TS_FLUCT": True}, @@ -421,16 +434,16 @@ def test_using_cached_halo_field(ic, test_direc): Prior to v3.1 this was segfaulting, so this test ensure that this behaviour does not regress. """ - halo_field = wrapper.determine_halo_list( + halo_field = p21c.determine_halo_list( redshift=10.0, - init_boxes=ic, + initial_conditions=ic, write=True, direc=test_direc, ) - pt_halos = wrapper.perturb_halo_list( + pt_halos = p21c.perturb_halo_list( redshift=10.0, - init_boxes=ic, + initial_conditions=ic, halo_field=halo_field, write=True, direc=test_direc, @@ -438,13 +451,13 @@ def test_using_cached_halo_field(ic, test_direc): print("DONE WITH FIRST BOXES!") # Now get the halo field again at the same redshift -- should be cached - new_halo_field = wrapper.determine_halo_list( - redshift=10.0, init_boxes=ic, write=False, regenerate=False + new_halo_field = p21c.determine_halo_list( + redshift=10.0, initial_conditions=ic, write=False, regenerate=False ) - new_pt_halos = wrapper.perturb_halo_list( + new_pt_halos = p21c.perturb_halo_list( redshift=10.0, - init_boxes=ic, + initial_conditions=ic, halo_field=new_halo_field, write=False, regenerate=False, @@ -467,7 +480,7 @@ def test_run_coeval_bad_inputs(): with pytest.raises( ValueError, match="Cannot use an interpolated perturb field with minihalos" ): - wrapper.run_coeval( + p21c.run_coeval( redshift=6.0, flag_options={ "USE_MINI_HALOS": True, @@ -482,24 +495,24 @@ def test_run_lc_bad_inputs(default_user_params): with pytest.raises( ValueError, match="You must provide either redshift, perturb or lightconer" ): - wrapper.run_lightcone() + p21c.run_lightcone() with pytest.warns( DeprecationWarning, match="passing redshift directly is deprecated" ): - wrapper.run_lightcone(redshift=6.0, user_params=default_user_params) + p21c.run_lightcone(redshift=6.0, user_params=default_user_params) with pytest.raises( ValueError, match="If trying to minimize memory usage, you must be caching. Set write=True", ): - wrapper.run_lightcone( + p21c.run_lightcone( redshift=6.0, user_params={"MINIMIZE_MEMORY": True}, write=False, ) - lcn = wrapper.RectilinearLightconer.with_equal_redshift_slices( + lcn = p21c.RectilinearLightconer.with_equal_redshift_slices( min_redshift=6.0, max_redshift=7.0, resolution=0.1 * un.Mpc, @@ -509,7 +522,7 @@ def test_run_lc_bad_inputs(default_user_params): ValueError, match="The lightcone redshifts are not compatible with the given redshift.", ): - wrapper.run_lightcone( + p21c.run_lightcone( redshift=8.0, lightconer=lcn, ) @@ -517,17 +530,17 @@ def test_run_lc_bad_inputs(default_user_params): def test_lc_with_lightcone_filename(rectlcn, perturbed_field, tmpdirec): fname = tmpdirec / "lightcone.h5" - lc = wrapper.run_lightcone( + lc = p21c.run_lightcone( lightconer=rectlcn, perturb=perturbed_field, lightcone_filename=fname ) assert fname.exists() - lc_loaded = wrapper.LightCone.read(fname) + lc_loaded = p21c.LightCone.read(fname) assert lc_loaded == lc del lc_loaded # This one should NOT run anything. - lc2 = wrapper.run_lightcone( + lc2 = p21c.run_lightcone( lightconer=rectlcn, lightcone_filename=fname, perturb=perturbed_field ) assert lc2 == lc @@ -542,7 +555,7 @@ def test_lc_partial_eval(rectlcn, perturbed_field, tmpdirec, lc): with pytest.raises( ValueError, match="Returning before the final redshift requires caching" ): - wrapper.run_lightcone( + p21c.run_lightcone( lightconer=rectlcn, perturb=perturbed_field, lightcone_filename=fname, @@ -550,7 +563,7 @@ def test_lc_partial_eval(rectlcn, perturbed_field, tmpdirec, lc): write=False, ) - partial = wrapper.run_lightcone( + partial = p21c.run_lightcone( lightconer=rectlcn, perturb=perturbed_field, lightcone_filename=fname, @@ -563,7 +576,7 @@ def test_lc_partial_eval(rectlcn, perturbed_field, tmpdirec, lc): assert partial._current_redshift <= 20.0 assert partial._current_redshift > 15.0 - finished = wrapper.run_lightcone( + finished = p21c.run_lightcone( lightconer=rectlcn, perturb=perturbed_field, lightcone_filename=fname, @@ -577,7 +590,7 @@ def test_lc_partial_eval(rectlcn, perturbed_field, tmpdirec, lc): fl.attrs["current_redshift"] = 2 * partial._current_redshift with pytest.raises(IOError, match="No component boxes found at z"): - wrapper.run_lightcone( + p21c.run_lightcone( lightconer=rectlcn, perturb=perturbed_field, lightcone_filename=fname, @@ -588,15 +601,17 @@ def test_lc_pass_redshift_deprecation(rectlcn, ic): with pytest.warns( DeprecationWarning, match="passing redshift directly is deprecated" ): - wrapper.run_lightcone( - lightconer=rectlcn, redshift=rectlcn.lc_redshifts.min(), init_box=ic + p21c.run_lightcone( + lightconer=rectlcn, + redshift=rectlcn.lc_redshifts.min(), + initial_conditions=ic, ) def test_coeval_lowerz_than_photon_cons(ic): with pytest.raises(ValueError, match="You have passed a redshift"): - wrapper.run_coeval( - init_box=ic, + p21c.run_coeval( + initial_conditions=ic, redshift=2.0, flag_options={ "PHOTON_CONS_TYPE": 1, @@ -608,8 +623,8 @@ def test_coeval_lowerz_than_photon_cons(ic): def test_lc_lowerz_than_photon_cons(rectlcn, ic): with pytest.raises(ValueError, match="You have passed a redshift"): - wrapper.run_lightcone( - init_box=ic, + p21c.run_lightcone( + initial_conditions=ic, redshift=2.0, flag_options={ "PHOTON_CONS_TYPE": 1,