From 483cfd480cf5a63cb71eedbca78d15bf5be806a4 Mon Sep 17 00:00:00 2001 From: TillHae Date: Thu, 30 Oct 2025 15:16:00 +0100 Subject: [PATCH 01/32] training progress unit realignment from epoch to mini_epoch --- config/default_config.yml | 4 +- config/evaluate/eval_config.yml | 4 +- integration_tests/small1.yaml | 4 +- integration_tests/small1_test.py | 10 +-- .../common/src/weathergen/common/config.py | 44 +++++----- .../weathergen/evaluate/export_inference.py | 20 ++--- .../src/weathergen/evaluate/io_reader.py | 8 +- .../evaluate/src/weathergen/evaluate/utils.py | 12 +-- src/weathergen/datasets/masking.py | 2 +- .../datasets/multi_stream_data_sampler.py | 24 +++--- src/weathergen/datasets/tokenizer_forecast.py | 2 +- src/weathergen/datasets/tokenizer_masking.py | 2 +- src/weathergen/model/model.py | 8 +- src/weathergen/run_train.py | 10 +-- src/weathergen/train/lr_scheduler.py | 14 ++-- src/weathergen/train/trainer.py | 84 +++++++++---------- src/weathergen/utils/cli.py | 4 +- src/weathergen/utils/compare_run_configs.py | 6 +- src/weathergen/utils/train_logger.py | 6 +- src/weathergen/utils/validation_io.py | 4 +- tests/test_cli.py | 4 +- tests/test_config.py | 24 +++--- 22 files changed, 150 insertions(+), 150 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..116b63267 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -113,8 +113,8 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_epochs: 32 -samples_per_epoch: 4096 +num_mini_epochs: 32 +samples_per_mini_epoch: 4096 samples_per_validation: 512 shuffle: True diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 0ecd0835f..3f436f736 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -30,7 +30,7 @@ run_ids : ar40mckx: label: "pretrained model ar40mckx" results_base_dir : "./results/" - epoch: 0 + mini_epoch: 0 rank: 0 streams: ERA5: @@ -61,7 +61,7 @@ run_ids : c8g5katp: label: "2 steps window" results_base_dir : "./results/" - epoch: 0 + mini_epoch: 0 rank: 0 streams: ERA5: diff --git a/integration_tests/small1.yaml b/integration_tests/small1.yaml index 2f42b6563..102dac3b2 100644 --- a/integration_tests/small1.yaml +++ b/integration_tests/small1.yaml @@ -3,8 +3,8 @@ run_path: "./results" model_path: "./models" loss_fcts: [["mse", 1.0]] loss_fcts_val: [["mse", 1.0]] -num_epochs: 1 -samples_per_epoch: 10 +num_mini_epochs: 1 +samples_per_mini_epoch: 10 samples_per_validation: 5 lr_steps: 4 lr_steps_warmup: 2 diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index 158af7722..47349b8af 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -69,7 +69,7 @@ def test_train(setup, test_run_id): def infer(run_id): logger.info("run inference") inference_from_args( - ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"] + ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"] + [ "--from_run_id", run_id, @@ -84,7 +84,7 @@ def infer(run_id): def infer_with_missing(run_id): logger.info("run inference") inference_from_args( - ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"] + ["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"] + [ "--from_run_id", run_id, @@ -128,7 +128,7 @@ def evaluate_results(run_id): } }, "label": "MTM ERA5", - "epoch": 0, + "mini_epoch": 0, "rank": 0, } }, @@ -170,7 +170,7 @@ def assert_train_loss_below_threshold(run_id): assert loss_metric is not None, ( "'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file" ) - # Check that the loss does not explode in a single epoch + # Check that the loss does not explode in a single mini_epoch # This is meant to be a quick test, not a convergence test target = 1.5 assert loss_metric < target, ( @@ -192,7 +192,7 @@ def assert_val_loss_below_threshold(run_id): assert loss_metric is not None, ( "'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file" ) - # Check that the loss does not explode in a single epoch + # Check that the loss does not explode in a single mini_epoch # This is meant to be a quick test, not a convergence test assert loss_metric < 1.25, ( f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below 0.25" diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index bdc4039fd..01167e7bc 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -54,23 +54,23 @@ def format_cf(config: Config) -> str: return stream.getvalue() -def save(config: Config, epoch: int | None): +def save(config: Config, mini_epoch: int | None): """Save current config into the current runs model directory.""" path_models = Path(config.model_path) # save in directory with model files dirname = path_models / config.run_id dirname.mkdir(exist_ok=True, parents=True) - fname = dirname / _get_model_config_file_name(config.run_id, epoch) + fname = dirname / _get_model_config_file_name(config.run_id, mini_epoch) json_str = json.dumps(OmegaConf.to_container(config)) with fname.open("w") as f: f.write(json_str) -def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> Config: +def load_model_config(run_id: str, mini_epoch: int | None, model_path: str | None) -> Config: """ - Load a configuration file from a given run_id and epoch. + Load a configuration file from a given run_id and mini_epoch. If run_id is a full path, loads it from the full path. """ if Path(run_id).exists(): # load from the full path if a full path is provided @@ -84,13 +84,13 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> config=pconf, attribute_name="model_path", fallback="models" ) path = Path(model_path) - fname = path / run_id / _get_model_config_file_name(run_id, epoch) + fname = path / run_id / _get_model_config_file_name(run_id, mini_epoch) assert fname.exists(), ( "The fallback path to the model does not exist. Please provide a `model_path`.", fname, ) - _logger.info(f"Loading config from specified run_id and epoch: {fname}") + _logger.info(f"Loading config from specified run_id and mini_epoch: {fname}") with fname.open() as f: json_str = f.read() @@ -100,22 +100,22 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> return _apply_fixes(config) -def _get_model_config_file_name(run_id: str, epoch: int | None): - if epoch is None: - epoch_str = "" - elif epoch == -1: - epoch_str = "_latest" +def _get_model_config_file_name(run_id: str, mini_epoch: int | None): + if mini_epoch is None: + mini_epoch_str = "" + elif mini_epoch == -1: + mini_epoch_str = "_latest" else: - epoch_str = f"_epoch{epoch:05d}" - return f"model_{run_id}{epoch_str}.json" + mini_epoch_str = f"_chkpt{mini_epoch:05d}" + return f"model_{run_id}{mini_epoch_str}.json" -def get_model_results(run_id: str, epoch: int, rank: int) -> Path: +def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: """ - Get the path to the model results zarr store from a given run_id and epoch. + Get the path to the model results zarr store from a given run_id and mini_epoch. """ run_results = Path(_load_private_conf(None)["path_shared_working_dir"]) / f"results/{run_id}" - zarr_path = run_results / f"validation_epoch{epoch:05d}_rank{rank:04d}.zarr" + zarr_path = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr" if not zarr_path.exists() or not zarr_path.is_dir(): raise FileNotFoundError(f"Zarr file {zarr_path} does not exist or is not a directory.") return zarr_path @@ -150,7 +150,7 @@ def _check_logging(config: Config) -> Config: def load_config( private_home: Path | None, from_run_id: str | None, - epoch: int | None, + mini_epoch: int | None, *overwrites: Path | dict | Config, ) -> Config: """ @@ -161,7 +161,7 @@ def load_config( private_home: Configuration file containing platform dependent information and secretes from_run_id: Run id of the pretrained WeatherGenerator model to continue training or inference - epoch: epoch of the checkpoint to load. -1 indicates last checkpoint available. + mini_epoch: mini_epoch of the checkpoint to load. -1 indicates last checkpoint available. *overwrites: Additional overwrites from different sources Note: The order of precendence for merging the final config is in ascending order: @@ -191,7 +191,7 @@ def load_config( if from_run_id is None: base_config = _load_default_conf() else: - base_config = load_model_config(from_run_id, epoch, private_config.get("model_path", None)) + base_config = load_model_config(from_run_id, mini_epoch, private_config.get("model_path", None)) from_run_id = base_config.run_id with open_dict(base_config): base_config.from_run_id = from_run_id @@ -456,9 +456,9 @@ def get_path_model(config: Config) -> Path: return Path(config.model_path) / config.run_id -def get_path_output(config: Config, epoch: int) -> Path: +def get_path_output(config: Config, mini_: int) -> Path: base_path = get_path_run(config) - fname = f"validation_epoch{epoch:05d}_rank{config.rank:04d}.zarr" + fname = f"validation_chkpt{mini_epoch:05d}_rank{config.rank:04d}.zarr" return base_path / fname @@ -523,7 +523,7 @@ def validate_forecast_policy_and_steps(cf: OmegaConf): valid_forecast_policies = ( "Valid values for 'forecast_policy' are, e.g., 'fixed' when using constant " "forecast steps throughout the training, or 'sequential' when varying the forecast " - "steps over epochs, such as, e.g., 'forecast_steps: [2, 2, 4, 4]'. " + "steps over mini_epochs, such as, e.g., 'forecast_steps: [2, 2, 4, 4]'. " ) valid_forecast_steps = ( "'forecast_steps' must be a positive integer or a non-empty list of positive integers. " diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py index 2c0cb4243..81602fa47 100755 --- a/packages/evaluate/src/weathergen/evaluate/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py @@ -363,8 +363,8 @@ def get_data_worker(args: tuple) -> xr.DataArray: ------- xarray DataArray for the specified sample and forecast step. """ - sample, fstep, run_id, stream, dtype, epoch, rank = args - fname_zarr = get_model_results(run_id, epoch, rank) + sample, fstep, run_id, stream, dtype, mini_epoch, rank = args + fname_zarr = get_model_results(run_id, mini_epoch, rank) with ZarrIO(fname_zarr) as zio: out = zio.get_data(sample, stream, fstep) if dtype == "target": @@ -383,7 +383,7 @@ def get_data( channels: list, fstep_hours: int, n_processes: list, - epoch: int, + mini_epoch: int, rank: int, output_dir: str, output_format: str, @@ -402,7 +402,7 @@ def get_data( fsteps : List of forecast steps to retrieve. If None, retrieves all available forecast steps. channels :List of channels to retrieve. If None, retrieves all available channels. n_processes : Number of parallel processes to use for data retrieval. - ecpoch : Epoch number to identify the Zarr store. + mini_epoch : Mini_epoch number to identify the Zarr store. rank : Rank number to identify the Zarr store. output_dir : Directory to save the NetCDF files. output_format : Output file format (currently only 'netcdf' supported). @@ -411,7 +411,7 @@ def get_data( if dtype not in ["target", "prediction"]: raise ValueError(f"Invalid type: {dtype}. Must be 'target' or 'prediction'.") - fname_zarr = get_model_results(run_id, epoch, rank) + fname_zarr = get_model_results(run_id, mini_epoch, rank) with ZarrIO(fname_zarr) as zio: zio_forecast_steps = sorted([int(step) for step in zio.forecast_steps]) zio_samples = sorted([int(sample) for sample in zio.samples]) @@ -430,7 +430,7 @@ def get_data( for sample_idx in tqdm(samples): da_fs = [] step_tasks = [ - (sample_idx, fstep, run_id, stream, dtype, epoch, rank) for fstep in fsteps + (sample_idx, fstep, run_id, stream, dtype, mini_epoch, rank) for fstep in fsteps ] for result in tqdm( pool.imap_unordered(get_data_worker, step_tasks, chunksize=1), @@ -627,10 +627,10 @@ def parse_args(args: list) -> argparse.Namespace: ) parser.add_argument( - "--epoch", + "--mini_epoch", type=int, default=0, - help="Epoch number to identify the Zarr store", + help="mini_epoch number to identify the Zarr store", ) parser.add_argument( @@ -673,7 +673,7 @@ def export_from_args(args: list) -> None: fstep_hours = np.timedelta64(args.fstep_hours, "h") channels = args.channels n_processes = args.n_processes - epoch = args.epoch + mini_epoch = args.mini_epoch rank = args.rank # Ensure output directory exists @@ -697,7 +697,7 @@ def export_from_args(args: list) -> None: channels, fstep_hours, n_processes, - epoch, + mini_epoch, rank, output_dir, output_format, diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 892191ad5..4c605fb20 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -304,7 +304,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non super().__init__(eval_cfg, run_id, private_paths) - self.epoch = eval_cfg.epoch + self.mini_epoch = eval_cfg.mini_epoch self.rank = eval_cfg.rank # Load model configuration and set (run-id specific) directories @@ -334,7 +334,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non ) self.fname_zarr = self.results_dir.joinpath( - f"validation_epoch{self.epoch:05d}_rank{self.rank:04d}.zarr" + f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" ) if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): @@ -356,12 +356,12 @@ def get_inference_config(self): _logger.info( f"Loading config for run {self.run_id} from private paths: {self.private_paths}" ) - config = load_config(self.private_paths, self.run_id, self.epoch) + config = load_config(self.private_paths, self.run_id, self.mini_epoch) else: _logger.info( f"Loading config for run {self.run_id} from model directory: {self.model_base_dir}" ) - config = load_model_config(self.run_id, self.epoch, self.model_base_dir) + config = load_model_config(self.run_id, self.mini_epoch, self.model_base_dir) if type(config) not in [dict, oc.DictConfig]: _logger.warning("Model config not found. inference config will be empty.") diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index ce3c3545c..beb9bfc3d 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -321,8 +321,8 @@ def metric_list_to_json( Output directory. run_id : Identifier of the inference run. - epoch : - Epoch number. + mini_epoch : + Mini_epoch number. """ assert len(metrics_list) == len(npoints_sample_list) == len(streams), ( "The lengths of metrics_list, npoints_sample_list, and streams must be the same." @@ -346,7 +346,7 @@ def metric_list_to_json( # Match the expected filename pattern save_path = ( reader.metrics_dir - / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" + / f"{reader.run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json" ) _logger.info(f"Saving results to {save_path}") @@ -354,13 +354,13 @@ def metric_list_to_json( json.dump(metric_dict, f, indent=4) _logger.info( - f"Saved all results of inference run {reader.run_id} - epoch {reader.epoch:d} successfully to {reader.metrics_dir}." + f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} successfully to {reader.metrics_dir}." ) def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str): """ - Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file. + Retrieve the score for a given run, stream, metric, mini_epoch, and rank from a JSON file. Parameters ---------- @@ -380,7 +380,7 @@ def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: """ score_path = ( Path(reader.metrics_dir) - / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" + / f"{reader.run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json" ) _logger.debug(f"Looking for: {score_path}") diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 58d5d5731..2e89ea04f 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -96,7 +96,7 @@ def __init__(self, cf: Config): def reset_rng(self, rng) -> None: """ - Reset rng after epoch to ensure proper randomization + Reset rng after mini_epoch to ensure proper randomization """ self.rng = rng diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index ca5ee6601..bbc256ac2 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -86,7 +86,7 @@ def __init__( start_date_, end_date_, batch_size, - samples_per_epoch, + samples_per_mini_epoch, stage: Stage, shuffle=True, ): @@ -194,7 +194,7 @@ def __init__( index_range = self.time_window_handler.get_index_range() self.len = int(index_range.end - index_range.start) - self.len = min(self.len, samples_per_epoch if samples_per_epoch else self.len) + self.len = min(self.len, samples_per_mini_epoch if samples_per_mini_epoch else self.len) # adjust len to split loading across all workers and ensure it is multiple of batch_size len_chunk = ((self.len // cf.world_size) // batch_size) * batch_size self.len = min(self.len, len_chunk) @@ -236,14 +236,14 @@ def __init__( else: assert False, f"Unsupported training mode: {cf.training_mode}" - self.epoch = 0 + self.mini_epoch = 0 ################################################### def advance(self): """ - Advance epoch (this is applied to the template for the worker processes) + Advance mini_epoch (this is applied to the template for the worker processes) """ - self.epoch += 1 + self.mini_epoch += 1 ################################################### def get_sources_size(self): @@ -278,17 +278,17 @@ def reset(self): self.rng = np.random.default_rng(self.data_loader_rng_seed) fsm = ( - self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)] + self.forecast_steps[min(self.mini_epoch, len(self.forecast_steps) - 1)] if self.forecast_policy != "random" else self.forecast_steps.max() ) if fsm > 0: - logger.info(f"forecast_steps at epoch={self.epoch} : {fsm}") + logger.info(f"forecast_steps at mini_epoch={self.mini_epoch} : {fsm}") # data index_range = self.time_window_handler.get_index_range() idx_end = index_range.end - # native length of datasets, independent of epoch length that has potentially been specified + # native length of datasets, independent of mini_epoch length that has potentially been specified forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs idx_end -= forecast_len + self.forecast_offset assert idx_end > 0, "dataset size too small for forecast range" @@ -466,17 +466,17 @@ def worker_workset(self): iter_end = len(self) else: - # ensure the rng seed is fully unique across workers and epochs + # ensure the rng seed is fully unique across workers and mini_epochs # the worker processes are generated as bit-wise copy of the "template" (the actual # instance of the present class that is created) whenever __iter__ is started. This - # happens for each epoch, for train and validation, and independently for each DDP + # happens for each mini_epoch, for train and validation, and independently for each DDP # worker. After the bit-wise copy, the rng seed needs to be made unique for - # DDP workers, loader process, epoch. + # DDP workers, loader process, mini_epoch. dist = torch.distributed self.data_loader_rng_seed *= ( (((dist.get_rank() + 1) * 73) if dist.is_initialized() else 1) * ((worker_info.id + 1) * 37) - * (self.epoch + 13) + * (self.mini_epoch + 13) * 7 ) # split workload diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index d54831265..c52d77790 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -29,7 +29,7 @@ class TokenizerForecast(Tokenizer): def reset_rng(self, rng) -> None: """ - Reset rng after epoch to ensure proper randomization + Reset rng after mini_epoch to ensure proper randomization """ self.rng = rng diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 548b52124..8cc3de2f5 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -34,7 +34,7 @@ def __init__(self, healpix_level: int, masker: Masker): def reset_rng(self, rng) -> None: """ - Reset rng after epoch to ensure proper randomization + Reset rng after mini_epoch to ensure proper randomization """ self.masker.reset_rng(rng) self.rng = rng diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 13c462a6f..6173316a3 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -549,16 +549,16 @@ def rename_old_state_dict(self, params: dict) -> dict: return new_params ######################################### - def load(self, run_id: str, epoch: str = -1) -> None: + def load(self, run_id: str, mini_epoch: str = -1) -> None: """Loads model state from checkpoint and checks for missing and unused keys. Args: run_id : model_id of the trained model - epoch : The epoch to load. Default (-1) is the latest epoch + mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch """ path_run = Path(self.cf.model_path) / run_id - epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" - filename = f"{run_id}_{epoch_id}.chkpt" + mini_epoch_id = f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + filename = f"{run_id}_{mini_epoch_id}.chkpt" params = torch.load( path_run / filename, map_location=torch.device("cpu"), weights_only=True diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index eb2cab895..b1b9b9726 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -54,7 +54,7 @@ def inference_from_args(argl: list[str]): cf = config.load_config( args.private_config, args.from_run_id, - args.epoch, + args.mini_epoch, *args.config, inference_overwrite, cli_overwrite, @@ -71,7 +71,7 @@ def inference_from_args(argl: list[str]): cf.run_history += [(args.from_run_id, cf.istep)] trainer = Trainer(cf.train_log_freq) - trainer.inference(cf, devices, args.from_run_id, args.epoch) + trainer.inference(cf, devices, args.from_run_id, args.mini_epoch) #################################################################################################### @@ -113,7 +113,7 @@ def train_continue_from_args(argl: list[str]): lr_policy_warmup="cosine", lr_policy_decay="linear", lr_policy_cooldown="linear", - num_epochs=12, # len(cf.forecast_steps) + 4 + num_mini_epochs=12, # len(cf.forecast_steps) + 4 istep=0, ) else: @@ -123,7 +123,7 @@ def train_continue_from_args(argl: list[str]): cf = config.load_config( args.private_config, args.from_run_id, - args.epoch, + args.mini_epoch, finetune_overwrite, *args.config, cli_overwrite, @@ -139,7 +139,7 @@ def train_continue_from_args(argl: list[str]): cf.run_history += [(args.from_run_id, cf.istep)] trainer = Trainer(cf.train_log_freq) - trainer.run(cf, devices, args.from_run_id, args.epoch) + trainer.run(cf, devices, args.from_run_id, args.mini_epoch) #################################################################################################### diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e64516a5d..e5421d0e0 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -154,7 +154,7 @@ def __init__( self.i_step = 0 self.lr = self.cur_scheduler.get_last_lr() - # advance manually to step_contd (last_epoch parameter for schedulers is not working and + # advance manually to step_contd (last_mini_epoch parameter for schedulers is not working and # this is also more brittle with the different phases) # optimizer.step() as required by torch; # won't have a material effect since grads are zero at this point @@ -218,8 +218,8 @@ def plot(): Use as LearningRateScheduler.plot() """ - num_epochs = 42 - num_samples_per_epoch = 4096 + num_mini_epochs = 42 + num_samples_per_mini_epoch = 4096 lr_start = 0.000001 lr_max = 0.000015 @@ -245,7 +245,7 @@ def plot(): lr_final_decay, lr_final, lr_steps_warmup, - num_epochs * num_samples_per_epoch, + num_mini_epochs * num_samples_per_mini_epoch, lr_steps_cooldown, lr_policy_warmup, lr_policy_decay, @@ -254,7 +254,7 @@ def plot(): lrs = [] for _ in range( - num_epochs * num_samples_per_epoch + lr_steps_warmup + lr_steps_cooldown + 1023 + num_mini_epochs * num_samples_per_mini_epoch + lr_steps_warmup + lr_steps_cooldown + 1023 ): optimizer.step() lrs.append(optimizer.param_groups[0]["lr"]) @@ -279,7 +279,7 @@ def plot(): lr_final_decay, lr_final, lr_steps_warmup, - num_epochs * num_samples_per_epoch, + num_mini_epochs * num_samples_per_mini_epoch, lr_steps_cooldown, lr_policy_warmup, lr_policy_decay, @@ -288,7 +288,7 @@ def plot(): lrs = [] for _ in range( - num_epochs * num_samples_per_epoch + lr_steps_warmup + lr_steps_cooldown + 1023 + num_mini_epochs * num_samples_per_mini_epoch + lr_steps_warmup + lr_steps_cooldown + 1023 ): optimizer.step() lrs.append(optimizer.param_groups[0]["lr"]) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..0ed3fe36d 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -78,7 +78,7 @@ def init(self, cf: Config, devices): self.freeze_modules = cf.get("freeze_modules", "") - assert cf.samples_per_epoch % cf.batch_size_per_gpu == 0 + assert cf.samples_per_mini_epoch % cf.batch_size_per_gpu == 0 assert cf.samples_per_validation % cf.batch_size_validation_per_gpu == 0 config.validate_forecast_policy_and_steps(cf=cf) @@ -100,7 +100,7 @@ def init(self, cf: Config, devices): self.init_perf_monitoring() self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) - def inference(self, cf, devices, run_id_trained, epoch): + def inference(self, cf, devices, run_id_trained, mini_epoch): # general initalization self.init(cf, devices) @@ -140,21 +140,21 @@ def inference(self, cf, devices, run_id_trained, epoch): self.model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() self.model = self.model.to(self.devices[0]) - self.model.load(run_id_trained, epoch) - logger.info(f"Loaded model {run_id_trained} at epoch {epoch}.") + self.model.load(run_id_trained, mini_epoch) + logger.info(f"Loaded model {run_id_trained} at mini_epoch {mini_epoch}.") self.model_params = ModelParams(cf).create(cf) self.model_params = self.model_params.to(self.devices[0]) - logger.info(f"Loaded model id={run_id_trained} at epoch={epoch}.") + logger.info(f"Loaded model id={run_id_trained} at mini_epoch={mini_epoch}.") self.loss_calculator_val = LossCalculator(cf=cf, stage=VAL, device=self.devices[0]) if is_root(): - config.save(self.cf, epoch=0) + config.save(self.cf, mini_epoch=0) logger.info(f"Starting inference with id={self.cf.run_id}.") # inference validation set - self.validate(epoch=0) + self.validate(mini_epoch=0) logger.info(f"Finished inference run with id: {cf.run_id}") def init_model_and_shard(self, cf, devices): @@ -254,7 +254,7 @@ def init_model_and_shard(self, cf, devices): return model, model_params - def run(self, cf, devices, run_id_contd=None, epoch_contd=None): + def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): # general initalization self.init(cf, devices) cf = self.cf @@ -268,7 +268,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): cf.start_date, cf.end_date, cf.batch_size_per_gpu, - cf.samples_per_epoch, + cf.samples_per_mini_epoch, stage=TRAIN, shuffle=cf.shuffle, ) @@ -301,8 +301,8 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): self.model.reset_parameters() else: if is_root(): - logger.info(f"Continuing run with id={self.cf.from_run_id} at epoch {epoch_contd}.") - self.load_model(self.cf.from_run_id, epoch_contd) + logger.info(f"Continuing run with id={self.cf.from_run_id} at mini_epoch {mini_epoch_contd}.") + self.load_model(self.cf.from_run_id, mini_epoch_contd) if is_root(): logger.info(f"Loaded model id={run_id_contd}.") self.model_params.reset_parameters(cf) @@ -351,7 +351,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): # lr is updated after each batch so account for this # TODO: conf should be read-only, do not modify the conf in flight - cf.lr_steps = int((len(self.dataset) * cf.num_epochs) / cf.batch_size_per_gpu) + cf.lr_steps = int((len(self.dataset) * cf.num_mini_epochs) / cf.batch_size_per_gpu) steps_decay = cf.lr_steps - cf.lr_steps_warmup - cf.lr_steps_cooldown if is_root(): @@ -400,15 +400,15 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): self.loss_calculator = LossCalculator(cf=cf, stage=TRAIN, device=self.device) self.loss_calculator_val = LossCalculator(cf=cf, stage=VAL, device=self.device) - # recover epoch when continuing run + # recover mini_epoch when continuing run if self.world_size_original is None: - epoch_base = int(self.cf.istep / len(self.data_loader)) + mini_epoch_base = int(self.cf.istep / len(self.data_loader)) else: len_per_rank = ( len(self.dataset) // (self.world_size_original * cf.batch_size_per_gpu) ) * cf.batch_size_per_gpu - epoch_base = int( - self.cf.istep / (min(len_per_rank, cf.samples_per_epoch) * self.world_size_original) + mini_epoch_base = int( + self.cf.istep / (min(len_per_rank, cf.samples_per_mini_epoch) * self.world_size_original) ) # torch.autograd.set_detect_anomaly(True) @@ -425,18 +425,18 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): if cf.val_initial: self.validate(-1) - for epoch in range(epoch_base, cf.num_epochs): - logger.info(f"Epoch {epoch} of {cf.num_epochs}: train.") - self.train(epoch) + for mini_epoch in range(mini_epoch_base, cf.num_mini_epochs): + logger.info(f"Mini_epoch {mini_epoch} of {cf.num_mini_epochs}: train.") + self.train(mini_epoch) - logger.info(f"Epoch {epoch} of {cf.num_epochs}: validate.") - self.validate(epoch) + logger.info(f"Mini_epoch {mini_epoch} of {cf.num_mini_epochs}: validate.") + self.validate(mini_epoch) - logger.info(f"Epoch {epoch} of {cf.num_epochs}: save_model.") - self.save_model(epoch) + logger.info(f"Mini_epoch {mini_epoch} of {cf.num_mini_epochs}: save_model.") + self.save_model(mini_epoch) # log final model - self.save_model(cf.num_epochs) + self.save_model(cf.num_mini_epochs) ########################################### def _prepare_logging( @@ -564,7 +564,7 @@ def _prepare_logging( targets_lens, ) - def train(self, epoch): + def train(self, mini_epoch): cf = self.cf self.model.train() # torch.autograd.set_detect_anomaly(True) @@ -640,7 +640,7 @@ def train(self, epoch): self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() - self._log_terminal(bidx, epoch, TRAIN) + self._log_terminal(bidx, mini_epoch, TRAIN) if bidx % self.train_log_freq.metrics == 0: self._log(TRAIN) @@ -652,7 +652,7 @@ def train(self, epoch): self.dataset.advance() - def validate(self, epoch): + def validate(self, mini_epoch): cf = self.cf self.model.eval() @@ -706,7 +706,7 @@ def validate(self, epoch): sources = [[item.source_raw for item in b] for b in batch[0]] write_output( self.cf, - epoch, + mini_epoch, bidx, sources, preds_all, @@ -728,7 +728,7 @@ def validate(self, epoch): pbar.update(self.cf.batch_size_validation_per_gpu) - self._log_terminal(bidx, epoch, VAL) + self._log_terminal(bidx, mini_, VAL) self._log(VAL) # avoid that there is a systematic bias in the validation subset @@ -745,16 +745,16 @@ def batch_to_device(self, batch): [[b.to(self.device) for b in bf] for bf in batch[2]], ) - def load_model(self, run_id: str, epoch=-1): + def load_model(self, run_id: str, mini_epoch=-1): """Loads model state from checkpoint and checks for missing and unused keys. Args: run_id : model_id of the trained model - epoch : The epoch to load. Default (-1) is the latest epoch + mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch """ path_run = Path(self.cf.model_path) / run_id - epoch_id = f"epoch{epoch:05d}" if epoch != -1 and epoch is not None else "latest" - filename = f"{run_id}_{epoch_id}.chkpt" + mini_epoch_id = f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + filename = f"{run_id}_{mini_epoch_id}.chkpt" params = torch.load( path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True @@ -911,10 +911,10 @@ def _get_full_optimizer_state_dict(self): else: return {} - def save_model(self, epoch: int, name=None): - # Saving at epoch == max_epoch means that we are saving the latest checkpoint. - max_epoch = self.cf.num_epochs - assert epoch <= max_epoch, (epoch, max_epoch) + def save_model(self, mini_epoch: int, name=None): + # Saving at mini_epoch == max_mini_epoch means that we are saving the latest checkpoint. + max_mini_epoch = self.cf.num_mini_epochs + assert mini_epoch <= max_mini_epoch, (mini_epoch, max_mini_epoch) model_state_dict = self._get_full_model_state_dict() # optim_state_dict = self._get_full_optimizer_state_dict() @@ -923,7 +923,7 @@ def save_model(self, epoch: int, name=None): [ self.cf.run_id, "_", - "latest" if epoch == -1 else f"epoch{epoch:05d}", + "latest" if mini_epoch == -1 else f"chkpt{mini_epoch:05d}", ("_" + name) if name is not None else "", ] ) @@ -938,7 +938,7 @@ def save_model(self, epoch: int, name=None): logger.info(f"Saved model to {file_out}") # save config - config.save(self.cf, epoch) + config.save(self.cf, mini_epoch) def _prepare_losses_for_logging( self, @@ -1024,7 +1024,7 @@ def _log_instant_grad_norms(self, stage: Stage): if is_root(): self.train_logger.log_metrics(stage, grad_norms) - def _log_terminal(self, bidx: int, epoch: int, stage: Stage): + def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): print_freq = self.train_log_freq.terminal if bidx % print_freq == 0 and bidx > 0 or stage == VAL: # compute from last iteration @@ -1033,7 +1033,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): if is_root(): if stage == VAL: logger.info( - f"validation ({self.cf.run_id}) : {epoch:03d} : {avg_loss.nanmean().item()}" + f"validation ({self.cf.run_id}) : {mini_epoch:03d} : {avg_loss.nanmean().item()}" ) for _, st in enumerate(self.cf.streams): logger.info( @@ -1047,7 +1047,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): dt = time.time() - self.t_start len_dataset = len(self.data_loader) // self.cf.batch_size_per_gpu pstr = ( - f"{epoch:03d} : {bidx:05d}/{len_dataset:05d} : " + f"{mini_epoch:03d} : {bidx:05d}/{len_dataset:05d} : " + f"{self.cf.istep:06d} : loss = {avg_loss.nanmean().item():.4E} " + f"(lr={self.lr_scheduler.get_lr():.2E}, " ) diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index 9e7b8f562..1fe6441c4 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -128,11 +128,11 @@ def _add_model_loading_params(parser: argparse.ArgumentParser): ) parser.add_argument( "-e", - "--epoch", + "--mini_epoch", type=int, default=-1, help=( - "Epoch of pretrained WeatherGenerator model used" + "Mini_epoch of pretrained WeatherGenerator model used" " (Default -1 corresponds to the last checkpoint)." ), ) diff --git a/src/weathergen/utils/compare_run_configs.py b/src/weathergen/utils/compare_run_configs.py index 76f8a199c..28eade1cf 100755 --- a/src/weathergen/utils/compare_run_configs.py +++ b/src/weathergen/utils/compare_run_configs.py @@ -217,13 +217,13 @@ def main(): logger.info(f"Loading config for run_id: {run_id} from {path}") try: - cfg = load_model_config(run_id=run_id, epoch=None, model_path=path) + cfg = load_model_config(run_id=run_id, mini_epoch=None, model_path=path) except Exception: logger.warning( f"Failed to load config for run_id: {run_id} from {path}", - "Assuming epoch=0 and retrying.", + "Assuming mini_epoch=0 and retrying.", ) - cfg = load_model_config(run_id=run_id, epoch=0, model_path=path) + cfg = load_model_config(run_id=run_id, mini_epoch=0, model_path=path) actual_run_id = cfg.get("run_id", run_id) # Process streams and flatten diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 3caea6839..4c467f8d0 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -181,15 +181,15 @@ def add_val( ####################################### @staticmethod - def read(run_id: str, model_path: str = None, epoch: int = -1) -> Metrics: + def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: """ Read data for run_id """ # Load config from given model_path if provided, otherwise use path from private config if model_path: - cf = config.load_model_config(run_id=run_id, epoch=epoch, model_path=model_path) + cf = config.load_model_config(run_id=run_id, mini_epoch=mini_epoch, model_path=model_path) else: - cf = config.load_config(private_home=None, from_run_id=run_id, epoch=epoch) + cf = config.load_config(private_home=None, from_run_id=run_id, mini_epoch=mini_epoch) run_id = cf.run_id result_dir_base = Path(cf.run_path) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e28563132..2c46c710b 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -17,7 +17,7 @@ def write_output( cf, - epoch, + mini_epoch, batch_idx, sources, preds_all, @@ -63,6 +63,6 @@ def write_output( cf.forecast_offset, ) - with io.ZarrIO(config.get_path_output(cf, epoch)) as writer: + with io.ZarrIO(config.get_path_output(cf, mini_epoch)) as writer: for subset in data.items(): writer.write_zarr(subset) diff --git a/tests/test_cli.py b/tests/test_cli.py index e39535c31..1ee451dd9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -6,7 +6,7 @@ DATE_FORMATS = ["2022-12-01T00:00:00", "20221201", "2022-12-01", "12.01.2022"] EXPECTED_DATE_STR = "202212010000" -MODEL_LOADING_ARGS = ["from_run_id", "epoch", "reuse_run_id"] +MODEL_LOADING_ARGS = ["from_run_id", "mini_epoch", "reuse_run_id"] GENERAL_ARGS = ["config", "private_config", "options", "run_id"] MODEL_LOADING_PARSERS = [cli.get_continue_parser(), cli.get_inference_parser()] BASIC_ARGLIST = ["--from_run_id", "test123"] @@ -80,7 +80,7 @@ def test_inference_defaults(inference_parser): "end_date", "samples", "analysis_streams_output", - "epoch", + "mini_epoch", "private_config", ] default_values = [inference_parser.get_default(arg) for arg in default_args] diff --git a/tests/test_config.py b/tests/test_config.py index 82802a59b..c1ed19569 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,7 +19,7 @@ }, } -DUMMY_OVERWRITES = [("num_epochs", 42), ("healpix_level", 42)] +DUMMY_OVERWRITES = [("num_mini_epochs", 42), ("healpix_level", 42)] DUMMY_STREAM_CONF = { "ERA5": { @@ -205,16 +205,16 @@ def test_load_multiple_overwrites(private_config_file): assert contains(cf, expected) -@pytest.mark.parametrize("epoch", [None, 0, 1, 2, -1]) -def test_load_existing_config(epoch, private_config_file, config_fresh): - test_num_epochs = 3000 +@pytest.mark.parametrize("mini_epoch", [None, 0, 1, 2, -1]) +def test_load_existing_config(mini_epoch, private_config_file, config_fresh): + test_num_mini_epochs = 3000 - config_fresh.num_epochs = test_num_epochs # some specific change - config.save(config_fresh, epoch) + config_fresh.num_mini_epochs = test_num_mini_epochs # some specific change + config.save(config_fresh, mini_epoch) - cf = config.load_config(private_config_file, config_fresh.run_id, epoch) + cf = config.load_config(private_config_file, config_fresh.run_id, mini_epoch) - assert cf.num_epochs == test_num_epochs + assert cf.num_mini_epochs == test_num_mini_epochs @pytest.mark.parametrize("options,cf", [(["foo=1", "bar=2"], {"foo": 1, "bar": 2}), ([], {})]) @@ -317,9 +317,9 @@ def test_load_duplicate_streams_same_file(streams_dir): config.load_streams(streams_dir) -@pytest.mark.parametrize("epoch", [None, 0, 1, 2, -1]) # maybe add -5 as test case -def test_save(epoch, config_fresh): - config.save(config_fresh, epoch) +@pytest.mark.parametrize("mini_epoch", [None, 0, 1, 2, -1]) # maybe add -5 as test case +def test_save(mini_epoch, config_fresh): + config.save(config_fresh, mini_epoch) - cf = config.load_model_config(config_fresh.run_id, epoch, config_fresh.model_path) + cf = config.load_model_config(config_fresh.run_id, mini_epoch, config_fresh.model_path) assert is_equal(cf, config_fresh) From ad7d5ede341d75f4a4b9dc932fa4fe3ec83c7852 Mon Sep 17 00:00:00 2001 From: TillHae Date: Fri, 31 Oct 2025 13:05:06 +0100 Subject: [PATCH 02/32] small naming fix of mini_epoch --- src/weathergen/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0ed3fe36d..83e289e20 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -728,7 +728,7 @@ def validate(self, mini_epoch): pbar.update(self.cf.batch_size_validation_per_gpu) - self._log_terminal(bidx, mini_, VAL) + self._log_terminal(bidx, mini_epoch, VAL) self._log(VAL) # avoid that there is a systematic bias in the validation subset From 3851147948c35d4a1fd55a6832e37f79f77489d6 Mon Sep 17 00:00:00 2001 From: TillHae Date: Fri, 31 Oct 2025 13:13:25 +0100 Subject: [PATCH 03/32] ruffed --- packages/common/src/weathergen/common/config.py | 6 ++++-- .../src/weathergen/evaluate/export_inference.py | 6 ++++-- src/weathergen/model/model.py | 4 +++- src/weathergen/train/lr_scheduler.py | 10 ++++++++-- src/weathergen/train/trainer.py | 11 ++++++++--- src/weathergen/utils/train_logger.py | 4 +++- 6 files changed, 30 insertions(+), 11 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 01167e7bc..195f7394c 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -191,7 +191,9 @@ def load_config( if from_run_id is None: base_config = _load_default_conf() else: - base_config = load_model_config(from_run_id, mini_epoch, private_config.get("model_path", None)) + base_config = load_model_config( + from_run_id, mini_epoch, private_config.get("model_path", None) + ) from_run_id = base_config.run_id with open_dict(base_config): base_config.from_run_id = from_run_id @@ -456,7 +458,7 @@ def get_path_model(config: Config) -> Path: return Path(config.model_path) / config.run_id -def get_path_output(config: Config, mini_: int) -> Path: +def get_path_output(config: Config, mini_epoch: int) -> Path: base_path = get_path_run(config) fname = f"validation_chkpt{mini_epoch:05d}_rank{config.rank:04d}.zarr" diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py index 81602fa47..3f67c37d9 100755 --- a/packages/evaluate/src/weathergen/evaluate/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py @@ -61,6 +61,7 @@ def detect_grid_type(input_data_array: xr.DataArray) -> str: # Otherwise it's Gaussian (irregular spacing or reduced grid) return "gaussian" + def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: """ Find all the pressure levels for each variable using regex and returns a dictionary @@ -90,6 +91,7 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: pl = list(set(pl)) return var_dict, pl + def reshape_dataset_adaptive(input_data_array: xr.DataArray) -> xr.Dataset: """ Reshape dataset while preserving grid structure (regular or Gaussian). @@ -176,8 +178,6 @@ def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) -> return ds - - def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: """ Add CF conventions to the dataset attributes. @@ -201,6 +201,7 @@ def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: ds.attrs["Conventions"] = "CF-1.12" return ds + def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: """ Modified CF parser that handles both regular and Gaussian grids. @@ -323,6 +324,7 @@ def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: return dataset + def output_filename( prefix: str, run_id: str, diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 6173316a3..34e320e97 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -557,7 +557,9 @@ def load(self, run_id: str, mini_epoch: str = -1) -> None: """ path_run = Path(self.cf.model_path) / run_id - mini_epoch_id = f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + mini_epoch_id = ( + f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + ) filename = f"{run_id}_{mini_epoch_id}.chkpt" params = torch.load( diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index e5421d0e0..60c8d4fd5 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -254,7 +254,10 @@ def plot(): lrs = [] for _ in range( - num_mini_epochs * num_samples_per_mini_epoch + lr_steps_warmup + lr_steps_cooldown + 1023 + num_mini_epochs * num_samples_per_mini_epoch + + lr_steps_warmup + + lr_steps_cooldown + + 1023 ): optimizer.step() lrs.append(optimizer.param_groups[0]["lr"]) @@ -288,7 +291,10 @@ def plot(): lrs = [] for _ in range( - num_mini_epochs * num_samples_per_mini_epoch + lr_steps_warmup + lr_steps_cooldown + 1023 + num_mini_epochs * num_samples_per_mini_epoch + + lr_steps_warmup + + lr_steps_cooldown + + 1023 ): optimizer.step() lrs.append(optimizer.param_groups[0]["lr"]) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 83e289e20..439150f4d 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -301,7 +301,9 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): self.model.reset_parameters() else: if is_root(): - logger.info(f"Continuing run with id={self.cf.from_run_id} at mini_epoch {mini_epoch_contd}.") + logger.info( + f"Continuing run with id={self.cf.from_run_id} at mini_epoch {mini_epoch_contd}." + ) self.load_model(self.cf.from_run_id, mini_epoch_contd) if is_root(): logger.info(f"Loaded model id={run_id_contd}.") @@ -408,7 +410,8 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): len(self.dataset) // (self.world_size_original * cf.batch_size_per_gpu) ) * cf.batch_size_per_gpu mini_epoch_base = int( - self.cf.istep / (min(len_per_rank, cf.samples_per_mini_epoch) * self.world_size_original) + self.cf.istep + / (min(len_per_rank, cf.samples_per_mini_epoch) * self.world_size_original) ) # torch.autograd.set_detect_anomaly(True) @@ -753,7 +756,9 @@ def load_model(self, run_id: str, mini_epoch=-1): """ path_run = Path(self.cf.model_path) / run_id - mini_epoch_id = f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + mini_epoch_id = ( + f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" + ) filename = f"{run_id}_{mini_epoch_id}.chkpt" params = torch.load( diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 4c467f8d0..8c09dfcfd 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -187,7 +187,9 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: """ # Load config from given model_path if provided, otherwise use path from private config if model_path: - cf = config.load_model_config(run_id=run_id, mini_epoch=mini_epoch, model_path=model_path) + cf = config.load_model_config( + run_id=run_id, mini_epoch=mini_epoch, model_path=model_path + ) else: cf = config.load_config(private_home=None, from_run_id=run_id, mini_epoch=mini_epoch) run_id = cf.run_id From 54115d823b00298dcb5dbddf61ed8b2d1253a7a1 Mon Sep 17 00:00:00 2001 From: TillHae Date: Fri, 31 Oct 2025 13:18:31 +0100 Subject: [PATCH 04/32] linted --- src/weathergen/datasets/multi_stream_data_sampler.py | 3 ++- src/weathergen/train/lr_scheduler.py | 4 ++-- src/weathergen/train/trainer.py | 6 ++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index bbc256ac2..daafd2a25 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -288,7 +288,8 @@ def reset(self): # data index_range = self.time_window_handler.get_index_range() idx_end = index_range.end - # native length of datasets, independent of mini_epoch length that has potentially been specified + # native length of datasets, independent of mini_epoch length that has potentially been + # specified forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs idx_end -= forecast_len + self.forecast_offset assert idx_end > 0, "dataset size too small for forecast range" diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index 60c8d4fd5..9220a30ad 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -154,8 +154,8 @@ def __init__( self.i_step = 0 self.lr = self.cur_scheduler.get_last_lr() - # advance manually to step_contd (last_mini_epoch parameter for schedulers is not working and - # this is also more brittle with the different phases) + # advance manually to step_contd (last_mini_epoch parameter for schedulers is not working + # and this is also more brittle with the different phases) # optimizer.step() as required by torch; # won't have a material effect since grads are zero at this point if self.step_contd > 0: diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 439150f4d..a792c67c7 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -302,7 +302,8 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): else: if is_root(): logger.info( - f"Continuing run with id={self.cf.from_run_id} at mini_epoch {mini_epoch_contd}." + f"""Continuing run with id={self.cf.from_run_id} at mini_epoch + {mini_epoch_contd}.""" ) self.load_model(self.cf.from_run_id, mini_epoch_contd) if is_root(): @@ -1038,7 +1039,8 @@ def _log_terminal(self, bidx: int, mini_epoch: int, stage: Stage): if is_root(): if stage == VAL: logger.info( - f"validation ({self.cf.run_id}) : {mini_epoch:03d} : {avg_loss.nanmean().item()}" + f"""validation ({self.cf.run_id}) : {mini_epoch:03d} : + {avg_loss.nanmean().item()}""" ) for _, st in enumerate(self.cf.streams): logger.info( From 99173e3d97d39df218e34111e178e92dbcf03acd Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Thu, 30 Oct 2025 15:56:51 +0100 Subject: [PATCH 05/32] Fix out of bounds in data_reader_obs (#1180) * fix out of bounds access * Adding comment * Removed debgu --- src/weathergen/datasets/data_reader_obs.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/data_reader_obs.py b/src/weathergen/datasets/data_reader_obs.py index ca78c8195..5fb2b7147 100644 --- a/src/weathergen/datasets/data_reader_obs.py +++ b/src/weathergen/datasets/data_reader_obs.py @@ -179,20 +179,20 @@ def _setup_sample_index(self) -> None: ) * self.indices_start[-1], ) + self.indices_end = np.append( self.indices_end, np.ones( - (diff_in_hours_end - self.hrly_index.shape[0] - 1) // step_hrs, dtype=int + # add (len_hrs + 1) since above we also have diff_in_hours_start + len_hrs + (diff_in_hours_end - self.hrly_index.shape[0] + (len_hrs + 1)) // step_hrs, + dtype=int, ) * self.indices_end[-1], ) - # Prevent -1 in samples before the we have data + # Prevent -1 in samples before we have data self.indices_end = np.maximum(self.indices_end, 0) - if self.indices_end.shape != self.indices_start.shape: - self.indices_end = np.append(self.indices_end, self.indices_end[-1]) - # If end (yyyymmddhhmm) is not a multiple of len_hrs # truncate the last sample so that it doesn't go beyond the requested dataset end date self.indices_end = np.minimum(self.indices_end, self.hrly_index[end_range_1]) From c6df83297304ee428e134208b773199e3b488cb0 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Fri, 31 Oct 2025 12:28:43 +0100 Subject: [PATCH 06/32] Fixed to use forward function for forecast engine (#1188) * Fixed to use forward function for forecast engine, and also fstep for conditioning * Fixed missing return statement --- src/weathergen/model/engines.py | 9 +++++---- src/weathergen/model/model.py | 4 +--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 3628ff2aa..3351dabc4 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -376,10 +376,11 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, use_reentrant): - for it, block in enumerate(self.fe_blocks): - aux_info = torch.tensor([it], dtype=torch.float32, device="cuda") - tokens = checkpoint(block, tokens, aux_info, use_reentrant=use_reentrant) + def forward(self, tokens, fstep): + aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") + for block in self.fe_blocks: + tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 34e320e97..f5b2bac00 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -806,9 +806,7 @@ def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) ValueError: For unexpected arguments in checkpoint method """ - for block in self.forecast_engine.fe_blocks: - aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + tokens = self.forecast_engine(tokens, fstep) return tokens From 2d412793cbce17d3e05c670d77de9c3ee02045ac Mon Sep 17 00:00:00 2001 From: Kacper Nowak Date: Fri, 31 Oct 2025 17:46:20 +0100 Subject: [PATCH 07/32] Enable FesomDataReader to have different source and target datasets (#1046) * Implement separate target and source files, adjust masking * Fix casual masking * Fix longitude conversion flag * Fix casual masking strategy --------- Co-authored-by: Seb Hickman <56727418+shmh40@users.noreply.github.com> --- config/ifs_fesom_config.yml | 2 +- config/streams/fesom/fesom.yml | 1 + config/streams/fesom/fesom_elem.yml | 1 + src/weathergen/datasets/data_reader_fesom.py | 530 +++++++++++++++---- src/weathergen/datasets/masking.py | 75 ++- 5 files changed, 469 insertions(+), 140 deletions(-) diff --git a/config/ifs_fesom_config.yml b/config/ifs_fesom_config.yml index bee401315..4167bf91d 100644 --- a/config/ifs_fesom_config.yml +++ b/config/ifs_fesom_config.yml @@ -9,7 +9,7 @@ ae_local_num_heads: 16 start_date: 2000-10-10T00:00 end_date: 2199-12-31T00:00 start_date_val: 2200-01-01T00:00 -end_date_val: 2220-12-31T00:00 +end_date_val: 2209-12-31T00:00 num_epochs: 111 samples_per_epoch: 64 diff --git a/config/streams/fesom/fesom.yml b/config/streams/fesom/fesom.yml index ef315e22f..789011e2d 100644 --- a/config/streams/fesom/fesom.yml +++ b/config/streams/fesom/fesom.yml @@ -10,6 +10,7 @@ FESOM_NODE : type : fesom filenames : ['ocean_node'] + target_file: "/work/ab0995/a270088/Kacper/weathergenertor/AWICM3/ocean_elem" loss_weight : 1. source : null target : null diff --git a/config/streams/fesom/fesom_elem.yml b/config/streams/fesom/fesom_elem.yml index 8afe69435..f9c07e847 100644 --- a/config/streams/fesom/fesom_elem.yml +++ b/config/streams/fesom/fesom_elem.yml @@ -10,6 +10,7 @@ FESOM_ELEM : type : fesom filenames : ['ocean_elem'] + target_file: "/work/ab0995/a270088/Kacper/weathergenertor/AWICM3/ocean_node" loss_weight : 1. source : null target : null diff --git a/src/weathergen/datasets/data_reader_fesom.py b/src/weathergen/datasets/data_reader_fesom.py index bcffc632f..16971322f 100644 --- a/src/weathergen/datasets/data_reader_fesom.py +++ b/src/weathergen/datasets/data_reader_fesom.py @@ -48,6 +48,15 @@ def __init__( self.filenames = sorted(glob.glob(str(filename) + "/*")) self._tw_handler = tw_handler self._stream_info = stream_info + self.target_files = self.filenames + + self._src_lat_conv = False + self._src_lon_conv = False + self._trg_lat_conv = False + self._trg_lon_conv = False + + if "target_file" in stream_info: + self.target_files = sorted(glob.glob(str(stream_info["target_file"]) + "/*")) if len(self.filenames) == 0: self.init_empty() @@ -55,8 +64,10 @@ def __init__( return # Initialize data-dependent attributes to None. They will be set by _lazy_init. - self.time: da.Array | None = None - self.data: da.Array | None = None + self.source_time: da.Array | None = None + self.source_data: da.Array | None = None + self.target_time: da.Array | None = None + self.target_data: da.Array | None = None self.len = 0 # Default length is 0 until initialized self.source_channels = [] self.source_idx = [] @@ -65,10 +76,10 @@ def __init__( self.geoinfo_channels = [] self.geoinfo_idx = [] self.properties = {} - self._lat_needs_conversion = False - self._lon_needs_conversion = False + self.fake_specs = {} + self.fake_target = False - if len(self.filenames) == 0: + if len(self.filenames) == 0 or len(self.target_files) == 0: name = stream_info["name"] _logger.warning( f"{name} couldn't find any files matching {filename}. Stream is skipped." @@ -83,6 +94,39 @@ def __init__( # This flag ensures initialization happens only once per worker self._initialized = False + # print(f"checking stream info {list(stream_info.keys())}") + + def _get_mesh_size(self, group: zarr.Group) -> int: + if "n_points" in group.data.attrs: + return group.data.attrs["n_points"] + else: + return group.data.attrs["nod2"] + + def _reorder_groups(self, colnames: list[str], groups: list[zarr.Group]) -> list[da.Array]: + reordered_data_arrays: list[da.Array] = [] + + for group in groups: + local_colnames = group["data"].attrs["colnames"] + + # If the order is already correct, no need to do anything. + if local_colnames == colnames: + reordered_data_arrays.append(da.from_zarr(group["data"])) + else: + # Create the list of indices to re-shuffle the columns. + reorder_indices = [local_colnames.index(name) for name in colnames] + + # Lazily re-index the dask array. This operation is not executed immediately. + dask_array = da.from_zarr(group["data"]) + reordered_array = dask_array[:, reorder_indices] + reordered_data_arrays.append(reordered_array) + + return reordered_data_arrays + + def _remove_lonlat(self, colnames: list[str]) -> list[str]: + temp_colnames = list(colnames) + temp_colnames.remove("lat") + temp_colnames.remove("lon") + return temp_colnames def _lazy_init(self) -> None: """ @@ -92,45 +136,94 @@ def _lazy_init(self) -> None: if self._initialized: return + _logger.info(f"Initialising {self._stream_info['name']}") + # Each worker now opens its own file handles safely - groups: list[zarr.Group] = [zarr.open_group(name, mode="r") for name in self.filenames] - times: list[zarr.Array] = [group["dates"] for group in groups] - self.time = da.concatenate(times, axis=0) + s_groups: list[zarr.Group] = [zarr.open_group(name, mode="r") for name in self.filenames] + t_groups: list[zarr.Group] = [zarr.open_group(name, mode="r") for name in self.target_files] + + s_times: list[zarr.Array] = [group["dates"] for group in s_groups] + t_times: list[zarr.Array] = [group["dates"] for group in t_groups] + + self.source_time = da.concatenate(s_times, axis=0) + self.target_time = da.concatenate(t_times, axis=0) # Use the first group for metadata - first_group = groups[0] - if "nod2" in first_group.data.attrs: - self.mesh_size = first_group.data.attrs["nod2"] - else: - self.mesh_size = first_group.data.attrs["n_points"] + self.source_mesh_size = self._get_mesh_size(s_groups[0]) + self.target_mesh_size = self._get_mesh_size(t_groups[0]) # Metadata reading is cheap, but let's do it with the rest of the init - start_ds = self.time[0][0].compute() - end_ds = self.time[-1][0].compute() + self.start_source = self.source_time[0][0].compute() + self.end_source = self.source_time[-1][0].compute() + + if self.start_source > self._tw_handler.t_end or self.end_source < self._tw_handler.t_start: + name = self._stream_info["name"] + _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.") + self.init_empty() + self._initialized = True + return + + self.start_target = self.target_time[0][0].compute() + self.end_target = self.target_time[-1][0].compute() - if start_ds > self._tw_handler.t_end or end_ds < self._tw_handler.t_start: + if self.start_target > self._tw_handler.t_end or self.end_target < self._tw_handler.t_start: name = self._stream_info["name"] _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.") self.init_empty() self._initialized = True return - period = (self.time[self.mesh_size][0] - self.time[0][0]).compute() + self.source_period = ( + self.source_time[self.source_mesh_size][0] - self.source_time[0][0] + ).compute() + self.target_period = ( + self.target_time[self.target_mesh_size][0] - self.target_time[0][0] + ).compute() # Re-initialize the parent class with correct time info - super().__init__(self._tw_handler, self._stream_info, start_ds, end_ds, period) + super().__init__( # Initialise only for source as source-target split is not supported + self._tw_handler, + self._stream_info, + self.start_source, + self.end_source, + self.source_period, + ) + + if ( + self._tw_handler.t_start > self.start_source + and self._tw_handler.t_start > self.end_source + ): + self.source_start_idx = ( + (self._tw_handler.t_start - self.start_source) // self.source_period + 1 + ) * self.source_mesh_size + else: + self.source_start_idx = 0 - if self._tw_handler.t_start > start_ds: - self.start_idx = ((self._tw_handler.t_start - start_ds) // period + 1) * self.mesh_size + if ( + self._tw_handler.t_start > self.start_target + and self._tw_handler.t_start > self.end_target + ): + self.target_start_idx = ( + (self._tw_handler.t_start - self.start_target) // self.target_period + 1 + ) * self.target_mesh_size else: - self.start_idx = 0 + self.target_start_idx = 0 - self.end_idx = ((self._tw_handler.t_end - start_ds) // period + 1) * self.mesh_size + self.source_end_idx = ( + (self._tw_handler.t_end - self.start_source) // self.source_period + 1 + ) * self.source_mesh_size + self.target_end_idx = ( + (self._tw_handler.t_end - self.start_target) // self.target_period + 1 + ) * self.target_mesh_size - if self.end_idx > len(self.time): - self.end_idx = len(self.time) + if self.source_end_idx > len(self.source_time): + self.source_end_idx = len(self.source_time) + if self.target_end_idx > len(self.target_time): + self.target_end_idx = len(self.target_time) - self.len = (self.end_idx - self.start_idx) // self.mesh_size + self.source_len = (self.source_end_idx - self.source_start_idx) // self.source_mesh_size + self.target_len = (self.target_end_idx - self.target_start_idx) // self.target_mesh_size + self.len = min(self.source_len, self.target_len) # Check for a valid length after calculations if self.len <= 0: @@ -138,99 +231,133 @@ def _lazy_init(self) -> None: self._initialized = True return - self.colnames: list[str] = list(first_group.data.attrs["colnames"]) - self.cols_idx = list(np.arange(len(self.colnames))) - self.lat_index = self.colnames.index("lat") - self.lon_index = self.colnames.index("lon") - - reordered_data_arrays: list[zarr.Group] = [] + self.source_colnames: list[str] = list(s_groups[0].data.attrs["colnames"]) + self.target_colnames: list[str] = list(t_groups[0].data.attrs["colnames"]) - for group in groups: - local_colnames = group["data"].attrs["colnames"] + self.source_cols_idx = list(np.arange(len(self.source_colnames), dtype=int)) + self.target_cols_idx = list(np.arange(len(self.target_colnames), dtype=int)) - # If the order is already correct, no need to do anything. - if local_colnames == self.colnames: - reordered_data_arrays.append(da.from_zarr(group["data"])) - else: - # Create the list of indices to re-shuffle the columns. - reorder_indices = [local_colnames.index(name) for name in self.colnames] + self.src_lat_index: int = self.source_colnames.index("lat") + self.src_lon_index: int = self.source_colnames.index("lon") + self.trg_lat_index: int = self.target_colnames.index("lat") + self.trg_lon_index: int = self.target_colnames.index("lon") - # Lazily re-index the dask array. This operation is not executed immediately. - dask_array = da.from_zarr(group["data"]) - reordered_array = dask_array[:, reorder_indices] - reordered_data_arrays.append(reordered_array) + source_reorderd = self._reorder_groups(self.source_colnames, s_groups) + target_reorderd = self._reorder_groups(self.target_colnames, t_groups) # Modify a copy, not the original list while iterating - temp_colnames = list(self.colnames) - temp_colnames.remove("lat") - temp_colnames.remove("lon") - self.colnames = temp_colnames + self.source_colnames = self._remove_lonlat(self.source_colnames) + self.target_colnames = self._remove_lonlat(self.target_colnames) + + self.source_cols_idx.remove(self.src_lat_index) + self.source_cols_idx.remove(self.src_lon_index) + self.source_cols_idx = np.array(self.source_cols_idx) - self.cols_idx.remove(self.lat_index) - self.cols_idx.remove(self.lon_index) - self.cols_idx = np.array(self.cols_idx) + self.target_cols_idx.remove(self.trg_lat_index) + self.target_cols_idx.remove(self.trg_lon_index) + self.target_cols_idx = np.array(self.target_cols_idx) - self.properties = {"stream_id": first_group.data.attrs["obs_id"]} + self.properties = {"stream_id": s_groups[0].data.attrs["obs_id"]} - self.mean = np.concatenate((np.array([0, 0]), np.array(first_group.data.attrs["means"]))) - self.stdev = np.sqrt( - np.concatenate((np.array([1, 1]), np.array(first_group.data.attrs["std"]))) + self.source_mean = np.concatenate( + (np.array([0, 0]), np.array(s_groups[0].data.attrs["means"])) + ) + self.source_stdev = np.sqrt( + np.concatenate((np.array([1, 1]), np.array(s_groups[0].data.attrs["std"]))) ) - self.stdev[self.stdev <= 1e-5] = 1.0 + self.source_stdev[self.source_stdev <= 1e-5] = 1.0 - self.data = da.concatenate(reordered_data_arrays, axis=0) + self.target_mean = np.concatenate( + (np.array([0, 0]), np.array(t_groups[0].data.attrs["means"])) + ) + self.target_stdev = np.sqrt( + np.concatenate((np.array([1, 1]), np.array(t_groups[0].data.attrs["std"]))) + ) + self.target_stdev[self.target_stdev <= 1e-5] = 1.0 + + self.source = da.concatenate(source_reorderd, axis=0) + self.target = da.concatenate(target_reorderd, axis=0) - first_timestep_lats = self.data[: self.mesh_size, self.lat_index].compute() - first_timestep_lons = self.data[: self.mesh_size, self.lon_index].compute() + source_channels = self._stream_info.get("source") + source_excl = self._stream_info.get("source_exclude") + self.source_channels, self.source_idx = ( + self.select(self.source_colnames, self.source_cols_idx, source_channels, source_excl) + if source_channels or source_excl + else (self.source_colnames, self.source_cols_idx) + ) - if np.any(first_timestep_lats > 90.0): + target_channels = self._stream_info.get("target") + target_excl = self._stream_info.get("target_exclude") + self.target_channels, self.target_idx = ( + self.select(self.target_colnames, self.target_cols_idx, target_channels, target_excl) + if target_channels or target_excl + else (self.target_colnames, self.target_cols_idx) + ) + + src_timestep_lats = self.source[: self.source_mesh_size, self.src_lat_index].compute() + trg_timestep_lats = self.target[: self.target_mesh_size, self.trg_lat_index].compute() + + if np.any(src_timestep_lats > 90.0): _logger.warning( - f"Latitude for stream '{self._stream_info['name']}' appears to be in a [0, 180] " - f"format. It will be automatically converted to the required [-90, 90] format." + f"Latitude for stream '{self._stream_info['name']}' " + f"source appears to be in a [0, 180] format. " + f"It will be automatically converted to the required [-90, 90] format." ) - self._lat_needs_conversion = True + self._src_lat_conv = True - if np.any(first_timestep_lons > 180.0): + if np.any(trg_timestep_lats > 90.0): _logger.warning( - f"Longitude for stream '{self._stream_info['name']}' appears to be in a [0, 360] " - f"format. It will be automatically converted to the required [-180, 180] format." + f"Latitude for stream '{self._stream_info['name']}' " + f"target appears to be in a [0, 180] format. " + f"It will be automatically converted to the required [-90, 90] format." ) - self._lon_needs_conversion = True + self._trg_lat_conv = True - source_channels = self._stream_info.get("source") - source_excl = self._stream_info.get("source_exclude") - self.source_channels, self.source_idx = self.select_channels(source_channels, source_excl) + src_timestep_lons = self.source[: self.source_mesh_size, self.src_lon_index].compute() + trg_timestep_lons = self.target[: self.target_mesh_size, self.trg_lon_index].compute() - target_channels = self._stream_info.get("target") - target_excl = self._stream_info.get("target_exclude") - self.target_channels, self.target_idx = self.select_channels(target_channels, target_excl) + if np.any(src_timestep_lons > 180.0): + _logger.warning( + f"Longitude for stream '{self._stream_info['name']}' " + f"source appears to be in a [0, 360] format. " + f"It will be automatically converted to the required [-180, 180] format." + ) + self._src_lon_conv = True + + if np.any(trg_timestep_lons > 180.0): + _logger.warning( + f"Longitude for stream '{self._stream_info['name']}' " + f"target appears to be in a [0, 360] format." + f"It will be automatically converted to the required [-180, 180] format." + ) + self._trg_lon_conv = True self.geoinfo_channels = [] self.geoinfo_idx = [] self._initialized = True - def select_channels( - self, ch_filters: list[str] | None, excl: list[str] | None = None + def select( + self, + colnames: list[str], + cols_idx: NDArray, + ch_filters: list[str] | None, + excl: list[str] | None = None, ) -> tuple[list[str], NDArray]: - """ - Allow user to specify which columns they want to access. - Get functions only returned for these specified columns. - """ if excl and ch_filters: mask = [ any(f == c for f in ch_filters) and all(ex not in c for ex in excl) - for c in self.colnames + for c in colnames ] elif ch_filters: - mask = [any(f == c for f in ch_filters) for c in self.colnames] + mask = [any(f == c for f in ch_filters) for c in colnames] elif excl: - mask = [all(ex not in c for ex in excl) for c in self.colnames] + mask = [all(ex not in c for ex in excl) for c in colnames] else: - return self.colnames, self.cols_idx + assert False, "Cannot use select with both ch_filters and excl as None" - selected_cols_idx = self.cols_idx[np.where(mask)[0]] - selected_colnames = [self.colnames[i] for i in np.where(mask)[0]] + selected_cols_idx = cols_idx[np.where(mask)[0]] + selected_colnames = [colnames[i] for i in np.where(mask)[0]] return selected_colnames, selected_cols_idx @override @@ -244,10 +371,9 @@ def length(self) -> int: self._lazy_init() return self.len - @override - def _get_dataset_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: + def _get_source_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: """ - Get dataset indexes for a given time window index, when the dataset is periodic. + Get source dataset indexes for a given time window index, when the dataset is periodic. This function assumes state of a variable is persistent, thus if no data is found in the time window, last measurement is used before the beggining of the windows is used. @@ -268,66 +394,160 @@ def _get_dataset_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: dtr = tw_handler.window(idx) # If there is no or only marginal overlap with the dataset, return empty index ranges if ( - not self.data_start_time - or not self.data_end_time - or dtr.end < self.data_start_time - or dtr.start > self.data_end_time - or dtr.start < self.data_start_time - or dtr.end > self.data_end_time - or (self.data_end_time is not None and dtr.start > self.data_end_time) + not self.start_source + or not self.end_source + or dtr.end < self.start_source + or dtr.start > self.end_source + or dtr.start < self.start_source + or dtr.end > self.end_source + or (self.end_source is not None and dtr.start > self.end_source) ): return (np.array([], dtype=np.int64), dtr) # relative time in dataset - delta_t_start = dtr.start - self.data_start_time - delta_t_end = dtr.end - self.data_start_time - t_epsilon + delta_t_start = dtr.start - self.start_source + delta_t_end = dtr.end - self.start_source - t_epsilon assert isinstance(delta_t_start, np.timedelta64), "delta_t_start must be timedelta64" - start_didx = delta_t_start // self.period - end_didx = delta_t_end // self.period + start_didx = delta_t_start // self.source_period + end_didx = delta_t_end // self.source_period # adjust start_idx if not exactly on start time - if (delta_t_start % self.period) > np.timedelta64(0, "s"): + if (delta_t_start % self.source_period) > np.timedelta64(0, "s"): # empty window in between two timesteps if start_didx == end_didx: return (np.array([start_didx], dtype=np.int64), dtr) start_didx += 1 - end_didx = start_didx + int((dtr.end - dtr.start - t_epsilon) / self.period) + end_didx = start_didx + int((dtr.end - dtr.start - t_epsilon) / self.source_period) + return (np.arange(start_didx, end_didx + 1, dtype=np.int64), dtr) + + def _get_target_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: + """ + Get target dataset indexes for a given time window index, when the dataset is periodic. + + This function assumes state of a variable is persistent, thus if no data is found + in the time window, last measurement is used before the beggining of the windows is used. + + Parameters + ---------- + idx : TIndex + Index of the time window. + + Returns + ------- + NDArray[np.int64] + Array of dataset indexes corresponding to the time window. + """ + tw_handler = self.time_window_handler + + # Function is separated from the class to allow testing without instantiating the class. + dtr = tw_handler.window(idx) + # If there is no or only marginal overlap with the dataset, return empty index ranges + if ( + not self.start_target + or not self.end_target + or dtr.end < self.start_target + or dtr.start > self.end_target + or dtr.start < self.start_target + or dtr.end > self.end_target + or (self.end_target is not None and dtr.start > self.end_target) + ): + return (np.array([], dtype=np.int64), dtr) + # relative time in dataset + delta_t_start = dtr.start - self.start_target + delta_t_end = dtr.end - self.start_target - t_epsilon + assert isinstance(delta_t_start, np.timedelta64), "delta_t_start must be timedelta64" + start_didx = delta_t_start // self.target_period + end_didx = delta_t_end // self.target_period + + # adjust start_idx if not exactly on start time + if (delta_t_start % self.target_period) > np.timedelta64(0, "s"): + # empty window in between two timesteps + if start_didx == end_didx: + return (np.array([start_didx], dtype=np.int64), dtr) + start_didx += 1 + + end_didx = start_didx + int((dtr.end - dtr.start - t_epsilon) / self.target_period) return (np.arange(start_didx, end_didx + 1, dtype=np.int64), dtr) @override - def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + def get_source(self, idx: TIndex) -> ReaderData: + self._lazy_init() + (t_idxs, dtr) = self._get_source_idxs(idx) + if self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(self.source_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + start_row = t_idxs[0] * self.source_mesh_size + end_row = (t_idxs[-1] + 1) * self.source_mesh_size + + # Note: we read all columns from start_row to end_row once, + # then select the ones we need. This is more efficient for Zarr. + full_data_slice = self.source[start_row:end_row] + datetimes_lazy = self.source_time[start_row:end_row] + + # Define the specific slices we need from the larger block + data_lazy = full_data_slice[:, self.source_idx] + lat_lazy = full_data_slice[:, self.src_lat_index] + lon_lazy = full_data_slice[:, self.src_lon_index] + + # Dask optimizes this to a single (or few) efficient read operation(s). + data, lat, lon, datetimes = dask.compute( + data_lazy, lat_lazy, lon_lazy, datetimes_lazy, scheduler="single-threaded" + ) + + if self._src_lat_conv: + lat = 90.0 - lat + + if self._src_lon_conv: + lon = ((lon + 180.0) % 360.0) - 180.0 + + coords = np.stack([lat, lon], axis=1) + geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) + datetimes = np.squeeze(datetimes) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + + return rd + + @override + def get_target(self, idx: TIndex) -> ReaderData: self._lazy_init() - (t_idxs, dtr) = self._get_dataset_idxs(idx) + (t_idxs, dtr) = self._get_target_idxs(idx) if self.len == 0 or len(t_idxs) == 0: return ReaderData.empty( - num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + num_data_fields=len(self.source_idx), num_geo_fields=len(self.geoinfo_idx) ) - start_row = t_idxs[0] * self.mesh_size - end_row = (t_idxs[-1] + 1) * self.mesh_size + start_row = t_idxs[0] * self.target_mesh_size + end_row = (t_idxs[-1] + 1) * self.target_mesh_size # Note: we read all columns from start_row to end_row once, # then select the ones we need. This is more efficient for Zarr. - full_data_slice = self.data[start_row:end_row] - time_slice = self.time[start_row:end_row] + full_data_slice = self.target[start_row:end_row] + datetimes_lazy = self.target_time[start_row:end_row] # Define the specific slices we need from the larger block - data_lazy = full_data_slice[:, channels_idx] - lat_lazy = full_data_slice[:, self.lat_index] - lon_lazy = full_data_slice[:, self.lon_index] - datetimes_lazy = time_slice + data_lazy = full_data_slice[:, self.target_idx] + lat_lazy = full_data_slice[:, self.trg_lat_index] + lon_lazy = full_data_slice[:, self.trg_lon_index] # Dask optimizes this to a single (or few) efficient read operation(s). data, lat, lon, datetimes = dask.compute( data_lazy, lat_lazy, lon_lazy, datetimes_lazy, scheduler="single-threaded" ) - if self._lat_needs_conversion: + if self._trg_lat_conv: lat = 90.0 - lat - if self._lon_needs_conversion: + if self._trg_lon_conv: lon = ((lon + 180.0) % 360.0) - 180.0 coords = np.stack([lat, lon], axis=1) @@ -342,3 +562,87 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: ) return rd + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + return self.get_source(idx) + + @override + def normalize_source_channels(self, source: NDArray) -> NDArray: + """ + Normalize source channels + + Parameters + ---------- + data : + data to be normalized + + Returns + ------- + Normalized data + """ + assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels" + for i, ch in enumerate(self.source_idx): + source[..., i] = (source[..., i] - self.source_mean[ch]) / self.source_stdev[ch] + + return source + + @override + def normalize_target_channels(self, target: NDArray) -> NDArray: + """ + Normalize target channels + + Parameters + ---------- + data : + data to be normalized + + Returns + ------- + Normalized data + """ + assert target.shape[-1] == len(self.target_idx), "incorrect number of target channels" + for i, ch in enumerate(self.target_idx): + target[..., i] = (target[..., i] - self.target_mean[ch]) / self.target_stdev[ch] + + return target + + @override + def denormalize_source_channels(self, source: NDArray) -> NDArray: + """ + Denormalize source channels + + Parameters + ---------- + data : + data to be denormalized + + Returns + ------- + Denormalized data + """ + assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels" + for i, ch in enumerate(self.source_idx): + source[..., i] = (source[..., i] * self.source_stdev[ch]) + self.source_mean[ch] + + return source + + @override + def denormalize_target_channels(self, data: NDArray) -> NDArray: + """ + Denormalize target channels + + Parameters + ---------- + data : + data to be denormalized (target or pred) + + Returns + ------- + Denormalized data + """ + assert data.shape[-1] == len(self.target_idx), "incorrect number of target channels" + for i, ch in enumerate(self.target_idx): + data[..., i] = (data[..., i] * self.target_stdev[ch]) + self.target_mean[ch] + + return data diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 2e89ea04f..93de031fe 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -68,31 +68,32 @@ def __init__(self, cf: Config): if self.current_strategy == "healpix": hl_data = self.healpix_level_data hl_mask = self.masking_strategy_config.get("hl_mask") - assert hl_data is not None and hl_mask is not None, ( - "If HEALPix masking, hl_mask must be given in masking_strategy_config." - ) + assert ( + hl_data is not None and hl_mask is not None + ), "If HEALPix masking, hl_mask must be given in masking_strategy_config." assert hl_mask < hl_data, "hl_mask must be less than hl_data for HEALPix masking." if self.current_strategy == "channel": # Ensure that masking_strategy_config contains either 'global' or 'per_cell' - assert self.masking_strategy_config.get("mode") in ["global", "per_cell"], ( - "masking_strategy_config must contain 'mode' key with value 'global' or 'per_cell'." - ) + assert self.masking_strategy_config.get("mode") in [ + "global", + "per_cell", + ], "masking_strategy_config must contain 'mode' key with value 'global' or 'per_cell'." # check all streams that source and target channels are identical for stream in cf.streams: # check explicit includes source_include = stream.get("source_include", []) target_include = stream.get("target_include", []) - assert set(source_include) == set(target_include), ( - "Source and target channels not identical. Required for masking_mode=channel" - ) + assert set(source_include) == set( + target_include + ), "Source and target channels not identical. Required for masking_mode=channel" # check excludes source_exclude = stream.get("source_exclude", []) target_exclude = stream.get("target_exclude", []) - assert set(source_exclude) == set(target_exclude), ( - "Source and target channels not identical. Required for masking_mode=channel" - ) + assert set(source_exclude) == set( + target_exclude + ), "Source and target channels not identical. Required for masking_mode=channel" def reset_rng(self, rng) -> None: """ @@ -277,6 +278,9 @@ def mask_target( # process all tokens used for embedding for cc, pp in zip(target_tokenized_data, self.perm_sel, strict=True): + if len(cc) == 0: # Skip if there's no target data + pass + if self.current_strategy == "channel": # If masking strategy is channel, handle target tokens differently. # We don't have Booleans per cell, instead per channel per cell, @@ -293,11 +297,28 @@ def mask_target( elif self.current_strategy == "causal": # select only the target times where mask is True - selected_tensors = [c for i, c in enumerate(cc) if pp[i]] - + if len(cc) == len(pp): + selected_tensors = [c for i, c in enumerate(cc) if pp[i]] + elif len(pp) == 0: + selected_tensors = cc + else: # If length of target and mask doesn't match, create new mask + ratio = np.sum(pp) / len(pp) # Ratio of masked tokens in source + indx = max(1, int(ratio * len(cc))) # Get the same for target + selected_tensors = cc[-indx:] + + elif self.current_strategy == "healpix": + selected_tensors = ( + cc if len(pp) > 0 and pp[0] else [] + ) # All tokens inside healpix cell have the same mask + + elif self.current_strategy == "random": + # For random masking, we simply select the tensors where the mask is True. + # When there's no mask it's assumed to be False. This is done via strict=False + selected_tensors = [c for c, p in zip(cc, pp, strict=False) if p] else: - # For other masking strategies, we simply select the tensors where the mask is True. - selected_tensors = [c for c, p in zip(cc, pp, strict=True) if p] + raise NotImplementedError( + f"Masking strategy {self.current_strategy} is not supported." + ) # Append the selected tensors to the processed_target_tokens list. if selected_tensors: @@ -346,9 +367,9 @@ def _generate_healpix_mask(self, token_lens: list[int], rate: float) -> np.typin hl_data = self.healpix_level_data hl_mask = self.masking_strategy_config.get("hl_mask") - assert len(token_lens) == self.healpix_num_cells, ( - f"Expected {self.healpix_num_cells} cells at level {hl_data}, got {len(token_lens)}." - ) + assert ( + len(token_lens) == self.healpix_num_cells + ), f"Expected {self.healpix_num_cells} cells at level {hl_data}, got {len(token_lens)}." # Calculate the number of parent cells at the mask level (hl_mask) num_parent_cells = 12 * (4**hl_mask) @@ -487,14 +508,16 @@ def _generate_causal_mask( # Create masks with list comprehension # Needed to handle variable lengths full_mask = [ - np.concatenate( - [ - np.zeros(start_idx, dtype=bool), - np.ones(max(0, token_len - start_idx), dtype=bool), - ] + ( + np.concatenate( + [ + np.zeros(start_idx, dtype=bool), + np.ones(max(0, token_len - start_idx), dtype=bool), + ] + ) + if token_len > 1 + else (np.zeros(1, dtype=bool) if token_len == 1 else np.array([], dtype=bool)) ) - if token_len > 1 - else (np.zeros(1, dtype=bool) if token_len == 1 else np.array([], dtype=bool)) for token_len, start_idx in zip(token_lens, start_mask_indices, strict=False) ] From 97041782d3c0e807c872a4729fc24236f080fcbc Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 3 Nov 2025 10:58:29 +0100 Subject: [PATCH 08/32] Add support for constant learning rate (#1186) * Added support for constant learning rate and minor clean-up in code * Fixed issues with overlap between lr phases * Changing default lr to constant --- config/default_config.yml | 2 +- src/weathergen/train/lr_scheduler.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 116b63267..19c8769fb 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -126,7 +126,7 @@ lr_final: 0.0 lr_steps_warmup: 512 lr_steps_cooldown: 512 lr_policy_warmup: "cosine" -lr_policy_decay: "linear" +lr_policy_decay: "constant" lr_policy_cooldown: "linear" grad_clip: 1.0 diff --git a/src/weathergen/train/lr_scheduler.py b/src/weathergen/train/lr_scheduler.py index 9220a30ad..f6ba7ab5d 100644 --- a/src/weathergen/train/lr_scheduler.py +++ b/src/weathergen/train/lr_scheduler.py @@ -123,6 +123,10 @@ def __init__( self.decay_factor = self.lr_max_scaled * np.sqrt(n_steps_warmup) self.scheduler_decay = None + elif policy_decay == "constant": + self.decay_factor = 0.0 + self.scheduler_decay = None + else: assert False, "Unsupported decay policy for learning rate scheduler" @@ -173,11 +177,10 @@ def step(self): if self.i_step >= (self.n_steps_warmup + self.n_steps_decay + self.n_steps_cooldown): return self.lr - if ( - self.policy_decay == "sqrt" - and self.i_step > self.n_steps_warmup - and self.i_step < self.n_steps_warmup + self.n_steps_decay - ): + end_decay = self.n_steps_warmup + self.n_steps_decay + phase_decay = (self.i_step > self.n_steps_warmup) and (self.i_step <= end_decay) + + if self.policy_decay == "sqrt" and phase_decay: self.lr = ( (self.decay_factor / np.sqrt(self.i_step)) if self.i_step > 0 @@ -185,6 +188,13 @@ def step(self): ) for g in self.optimizer.param_groups: g["lr"] = self.lr + elif self.policy_decay == "constant" and phase_decay: + cur_lr = self.lr + self.lr = self.lr_max_scaled + # make sure lr_max_scaled rate is used if warm-up end is not lr_max_scaled + if cur_lr < self.lr: + for g in self.optimizer.param_groups: + g["lr"] = self.lr else: self.cur_scheduler.step() self.lr = self.cur_scheduler.get_last_lr()[0] From 362165d29adfd2e2abb8c61de9364120053b5324 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:41:57 +0100 Subject: [PATCH 09/32] [issue 1123] restore probabilistic scores (#1128) * rebase * add ensemble * fix deterministic * fix plotting * lint * fix eval_config * probabilistic scores working now * lint * Fix spoofing and refactor handling of multiple source files (#1118) * Cleaning up spoofing and related code on data preprocessing for model * Fixed typo * Updated comments * Removed merge cells and implemented necessary adjustments * Fixed forecasting * Fixed missing handling of NaNs in coordinates and channel data * Minor clean up * Fix to removing/renaming variables * Changed funtion name to improve readability * Fixed bug with incorrect handling of multiple input datasources. * Addressed reviewer comments * resolve conflict * [1131] fixes circular dependencies (#1134) * fixes dependencies * cleanup * make the type checker not fail * cleanup * cleanup of type issues * Give option to plot only prediction maps (#1139) * add plot_preds_only feature * minor changes after comments * Tell FSDP2 about embedding engine forward functions (#1133) * Tell FSDP2 about embedding engine forward functions Note DO NOT add print functions in forward functions of the model, it will break with FSDP2 * Add comment * recover 'all' option (#1146) * Fixed problem in inferecne (#1145) * implement vrmse (#1147) * [1144] Extra fixes (#1148) * Fixed problem in inferecne * more fixes * fixes * lint * lint --------- Co-authored-by: Christian Lessig * Jk/log grad norms/log grad norms (#1068) * Log gradient norms * Prototype for recording grad norms * Address review changes + hide behind feature flag * Final fixes including backward compatibility * Ruff * More ruff stuff * forecast config with small decoder * fixed uv.lock * test gradient logging on mutli gpus * update uv.lock to latest develop version * revert to default confit * add comment on FSDP2 specifics * move plot grad script to private repo * rm seaborn from pyproject * updating terminal and metrics loggin, add get_tensor_item fct * check for DTensor instead of world size * revert forecast fct, fix in separate PR * rename grad_norm log names to exclude from MLFlow * add log_grad_norms to default config --------- Co-authored-by: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> * Add forecast and observation activity (#1126) * Add calculation methods for forecast and observation activity metrics in Scores class * Add new calculation methods for forecast activity metrics in Scores class * ruff * fix func name * Rename observation activity calculation method to target activity in Scores class * typo * refactor to common calc_act function for activity * fix cases * have calc_tact and calc_fact that use _calc_act for maintainability * fix small thing in style --------- Co-authored-by: iluise * hotfix: use correct methot `create` instead of `construct` (#1090) * restore develop * fix deterministic * fix plotting * lint * fix eval_config * probabilistic scores working now * lint * update utils * packages/evaluate/src/weathergen/evaluate/score.py * lint * removing duplication --------- Co-authored-by: Christian Lessig Co-authored-by: Timothy Hunter Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Co-authored-by: Sophie X <24638638+sophie-xhonneux@users.noreply.github.com> Co-authored-by: Julius Polz <56866670+jpolz@users.noreply.github.com> Co-authored-by: Julian Kuehnert Co-authored-by: Simon Grasse <161459968+grassesi@users.noreply.github.com> --- .../src/weathergen/evaluate/io_reader.py | 8 +- .../src/weathergen/evaluate/plotter.py | 73 ++- .../evaluate/src/weathergen/evaluate/score.py | 480 +++++++++--------- .../evaluate/src/weathergen/evaluate/utils.py | 35 +- 4 files changed, 344 insertions(+), 252 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 4c605fb20..2aff8147e 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -164,7 +164,6 @@ def check_availability( fsteps = requested_data.fsteps samples = requested_data.samples ensemble = requested_data.ensemble - requested = { "channel": set(channels) if channels is not None else None, "fstep": set(fsteps) if fsteps is not None else None, @@ -478,6 +477,13 @@ def get_data( _logger.debug(f"Selecting ensemble members {ensemble}.") pred = pred.sel(ens=ensemble) + if ensemble == ["mean"]: + _logger.debug("Averaging over ensemble members.") + pred = pred.mean("ens", keepdims=True) + else: + _logger.debug(f"Selecting ensemble members {ensemble}.") + pred = pred.sel(ens=ensemble) + da_tars_fs.append(target.squeeze()) da_preds_fs.append(pred.squeeze()) pps.append(npoints) diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index 696d47ef1..0d35a12a1 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -138,7 +138,6 @@ def select_from_da(self, da: xr.DataArray, selection: dict) -> xr.DataArray: ------- xarray DataArray with selected data. """ - for key, value in selection.items(): if key in da.coords and key not in da.dims: # Coordinate like 'sample' aligned to another dim @@ -710,6 +709,77 @@ def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. Skipping ensemble plotting." ) + def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: + """ + Plot ensemble spread for a data array. + + Parameters + ---------- + data: xr.xArray + DataArray to be plotted + x_dim: str + Dimension to be used for the x-axis. + label: str + Label for the dataset + Returns + ------- + None + """ + averaged = data.mean(dim=[dim for dim in data.dims if dim != x_dim], skipna=True).sortby( + x_dim + ) + + lines = plt.plot( + averaged[x_dim], + averaged.values, + label=label, + marker="o", + linestyle="-", + ) + line = lines[0] + color = line.get_color() + + ens = data.mean( + dim=[dim for dim in data.dims if dim not in [x_dim, "ens"]], skipna=True + ).sortby(x_dim) + + if self.plot_ensemble == "std": + std_dev = ens.std(dim="ens", skipna=True).sortby(x_dim) + plt.fill_between( + averaged[x_dim], + (averaged - std_dev).values, + (averaged + std_dev).values, + label=f"{label} - std dev", + color=color, + alpha=0.2, + ) + + elif self.plot_ensemble == "minmax": + ens_min = ens.min(dim="ens", skipna=True).sortby(x_dim) + ens_max = ens.max(dim="ens", skipna=True).sortby(x_dim) + + plt.fill_between( + averaged[x_dim], + ens_min.values, + ens_max.values, + label=f"{label} - min max", + color=color, + alpha=0.2, + ) + + elif self.plot_ensemble == "members": + for j in range(ens.ens.size): + plt.plot( + ens[x_dim], + ens.isel(ens=j).values, + color=color, + alpha=0.2, + ) + else: + _logger.warning( + f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. Skipping ensemble plotting." + ) + def plot( self, data: xr.DataArray | list, @@ -737,7 +807,6 @@ def plot( Name of the dimension to be used for the y-axis. print_summary: If True, print a summary of the values from the graph. - Returns ------- None diff --git a/packages/evaluate/src/weathergen/evaluate/score.py b/packages/evaluate/src/weathergen/evaluate/score.py index 62448c3f7..e5a45b9b7 100755 --- a/packages/evaluate/src/weathergen/evaluate/score.py +++ b/packages/evaluate/src/weathergen/evaluate/score.py @@ -241,9 +241,11 @@ def get_score( if score_name in self.det_metrics_dict.keys(): f = self.det_metrics_dict[score_name] elif score_name in self.prob_metrics_dict.keys(): - assert self.ens_dim in data.prediction.dims, ( - f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} not found in prediction data" - ) + if self._ens_dim not in data.prediction.dims: + _logger.error( + f"Probablistic score {score_name} chosen, but ensemble dimension {self._ens_dim} not found in prediction data. Skipping score calculation." + ) + return None f = self.prob_metrics_dict[score_name] else: raise ValueError( @@ -288,22 +290,34 @@ def get_score( keys = score_args_map.get(score_name, ["p", "gt"]) args = {k: available[k] for k in keys} - # Add group_by_coord if provided - if group_by_coord is not None: - if self._validate_groupby_coord(data, group_by_coord): - args["group_by_coord"] = group_by_coord - for an in arg_names: if an in kwargs: args[an] = kwargs[an] - # Call lazy evaluation function - result = f(**args) + if group_by_coord is not None and self._validate_groupby_coord(data, group_by_coord): + # Apply groupby to all DataArrays in args + grouped_args = { + k: (v.groupby(group_by_coord) if isinstance(v, xr.DataArray) else v) + for k, v in args.items() + } + + # Apply function f to each group and concatenate results + group_names = list(next(iter(grouped_args.values())).groups.keys()) + results = [] + for name in group_names: + group_slice = {k: v[name] for k, v in grouped_args.items()} + res = f(**group_slice) + # Add coordinate for concatenation + res = res.expand_dims({group_by_coord: [name]}) + results.append(res) + result = xr.concat(results, dim=group_by_coord) + else: + # No grouping: just call the function + result = f(**args) if compute: return result.compute() else: - # Return lazy evaluation result return result def _validate_agg_dims(self, dims: str | list[str]) -> list[str] | str: @@ -400,14 +414,23 @@ def get_2x2_event_counts( p: xr.DataArray, gt: xr.DataArray, thresh: float, - group_by_coord: str | None = None, ) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray]: """ Get counts of 2x2 contingency tables + + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + thresh: float + Threshold to define event occurrence + Returns + ------- + tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray] + Counts of hits (a), false alarms (b), misses (c), and correct negatives (d) """ - if group_by_coord: - p = p.groupby(group_by_coord) - gt = gt.groupby(group_by_coord) a = self._sum((p >= thresh) & (gt >= thresh)) b = self._sum((p >= thresh) & (gt >= thresh)) @@ -422,10 +445,24 @@ def calc_ets( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, thresh: float = 0.1, - ): - a, b, c, d = self.get_2x2_event_counts(p, gt, thresh, group_by_coord) + ) -> xr.DataArray: + """ + Calculate the equitable threat score (ETS) of forecast data w.r.t. reference data. + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + thresh: float + Threshold to define event occurrence + Returns + ------- + xr.DataArray + Equitable threat score (ETS) + """ + a, b, c, d = self.get_2x2_event_counts(p, gt, thresh) n = a + b + c + d ar = (a + b) * (a + c) / n # random reference forecast @@ -440,10 +477,25 @@ def calc_fbi( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, thresh: float = 0.1, - ): - a, b, c, _ = self.get_2x2_event_counts(p, gt, thresh, group_by_coord) + ) -> xr.DataArray: + """ + Calculate the frequency bias index (FBI) of forecast data w.r.t. reference data. + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + thresh: float + Threshold to define event occurrence + Returns + ------- + xr.DataArray + Frequency bias index (FBI) + """ + + a, b, c, _ = self.get_2x2_event_counts(p, gt, thresh) denom = a + c fbi = (a + b) / denom @@ -456,10 +508,25 @@ def calc_pss( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, thresh: float = 0.1, - ): - a, b, c, d = self.get_2x2_event_counts(p, gt, thresh, group_by_coord) + ) -> xr.DataArray: + """ + Calculate the Peirce skill score (PSS) of forecast data w.r.t. reference data. + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + thresh: float + Threshold to define event occurrence + Returns + ------- + xr.DataArray + Pierce skill score (PSS) + """ + + a, b, c, d = self.get_2x2_event_counts(p, gt, thresh) denom = (a + c) * (b + d) pss = (a * d - b * c) / denom @@ -472,19 +539,27 @@ def calc_l1( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, scale_dims: list | None = None, - ): + ) -> xr.DataArray: """ Calculate the L1 error norm of forecast data w.r.t. reference data. Note that the L1 error norm is calculated as the sum of absolute differences. - If scale_dims is not None, the L1 will scaled by the number of elements in the average dimensions. + Parameters + ---------- + p: xr.DataArray + Forecast data array + gt: xr.DataArray + Ground truth data array + scale_dims: list | None + List of dimensions over which the L1 score will be scaled. + If provided, the L1 score will be divided by the product of the sizes of these dimensions. + Returns + ------- + xr.DataArray + L1 error norm """ l1 = np.abs(p - gt) - if group_by_coord: - l1 = l1.groupby(group_by_coord) - l1 = self._sum(l1) if scale_dims: @@ -503,10 +578,9 @@ def calc_l2( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, scale_dims: list | None = None, squared_l2: bool = False, - ): + ) -> xr.DataArray: """ Calculate the L2 error norm of forecast data w.r.t. reference data. @@ -516,9 +590,6 @@ def calc_l2( Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the L2 score. scale_dims: list | None List of dimensions over which the L2 score will be scaled. If provided, the L2 score will be divided by the product of the sizes of these dimensions. @@ -526,12 +597,13 @@ def calc_l2( If True, the L2 score will be returned as the sum of squared differences. If False, the L2 score will be returned as the square root of the sum of squared differences. Default is False, i.e. the L2 score is returned as the square root of the sum of squared differences. + Returns + ------- + xr.DataArray + L2 error norm """ l2 = np.square(p - gt) - if group_by_coord: - l2 = l2.groupby(group_by_coord) - l2 = self._sum(l2) if not squared_l2: @@ -549,7 +621,7 @@ def calc_l2( return l2 - def calc_mae(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None): + def calc_mae(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate mean absolute error (MAE) of forecast data w.r.t. reference data. @@ -559,24 +631,15 @@ def calc_mae(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the MAE score. """ if self._agg_dims is None: raise ValueError( "Cannot calculate mean absolute error without aggregation dimensions (agg_dims=None)." ) - mae = np.abs(p - gt) - - if group_by_coord: - mae = mae.groupby(group_by_coord) - mae = self._mean(mae) + return self._mean(np.abs(p - gt)) - return mae - - def calc_mse(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None): + def calc_mse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate mean squared error (MSE) of forecast data w.r.t. reference data. @@ -586,24 +649,19 @@ def calc_mse(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. + Returns + ------- + xr.DataArray + Mean squared error (MSE) """ if self._agg_dims is None: raise ValueError( "Cannot calculate mean squared error without aggregation dimensions (agg_dims=None)." ) - mse = np.square(p - gt) - - if group_by_coord: - mse = mse.groupby(group_by_coord) - - mse = self._mean(mse) + return self._mean(np.square(p - gt)) - return mse - - def calc_rmse(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None): + def calc_rmse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate root mean squared error (RMSE) of forecast data w.r.t. reference data Parameters @@ -612,20 +670,22 @@ def calc_rmse(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | Non Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the RMSE score. + Returns + ------- + xr.DataArray + Root mean squared error (RMSE) + """ if self._agg_dims is None: raise ValueError( "Cannot calculate root mean squared error without aggregation dimensions (agg_dims=None)." ) - rmse = np.sqrt(self.calc_mse(p, gt, group_by_coord)) + rmse = np.sqrt(self.calc_mse(p, gt)) return rmse - def calc_vrmse(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None): + def calc_vrmse(self, p: xr.DataArray, gt: xr.DataArray): """ Calculate variance-normalized root mean squared error (VRMSE) of forecast data w.r.t. reference data Parameters @@ -634,16 +694,13 @@ def calc_vrmse(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | No Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the VRMSE score. """ if self._agg_dims is None: raise ValueError( "Cannot calculate variance-normalized root mean squared error without aggregation dimensions (agg_dims=None)." ) - vrmse = np.sqrt(self.calc_mse(p, gt, group_by_coord) / (gt.var(dim=self._agg_dims) + 1e-6)) + vrmse = np.sqrt(self.calc_mse(p, gt) / (gt.var(dim=self._agg_dims) + 1e-6)) return vrmse @@ -713,7 +770,7 @@ def calc_change_rate( self, s0: xr.DataArray, s1: xr.DataArray, - ): + ) -> xr.DataArray: """ Calculate the "change rate" of a data array as the mean absolute difference between two consecutive time steps. @@ -745,8 +802,7 @@ def calc_froct( gt: xr.DataArray, p_next: xr.DataArray, gt_next: xr.DataArray, - group_by_coord: str | None = None, - ): + ) -> xr.DataArray: """ Calculate forecast rate of change over time @@ -760,9 +816,10 @@ def calc_froct( Next forecast step data array gt_next: xr.DataArray Next ground truth step data array (not used in calculation, but kept for consistency) - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the FROCT score. + Returns + ------- + xr.DataArray + Forecast rate of change over time """ if self._agg_dims is None: raise ValueError( @@ -771,9 +828,6 @@ def calc_froct( froct = self.calc_change_rate(p, p_next) - if group_by_coord: - froct = froct.groupby(group_by_coord) - froct = self._mean(froct) return froct @@ -784,7 +838,6 @@ def calc_troct( gt: xr.DataArray, gt_next: xr.DataArray, p_next: xr.DataArray, - group_by_coord: str | None = None, ): """ Calculate target rate of change over time @@ -799,9 +852,10 @@ def calc_troct( Next forecast step data array (not used in calculation, but kept for consistency) gt_next: xr.DataArray Next ground truth step data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the FROCT score. + Returns + ------- + xr.DataArray + Target rate of change over time """ if self._agg_dims is None: raise ValueError( @@ -809,10 +863,6 @@ def calc_troct( ) troct = self.calc_change_rate(gt, gt_next) - - if group_by_coord: - troct = troct.groupby(group_by_coord) - troct = self._mean(troct) return troct @@ -821,7 +871,6 @@ def _calc_act( self, x: xr.DataArray, c: xr.DataArray, - group_by_coord: str | None = None, spatial_dims: list = None, ): """ @@ -836,9 +885,6 @@ def _calc_act( Forecast or target data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the activity score. spatial_dims: List[str] Names of spatial dimensions over which activity is calculated. Note: No averaging is possible over these dimensions. @@ -857,20 +903,7 @@ def _calc_act( # Calculate anomalies ano = x - c - - if group_by_coord: - # Apply groupby and calculate activity within each group using apply - ano_grouped = ano.groupby(group_by_coord) - - # Use apply to calculate activity for each group - this preserves the coordinate structure - act = xr.concat( - [ano_group.std(dim=spatial_dims) for group_label, ano_group in ano_grouped], - dim=group_by_coord, - ).assign_coords({group_by_coord: list(ano_grouped.groups.keys())}) - - else: - # Calculate forecast activity over spatial dimensions (no grouping) - act = ano.std(dim=spatial_dims) + act = ano.std(dim=spatial_dims) return act @@ -878,7 +911,6 @@ def calc_fact( self, p: xr.DataArray, c: xr.DataArray, - group_by_coord: str | None = None, spatial_dims: list = None, ): """ @@ -893,21 +925,17 @@ def calc_fact( Forecast data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the activity score. spatial_dims: List[str] Names of spatial dimensions over which activity is calculated. Note: No averaging is possible over these dimensions. """ - return self._calc_act(p, c, group_by_coord, spatial_dims) + return self._calc_act(p, c, spatial_dims) def calc_tact( self, gt: xr.DataArray, c: xr.DataArray, - group_by_coord: str | None = None, spatial_dims: list = None, ): """ @@ -922,48 +950,20 @@ def calc_tact( Target data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the activity score. spatial_dims: List[str] Names of spatial dimensions over which activity is calculated. Note: No averaging is possible over these dimensions. """ - return self._calc_act(gt, c, group_by_coord, spatial_dims) - - def _calc_acc_group( - self, fcst: xr.DataArray, obs: xr.DataArray, spatial_dims: list[str] - ) -> xr.DataArray: - """Calculate ACC for a single group - Parameters - ---------- - ---------- - fcst: xr.DataArray - Forecast data for the group - Forecast data for the group - obs: xr.DataArray - Observation data for the group - spatial_dims: List[str] - Names of spatial dimensions over which ACC is calculated. - Returns - ------- - xr.DataArray - ACC for the group - """ - - return (fcst * obs).sum(spatial_dims) / np.sqrt( - (fcst**2).sum(spatial_dims) * (obs**2).sum(spatial_dims) - ) + return self._calc_act(gt, c, spatial_dims) def calc_acc( self, p: xr.DataArray, gt: xr.DataArray, c: xr.DataArray, - group_by_coord: str | None = None, spatial_dims: list = None, - ): + ) -> xr.DataArray: """ Calculate anomaly correlation coefficient (ACC). @@ -979,12 +979,13 @@ def calc_acc( Ground truth data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the ACC score. spatial_dims: List[str] Names of spatial dimensions over which ACC is calculated. Note: No averaging is possible over these dimensions. + Returns + ------- + xr.DataArray + Anomaly correlation coefficient (ACC) """ # Check if spatial_dims are in the data @@ -1001,29 +1002,14 @@ def calc_acc( # Calculate anomalies fcst_ano, obs_ano = p - c, gt - c - if group_by_coord: - # Apply groupby and calculate ACC within each group using apply - fcst_grouped = fcst_ano.groupby(group_by_coord) - obs_grouped = obs_ano.groupby(group_by_coord) - - # Use apply to calculate ACC for each group - this preserves the coordinate structure - acc = xr.concat( - [ - self._calc_acc_group(fcst_group, obs_grouped[group_label], spatial_dims) - for group_label, fcst_group in fcst_grouped - ], - dim=group_by_coord, - ).assign_coords({group_by_coord: list(fcst_grouped.groups.keys())}) - - else: - # Calculate ACC over spatial dimensions (no grouping) - acc = self._calc_acc_group(fcst_ano, obs_ano, spatial_dims) - - acc = self._calc_acc_group(fcst_ano, obs_ano, spatial_dims) + # Calculate ACC over spatial dimensions (no grouping) + acc = (fcst_ano * obs_ano).sum(spatial_dims) / np.sqrt( + (fcst_ano**2).sum(spatial_dims) * (obs_ano**2).sum(spatial_dims) + ) return acc - def calc_bias(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None): + def calc_bias(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate mean bias of forecast data w.r.t. reference data @@ -1033,16 +1019,12 @@ def calc_bias(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | Non Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the bias score. + Returns + ------- + xr.DataArray + Mean bias """ - bias = p - gt - - if group_by_coord: - bias = bias.groupby(group_by_coord) - - bias = self._mean(bias) + bias = self._mean(p - gt) return bias @@ -1050,9 +1032,8 @@ def calc_psnr( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, pixel_max: float = 1.0, - ): + ) -> xr.DataArray: """ Calculate PSNR of forecast data w.r.t. reference data @@ -1062,14 +1043,15 @@ def calc_psnr( Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the PSNR score. pixel_max: float Maximum pixel value in the data. Default is 1.0. + Returns + ------- + xr.DataArray + Peak signal-to-noise ratio (PSNR) """ - mse = self.calc_mse(p, gt, group_by_coord) + mse = self.calc_mse(p, gt) if np.count_nonzero(mse) == 0: psnr = mse psnr[...] = 100.0 @@ -1082,10 +1064,9 @@ def calc_spatial_variability( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, order: int = 1, non_spatial_avg_dims: list[str] = None, - ): + ) -> xr.DataArray: """ Calculates the ratio between the spatial variability of differental operator with order 1 (higher values unsupported yest) forecast and ground truth data using the calc_geo_spatial-method. @@ -1099,14 +1080,15 @@ def calc_spatial_variability( Forecast data array gt: xr.DataArray Ground truth data array - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the spatial variability ratio. order: int Order of the spatial differential operator to be applied. Supported orders: 1 non_spatial_avg_dims: List[str] List of dimensions over which the spatial variability ratio should be averaged. It must be non-spatial dimensions, i.e. not latitude or longitude. + Returns + ------- + xr.DataArray + Ratio of spatial variability between forecast and ground truth data """ fcst_grad = self.calc_geo_spatial_diff(p, order=order) @@ -1114,9 +1096,6 @@ def calc_spatial_variability( ratio_spat_variability = fcst_grad / ref_grd - if group_by_coord: - ratio_spat_variability = ratio_spat_variability.groupby(group_by_coord) - if non_spatial_avg_dims is not None: ratio_spat_variability = ratio_spat_variability.mean(dim=non_spatial_avg_dims) @@ -1130,8 +1109,7 @@ def calc_seeps( t1: xr.DataArray, t3: xr.DataArray, spatial_dims: list, - group_by_coord: str | None = None, - ): + ) -> xr.DataArray: """ Calculates stable equitable error in probabiliyt space (SEEPS), see Rodwell et al., 2011 @@ -1153,10 +1131,6 @@ def calc_seeps( Threshold for strong precipitation events spatial_dims: List[str] List of spatial dimensions of the data, e.g. ["lat", "lon"] - group_by_coord: str - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the sseps score. - Returns ------- xr.DataArray @@ -1226,9 +1200,6 @@ def seeps(ground_truth, prediction, thr_light, thr_heavy, seeps_weights): if lstack: seeps_values_all = seeps_values_all.unstack() - if group_by_coord: - seeps_values_all = seeps_values_all.groupby(group_by_coord) - if self._agg_dims is not None: seeps_values = self._mean(seeps_values_all) else: @@ -1238,18 +1209,24 @@ def seeps(ground_truth, prediction, thr_light, thr_heavy, seeps_weights): ### Probablistic scores - def calc_spread(self, p: xr.DataArray, group_by_coord: str | None = None): + def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray: """ Calculate the spread of the forecast ensemble - """ - ens_std = p.std(dim=self.ens_dim) + Parameters + ---------- + p: xr.DataArray + Forecast data array with ensemble dimension - if group_by_coord: - ens_std = ens_std.groupby(group_by_coord) + Returns + ------- + xr.DataArray + Spread of the forecast ensemble + """ + ens_std = p.std(dim=self._ens_dim) return self._mean(np.sqrt(ens_std**2)) - def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None): + def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate the Spread-Skill Ratio (SSR) of the forecast ensemble data w.r.t. reference data @@ -1259,13 +1236,12 @@ def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None Forecast data array with ensemble dimension gt: xr.DataArray Ground truth data array - group_by_coord: str | None - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the SSR score. + Returns + ------- + xr.DataArray + Spread-Skill Ratio (SSR) """ - ssr = self.calc_spread(p, group_by_coord) / self.calc_rmse( - p, gt, group_by_coord - ) # spread/rmse + ssr = self.calc_spread(p) / self.calc_rmse(p, gt) # spread/rmse return ssr @@ -1273,10 +1249,9 @@ def calc_crps( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, method: str = "ensemble", **kwargs, - ): + ) -> xr.DataArray: """ Wrapper around CRPS-methods provided by xskillscore-package. See https://xskillscore.readthedocs.io/en/stable/api @@ -1287,9 +1262,6 @@ def calc_crps( Forecast data array with ensemble dimension gt: xr.DataArray Ground truth data array - group_by_coord: str | None - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the CRPS score. method: str Method to calculate CRPS. Supported methods: ["ensemble", "gaussian"] kwargs: dict @@ -1302,22 +1274,18 @@ def calc_crps( """ crps_methods = ["ensemble", "gaussian"] - if group_by_coord: - p = p.groupby(group_by_coord) - gt = gt.groupby(group_by_coord) - if method == "ensemble": func_kwargs = { "forecasts": p, - "member_dim": self.ens_dim, + "member_dim": self._ens_dim, "dim": self._agg_dims, **kwargs, } crps_func = xskillscore.crps_ensemble elif method == "gaussian": func_kwargs = { - "mu": p.mean(dim=self.ens_dim), - "sig": p.std(dim=self.ens_dim), + "mu": p.mean(dim=self._ens_dim), + "sig": p.std(dim=self._ens_dim), "dim": self._agg_dims, **kwargs, } @@ -1335,11 +1303,10 @@ def calc_rank_histogram( self, p: xr.DataArray, gt: xr.DataArray, - group_by_coord: str | None = None, norm: bool = True, add_noise: bool = True, noise_fac=1.0e-03, - ): + ) -> xr.DataArray: """ Calculate the rank histogram of the forecast data w.r.t. reference data. @@ -1352,19 +1319,17 @@ def calc_rank_histogram( norm: bool Flag if normalized counts should be returned. If True, the rank histogram will be normalized by the number of ensemble members in the forecast data. - group_by_coord: str | None - Name of the coordinate to group by. - If provided, the coordinate becomes a new dimension of the rank histogram. add_noise: bool Flag if a small amount of random noise should be added to the data to avoid ties in the rank histogram. This is recommended for fair computations, cf. Sec. 4.2.2 in Harris et al. 2022 noise_fac: float Magnitude of random noise to be added to the data if add_noise is True. Default is 1.0e-03. This value is only relevant if add_noise is True + Returns + ------- + xr.DataArray + Rank histogram data array averaged over the provided dimensions """ - if group_by_coord is not None: - p = p.groupby(group_by_coord) - gt = gt.groupby(group_by_coord) # unstack stacked time-dimension beforehand if required (time may be stacked for forecast data) ground_truth = gt @@ -1398,18 +1363,38 @@ def calc_rank_histogram( da.random.random(size=fcst_stacked.shape, chunks=fcst_stacked.chunks) * noise_fac ) + # preserve the other coordinates + preserved_coords = { + c: obs_stacked[c].values + for c in obs_stacked.coords + if all(dim not in {self._ens_dim, "npoints"} for dim in obs_stacked[c].dims) + } # calculate ranks for all data points - rank = (obs_stacked >= fcst_stacked).sum(dim=self.ens_dim) + rank = (obs_stacked >= fcst_stacked).sum(dim=self._ens_dim) # and count occurence of rank values rank.name = "rank" # name for xr.DataArray is required for histogram-method rank_counts = histogram( rank, dim=["npoints"], - bins=np.arange(len(fcst_stacked[self.ens_dim]) + 2), + bins=np.arange(len(fcst_stacked[self._ens_dim]) + 2), block_size=None if rank.chunks is None else "auto", ) + # Reattach preserved coordinates by broadcasting + for coord_name, coord_values in preserved_coords.items(): + # Only keep unique values along npoints if necessary + if coord_name in rank_counts.coords: + continue + rank_counts = rank_counts.assign_coords({coord_name: coord_values}) + + # Reattach preserved coordinates by broadcasting + for coord_name, coord_values in preserved_coords.items(): + # Only keep unique values along npoints if necessary + if coord_name in rank_counts.coords: + continue + rank_counts = rank_counts.assign_coords({coord_name: coord_values}) + # provide normalized rank counts if desired if norm: npoints = len(fcst_stacked["npoints"]) @@ -1417,13 +1402,23 @@ def calc_rank_histogram( return rank_counts - def calc_rank_histogram_xskillscore(self, p: xr.DataArray, gt: xr.DataArray): + def calc_rank_histogram_xskillscore(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Wrapper around rank_histogram-method by xskillscore-package. See https://xskillscore.readthedocs.io/en/stable/api Note: this version is found to be very slow. Use calc_rank_histogram alternatively. + Parameters + ---------- + p: xr.DataArray + Forecast data array with ensemble dimension + gt: xr.DataArray + Ground truth data array + Returns + ------- + xr.DataArray + Rank histogram data array averaged over the provided dimensions """ - rank_hist = xskillscore.rank_histogram(gt, p, member_dim=self.ens_dim, dim=self._agg_dims) + rank_hist = xskillscore.rank_histogram(gt, p, member_dim=self._ens_dim, dim=self._agg_dims) return rank_hist @@ -1433,15 +1428,26 @@ def calc_geo_spatial_diff( order: int = 1, r_e: float = 6371.0e3, dom_avg: bool = True, - ): + ) -> xr.DataArray: """ Calculates the amplitude of the gradient (order=1) or the Laplacian (order=2) of a scalar field given on a regular, geographical grid (i.e. dlambda = const. and dphi=const.) - :param scalar_field: scalar field as data array with latitude and longitude as coordinates - :param order: order of spatial differential operator - :param r_e: radius of the sphere - :return: the amplitude of the gradient/laplacian at each grid point or over the whole domain (see dom_avg) + + Parameters + ---------- + scalar_field: + Scalar field as data array with latitude and longitude as coordinates + order: + Order of spatial differential operator + r_e: + Radius of the sphere + dom_avg: + Flag whether to return the domain-averaged amplitude or the amplitude at each grid point + Returns + ------- + xr.DataArray + the amplitude of the gradient/laplacian at each grid point or over the whole domain (see dom_avg) """ method = Scores.calc_geo_spatial_diff.__name__ # sanity checks diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index beb9bfc3d..11a6aa7d4 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -119,18 +119,31 @@ def calc_scores_per_stream( # Build up computation graphs for all metrics _logger.debug(f"Build computation graphs for metrics for stream {stream}...") - combined_metrics = [ - get_score( - score_data, - metric, - agg_dims="ipoint", - group_by_coord="sample", - ) + # Add it only if it is not None + valid_scores = [ + score for metric in metrics + if ( + score := get_score( + score_data, + metric, + agg_dims="ipoint", + group_by_coord="sample", + ) + ) + is not None + ] + + # Keep only metrics corresponding to valid_scores + valid_metric_names = [ + metric + for metric, score in zip(metrics, valid_scores, strict=False) + if score is not None ] - combined_metrics = xr.concat(combined_metrics, dim="metric") - combined_metrics["metric"] = metrics + # Concatenate along a new "metric" dimension and assign metric names + combined_metrics = xr.concat(valid_scores, dim="metric") + combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) _logger.debug(f"Running computation of metrics for stream {stream}...") combined_metrics = combined_metrics.compute() @@ -152,11 +165,11 @@ def calc_scores_per_stream( "forecast_step": int(combined_metrics.forecast_step), "sample": combined_metrics.sample, "channel": combined_metrics.channel, + "metric": combined_metrics.metric, } if "ens" in combined_metrics.dims: criteria["ens"] = combined_metrics.ens - metric_stream.loc[criteria] = combined_metrics _logger.info(f"Scores for run {reader.run_id} - {stream} calculated successfully.") @@ -467,7 +480,6 @@ def common_ranges( if not isinstance(maps_config[var].get("vmax"), (int | float)): list_max = calc_bounds(data_tars, data_preds, var, "max") list_max = np.concatenate([arr.flatten() for arr in list_max]).tolist() - maps_config[var].update({"vmax": float(max(list_max))}) if not isinstance(maps_config[var].get("vmin"), (int | float)): @@ -478,7 +490,6 @@ def common_ranges( else: list_max = calc_bounds(data_tars, data_preds, var, "max") list_max = np.concatenate([arr.flatten() for arr in list_max]).tolist() - list_min = calc_bounds(data_tars, data_preds, var, "min") list_min = np.concatenate([arr.flatten() for arr in list_min]).tolist() From a4e7e8bb4c0de09bb17144474127d0547f35599b Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 4 Nov 2025 11:18:20 +0100 Subject: [PATCH 10/32] Adding config to issue templates The issue template seems to have disappeared, attempting to solve that. --- .github/ISSUE_TEMPLATE/config.yml | 1 + 1 file changed, 1 insertion(+) create mode 100644 .github/ISSUE_TEMPLATE/config.yml diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..0086358db --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true From dd8f5a6367f12a5efecd147e774d2f262ca14f6d Mon Sep 17 00:00:00 2001 From: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Date: Tue, 4 Nov 2025 11:26:26 +0100 Subject: [PATCH 11/32] Add the duration of animation as global plotting option (#1189) * Add the animation duration as global plotting option * Linting * Use FPS instead of milliseconds * Linting --- .../src/weathergen/evaluate/plotter.py | 24 ++++++++++++------- .../evaluate/src/weathergen/evaluate/utils.py | 1 + 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index 0d35a12a1..2ca09920e 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -64,6 +64,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): self.image_format = plotter_cfg.get("image_format") self.dpi_val = plotter_cfg.get("dpi_val") self.fig_size = plotter_cfg.get("fig_size") + self.fps = plotter_cfg.get("fps") self.plot_subtimesteps = plotter_cfg.get( "plot_subtimesteps", False ) # True if plots are created for each valid time separately @@ -518,6 +519,9 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: self.update_data_selection(select) map_output_dir = self.get_map_output_dir(tag) + # Convert FPS to duration in milliseconds + duration_ms = int(1000 / self.fps) if self.fps > 0 else 400 + for _, sa in enumerate(samples): for _, var in enumerate(variables): _logger.info(f"Creating animation for {var} sample: {sa} - {tag}") @@ -542,14 +546,18 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: names = glob.glob(fname) image_paths += names - images = [Image.open(path) for path in image_paths] - images[0].save( - f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{var}.gif", - save_all=True, - append_images=images[1:], - duration=500, - loop=0, - ) + if image_paths: + images = [Image.open(path) for path in image_paths] + images[0].save( + f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{var}.gif", + save_all=True, + append_images=images[1:], + duration=duration_ms, + loop=0, + ) + + else: + _logger.warning(f"No images found for animation {var} sample {sa}") return image_paths diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 11a6aa7d4..da98a1b6b 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -213,6 +213,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: "image_format": global_plotting_opts.get("image_format", "png"), "dpi_val": global_plotting_opts.get("dpi_val", 300), "fig_size": global_plotting_opts.get("fig_size", (8, 10)), + "fps": global_plotting_opts.get("fps", 2), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False), } From 787aa9c5baf735cd04435a0d8390164e167b6022 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 4 Nov 2025 11:41:41 +0100 Subject: [PATCH 12/32] Attempt to fix the bug report template --- .github/ISSUE_TEMPLATE/bug_report.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 68f4080a0..6ba186f3d 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,6 +1,9 @@ name: Bug Report description: Report a bug related to WeatherGenerator. -labels: ["bug"] +title: Bug report +labels: + - "bug" +assignees: [] body: - type: textarea id: what-happened @@ -32,3 +35,5 @@ body: attributes: label: Hedgedoc link to logs and more information. This ticket is public, do not attach files directly. description: Please put all relevant information (logs, plots, etc.) in the Hedgedoc and link it here. + validations: + required: false From a760c1e4412a1ad32f6ef3bed237e30e4db44628 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 4 Nov 2025 11:44:00 +0100 Subject: [PATCH 13/32] Attempt to fix initiative template --- .github/ISSUE_TEMPLATE/initiative.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/initiative.yml b/.github/ISSUE_TEMPLATE/initiative.yml index 68aac136a..83bb8db58 100644 --- a/.github/ISSUE_TEMPLATE/initiative.yml +++ b/.github/ISSUE_TEMPLATE/initiative.yml @@ -1,8 +1,9 @@ name: Initiative description: A piece of work that will likely take more than a week to complete. -title: "" +title: "Initiative" labels: ["initiative"] + body: - type: textarea id: description From 926bdf3439216ca17eb448dae0085b7b7ebdc63f Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 4 Nov 2025 11:45:38 +0100 Subject: [PATCH 14/32] Update task template --- .github/ISSUE_TEMPLATE/task.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/task.yml b/.github/ISSUE_TEMPLATE/task.yml index af7a4cd1a..2008dd81a 100644 --- a/.github/ISSUE_TEMPLATE/task.yml +++ b/.github/ISSUE_TEMPLATE/task.yml @@ -1,6 +1,6 @@ name: Task / Issue description: A task or issue that should take less than a week to complete. -title: "" +title: "Task" body: - type: textarea From 20ce94a7450155474d4d6476c935212cc2630eac Mon Sep 17 00:00:00 2001 From: Michael Tarnawa <18899420+mtar@users.noreply.github.com> Date: Tue, 4 Nov 2025 15:17:41 +0100 Subject: [PATCH 15/32] [1081][Evaluation] Use parent ruff rules (#1177) * use ruff settings from parent * fix code checks * check fixes 2nd round * reformat to line length --- packages/evaluate/pyproject.toml | 53 +------- .../src/weathergen/evaluate/clim_utils.py | 13 +- .../weathergen/evaluate/derived_channels.py | 9 +- .../weathergen/evaluate/export_inference.py | 62 +++++++--- .../src/weathergen/evaluate/io_reader.py | 69 +++++++---- .../src/weathergen/evaluate/plot_utils.py | 3 +- .../src/weathergen/evaluate/plotter.py | 45 ++++--- .../src/weathergen/evaluate/run_evaluation.py | 3 +- .../evaluate/src/weathergen/evaluate/score.py | 114 +++++++++++------- .../src/weathergen/evaluate/score_utils.py | 3 +- .../evaluate/src/weathergen/evaluate/utils.py | 15 ++- src/weathergen/datasets/masking.py | 24 ++-- 12 files changed, 231 insertions(+), 182 deletions(-) diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index 8af89ba95..d1ef2d40a 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -28,58 +28,7 @@ export = "weathergen.evaluate.export_inference:export" # The linting configuration [tool.ruff] - -# Wide rows -line-length = 100 - -[tool.ruff.lint] -# All disabled until the code is formatted. -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # Banned imports - "TID", - # Naming conventions - "N", - # print - "T201" -] - -# These rules are sensible and should be enabled at a later stage. -ignore = [ - # "B006", - "B011", - "UP008", - "SIM117", - "SIM118", - "SIM102", - "SIM401", - "E501", # to be removed - "E721", - # To ignore, not relevant for us - "SIM108", # in case additional norm layer supports are added in future - "N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted) - "E731", # overly restrictive and less readable code - "N812", # prevents us following the convention for importing torch.nn.functional as F -] - -[tool.ruff.lint.flake8-tidy-imports.banned-api] -"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example" - -[tool.ruff.format] -# Use Unix `\n` line endings for all files -line-ending = "lf" - +extend = "../../pyproject.toml" [tool.pyrefly] project-includes = ["src/"] diff --git a/packages/evaluate/src/weathergen/evaluate/clim_utils.py b/packages/evaluate/src/weathergen/evaluate/clim_utils.py index d9fed13c0..768f93803 100644 --- a/packages/evaluate/src/weathergen/evaluate/clim_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/clim_utils.py @@ -52,7 +52,8 @@ def match_climatology_time(target_datetime: pd.Timestamp, clim_data: xr.Dataset) # To Do: leap years and other edge cases if len(matching_indices) == 0: _logger.warning( - f"No matching climatology time found for {target_datetime} (DOY: {target_doy}, Hour: {target_hour})" + f"No matching climatology time found for {target_datetime} (DOY: {target_doy}, " + f"Hour: {target_hour})" f"Please check that climatology data and stream input data filenames match." ) return None @@ -156,8 +157,8 @@ def align_clim_data( if np.any(unmatched_mask): n_unmatched = np.sum(unmatched_mask) raise ValueError( - f"Found {n_unmatched} target coordinates with no matching climatology coordinates. " - f"This will cause incorrect ACC calculations. " + f"Found {n_unmatched} target coordinates with no matching climatology " + f"coordinates. This will cause incorrect ACC calculations. " f"Check coordinate alignment between target and climatology data." ) # Cache the computed indices and target coords @@ -175,8 +176,10 @@ def align_clim_data( except (ValueError, IndexError) as e: raise ValueError( f"Failed to align climatology data with target data for ACC calculation. " - f"This error typically occurs when the number of points per sample varies between samples. " - f"ACC metric is currently only supported for forecasting data with constant points per sample. " + f"This error typically occurs when the number of points per sample varies " + f"between samples. " + f"ACC metric is currently only supported for forecasting data with constant " + f"points per sample. " f"Please ensure all samples have the same spatial coverage and grid points. " f"Original error: {e}" ) from e diff --git a/packages/evaluate/src/weathergen/evaluate/derived_channels.py b/packages/evaluate/src/weathergen/evaluate/derived_channels.py index 7b8dccaf0..7811407d7 100644 --- a/packages/evaluate/src/weathergen/evaluate/derived_channels.py +++ b/packages/evaluate/src/weathergen/evaluate/derived_channels.py @@ -21,9 +21,11 @@ def __init__( Initializes the DeriveChannels class with necessary configurations for channel derivation. Args: - available_channels (np.array): an array of all available channel names in the datasets (target or pred). + available_channels (np.array): an array of all available channel names + in the datasets (target or pred). channels (list): A list of channels of interest to be evaluated and/or plotted. - stream_cfg (dict): A dictionary containing the stream configuration settings for evaluation and plottings. + stream_cfg (dict): A dictionary containing the stream configuration settings for + evaluation and plottings. Returns: None @@ -147,6 +149,7 @@ def get_derived_channels( ) else: _logger.debug( - f"Calculation of {tag} is skipped because it is included in the available channels..." + f"Calculation of {tag} is skipped because it is included " + "in the available channels..." ) return data_tars, data_preds, self.channels diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py index 3f67c37d9..bda1c6cd3 100755 --- a/packages/evaluate/src/weathergen/evaluate/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py @@ -10,7 +10,9 @@ # weathergen-common = { path = "../../../../../packages/common" } # weathergen = { path = "../../../../../" } # /// -## Example USAGE: uv run export --run-id grwnhykd --stream ERA5 --output-dir /p/home/jusers/owens1/jureca/WeatherGen/test_output1 --format netcdf --type prediction target --fsteps 1 --samples 1 +## Example USAGE: uv run export --run-id grwnhykd --stream ERA5 \ +## --output-dir /p/home/jusers/owens1/jureca/WeatherGen/test_output1 \ +## --format netcdf --type prediction target --fsteps 1 --samples 1 import argparse import logging import re @@ -66,9 +68,11 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: """ Find all the pressure levels for each variable using regex and returns a dictionary mapping variable names to their corresponding pressure levels. + Parameters ---------- all_variables : list of variable names with pressure levels (e.g.,'q_500','t_2m'). + Returns ------- A tuple containing: @@ -333,7 +337,9 @@ def output_filename( forecast_ref_time: np.datetime64, ) -> Path: """ - Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample index, output directory, format and forecast_ref_time. + Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample + index, output directory, format and forecast_ref_time. + Parameters ---------- prefix : Prefix for file name (e.g., 'pred' or 'targ'). @@ -341,6 +347,7 @@ def output_filename( output_dir : Directory to save the output file. output_format : Output file format (currently only 'netcdf' supported). forecast_ref_time : Forecast reference time to include in the filename. + Returns ------- Full path to the output file. @@ -358,9 +365,11 @@ def output_filename( def get_data_worker(args: tuple) -> xr.DataArray: """ Worker function to retrieve data for a single sample and forecast step. + Parameters ---------- args : Tuple containing (sample, fstep, run_id, stream, type). + Returns ------- xarray DataArray for the specified sample and forecast step. @@ -397,18 +406,30 @@ def get_data( Parameters ---------- - run_id : Run ID to identify the Zarr store. - samples : Sample to process - stream : Stream name to retrieve data for (e.g., 'ERA5'). - type : Type of data to retrieve ('target' or 'prediction'). - fsteps : List of forecast steps to retrieve. If None, retrieves all available forecast steps. - channels :List of channels to retrieve. If None, retrieves all available channels. - n_processes : Number of parallel processes to use for data retrieval. - mini_epoch : Mini_epoch number to identify the Zarr store. - rank : Rank number to identify the Zarr store. - output_dir : Directory to save the NetCDF files. - output_format : Output file format (currently only 'netcdf' supported). - config : Loaded config for cf_parser function. + run_id : str + Run ID to identify the Zarr store. + samples : list + Sample to process + stream : str + Stream name to retrieve data for (e.g., 'ERA5'). + dtype : str + Type of data to retrieve ('target' or 'prediction'). + fsteps : list + List of forecast steps to retrieve. If None, retrieves all available forecast steps. + channels : list + List of channels to retrieve. If None, retrieves all available channels. + n_processes : list + Number of parallel processes to use for data retrieval. + mini_epoch : int + Mini_epoch number to identify the Zarr store. + rank : int + Rank number to identify the Zarr store. + output_dir : str + Directory to save the NetCDF files. + output_format : str + Output file format (currently only 'netcdf' supported). + config : OmegaConf + Loaded config for cf_parser function. """ if dtype not in ["target", "prediction"]: raise ValueError(f"Invalid type: {dtype}. Must be 'target' or 'prediction'.") @@ -451,7 +472,8 @@ def get_data( f"{list(set(channels) - set(existing_channels))}. Skipping them." ) result = result.sel(channel=existing_channels) - # reshape result - use adaptive function to handle both regular and Gaussian grids + # reshape result - use adaptive function to handle both regular and Gaussian + # grids result = reshape_dataset_adaptive(result) da_fs.append(result) @@ -484,12 +506,14 @@ def save_sample_to_netcdf( ) -> None: """ Uses list of pred/target xarray DataArrays to save one sample to a NetCDF file. + Parameters ---------- type_str : str Type of data ('pred' or 'targ') to include in the filename. dict_sample_all_steps : dict - Dictionary where keys is sample index and values is a list of xarray DataArrays for all the forecast steps + Dictionary where keys is sample index and values is a list of xarray DataArrays + for all the forecast steps fstep_hours : np.timedelta64 Time difference between forecast steps (e.g., 6 hours). run_id : str @@ -595,7 +619,8 @@ def parse_args(args: list) -> argparse.Namespace: type=int, nargs="+", default=None, - help="List of forecast steps to retrieve (e.g. 1 2 3). If not provided, retrieves all available forecast steps.", + help="List of forecast steps to retrieve (e.g. 1 2 3). " + "If not provided, retrieves all available forecast steps.", ) parser.add_argument( @@ -611,7 +636,8 @@ def parse_args(args: list) -> argparse.Namespace: type=str, nargs="+", default=None, - help="List of channels to retrieve (e.g., 'q_500 t_2m'). If not provided, retrieves all available channels.", + help="List of channels to retrieve (e.g., 'q_500 t_2m'). " + "If not provided, retrieves all available channels.", ) parser.add_argument( diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 2aff8147e..a3d836d7f 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -140,7 +140,8 @@ def check_availability( ii) available in the source file (e.g. the Zarr file, return error otherwise) Additionally, if channels, forecast steps or samples is None/'all', it will i) set the variable to all available vars in source file - ii) return True only if the respective variable contains the same indeces in metric file and source file (return False otherwise) + ii) return True only if the respective variable contains the same indeces in metric file + and source file (return False otherwise) Parameters ---------- @@ -173,18 +174,26 @@ def check_availability( # fill info from available metric file (if provided) available = { - "channel": set(available_data["channel"].values.ravel()) - if available_data is not None - else set(), - "fstep": set(available_data["forecast_step"].values.ravel()) - if available_data is not None - else set(), - "sample": set(available_data.coords["sample"].values.ravel()) - if available_data is not None - else set(), - "ensemble": set(available_data["ens"].values.ravel()) - if available_data is not None and "ens" in available_data.coords - else set(), + "channel": ( + set(available_data["channel"].values.ravel()) + if available_data is not None + else set() + ), + "fstep": ( + set(available_data["forecast_step"].values.ravel()) + if available_data is not None + else set() + ), + "sample": ( + set(available_data.coords["sample"].values.ravel()) + if available_data is not None + else set() + ), + "ensemble": ( + set(available_data["ens"].values.ravel()) + if available_data is not None and "ens" in available_data.coords + else set() + ), } # fill info from reader @@ -204,7 +213,8 @@ def check_availability( # If file with metrics exists, must exactly match if available_data is not None and reader_data[name] != available[name]: _logger.info( - f"Requested all {name}s for {mode}, but previous config was a strict subset. Recomputing." + f"Requested all {name}s for {mode}, but previous config was a " + "strict subset. Recomputing." ) check_score = False @@ -233,7 +243,8 @@ def check_availability( if check_score and not corrected: scope = "metric file" if available_data is not None else "Zarr file" _logger.info( - f"All checks passed – All channels, samples, fsteps requested for {mode} are present in {scope}..." + f"All checks passed – All channels, samples, fsteps requested for {mode} are " + f"present in {scope}..." ) return DataAvailability( @@ -246,7 +257,8 @@ def check_availability( def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailability: """ - Get channels, fsteps and samples for a given run and stream from the config. Replace 'all' with None. + Get channels, fsteps and samples for a given run and stream from the config. + Replace 'all' with None. Parameters ---------- @@ -344,7 +356,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non def get_inference_config(self): """ - load the config associated to the inference run (different from the eval_cfg which contains plot and evaluaiton options.) + load the config associated to the inference run (different from the eval_cfg which + contains plot and evaluaiton options.) Returns ------- @@ -386,7 +399,8 @@ def get_data( cfg : Configuration dictionary containing all information for the evaluation. results_dir : Path - Directory where the inference results are stored. Expected scheme `/`. + Directory where the inference results are stored. + Expected scheme `/`. stream : Stream name to retrieve data for. region : @@ -406,7 +420,8 @@ def get_data( A dataclass containing: - target: Dictionary of xarray DataArrays for targets, indexed by forecast step. - prediction: Dictionary of xarray DataArrays for predictions, indexed by forecast step. - - points_per_sample: xarray DataArray containing the number of points per sample, if `return_counts` is True. + - points_per_sample: xarray DataArray containing the number of points per sample, + if `return_counts` is True. """ bbox = RegionBoundingBox.from_region_name(region) @@ -458,7 +473,8 @@ def get_data( if region != "global": _logger.debug( - f"Applying bounding box mask for region '{region}' to targets and predictions..." + f"Applying bounding box mask for region '{region}' to targets " + "and predictions..." ) target = bbox.apply_mask(target) pred = bbox.apply_mask(pred) @@ -466,7 +482,8 @@ def get_data( npoints = len(target.ipoint) if npoints == 0: _logger.info( - f"Skipping {stream} sample {sample} forecast step: {fstep}. Dataset is empty." + f"Skipping {stream} sample {sample} forecast step: {fstep}. " + "Dataset is empty." ) continue @@ -492,7 +509,8 @@ def get_data( fsteps_final.append(fstep) _logger.debug( - f"Concatenating targets and predictions for stream {stream}, forecast_step {fstep}..." + f"Concatenating targets and predictions for stream {stream}, " + f"forecast_step {fstep}..." ) if da_tars_fs: @@ -515,7 +533,8 @@ def get_data( if set(channels) != set(all_channels): _logger.debug( - f"Restricting targets and predictions to channels {channels} for stream {stream}..." + f"Restricting targets and predictions to channels {channels} " + f"for stream {stream}..." ) da_tars_fs, da_preds_fs, channels = dc.get_derived_channels( @@ -572,8 +591,8 @@ def get_climatology_filename(self, stream: str) -> str | None: clim_data_path = Path(clim_base_dir).join(clim_fn) else: _logger.warning( - f"No climatology path specified for stream {stream}. Setting climatology to NaN. " - "Add 'climatology_path' to evaluation config to use metrics like ACC." + f"No climatology path specified for stream {stream}. Setting climatology to " + "NaN. Add 'climatology_path' to evaluation config to use metrics like ACC." ) return clim_data_path diff --git a/packages/evaluate/src/weathergen/evaluate/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plot_utils.py index 3821fdab6..55e9c9f9a 100644 --- a/packages/evaluate/src/weathergen/evaluate/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plot_utils.py @@ -215,7 +215,8 @@ def bar_plot_metric_region( br_plotter.plot(selected_data, run_ids, channels_set, name) else: _logger.info( - f"Only one run_id for ({region}) region under stream : {stream}. Creating bar plot is skipped..." + f"Only one run_id for ({region}) region under stream : {stream}. " + "Creating bar plot is skipped..." ) diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index 2ca09920e..f9e34dbb3 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -133,7 +133,8 @@ def select_from_da(self, da: xr.DataArray, selection: dict) -> xr.DataArray: da: xarray DataArray to select data from. selection: - Dictionary of selectors where keys are coordinate names and values are the values to select. + Dictionary of selectors where keys are coordinate names and values are the values to + select. Returns ------- @@ -264,7 +265,8 @@ def plot_histogram( plt.xlabel(f"Variable: {varname}") plt.ylabel("Frequency") plt.title( - f"Histogram of Target and Prediction: {self.stream}, {varname} : fstep = {self.fstep:03}" + f"Histogram of Target and Prediction: {self.stream}, {varname} : " + f"fstep = {self.fstep:03}" ) plt.legend(frameon=False) @@ -322,7 +324,8 @@ def create_maps_per_sample( Additional keyword arguments for the map. Known keys are: - marker_size: base size of the marker (default is 1) - - scale_marker_size: if True, the marker size will be scaled based on latitude (default is False) + - scale_marker_size: if True, the marker size will be scaled based on latitude + (default is False) - marker: marker style (default is 'o') Unknown keys will be passed to the scatter plot function. @@ -618,17 +621,17 @@ def _check_lengths(self, data: xr.DataArray | list, labels: str | list) -> tuple ------- data_list, label_list - lists of data and labels """ - assert type(data) == xr.DataArray or type(data) == list, ( + assert isinstance(data, xr.DataArray | list), ( "Compare::plot - Data should be of type xr.DataArray or list" ) - assert type(labels) == str or type(labels) == list, ( + assert isinstance(labels, str | list), ( "Compare::plot - Labels should be of type str or list" ) # convert to lists - data_list = [data] if type(data) == xr.DataArray else data - label_list = [labels] if type(labels) == str else labels + data_list = [data] if isinstance(data, xr.DataArray) else data + label_list = [labels] if isinstance(labels, str) else labels assert len(data_list) == len(label_list), "Compare::plot - Data and Labels do not match" @@ -714,7 +717,8 @@ def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: ) else: _logger.warning( - f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. Skipping ensemble plotting." + f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. " + "Skipping ensemble plotting." ) def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: @@ -785,7 +789,8 @@ def _plot_ensemble(self, data: xr.DataArray, x_dim: str, label: str) -> None: ) else: _logger.warning( - f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. Skipping ensemble plotting." + f"LinePlot:: Unknown option for plot_ensemble: {self.plot_ensemble}. " + "Skippingensemble plotting." ) def plot( @@ -837,7 +842,8 @@ def plot( else: if non_zero_dims: _logger.info( - f"LinePlot:: Found multiple entries for dimensions: {non_zero_dims}. Averaging..." + f"LinePlot:: Found multiple entries for dimensions: {non_zero_dims}. " + "Averaging..." ) averaged = data.mean( @@ -906,7 +912,8 @@ def plot( self, data: list[xr.DataArray], runs: list[str], channels: list[str], tag: str ) -> None: """ - Plot score cards comparing performance between run_ids against a baseline over channels of interest. + Plot score cards comparing performance between run_ids against a baseline over channels + of interest. Parameters ---------- @@ -1119,7 +1126,8 @@ def get_plot_symbols( color: str The color "red" or "blue" that indicates improvement or deterioration over baseline. triangle: str - The triangle symbol "^" or "v" that indicates improvement or deterioration over baseline. + The triangle symbol "^" or "v" that indicates improvement or deterioration over + baseline. size: xr.DataArray Size of the triangles in the final plot """ @@ -1179,7 +1187,8 @@ def plot( self, data: list[xr.DataArray], runs: list[str], channels: list[str], tag: str ) -> None: """ - Plot (ratio) bar plots comparing performance between different run_ids over channels of interest. + Plot (ratio) bar plots comparing performance between different run_ids over channels of + interest. Parameters ---------- @@ -1220,7 +1229,8 @@ def plot( ) ax[run_index - 1].invert_yaxis() ax[run_index - 1].set_xlabel( - f"Relative {data[0].coords['metric'].item().upper()}: Target Model ({runs[run_index]}) / Reference Model ({runs[0]})" + f"Relative {data[0].coords['metric'].item().upper()}: " + f"Target Model ({runs[run_index]}) / Reference Model ({runs[0]})" ) _logger.info(f"Saving bar plots to: {self.out_plot_dir}") @@ -1276,8 +1286,8 @@ def calc_ratio_per_run_id( def colors(self, ratio_score: np.array) -> list[tuple]: """ - This function calculates colormaps based on the skill scores. From negative value blue color variations - should be given otherwise red color variations should be given. + This function calculates colormaps based on the skill scores. From negative value blue + color variations should be given otherwise red color variations should be given. Parameters ---------- @@ -1298,7 +1308,8 @@ def calculate_average_over_dim( x_dim: str, baseline_var: xr.DataArray, data_var: xr.DataArray ) -> tuple[xr.DataArray, xr.DataArray]: """ - Calculate average over xarray dimensions that are larger than 1. Those might be the forecast-steps or the samples. + Calculate average over xarray dimensions that are larger than 1. Those might be the + forecast-steps or the samples. Parameters ---------- diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index e98d38389..b94a19480 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -91,7 +91,8 @@ def evaluate_from_config(cfg): stream_dict = reader.get_stream(stream) if not stream_dict: _logger.info( - f"Stream {stream} does not exist in source data or config file is empty. Skipping." + f"Stream {stream} does not exist in source data or config file is empty. " + "Skipping." ) continue diff --git a/packages/evaluate/src/weathergen/evaluate/score.py b/packages/evaluate/src/weathergen/evaluate/score.py index e5a45b9b7..5d1de4186 100755 --- a/packages/evaluate/src/weathergen/evaluate/score.py +++ b/packages/evaluate/src/weathergen/evaluate/score.py @@ -28,7 +28,7 @@ except Exception: _logger.warning( "Could not import xskillscore and xhistogram. Thus, CRPS and " - + "rank histogram-calculations are not supported." + "rank histogram-calculations are not supported." ) @@ -40,7 +40,8 @@ def _get_skill_score( ) -> xr.DataArray: """ Calculate the skill score of a forecast data array w.r.t. a reference and a perfect score. - Definition follows Wilks, Statistical Methods in the Atmospheric Sciences (2006), Chapter 7.1.4, Equation 7.4 + Definition follows Wilks, Statistical Methods in the Atmospheric Sciences (2006), + Chapter 7.1.4, Equation 7.4 Parameters ---------- @@ -241,11 +242,11 @@ def get_score( if score_name in self.det_metrics_dict.keys(): f = self.det_metrics_dict[score_name] elif score_name in self.prob_metrics_dict.keys(): - if self._ens_dim not in data.prediction.dims: - _logger.error( - f"Probablistic score {score_name} chosen, but ensemble dimension {self._ens_dim} not found in prediction data. Skipping score calculation." - ) - return None + assert self.ens_dim in data.prediction.dims, ( + f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} " + "not found in prediction data. Skipping score calculation." + ) + return None f = self.prob_metrics_dict[score_name] else: raise ValueError( @@ -264,7 +265,8 @@ def get_score( for dim in self._agg_dims_in: if dim not in data.prediction.dims: raise ValueError( - f"Average dimension '{dim}' not found in prediction data dimensions: {data.prediction.dims}" + f"Average dimension '{dim}' not found in prediction data " + f"dimensions: {data.prediction.dims}" ) self._agg_dims = self._agg_dims_in @@ -336,10 +338,11 @@ def _validate_ens_dim(self, dim: str) -> str: def _validate_groupby_coord(self, data: VerifiedData, group_by_coord: str | None) -> bool: """ - Check if the group_by_coord is present in both prediction and ground truth data and compatible. - Raises ValueError if conditions are not met. + Check if the group_by_coord is present in both prediction and ground truth data + and compatible. Raises ValueError if conditions are not met. If group_by_coord does not have more than one unique value in the prediction data, - a warning is logged and the function returns False, indicating that grouping is not applicable. + a warning is logged and the function returns False, indicating that grouping is + not applicable. Parameters ---------- @@ -356,7 +359,8 @@ def _validate_groupby_coord(self, data: VerifiedData, group_by_coord: str | None p, gt = data.prediction, data.ground_truth if group_by_coord not in p.coords or group_by_coord not in gt.coords: raise ValueError( - f"Coordinate '{group_by_coord}' must be present in both prediction and ground truth data." + f"Coordinate '{group_by_coord}' must be present in both prediction " + "and ground truth data." ) # Check if the dims associated with the groupby_coord are compatible @@ -544,6 +548,7 @@ def calc_l1( """ Calculate the L1 error norm of forecast data w.r.t. reference data. Note that the L1 error norm is calculated as the sum of absolute differences. + Parameters ---------- p: xr.DataArray @@ -552,7 +557,9 @@ def calc_l1( Ground truth data array scale_dims: list | None List of dimensions over which the L1 score will be scaled. - If provided, the L1 score will be divided by the product of the sizes of these dimensions. + If provided, the L1 score will be divided by the product of the sizes of these + dimensions. + Returns ------- xr.DataArray @@ -566,7 +573,8 @@ def calc_l1( scale_dims = to_list(scale_dims) assert all([dim in p.dims for dim in scale_dims]), ( - f"Provided scale dimensions {scale_dims} are not all present in the prediction data dimensions {p.dims}." + f"Provided scale dimensions {scale_dims} are not all present in the prediction " + f"data dimensions {p.dims}." ) len_dims = np.array([p.sizes[dim] for dim in scale_dims]) @@ -592,11 +600,14 @@ def calc_l2( Ground truth data array scale_dims: list | None List of dimensions over which the L2 score will be scaled. - If provided, the L2 score will be divided by the product of the sizes of these dimensions. + If provided, the L2 score will be divided by the product of the sizes of these + dimensions. squared_l2: bool If True, the L2 score will be returned as the sum of squared differences. - If False, the L2 score will be returned as the square root of the sum of squared differences. - Default is False, i.e. the L2 score is returned as the square root of the sum of squared differences. + If False, the L2 score will be returned as the square root of the sum of squared + differences. Default is False, i.e. the L2 score is returned as the square root of the + sum of squared differences. + Returns ------- xr.DataArray @@ -613,7 +624,8 @@ def calc_l2( scale_dims = to_list(scale_dims) assert all([dim in p.dims for dim in scale_dims]), ( - f"Provided scale dimensions {scale_dims} are not all present in the prediction data dimensions {p.dims}." + f"Provided scale dimensions {scale_dims} are not all present in the prediction " + f"data dimensions {p.dims}." ) len_dims = np.array([p.sizes[dim] for dim in scale_dims]) @@ -634,7 +646,8 @@ def calc_mae(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ if self._agg_dims is None: raise ValueError( - "Cannot calculate mean absolute error without aggregation dimensions (agg_dims=None)." + "Cannot calculate mean absolute error without aggregation dimensions " + "(agg_dims=None)." ) return self._mean(np.abs(p - gt)) @@ -656,7 +669,8 @@ def calc_mse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ if self._agg_dims is None: raise ValueError( - "Cannot calculate mean squared error without aggregation dimensions (agg_dims=None)." + "Cannot calculate mean squared error without aggregation dimensions " + "(agg_dims=None)." ) return self._mean(np.square(p - gt)) @@ -678,7 +692,8 @@ def calc_rmse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ if self._agg_dims is None: raise ValueError( - "Cannot calculate root mean squared error without aggregation dimensions (agg_dims=None)." + "Cannot calculate root mean squared error without aggregation dimensions " + "(agg_dims=None)." ) rmse = np.sqrt(self.calc_mse(p, gt)) @@ -687,7 +702,9 @@ def calc_rmse(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: def calc_vrmse(self, p: xr.DataArray, gt: xr.DataArray): """ - Calculate variance-normalized root mean squared error (VRMSE) of forecast data w.r.t. reference data + Calculate variance-normalized root mean squared error (VRMSE) of forecast data w.r.t. + reference data + Parameters ---------- p: xr.DataArray @@ -697,7 +714,8 @@ def calc_vrmse(self, p: xr.DataArray, gt: xr.DataArray): """ if self._agg_dims is None: raise ValueError( - "Cannot calculate variance-normalized root mean squared error without aggregation dimensions (agg_dims=None)." + "Cannot calculate variance-normalized root mean squared error without aggregation " + "dimensions (agg_dims=None)." ) vrmse = np.sqrt(self.calc_mse(p, gt) / (gt.var(dim=self._agg_dims) + 1e-6)) @@ -759,7 +777,8 @@ def sort_by_coords(da_to_sort: xr.DataArray, da_reference: xr.DataArray) -> xr.D if np.any(unmatched_mask): n_unmatched = np.sum(unmatched_mask) _logger.info( - f"Found {n_unmatched} reference coordinates with no matching coordinates in array to sort. Returning NaN DataArray." + f"Found {n_unmatched} reference coordinates with no matching coordinates in array" + "to sort. Returning NaN DataArray." ) return xr.full_like(da_reference, np.nan) @@ -772,7 +791,8 @@ def calc_change_rate( s1: xr.DataArray, ) -> xr.DataArray: """ - Calculate the "change rate" of a data array as the mean absolute difference between two consecutive time steps. + Calculate the "change rate" of a data array as the mean absolute difference between two + consecutive time steps. Parameters ---------- @@ -1068,8 +1088,9 @@ def calc_spatial_variability( non_spatial_avg_dims: list[str] = None, ) -> xr.DataArray: """ - Calculates the ratio between the spatial variability of differental operator with order 1 (higher values unsupported yest) - forecast and ground truth data using the calc_geo_spatial-method. + Calculates the ratio between the spatial variability of differental operator + with order 1 (higher values unsupported yet) forecast and ground truth data using + the calc_geo_spatial-method. NOTE: Requires that data is provided on a regular lat/lon-grid! @@ -1168,7 +1189,8 @@ def seeps(ground_truth, prediction, thr_light, thr_heavy, seeps_weights): # check dimensioning of data assert prediction.ndim <= 2, ( - f"Data must be one- or two-dimensional, but has {prediction.ndim} dimensions. Check if stacking with spatial_dims may help." + f"Data must be one- or two-dimensional, but has {prediction.ndim} dimensions. " + "Check if stacking with spatial_dims may help." ) if prediction.ndim == 1: @@ -1265,7 +1287,8 @@ def calc_crps( method: str Method to calculate CRPS. Supported methods: ["ensemble", "gaussian"] kwargs: dict - Other keyword parameters supported by respective CRPS-method from the xskillscore package + Other keyword parameters supported by respective CRPS-method from + the xskillscore package Returns ------- @@ -1292,7 +1315,8 @@ def calc_crps( crps_func = xskillscore.crps_gaussian else: raise ValueError( - f"Unsupported CRPS-calculation method {method} chosen. Supported methods: {', '.join(crps_methods)}" + f"Unsupported CRPS-calculation method {method} chosen." + + f"Supported methods: {', '.join(crps_methods)}" ) crps = crps_func(gt, **func_kwargs) @@ -1317,21 +1341,24 @@ def calc_rank_histogram( gt: xr.DataArray Ground truth data array norm: bool - Flag if normalized counts should be returned. If True, the rank histogram will be normalized by - the number of ensemble members in the forecast data. + Flag if normalized counts should be returned. If True, the rank histogram will be + normalized by the number of ensemble members in the forecast data. add_noise: bool - Flag if a small amount of random noise should be added to the data to avoid ties in the rank histogram. + Flag if a small amount of random noise should be added to the data to avoid ties in the + rank histogram. This is recommended for fair computations, cf. Sec. 4.2.2 in Harris et al. 2022 noise_fac: float - Magnitude of random noise to be added to the data if add_noise is True. Default is 1.0e-03. - This value is only relevant if add_noise is True + Magnitude of random noise to be added to the data if add_noise is True. + Default is 1.0e-03. This value is only relevant if add_noise is True + Returns ------- xr.DataArray Rank histogram data array averaged over the provided dimensions """ - # unstack stacked time-dimension beforehand if required (time may be stacked for forecast data) + # unstack stacked time-dimension beforehand if required (time may be stacked for forecast + # data) ground_truth = gt if "time" in ground_truth.indexes: if isinstance(ground_truth.indexes["time"], pd.MultiIndex): @@ -1436,18 +1463,21 @@ def calc_geo_spatial_diff( Parameters ---------- - scalar_field: + scalar_field: xr.DataArray Scalar field as data array with latitude and longitude as coordinates - order: + order: int Order of spatial differential operator - r_e: + r_e: float Radius of the sphere - dom_avg: - Flag whether to return the domain-averaged amplitude or the amplitude at each grid point + dom_avg: bool + Flag whether to return the domain-averaged amplitude or the amplitude at each + grid point + Returns ------- xr.DataArray - the amplitude of the gradient/laplacian at each grid point or over the whole domain (see dom_avg) + the amplitude of the gradient/laplacian at each grid point or over the whole domain + (see dom_avg) """ method = Scores.calc_geo_spatial_diff.__name__ # sanity checks diff --git a/packages/evaluate/src/weathergen/evaluate/score_utils.py b/packages/evaluate/src/weathergen/evaluate/score_utils.py index 1561216ca..a6339d009 100644 --- a/packages/evaluate/src/weathergen/evaluate/score_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/score_utils.py @@ -70,7 +70,8 @@ def validate(self): ) if not (-180 <= self.lon_min <= 180 and -180 <= self.lon_max <= 180): raise ValueError( - f"Longitude bounds must be between -180 and 180. Got: {self.lon_min}, {self.lon_max}" + "Longitude bounds must be between -180 and 180. " + + f"Got: {self.lon_min}, {self.lon_max}" ) if self.lat_min >= self.lat_max: raise ValueError( diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index da98a1b6b..3245d1dc8 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -151,9 +151,11 @@ def calc_scores_per_stream( combined_metrics = scalar_coord_to_dim(combined_metrics, "sample") combined_metrics = scalar_coord_to_dim(combined_metrics, "ens") else: - # depending on the datset, there might be no data (e.g. no CERRA in southern hemisphere region) + # depending on the datset, there might be no data (e.g. no CERRA in southern + # hemisphere region) _logger.warning( - f"No data available for stream {stream} at forecast step {fstep} in region {region}. Skipping metrics calculation." + f"No data available for stream {stream} at forecast step {fstep} in " + f"region {region}. Skipping metrics calculation." ) continue @@ -368,7 +370,8 @@ def metric_list_to_json( json.dump(metric_dict, f, indent=4) _logger.info( - f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} successfully to {reader.metrics_dir}." + f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} " + f"successfully to {reader.metrics_dir}." ) @@ -473,7 +476,8 @@ def common_ranges( Returns ------- maps_config : - the global plotting configuration with the ranges added and included for each variable (and for each stream). + the global plotting configuration with the ranges added and included for each variable (and + for each stream). """ for var in plot_chs: @@ -527,7 +531,8 @@ def calc_bounds( bound, ): """ - Calculate the minimum and maximum values per variable for all forecasteps for both targets and predictions + Calculate the minimum and maximum values per variable for all forecasteps for both targets and + predictions Parameters ---------- diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py index 93de031fe..fbcf10f3a 100644 --- a/src/weathergen/datasets/masking.py +++ b/src/weathergen/datasets/masking.py @@ -68,9 +68,9 @@ def __init__(self, cf: Config): if self.current_strategy == "healpix": hl_data = self.healpix_level_data hl_mask = self.masking_strategy_config.get("hl_mask") - assert ( - hl_data is not None and hl_mask is not None - ), "If HEALPix masking, hl_mask must be given in masking_strategy_config." + assert hl_data is not None and hl_mask is not None, ( + "If HEALPix masking, hl_mask must be given in masking_strategy_config." + ) assert hl_mask < hl_data, "hl_mask must be less than hl_data for HEALPix masking." if self.current_strategy == "channel": @@ -85,15 +85,15 @@ def __init__(self, cf: Config): # check explicit includes source_include = stream.get("source_include", []) target_include = stream.get("target_include", []) - assert set(source_include) == set( - target_include - ), "Source and target channels not identical. Required for masking_mode=channel" + assert set(source_include) == set(target_include), ( + "Source and target channels not identical. Required for masking_mode=channel" + ) # check excludes source_exclude = stream.get("source_exclude", []) target_exclude = stream.get("target_exclude", []) - assert set(source_exclude) == set( - target_exclude - ), "Source and target channels not identical. Required for masking_mode=channel" + assert set(source_exclude) == set(target_exclude), ( + "Source and target channels not identical. Required for masking_mode=channel" + ) def reset_rng(self, rng) -> None: """ @@ -367,9 +367,9 @@ def _generate_healpix_mask(self, token_lens: list[int], rate: float) -> np.typin hl_data = self.healpix_level_data hl_mask = self.masking_strategy_config.get("hl_mask") - assert ( - len(token_lens) == self.healpix_num_cells - ), f"Expected {self.healpix_num_cells} cells at level {hl_data}, got {len(token_lens)}." + assert len(token_lens) == self.healpix_num_cells, ( + f"Expected {self.healpix_num_cells} cells at level {hl_data}, got {len(token_lens)}." + ) # Calculate the number of parent cells at the mask level (hl_mask) num_parent_cells = 12 * (4**hl_mask) From f3e195acf3987b0fa723b62f724595efbd162760 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Wed, 5 Nov 2025 09:33:55 +0100 Subject: [PATCH 16/32] [1092] Adds pushing metrics to the evaluation pipeline (#1127) * changes * changes * changes * changes * changes * scores successfully pushed to MLFlow, still need to refactor * try to batch upload all metrics form same runid * batch logging all scores of each run_id * get parent_run by from_run_id * changes * cleanups * bug fixes * typing issue * Cleanup * pdb * integration test --------- Co-authored-by: Jubeku --- integration_tests/small1_test.py | 3 +- packages/common/src/weathergen/common/io.py | 1 - .../src/weathergen/common/platform_env.py | 38 +++ packages/evaluate/pyproject.toml | 3 +- .../src/weathergen/evaluate/plot_utils.py | 4 +- .../src/weathergen/evaluate/run_evaluation.py | 73 ++++- packages/metrics/pyproject.toml | 102 ++++++ .../src/weathergen/metrics/__init__.py | 0 .../src/weathergen/metrics/mlflow_utils.py | 176 +++++++++++ pyproject.toml | 4 +- scripts/actions.sh | 10 + uv.lock | 297 ++++++++++++++++++ 12 files changed, 701 insertions(+), 10 deletions(-) create mode 100644 packages/common/src/weathergen/common/platform_env.py create mode 100644 packages/metrics/pyproject.toml create mode 100644 packages/metrics/src/weathergen/metrics/__init__.py create mode 100644 packages/metrics/src/weathergen/metrics/mlflow_utils.py diff --git a/integration_tests/small1_test.py b/integration_tests/small1_test.py index 47349b8af..e9b35ed8e 100644 --- a/integration_tests/small1_test.py +++ b/integration_tests/small1_test.py @@ -134,7 +134,8 @@ def evaluate_results(run_id): }, } ) - evaluate_from_config(cfg) + # Not passing the mlflow client for tests. + evaluate_from_config(cfg, None) def load_metrics(run_id): diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 3e7594d1c..2dba8b727 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -83,7 +83,6 @@ def combine(cls, others: list["IOReaderData"]) -> "IOReaderData": others is list of ReaderData instances. """ - assert len(others) > 0, len(others) other = others[0] diff --git a/packages/common/src/weathergen/common/platform_env.py b/packages/common/src/weathergen/common/platform_env.py new file mode 100644 index 000000000..485969588 --- /dev/null +++ b/packages/common/src/weathergen/common/platform_env.py @@ -0,0 +1,38 @@ +""" +Platform environment configuration for WeatherGenerator. + +These are loaded from secrets in the private repository. +""" + +import importlib +import importlib.util +from functools import lru_cache +from typing import Protocol + +from weathergen.common.config import _REPO_ROOT + + +class PlatformEnv(Protocol): + """ + Interface for platform environment configuration. + """ + + def get_hpc(self) -> str | None: ... + + def get_hpc_user(self) -> str | None: ... + + def get_hpc_config(self) -> str | None: ... + + def get_hpc_certificate(self) -> str | None: ... + + +@lru_cache(maxsize=1) +def get_platform_env() -> PlatformEnv: + """ + Loads the platform environment module from the private repository. + """ + env_script_path = _REPO_ROOT.parent / "WeatherGenerator-private" / "hpc" / "platform-env.py" + spec = importlib.util.spec_from_file_location("platform_env", env_script_path) + platform_env = importlib.util.module_from_spec(spec) + spec.loader.exec_module(platform_env) # type: ignore + return platform_env # type: ignore diff --git a/packages/evaluate/pyproject.toml b/packages/evaluate/pyproject.toml index d1ef2d40a..862358e5d 100644 --- a/packages/evaluate/pyproject.toml +++ b/packages/evaluate/pyproject.toml @@ -10,8 +10,9 @@ dependencies = [ "xhistogram", "panel", "omegaconf", - "weathergen-common", "plotly>=6.2.0", + "weathergen-common", + "weathergen-metrics", ] [dependency-groups] diff --git a/packages/evaluate/src/weathergen/evaluate/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plot_utils.py index 55e9c9f9a..3000b0767 100644 --- a/packages/evaluate/src/weathergen/evaluate/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plot_utils.py @@ -30,7 +30,7 @@ def collect_streams(runs: dict): return sorted({s for run in runs.values() for s in run["streams"].keys()}) -def collect_channels(scores_dict: dict, metric: str, region: str, runs) -> dict: +def collect_channels(scores_dict: dict, metric: str, region: str, runs) -> list[str]: """Get all unique channels available for given metric and region across runs. Parameters @@ -56,7 +56,7 @@ def collect_channels(scores_dict: dict, metric: str, region: str, runs) -> dict: if run_id not in run_data: continue values = run_data[run_id]["channel"].values - channels.update(np.atleast_1d(values)) + channels.update([str(x) for x in np.atleast_1d(values)]) return list(channels) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index b94a19480..3e30acedd 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -3,6 +3,7 @@ # dependencies = [ # "weathergen-evaluate", # "weathergen-common", +# "weathergen-metrics", # ] # [tool.uv.sources] # weathergen-evaluate = { path = "../../../../../packages/evaluate" } @@ -14,10 +15,15 @@ from collections import defaultdict from pathlib import Path +import mlflow +from mlflow.client import MlflowClient from omegaconf import OmegaConf +from xarray import DataArray from weathergen.common.config import _REPO_ROOT +from weathergen.common.platform_env import get_platform_env from weathergen.evaluate.io_reader import WeatherGenReader +from weathergen.evaluate.plot_utils import collect_channels from weathergen.evaluate.utils import ( calc_scores_per_stream, metric_list_to_json, @@ -25,11 +31,19 @@ plot_summary, retrieve_metric_from_json, ) +from weathergen.metrics.mlflow_utils import ( + MlFlowUpload, + get_or_create_mlflow_parent_run, + log_scores, + setup_mlflow, +) _logger = logging.getLogger(__name__) _DEFAULT_PLOT_DIR = _REPO_ROOT / "plots" +_platform_env = get_platform_env() + def evaluate() -> None: # By default, arguments from the command line are read. @@ -37,6 +51,8 @@ def evaluate() -> None: def evaluate_from_args(argl: list[str]) -> None: + # configure logging + logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser(description="Fast evaluation of WeatherGenerator runs.") parser.add_argument( "--config", @@ -44,6 +60,12 @@ def evaluate_from_args(argl: list[str]) -> None: default=None, help="Path to the configuration yaml file for plotting. e.g. config/plottig_config.yaml", ) + parser.add_argument( + "--push-metrics", + required=False, + action="store_true", + help="(optional) Upload scores to MLFlow.", + ) args = parser.parse_args(argl) if args.config: @@ -53,13 +75,19 @@ def evaluate_from_args(argl: list[str]) -> None: "No config file provided, using the default template config (please edit accordingly)" ) config = Path(_REPO_ROOT / "config" / "evaluate" / "eval_config.yml") - evaluate_from_config(OmegaConf.load(config)) + mlflow_client: MlflowClient | None = None + if args.push_metrics: + hpc_conf = _platform_env.get_hpc_config() + assert hpc_conf is not None + private_home = Path(hpc_conf) + private_cf = OmegaConf.load(private_home) + mlflow_client = setup_mlflow(private_cf) + _logger.info(f"MLFlow client set up: {mlflow_client}") + evaluate_from_config(OmegaConf.load(config), mlflow_client) -def evaluate_from_config(cfg): - # configure logging - logging.basicConfig(level=logging.INFO) +def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: # load configuration runs = cfg.run_ids @@ -149,6 +177,43 @@ def evaluate_from_config(cfg): {"metric": metric} ) + if mlflow_client: + # Reorder scores_dict to push to MLFlow per run_id: + # Create a new defaultdict with the target structure: [run_id][metric][region][stream] + reordered_dict: dict[str, dict[str, dict[str, dict[str, DataArray]]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(dict)) + ) + + # Iterate through the original dictionary to get all keys and the final value + for metric, regions_dict in scores_dict.items(): + for region, streams_dict in regions_dict.items(): + for stream, runs_dict in streams_dict.items(): + for run_id, final_dict in runs_dict.items(): + # Assign the final_dict to the new structure using the reordered keys + reordered_dict[run_id][metric][region][stream] = final_dict + + channels_set = collect_channels(scores_dict, metric, region, runs) + + for run_id, run in runs.items(): + reader = WeatherGenReader(run, run_id, private_paths) + from_run_id = reader.inference_cfg["from_run_id"] + parent_run = get_or_create_mlflow_parent_run(mlflow_client, from_run_id) + _logger.info(f"MLFlow parent run: {parent_run}") + phase = "eval" + with mlflow.start_run(run_id=parent_run.info.run_id): + with mlflow.start_run( + run_name=f"{phase}_{from_run_id}_{run_id}", + parent_run_id=parent_run.info.run_id, + nested=True, + ) as run: + mlflow.set_tags(MlFlowUpload.run_tags(run_id, phase, from_run_id)) + log_scores( + reordered_dict[run_id], + mlflow_client, + run.info.run_id, + channels_set, + ) + # plot summary if scores_dict and cfg.evaluation.get("summary_plots", True): _logger.info("Started creating summary plots..") diff --git a/packages/metrics/pyproject.toml b/packages/metrics/pyproject.toml new file mode 100644 index 000000000..ba54aa4a1 --- /dev/null +++ b/packages/metrics/pyproject.toml @@ -0,0 +1,102 @@ +[project] +name = "weathergen-metrics" +version = "0.1.0" +description = "The WeatherGenerator Machine Learning Earth System Model" +readme = "../../README.md" +requires-python = ">=3.12,<3.13" +dependencies = [ + "mlflow-skinny", + "weathergen-common", +] + +[dependency-groups] +dev = [ + "pytest~=8.3.5", + "pytest-mock>=3.14.1", + "ruff==0.9.7", + "pyrefly==0.36.0", +] + + +[tool.pyrefly] +project-includes = ["src/"] +project-excludes = [ +] + +[tool.pyrefly.errors] +bad-argument-type = false +unsupported-operation = false +missing-attribute = false +no-matching-overload = false +bad-context-manager = false + +# To do: +bad-assignment = false +bad-return = false +index-error = false +not-iterable = false +not-callable = false + + + + +# The linting configuration +[tool.ruff] + +# Wide rows +line-length = 100 + +[tool.ruff.lint] +# All disabled until the code is formatted. +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # Banned imports + "TID", + # Naming conventions + "N", + # print + "T201" +] + +# These rules are sensible and should be enabled at a later stage. +ignore = [ + # "B006", + "B011", + "UP008", + "SIM117", + "SIM118", + "SIM102", + "SIM401", + # To ignore, not relevant for us + "SIM108", # in case additional norm layer supports are added in future + "N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted) + "E731", # overly restrictive and less readable code + "N812", # prevents us following the convention for importing torch.nn.functional as F +] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example" + +[tool.ruff.format] +# Use Unix `\n` line endings for all files +line-ending = "lf" + + + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/weathergen"] diff --git a/packages/metrics/src/weathergen/metrics/__init__.py b/packages/metrics/src/weathergen/metrics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/metrics/src/weathergen/metrics/mlflow_utils.py b/packages/metrics/src/weathergen/metrics/mlflow_utils.py new file mode 100644 index 000000000..27a8bec8e --- /dev/null +++ b/packages/metrics/src/weathergen/metrics/mlflow_utils.py @@ -0,0 +1,176 @@ +import logging +import os + +import mlflow +import mlflow.client +import numpy as np +from mlflow.client import MlflowClient +from mlflow.entities.metric import Metric +from mlflow.entities.run import Run +from xarray import DataArray + +from weathergen.common.config import Config +from weathergen.common.platform_env import get_platform_env + +_logger = logging.getLogger(__name__) + +project_name = "WeatherGenerator" +project_lifecycle = "dev" + +_platform_env = get_platform_env() + + +class MlFlowUpload: + tracking_uri = "databricks" + registry_uri = "databricks-uc" + experiment_name = "/Shared/weathergen-dev/core-model/defaultExperiment" + + experiment_tags = { + "project": project_name, + "lifecycle": project_lifecycle, + } + + @classmethod + def run_tags(cls, run_id: str, phase: str, from_run_id: str | None) -> dict[str, str]: + """ + Returns the tags to be set for a run. + """ + dct = { + "lifecycle": project_lifecycle, + "hpc": _platform_env.get_hpc() or "unknown", + "run_id": run_id, + "stage": phase, + "project": project_name, + "uploader": _platform_env.get_hpc_user() or "unknown", + "completion_status": "success", + } + if from_run_id: + dct["from_run_id"] = from_run_id + return dct + + +def log_metrics( + metrics: list[dict[str, float | int]], + mlflow_client: MlflowClient, + mlflow_run_id: str, +): + """ + Logs the metrics to MLFlow. + """ + if not metrics: + return + + # Converts teh metrics to a single batch of metrics object. This limits the IO and DB calls + def _convert_to_mlflow_metric(dct): + # Convert the metric to a mlflow metric + ts = int(dct.get("weathergen.timestamp", 0)) + step = int(dct.get("weathergen.step", 0)) + return [ + Metric(key=k, value=v, timestamp=ts, step=step) + for k, v in dct.items() + if not k.startswith("weathergen.") + ] + + mlflow_metrics = [met for dct in metrics for met in _convert_to_mlflow_metric(dct)] + mlflow_client.log_batch( + run_id=mlflow_run_id, + metrics=mlflow_metrics, + ) + + +def log_scores( + metrics_dict: dict[str, dict[str, dict[str, DataArray]]], + mlflow_client: MlflowClient, + mlflow_run_id: str, + channels_set: list[str], + x_dim="forecast_step", +): + """ + Logs the evaluation scores to MLFlow. + metrics_dict: metric -> region -> stream -> DataArray + """ + + ts = 0 + + mlflow_metrics = [] + for metric, regions_dict in metrics_dict.items(): + for region, streams_dict in regions_dict.items(): + for stream, data in streams_dict.items(): + for ch in channels_set: + # skip if channel is missing or contains NaN + if ch not in np.atleast_1d(data.channel.values) or data.isnull().all(): + _logger.info( + f"Skipping channel {ch} for {metric} - {region} - {stream} ", + "due to missing data.", + ) + continue + _logger.info(f"Collecting data for {metric} - {region} - {stream} - {ch}.") + data_ch = data.sel(channel=ch) + non_zero_dims = [ + dim for dim in data_ch.dims if dim != x_dim and data_ch[dim].shape[0] > 1 + ] + if "ens" in non_zero_dims: + _logger.info("Uploading ensembles not yet imnplemented") + else: + if non_zero_dims: + _logger.info( + f"LinePlot:: Found multiple entries for dimensions: {non_zero_dims}" + + ". Averaging..." + ) + averaged = data_ch.mean( + dim=[dim for dim in data_ch.dims if dim != x_dim], skipna=True + ).sortby(x_dim) + label = f"score.{region}.{metric}.{stream}.{ch}" + + mlflow_metrics.append( + [ + Metric(key=label, value=y, timestamp=ts, step=int(x)) + for x, y in zip( + averaged[x_dim].values, averaged.values, strict=False + ) + ] + ) + + all_metrics = [met for dict in mlflow_metrics for met in dict] + _logger.info(f"Logging total of {len(all_metrics)} metrics to MLFlow.") + mlflow_client.log_batch( + run_id=mlflow_run_id, + metrics=all_metrics, + ) + + +def setup_mlflow(private_config: Config) -> MlflowClient: + os.environ["DATABRICKS_HOST"] = private_config["mlflow"]["tracking_uri"] + os.environ["DATABRICKS_TOKEN"] = private_config["secrets"]["mlflow_token"] + mlflow.set_tracking_uri(MlFlowUpload.tracking_uri) + mlflow.set_registry_uri(MlFlowUpload.registry_uri) + mlflow_client = mlflow.client.MlflowClient( + tracking_uri=MlFlowUpload.tracking_uri, registry_uri=MlFlowUpload.registry_uri + ) + return mlflow_client + + +def get_or_create_mlflow_parent_run(mlflow_client: MlflowClient, run_id: str) -> Run: + exp_name = MlFlowUpload.experiment_name + _logger.info(f"Setting experiment name to {exp_name}: host: {os.environ['DATABRICKS_HOST']}") + exp = mlflow.set_experiment(exp_name) + _logger.info(f"Experiment {exp_name} created with ID {exp.experiment_id}: {exp}") + runs = mlflow_client.search_runs( + experiment_ids=[exp.experiment_id], + filter_string=f"tags.run_id='{run_id}' AND tags.stage='unknown'", + ) + if len(runs) == 0: + _logger.info(f"No existing parent run found for run_id {run_id}, creating new run") + return mlflow_client.create_run( + experiment_id=exp.experiment_id, + tags=MlFlowUpload.run_tags(run_id, "unknown", from_run_id=None), + run_name=run_id, + ) + if len(runs) > 1: + _logger.warning( + ( + f"Multiple existing parent runs found for run_id {run_id},", + f" using the first one: {runs[0].info.run_id}", + ) + ) + return runs[0] diff --git a/pyproject.toml b/pyproject.toml index 411d29683..016552e8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,6 +213,7 @@ explicit = true [tool.uv.sources] weathergen-common = { workspace = true } +weathergen-metrics = { workspace = true } weathergen-evaluate = { workspace = true } @@ -246,6 +247,7 @@ log_cli_date_format = "%Y-%m-%d %H:%M:%S" [tool.uv.workspace] members = [ "packages/evaluate", - "packages/common" + "packages/common", + "packages/metrics", ] diff --git a/scripts/actions.sh b/scripts/actions.sh index a99b9d896..c19d20f4b 100755 --- a/scripts/actions.sh +++ b/scripts/actions.sh @@ -49,6 +49,16 @@ case "$1" in exit 1 fi + # weathergen-metrics + uv sync --project packages/metrics --no-install-workspace + uv pip list + uv run --project packages/metrics --frozen pyrefly check packages/metrics + # Fail for errors on weathergen-metrics: + if [ $? -ne 0 ]; then + echo "Type checking failed for weathergen-metrics." + exit 1 + fi + # weathergen-evaluate uv sync --project packages/evaluate --no-install-workspace --package weathergen-evaluate uv pip list diff --git a/uv.lock b/uv.lock index 74a894572..c4f489d1b 100644 --- a/uv.lock +++ b/uv.lock @@ -19,6 +19,7 @@ members = [ "weathergen", "weathergen-common", "weathergen-evaluate", + "weathergen-metrics", ] [[package]] @@ -114,6 +115,20 @@ version = "4.9.3" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" } +[[package]] +name = "anyio" +version = "4.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "sniffio", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, +] + [[package]] name = "array-api-compat" version = "1.12.0" @@ -230,6 +245,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/48/08b2382e739236aa3360b7976360ba3e0c043b6234e25951c18c1eb6fa06/bokeh-3.7.3-py3-none-any.whl", hash = "sha256:b0e79dd737f088865212e4fdcb0f3b95d087f0f088bf8ca186a300ab1641e2c7", size = 7031447, upload-time = "2025-05-12T12:13:27.47Z" }, ] +[[package]] +name = "cachetools" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/7e/b975b5814bd36faf009faebe22c1072a1fa1168db34d285ef0ba071ad78c/cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201", size = 31325, upload-time = "2025-10-12T14:55:30.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/c5/1e741d26306c42e2bf6ab740b2202872727e0f606033c9dd713f8b93f5a8/cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701", size = 11280, upload-time = "2025-10-12T14:55:28.382Z" }, +] + [[package]] name = "cartopy" version = "0.24.1" @@ -426,6 +450,20 @@ array = [ { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] +[[package]] +name = "databricks-sdk" +version = "0.69.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/ba/1dc248e4cc646a1a29504bcbb910bfb28d3affe58063df622e7e3c5c0634/databricks_sdk-0.69.0.tar.gz", hash = "sha256:5ad7514325d941afe47da4cf8748ba9f7da7250977666c519f534c9f6298d2f5", size = 794676, upload-time = "2025-10-20T11:38:15.004Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/73/6f82f2a926a2129f9a08ba550b3f5c837d23156082c8d1f4226801168456/databricks_sdk-0.69.0-py3-none-any.whl", hash = "sha256:f75c37c0da2126d9fec31cefd7b5c5491a7c8b5d62481cd661d3e9f1efec0b1f", size = 749754, upload-time = "2025-10-20T11:38:13.451Z" }, +] + [[package]] name = "debugpy" version = "1.8.15" @@ -599,6 +637,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/c3/6f0e3896f193528bbd2b4d2122d4be8108a37efab0b8475855556a8c4afa/fancycompleter-0.11.1-py3-none-any.whl", hash = "sha256:44243d7fab37087208ca5acacf8f74c0aa4d733d04d593857873af7513cdf8a6", size = 11207, upload-time = "2025-05-26T12:59:09.857Z" }, ] +[[package]] +name = "fastapi" +version = "0.119.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "starlette", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/f4/152127681182e6413e7a89684c434e19e7414ed7ac0c632999c3c6980640/fastapi-0.119.1.tar.gz", hash = "sha256:a5e3426edce3fe221af4e1992c6d79011b247e3b03cc57999d697fe76cbf8ae0", size = 338616, upload-time = "2025-10-20T11:30:27.734Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/26/e6d959b4ac959fdb3e9c4154656fc160794db6af8e64673d52759456bf07/fastapi-0.119.1-py3-none-any.whl", hash = "sha256:0b8c2a2cce853216e150e9bd4faaed88227f8eb37de21cb200771f491586a27f", size = 108123, upload-time = "2025-10-20T11:30:26.185Z" }, +] + [[package]] name = "fasteners" version = "0.19" @@ -751,6 +803,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599, upload-time = "2025-01-02T07:32:40.731Z" }, ] +[[package]] +name = "google-auth" +version = "2.41.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pyasn1-modules", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "rsa", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/af/5129ce5b2f9688d2fa49b463e544972a7c82b0fdb50980dafee92e121d9f/google_auth-2.41.1.tar.gz", hash = "sha256:b76b7b1f9e61f0cb7e88870d14f6a94aeef248959ef6992670efee37709cbfd2", size = 292284, upload-time = "2025-09-30T22:51:26.363Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/a4/7319a2a8add4cc352be9e3efeff5e2aacee917c85ca2fa1647e29089983c/google_auth-2.41.1-py2.py3-none-any.whl", hash = "sha256:754843be95575b9a19c604a848a41be03f7f2afd8c019f716dc1f51ee41c639d", size = 221302, upload-time = "2025-09-30T22:51:24.212Z" }, +] + [[package]] name = "grpcio" version = "1.74.0" @@ -769,6 +835,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/59/900aa2445891fc47a33f7d2f76e00ca5d6ae6584b20d19af9c06fa09bf9a/grpcio-1.74.0-cp312-cp312-win_amd64.whl", hash = "sha256:42f8fee287427b94be63d916c90399ed310ed10aadbf9e2e5538b3e497d269bc", size = 4490123, upload-time = "2025-07-24T18:53:39.528Z" }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + [[package]] name = "hatchling" version = "1.27.0" @@ -793,6 +868,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -1115,6 +1202,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] +[[package]] +name = "mlflow-skinny" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "click", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "cloudpickle", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "databricks-sdk", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "fastapi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "gitpython", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "importlib-metadata", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "opentelemetry-api", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "opentelemetry-proto", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "opentelemetry-sdk", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "packaging", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "python-dotenv", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pyyaml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "requests", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "sqlparse", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "uvicorn", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d2/12/3143c5275531cc318146a1b36f0780991e899639551e5554d27573ba74be/mlflow_skinny-3.5.0.tar.gz", hash = "sha256:d9cf914ed6746a6097ef51d1a377a4c5c0f46aa174d3f89efbdc31feb2cf572b", size = 1925967, upload-time = "2025-10-16T14:04:13.777Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9a/bc/1e0c324bdd4e49d386625e6d5259a1352d8b4a39dc4af36b9dd474536843/mlflow_skinny-3.5.0-py3-none-any.whl", hash = "sha256:496cb9bf4e0d5b96082407a923e34636ea748ab928d35c288d1f19ec5493705e", size = 2311609, upload-time = "2025-10-16T14:04:12.142Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -1443,6 +1560,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/d8/0f354c375628e048bd0570645b310797299754730079853095bf000fba69/opentelemetry_api-1.38.0.tar.gz", hash = "sha256:f4c193b5e8acb0912b06ac5b16321908dd0843d75049c091487322284a3eea12", size = 65242, upload-time = "2025-10-16T08:35:50.25Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/a2/d86e01c28300bd41bab8f18afd613676e2bd63515417b77636fc1add426f/opentelemetry_api-1.38.0-py3-none-any.whl", hash = "sha256:2891b0197f47124454ab9f0cf58f3be33faca394457ac3e09daba13ff50aa582", size = 65947, upload-time = "2025-10-16T08:35:30.23Z" }, +] + +[[package]] +name = "opentelemetry-proto" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/51/14/f0c4f0f6371b9cb7f9fa9ee8918bfd59ac7040c7791f1e6da32a1839780d/opentelemetry_proto-1.38.0.tar.gz", hash = "sha256:88b161e89d9d372ce723da289b7da74c3a8354a8e5359992be813942969ed468", size = 46152, upload-time = "2025-10-16T08:36:01.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/6a/82b68b14efca5150b2632f3692d627afa76b77378c4999f2648979409528/opentelemetry_proto-1.38.0-py3-none-any.whl", hash = "sha256:b6ebe54d3217c42e45462e2a1ae28c3e2bf2ec5a5645236a490f55f45f1a0a18", size = 72535, upload-time = "2025-10-16T08:35:45.749Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "opentelemetry-semantic-conventions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/cb/f0eee1445161faf4c9af3ba7b848cc22a50a3d3e2515051ad8628c35ff80/opentelemetry_sdk-1.38.0.tar.gz", hash = "sha256:93df5d4d871ed09cb4272305be4d996236eedb232253e3ab864c8620f051cebe", size = 171942, upload-time = "2025-10-16T08:36:02.257Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/2e/e93777a95d7d9c40d270a371392b6d6f1ff170c2a3cb32d6176741b5b723/opentelemetry_sdk-1.38.0-py3-none-any.whl", hash = "sha256:1c66af6564ecc1553d72d811a01df063ff097cdc82ce188da9951f93b8d10f6b", size = 132349, upload-time = "2025-10-16T08:35:46.995Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.59b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/bc/8b9ad3802cd8ac6583a4eb7de7e5d7db004e89cb7efe7008f9c8a537ee75/opentelemetry_semantic_conventions-0.59b0.tar.gz", hash = "sha256:7a6db3f30d70202d5bf9fa4b69bc866ca6a30437287de6c510fb594878aed6b0", size = 129861, upload-time = "2025-10-16T08:36:03.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/7d/c88d7b15ba8fe5c6b8f93be50fc11795e9fc05386c44afaf6b76fe191f9b/opentelemetry_semantic_conventions-0.59b0-py3-none-any.whl", hash = "sha256:35d3b8833ef97d614136e253c1da9342b4c3c083bbaf29ce31d572a1c3825eed", size = 207954, upload-time = "2025-10-16T08:35:48.054Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -1740,6 +1909,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "2.22" @@ -1928,6 +2118,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/b0/4bc07ccd3572a2f9df7e6782f52b0c6c90dcbb803ac4a167702d7d0dfe1e/python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab", size = 41978, upload-time = "2025-06-24T04:21:07.341Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556, upload-time = "2025-06-24T04:21:06.073Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -2038,6 +2237,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/f2/c2d64f6564f32af913bf5f3f7ae41c7c263c5ae4c4e8f1a17af8af66cd46/rpds_py-0.25.1-cp312-cp312-win_arm64.whl", hash = "sha256:6d50841c425d16faf3206ddbba44c21aa3310a0cebc3c1cdfc3e3f4f9f6f5728", size = 225399, upload-time = "2025-05-21T12:43:53.351Z" }, ] +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + [[package]] name = "ruff" version = "0.9.7" @@ -2138,6 +2349,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, ] +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "sqlparse" +version = "0.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/40/edede8dd6977b0d3da179a342c198ed100dd2aba4be081861ee5911e4da4/sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272", size = 84999, upload-time = "2024-12-10T12:05:30.728Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca", size = 44415, upload-time = "2024-12-10T12:05:27.824Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -2152,6 +2381,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] +[[package]] +name = "starlette" +version = "0.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/72/2db2f49247d0a18b4f1bb9a5a39a0162869acf235f3a96418363947b3d46/starlette-0.48.0-py3-none-any.whl", hash = "sha256:0764ca97b097582558ecb498132ed0c7d942f233f365b86ba37770e026510659", size = 73736, upload-time = "2025-09-13T08:41:03.869Z" }, +] + [[package]] name = "statsmodels" version = "0.14.5" @@ -2445,6 +2687,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680, upload-time = "2025-04-10T15:23:37.377Z" }, ] +[[package]] +name = "uvicorn" +version = "0.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "h11", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/f06b84e2697fef4688ca63bdb2fdf113ca0a3be33f94488f2cadb690b0cf/uvicorn-0.38.0.tar.gz", hash = "sha256:fd97093bdd120a2609fc0d3afe931d4d4ad688b6e75f0f929fde1bc36fe0e91d", size = 80605, upload-time = "2025-10-18T13:46:44.63Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" }, +] + [[package]] name = "wcwidth" version = "0.2.13" @@ -2596,6 +2851,7 @@ dependencies = [ { name = "panel", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "plotly", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "weathergen-metrics", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "xhistogram", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "xskillscore", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] @@ -2615,6 +2871,7 @@ requires-dist = [ { name = "panel" }, { name = "plotly", specifier = ">=6.2.0" }, { name = "weathergen-common", editable = "packages/common" }, + { name = "weathergen-metrics", editable = "packages/metrics" }, { name = "xhistogram" }, { name = "xskillscore" }, ] @@ -2627,6 +2884,37 @@ dev = [ { name = "ruff", specifier = "==0.9.7" }, ] +[[package]] +name = "weathergen-metrics" +version = "0.1.0" +source = { editable = "packages/metrics" } +dependencies = [ + { name = "mlflow-skinny", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyrefly", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pytest", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pytest-mock", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "ruff", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] + +[package.metadata] +requires-dist = [ + { name = "mlflow-skinny" }, + { name = "weathergen-common", editable = "packages/common" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyrefly", specifier = "==0.36.0" }, + { name = "pytest", specifier = "~=8.3.5" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, + { name = "ruff", specifier = "==0.9.7" }, +] + [[package]] name = "webencodings" version = "0.5.1" @@ -2746,3 +3034,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/21/d1/764ca5b66d91b20de wheels = [ { url = "https://files.pythonhosted.org/packages/b4/d1/c84022a44afc7b7ccc442fba3daee56bdd03593d91ee4bc245a08e4fcc55/zarr-2.18.4-py3-none-any.whl", hash = "sha256:2795e20aff91093ce7e4da36ab1a138aededbd8ab66bf01fd01512e61d31e5d1", size = 210600, upload-time = "2024-12-12T16:04:06.642Z" }, ] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +] From 2964381bc14ca7eee14b0761a7d4f813db375fde Mon Sep 17 00:00:00 2001 From: Jifeng Wang Date: Wed, 5 Nov 2025 14:08:59 +0100 Subject: [PATCH 17/32] Fix the issue - "Empty source still have embedding network" (#1114) * Replace cf.rank==0 with utils.distributed.is_root * fix empty source inputs still have embedding layer * fix lint * fix source empty or source exclude all * fix source empty or source exclude all * fix forecast mode empty source --------- Co-authored-by: wang85 Co-authored-by: wang85 Co-authored-by: wang85 Co-authored-by: wang85 --- src/weathergen/datasets/multi_stream_data_sampler.py | 4 +++- src/weathergen/model/engines.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index daafd2a25..6519bbd92 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -248,7 +248,9 @@ def advance(self): ################################################### def get_sources_size(self): return [ - ds[0].get_source_num_channels() + 0 + if ds[0].get_source_num_channels() == 0 + else ds[0].get_source_num_channels() + ds[0].get_geoinfo_size() + ds[0].get_coords_size() + self.tokenizer.get_size_time_embedding() diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 3351dabc4..7359d1403 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -47,7 +47,7 @@ def __init__(self, cf: Config, sources_size) -> None: for i, si in enumerate(self.cf.streams): stream_name = si.get("name", i) - if "diagnostic" in si and si["diagnostic"]: + if si.get("diagnostic", False) or self.sources_size[i] == 0: self.embeds.append(torch.nn.Identity()) continue From 52ef0aa3fdacc4d069f31fe987353bcc5eb3dc21 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 6 Nov 2025 08:24:07 +0100 Subject: [PATCH 18/32] [930][evaluation] implement CSVReader (#932) * first version of quaver reader * working version * add CSVReader * rebase to develop * add polimorphism * fix names * lint --- .../src/weathergen/evaluate/io_reader.py | 181 +++++++++++++++++- .../src/weathergen/evaluate/run_evaluation.py | 51 ++--- .../evaluate/src/weathergen/evaluate/utils.py | 34 ---- 3 files changed, 208 insertions(+), 58 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index a3d836d7f..ae30ea1e5 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import json import logging import re from dataclasses import dataclass @@ -14,6 +15,7 @@ import numpy as np import omegaconf as oc +import pandas as pd import xarray as xr from tqdm import tqdm @@ -85,8 +87,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | self.eval_cfg = eval_cfg self.run_id = run_id self.private_paths = private_paths - self.streams = eval_cfg.streams.keys() + self.data = None # If results_base_dir and model_base_dir are not provided, default paths are used self.model_base_dir = self.eval_cfg.get("model_base_dir", None) @@ -128,6 +130,10 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: """Placeholder implementation ensemble member names getter. Override in subclass.""" return list() + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: + """Placeholder to load pre-computed scores for a given run, stream, metric""" + return None + def check_availability( self, stream: str, @@ -309,6 +315,146 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili ) +##### Helper function for CSVReader #### +def _rename_channels(data) -> pd.DataFrame: + """ + The scores downloaded from Quaver have a different convention. Need renaming. + Rename channel names to include underscore between letters and digits. + E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' + + Parameters + ---------- + name : str + Original channel name. + + Returns + ------- + pd.DataFrame + Dataset with renamed channel names. + """ + for name in list(data.index): + # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged + if re.match(r"^\d", name): + continue + + # Otherwise, insert underscore between letters and digits + data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}) + + return data + + +class CsvReader(Reader): + """ + Reader class to read evaluation data from CSV files and convert to xarray DataArray. + """ + + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + """ + Initialize the CsvReader. + + Parameters + ---------- + eval_cfg : dir + config with plotting and evaluation options for that run id + run_id : str + run id of the model + private_paths: lists + list of private paths for the supported HPC + """ + + super().__init__(eval_cfg, run_id, private_paths) + self.csv_path = eval_cfg.get("csv_path") + assert self.csv_path is not None, "CSV path must be provided in the config." + + pd_data = pd.read_csv(self.csv_path, index_col=0) + + self.data = _rename_channels(pd_data) + self.metrics_base_dir = Path(self.csv_path).parent + # for backward compatibility allow metric_dir to be specified in the run config + self.metrics_dir = Path( + self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation") + ) + + assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream." + self.stream = list(eval_cfg.streams.keys())[0] + self.channels = self.data.index.tolist() + self.samples = [0] + self.forecast_steps = [int(col.split()[0]) for col in self.data.columns] + self.npoints_per_sample = [0] + self.epoch = eval_cfg.get("epoch", 0) + self.metric = eval_cfg.get("metric") + self.region = eval_cfg.get("region") + + def get_samples(self) -> set[int]: + """get set of samples for the retrieved scores (initialisation times)""" + return set(self.samples) # Placeholder implementation + + def get_forecast_steps(self) -> set[int]: + """get set of forecast steps""" + return set(self.forecast_steps) # Placeholder implementation + + # TODO: get this from config + def get_channels(self, stream: str | None = None) -> list[str]: + """get set of channels""" + assert stream == self.stream, "streams do not match in CSVReader." + return list(self.channels) # Placeholder implementation + + def get_values(self) -> xr.DataArray: + """get score values in the right format""" + return self.data.values[np.newaxis, :, :, np.newaxis].T + + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: + """ + Load the existing scores for a given run, stream and metric. + + Parameters + ---------- + reader : + Reader object containing all info for a specific run_id + stream : + Stream name. + region : + Region name. + metric : + Metric name. + + Returns + ------- + xr.DataArray + The metric DataArray. + """ + + available_data = self.check_availability(stream, mode="evaluation") + + # fill it only for matching metric + if metric == self.metric and region == self.region and stream == self.stream: + data = self.get_values() + else: + data = np.full( + ( + len(available_data.samples), + len(available_data.fsteps), + len(available_data.channels), + 1, + ), + np.nan, + ) + + da = xr.DataArray( + data.astype(np.float32), + dims=("sample", "forecast_step", "channel", "metric"), + coords={ + "sample": available_data.samples, + "forecast_step": available_data.fsteps, + "channel": available_data.channels, + "metric": [metric], + }, + attrs={"npoints_per_sample": self.npoints_per_sample}, + ) + + return da + + class WeatherGenReader(Reader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" @@ -656,6 +802,39 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) return list(dummy.prediction.as_xarray().coords["ens"].values) + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: + """ + Load the pre-computed scores for a given run, stream and metric and epoch. + + Parameters + ---------- + reader : + Reader object containing all info for a specific run_id + stream : + Stream name. + region : + Region name. + metric : + Metric name. + + Returns + ------- + xr.DataArray + The metric DataArray or None if the file does not exist. + """ + score_path = ( + Path(self.metrics_dir) + / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json" + ) + _logger.debug(f"Looking for: {score_path}") + + if score_path.exists(): + with open(score_path) as f: + data_dict = json.load(f) + return xr.DataArray.from_dict(data_dict) + else: + return None + def get_inference_stream_attr(self, stream_name: str, key: str, default=None): """ Get the value of a key for a specific stream from the a model config. diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 3e30acedd..8884b5e91 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -22,14 +22,13 @@ from weathergen.common.config import _REPO_ROOT from weathergen.common.platform_env import get_platform_env -from weathergen.evaluate.io_reader import WeatherGenReader +from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader from weathergen.evaluate.plot_utils import collect_channels from weathergen.evaluate.utils import ( calc_scores_per_stream, metric_list_to_json, plot_data, plot_summary, - retrieve_metric_from_json, ) from weathergen.metrics.mlflow_utils import ( MlFlowUpload, @@ -111,7 +110,13 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: for run_id, run in runs.items(): _logger.info(f"RUN {run_id}: Getting data...") - reader = WeatherGenReader(run, run_id, private_paths) + type = run.get("type", "zarr") + if type == "zarr": + reader = WeatherGenReader(run, run_id, private_paths) + elif type == "csv": + reader = CsvReader(run, run_id, private_paths) + else: + raise ValueError(f"Unknown run type {type} for run {run_id}. Supported: zarr, csv.") for stream in reader.streams: _logger.info(f"RUN {run_id}: Processing stream {stream}...") @@ -135,29 +140,29 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: metrics_to_compute = [] for metric in metrics: - try: - metric_data = retrieve_metric_from_json( - reader, - stream, - region, - metric, - ) + metric_data = reader.load_scores( + stream, + region, + metric, + ) - available_data = reader.check_availability( - stream, metric_data, mode="evaluation" - ) + if metric_data is None: + metrics_to_compute.append(metric) + continue + + available_data = reader.check_availability( + stream, metric_data, mode="evaluation" + ) - if not available_data.score_availability: - metrics_to_compute.append(metric) - else: - # simply select the chosen eval channels, samples, fsteps here... - scores_dict[metric][region][stream][run_id] = metric_data.sel( - sample=available_data.samples, - channel=available_data.channels, - forecast_step=available_data.fsteps, - ) - except (FileNotFoundError, KeyError): + if not available_data.score_availability: metrics_to_compute.append(metric) + else: + # simply select the chosen eval channels, samples, fsteps here... + scores_dict[metric][region][stream][run_id] = metric_data.sel( + sample=available_data.samples, + channel=available_data.channels, + forecast_step=available_data.fsteps, + ) if metrics_to_compute: all_metrics, points_per_sample = calc_scores_per_stream( diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 3245d1dc8..98a463a92 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -375,40 +375,6 @@ def metric_list_to_json( ) -def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str): - """ - Retrieve the score for a given run, stream, metric, mini_epoch, and rank from a JSON file. - - Parameters - ---------- - reader : - Reader object containing all info for a specific run_id - stream : - Stream name. - region : - Region name. - metric : - Metric name. - - Returns - ------- - xr.DataArray - The metric DataArray. - """ - score_path = ( - Path(reader.metrics_dir) - / f"{reader.run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") - - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - return xr.DataArray.from_dict(data_dict) - else: - raise FileNotFoundError(f"File {score_path} not found in the archive.") - - def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): """ Plot summary of the evaluation results. From fe5e4cf8a2487f720f5eca9166407c5d87d5bd0c Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Thu, 6 Nov 2025 09:31:34 +0100 Subject: [PATCH 19/32] Iluise/hot fixes (#1209) * fix froct * fix 1150 --- .../evaluate/src/weathergen/evaluate/clim_utils.py | 6 +++++- .../evaluate/src/weathergen/evaluate/io_reader.py | 12 +++++++----- packages/evaluate/src/weathergen/evaluate/score.py | 4 +++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/clim_utils.py b/packages/evaluate/src/weathergen/evaluate/clim_utils.py index 768f93803..65091dc4d 100644 --- a/packages/evaluate/src/weathergen/evaluate/clim_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/clim_utils.py @@ -128,6 +128,10 @@ def align_clim_data( timestamp = target_data.valid_time.values[sample_mask][0] # Prepare climatology data for each sample matching_time_idx = match_climatology_time(timestamp, clim_data) + + if matching_time_idx is None: + continue + prepared_clim_data = ( clim_data.data.isel( time=matching_time_idx, @@ -209,7 +213,7 @@ def get_climatology(reader, da_tars, stream: str) -> xr.Dataset | None: aligned_clim_data = None - if clim_data_path: + if clim_data_path is not None: clim_data = xr.open_dataset(clim_data_path) _logger.info("Aligning climatological data with target structure...") aligned_clim_data = align_clim_data(da_tars, clim_data) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index ae30ea1e5..5b78ad507 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -495,10 +495,10 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non ) if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): - _logger.error(f"Zarr file {self.fname_zarr} does not exist or is not a directory.") - raise FileNotFoundError( - f"Zarr file {self.fname_zarr} does not exist or is not a directory." - ) + _logger.error(f"Zarr file {self.fname_zarr} does not exist.") + # raise FileNotFoundError( + # f"Zarr file {self.fname_zarr} does not exist or is not a directory." + # ) def get_inference_config(self): """ @@ -626,6 +626,8 @@ def get_data( pred = bbox.apply_mask(pred) npoints = len(target.ipoint) + pps.append(npoints) + if npoints == 0: _logger.info( f"Skipping {stream} sample {sample} forecast step: {fstep}. " @@ -649,7 +651,6 @@ def get_data( da_tars_fs.append(target.squeeze()) da_preds_fs.append(pred.squeeze()) - pps.append(npoints) if len(da_tars_fs) > 0: fsteps_final.append(fstep) @@ -692,6 +693,7 @@ def get_data( da_tars.append(da_tars_fs) da_preds.append(da_preds_fs) + if return_counts: points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps) diff --git a/packages/evaluate/src/weathergen/evaluate/score.py b/packages/evaluate/src/weathergen/evaluate/score.py index 5d1de4186..384a0a1bb 100755 --- a/packages/evaluate/src/weathergen/evaluate/score.py +++ b/packages/evaluate/src/weathergen/evaluate/score.py @@ -307,7 +307,9 @@ def get_score( group_names = list(next(iter(grouped_args.values())).groups.keys()) results = [] for name in group_names: - group_slice = {k: v[name] for k, v in grouped_args.items()} + group_slice = { + k: (v[name] if v is not None else v) for k, v in grouped_args.items() + } res = f(**group_slice) # Add coordinate for concatenation res = res.expand_dims({group_by_coord: [name]}) From c55844b38a31dd98c1d52460c482bafbe4fc728b Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Sun, 9 Nov 2025 09:21:41 +0100 Subject: [PATCH 20/32] Fix plot_train verbosity (#1225) --- src/weathergen/utils/train_logger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 8c09dfcfd..4281743a8 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -412,7 +412,6 @@ def clean_df(df, columns: list[str] | None): df = df.with_columns( (df[_weathergen_timestamp] - df[_weathergen_timestamp].min()).alias(_weathergen_reltime) ) - _logger.info(f"schema {df.schema}") if columns: columns = list(set(columns)) # remove duplicates From 8af56c49ac20130d5bfcfb11a7aedbd611980e91 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Mon, 10 Nov 2025 09:43:25 +0100 Subject: [PATCH 21/32] [1206] Experimentation for extra data readers (#1207) * initial implementation * changes * toml --- packages/readers_extra/pyproject.toml | 106 ++++ .../src/weathergen/readers_extra/__init__.py | 7 + .../readers_extra/data_reader_icon.py | 530 ++++++++++++++++++ .../src/weathergen/readers_extra/registry.py | 24 + pyproject.toml | 7 +- scripts/check_tomls.py | 6 +- src/weathergen/datasets/icon_dataset.py | 484 ---------------- .../datasets/multi_stream_data_sampler.py | 18 +- uv.lock | 36 ++ 9 files changed, 721 insertions(+), 497 deletions(-) create mode 100644 packages/readers_extra/pyproject.toml create mode 100644 packages/readers_extra/src/weathergen/readers_extra/__init__.py create mode 100644 packages/readers_extra/src/weathergen/readers_extra/data_reader_icon.py create mode 100644 packages/readers_extra/src/weathergen/readers_extra/registry.py delete mode 100644 src/weathergen/datasets/icon_dataset.py diff --git a/packages/readers_extra/pyproject.toml b/packages/readers_extra/pyproject.toml new file mode 100644 index 000000000..21179f146 --- /dev/null +++ b/packages/readers_extra/pyproject.toml @@ -0,0 +1,106 @@ +[project] +name = "weathergen-readers-extra" +version = "0.1.0" +description = "The WeatherGenerator Machine Learning Earth System Model" +readme = "../../README.md" +requires-python = ">=3.12,<3.13" +# TODO: incomplete: it also implicitly depends on the main project for the base classes +# There is currently a circular dependency readers-extra => root => readers-extra +# It needs to be broken by moving the base class of the readers code to its own package. +dependencies = [ + "xarray", + "zarr", + "weathergen-common", +] + +[dependency-groups] +dev = [ + "pytest~=8.3.5", + "pytest-mock>=3.14.1", + "ruff==0.9.7", + "pyrefly==0.36.0", +] + + +[tool.pyrefly] +project-includes = ["src/"] +project-excludes = [ +] + +[tool.pyrefly.errors] +bad-argument-type = false +unsupported-operation = false +missing-attribute = false +no-matching-overload = false +bad-context-manager = false + +# To do: +bad-assignment = false +bad-return = false +index-error = false +not-iterable = false +not-callable = false + + + + +# The linting configuration +[tool.ruff] + +# Wide rows +line-length = 100 + +[tool.ruff.lint] +# All disabled until the code is formatted. +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # Banned imports + "TID", + # Naming conventions + "N", + # print + "T201" +] + +# These rules are sensible and should be enabled at a later stage. +ignore = [ + # "B006", + "B011", + "UP008", + "SIM117", + "SIM118", + "SIM102", + "SIM401", + # To ignore, not relevant for us + "SIM108", # in case additional norm layer supports are added in future + "N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted) + "E731", # overly restrictive and less readable code + "N812", # prevents us following the convention for importing torch.nn.functional as F +] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example" + +[tool.ruff.format] +# Use Unix `\n` line endings for all files +line-ending = "lf" + + + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/weathergen"] diff --git a/packages/readers_extra/src/weathergen/readers_extra/__init__.py b/packages/readers_extra/src/weathergen/readers_extra/__init__.py new file mode 100644 index 000000000..df6164120 --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/__init__.py @@ -0,0 +1,7 @@ +""" +readers-extra package. + +Contains additional data readers for the WeatherGenerator project. + +This code is not as stable and tested as the main readers. +""" diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon.py new file mode 100644 index 000000000..78a103ff6 --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_icon.py @@ -0,0 +1,530 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import json +import logging +from pathlib import Path +from typing import override + +import fsspec +import numpy as np +import xarray as xr +import zarr +from numpy.typing import NDArray + +from weathergen.datasets.data_reader_anemoi import _clip_lat, _clip_lon +from weathergen.datasets.data_reader_base import ( + DataReaderTimestep, + ReaderData, + TimeWindowHandler, + TIndex, + check_reader_data, +) + +_logger = logging.getLogger(__name__) + +frequencies = { + "3hrPt": np.timedelta64(10800000000000, "ns"), + "day": np.timedelta64(86400000000000, "ns"), + "fx": np.timedelta64(0, "ns"), + "mon": np.timedelta64(2548800000000000, "ns"), + "monC": np.timedelta64(2505600000000000, "ns"), + "yr": np.timedelta64(31536000000000000, "ns"), +} + + +class DataReaderIconBase(DataReaderTimestep): + "Wrapper for ICON data variables" + + def __init__( + self, + tw_handler: TimeWindowHandler, + stream_info: dict, + ) -> None: + """ + Parent class for ICON data variables + + Parameters + ---------- + tw_handler : TimeWindowHandler + Handles temporal slicing and mapping from time indices to datetimes + stream_info : dict + Stream metadata + """ + + # Extract key metadata from stream_info + lon_attribute = stream_info["attributes"]["lon"] + lat_attribute = stream_info["attributes"]["lat"] + mesh_attribute = stream_info["attributes"]["grid"] + + # Set mesh size based on spatial grid definition + self.mesh_size = len(self.ds[mesh_attribute]) + + # Time range in the dataset + self.time = self.ds["time"].values + start_ds = np.datetime64(self.time[0]) + end_ds = np.datetime64(self.time[-1]) + + # Skip stream if it doesn't intersect with time window + if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start: + name = stream_info["name"] + _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.") + super().__init__(tw_handler, stream_info) + self.init_empty() + return + + # Compute temporal resolution if not already defined + self.temporal_frequency = ( + self.time[1] - self.time[0] + if self.temporal_frequency is None + else self.temporal_frequency + ) + + # Initialize parent class with resolved time window + super().__init__( + tw_handler, + stream_info, + start_ds, + end_ds, + self.temporal_frequency, + ) + + # Compute absolute start/end indices in the dataset based on time window + self.start_idx = (tw_handler.t_start - start_ds).astype("timedelta64[D]").astype( + int + ) * self.mesh_size + self.end_idx = ( + (tw_handler.t_end - start_ds).astype("timedelta64[D]").astype(int) + 1 + ) * self.mesh_size - 1 + + # Sanity check + assert self.end_idx > self.start_idx, ( + f"Abort: Final index of {self.end_idx} is the same or smaller than " + f"start index {self.start_idx}" + ) + + # Number of time steps in selected range + self.len = int((self.end_idx - self.start_idx) // self.mesh_size) + + # === Coordinates === + + # Convert to degrees if stored in radians + coords_units = self.ds[lat_attribute].attrs["units"] + if coords_units == "radian": + self.lat = np.rad2deg(self.ds[lat_attribute][:].astype("f")) + self.lon = np.rad2deg(self.ds[lon_attribute][:].astype("f")) + else: + self.lat = self.ds[lat_attribute][:].astype("f") + self.lon = self.ds[lon_attribute][:].astype("f") + + # Extract coordinates and pressure level + self.lat = _clip_lat(self.lat) + self.lon = _clip_lon(self.lon) + + # Placeholder; currently unused + self.step_hrs = 1 + + # Stream metadata + self.properties = { + "stream_id": 0, + } + + # === Normalization statistics === + + # Ensure stats match dataset columns + assert self.stats_vars == self.colnames, ( + f"Variables in normalization file {self.stats_vars} do not match " + f"dataset columns {self.colnames}" + ) + + # === Channel selection === + source_channels = stream_info.get("source") + if source_channels: + self.source_channels, self.source_idx = self.select(source_channels) + elif getattr(self, "levels", None): + self.source_channels, self.source_idx = self.select_by_level("source") + else: + self.source_channels = self.colnames + self.source_idx = self.cols_idx + + target_channels = stream_info.get("target") + if target_channels: + self.target_channels, self.target_idx = self.select(target_channels) + elif getattr(self, "levels", None): + self.target_channels, self.target_idx = self.select_by_level("target") + else: + self.target_channels = self.colnames + self.target_idx = self.cols_idx + + # Ensure all selected channels have valid standard deviations + selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx))) + non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0] + if len(non_positive_stds) != 0: + bad_vars = [self.colnames[selected_channel_indices[i]] for i in non_positive_stds] + raise ValueError( + f"Abort: Encountered non-positive standard deviations" + f" for selected columns {bad_vars}." + ) + + # === Geo-info channels (currently unused) === + self.geoinfo_channels = [] + self.geoinfo_idx = [] + + def select(self, ch_filters: list[str]) -> (NDArray, list[str]): + """ + Allow user to specify which columns they want to access. + Get functions only returned for these specified columns. + + Parameters + ---------- + ch_filters: list[str] + list of patterns to access + + Returns + ------- + selected_colnames: np.array, + Selected columns according to the patterns specified in ch_filters + selected_cols_idx + respective index of these patterns in the data array + """ + mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames] + + selected_cols_idx = self.cols_idx[np.where(mask)[0]] + selected_colnames = [self.colnames[int(i)] for i in np.where(mask)[0]] + + return selected_colnames, selected_cols_idx + + def select_by_level(self, ch_type: str) -> tuple[list[str], NDArray[np.int64]]: + """ + Select channels constrained by allowed pressure levels and optional excludes. + ch_type: "source" or "target" (for *_exclude key in stream_info) + """ + channels_exclude = self.stream_info.get(f"{ch_type}_exclude", []) + allowed_levels = set(self.levels) if getattr(self, "levels", None) else set() + + new_colnames: list[str] = [] + for ch in self.colnames: + parts = ch.split("_") + # Profile channel if exactly one level suffix exists + if len(parts) == 2 and parts != "": + level = parts[1] + ch_base = parts[0] + if ( + not allowed_levels or level in allowed_levels + ) and ch_base not in channels_exclude: + new_colnames.append(ch) + else: + if ch not in channels_exclude: + new_colnames.append(ch) + + mask = [c in new_colnames for c in self.colnames] + selected_cols_idx = self.cols_idx[np.where(mask)] + selected_colnames = [self.colnames[int(i)] for i in np.where(mask)[0]] + + return selected_colnames, selected_cols_idx + + @override + def init_empty(self) -> None: + super().init_empty() + self.len = 0 + + @override + def length(self) -> int: + """ + Length of dataset + + Parameters + ---------- + None + + Returns + ------- + length of dataset + """ + return self.len + + +########################## +class DataReaderIcon(DataReaderIconBase): + "Wrapper for ICON variables - This class reads Zarr format datasets" + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + ) -> None: + # Open Zarr dataset with Xarray + self.ds = xr.open_zarr(filename, consolidated=True) + + # Column (variable) names and indices + self.colnames = list(self.ds) + self.cols_idx = np.array(list(np.arange(len(self.colnames)))) + + # get pressure levels + # TODO Julius ? + self.levels = [] + + # Will be inferred later based on the dataset’s time variable + self.temporal_frequency = None + + # Load associated statistics file for normalization + stats_filename = Path(filename).with_name(Path(filename).stem + "_stats.json") + with open(stats_filename) as stats_file: + self.stats = json.load(stats_file) + + # Extract variable list from stats metadata + stats_vars_metadata = self.stats["metadata"]["variables"] + self.stats_vars = [v for v in stats_vars_metadata if v not in {"clat", "clon", "time"}] + + # Load mean and standard deviation per variable + self.mean = np.array(self.stats["statistics"]["mean"], dtype="d") + self.stdev = np.array(self.stats["statistics"]["std"], dtype="d") + + # Delegate further initialization to the base class + super().__init__( + tw_handler, + stream_info, + ) + + # TODO Julius ? + def select_by_level(self): + return + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for temporal window + Parameters + ---------- + idx : int + Index of temporal window + channels_idx : np.array + Selection of channels + Returns + ------- + data (coords, geoinfos, data, datetimes) + """ + + (t_idxs, dtr) = self._get_dataset_idxs(idx) + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # TODO: handle sub-sampling + + t_idxs_start = t_idxs[0] + t_idxs_end = t_idxs[-1] + 1 + + # datetimes + datetimes = np.asarray(self.time[t_idxs_start:t_idxs_end]) + + # lat/lon coordinates + tiling to match time steps + lat = self.lat.values[:, np.newaxis] + lon = self.lon.values[:, np.newaxis] + + lat = np.tile(lat, len(datetimes)) + lon = np.tile(lon, len(datetimes)) + + coords = np.concatenate([lat, lon], axis=1) + + # time coordinate repeated to match grid points + datetimes = np.repeat(datetimes, self.mesh_size).reshape(-1, 1) + datetimes = np.squeeze(datetimes) + + # expanding indexes for data + start_row = t_idxs_start * self.mesh_size + end_row = t_idxs_end * self.mesh_size + + # data + channels = np.array(self.colnames)[channels_idx] + + data_reshaped = [ + np.asarray(self.ds[ch_]).reshape(-1, 1)[start_row:end_row] for ch_ in channels + ] + data = np.concatenate(data_reshaped, axis=1) + + # empty geoinfos + geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + check_reader_data(rd, dtr) + + return rd + + +########################## +class DataReaderIconCmip6(DataReaderIconBase): + "Wrapper for ICON CMIP6 data variables - This class reads NetCDF4 using kerchunk" + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + ) -> None: + # Open the kerchunk-generated reference JSON + ref_path = Path(filename) + if not ref_path.exists(): + raise FileNotFoundError(f"Kerchunk reference JSON not found: {ref_path}") + + # Load JSON references and initialize a virtual file system + kerchunk_ref = json.loads(ref_path.read_text()) + fs = fsspec.filesystem("reference", fo=kerchunk_ref) + mapper = fs.get_mapper("") + + # Ensure metadata is consolidated for zarr-style access + zarr.consolidate_metadata(mapper) + + # Open the dataset using Xarray with Zarr engine + self.ds = xr.open_dataset(mapper, engine="zarr", consolidated=True, chunks={"time": 1}) + + # get pressure levels + # TODO add self.dataset_levels + self.levels = stream_info["pressure_levels"] + + # Column (variable) names and indices + self.colnames, self.cols_idx = self.get_cols(stream_info["variables"]) + + # Determine temporal frequency from dataset metadata + frequency_attr = self.ds.attrs["frequency"] + self.temporal_frequency = frequencies[frequency_attr] + + # Load associated statistics file for normalization + stats_filename = Path(filename).with_name(Path(filename).stem + "_stats.json") + with open(stats_filename) as stats_file: + self.stats = json.load(stats_file) + + # Variables included in the stats + self.stats_vars = list(self.stats) + + # Load mean and standard deviation per variable + self.mean = np.array([self.stats[var]["mean"] for var in self.stats_vars], dtype=np.float64) + self.stdev = np.array([self.stats[var]["std"] for var in self.stats_vars], dtype=np.float64) + + # Delegate further initialization to the base class + super().__init__( + tw_handler, + stream_info, + ) + + def get_cols(self, channels: list[str]) -> (list[str], list[int]): + """ + TBD + """ + colnames = [] + for ch in channels: + coords_list = list(self.ds[ch].coords) + if "plev" not in coords_list: + colnames.append(f"{ch}") + else: + dataset_levels = self.ds[ch]["plev"][0, :].values + for level in dataset_levels: + colnames.append(f"{ch}_{int(level)}") + + cols_idx = np.array(list(np.arange(len(colnames)))) + + return colnames, cols_idx + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for temporal window + + Parameters + ---------- + idx : int + Index of temporal window + channels_idx : list[int] + Selection of channels + + Returns + ------- + ReaderData + """ + (t_idxs, dtr) = self._get_dataset_idxs(idx) + # dtr is a time window object it has the attributes t_start_win and t_end_win + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # Select channels + channels = np.array(self.colnames)[channels_idx] + + start_ts = dtr.start + end_ts = dtr.end - np.timedelta64(1, "h") + + try: + data_per_channel = [] + datetimes = [] + coords = [] + + for ch in channels: + ch_parts = ch.split("_") + if ( + hasattr(self, "levels") + and self.levels + and len(ch_parts) == 2 + and ch_parts[1] in self.levels + ): + ch_ = ch_parts[0] + plev_int = ch_parts[1] + levels_all = self.ds[ch_]["plev"][0].values + da = self.ds[ch_].assign_coords(plev=("plev", levels_all)) + da = da.sel(plev=plev_int, time=slice(start_ts, end_ts)) + else: + da = self.ds[ch].sel(time=slice(start_ts, end_ts)) + data_arr = da.compute(scheduler="synchronous") + + if not data_per_channel: + # datetimes + datetimes = np.repeat(data_arr.time.values, self.mesh_size).reshape(-1, 1) + datetimes = np.squeeze(datetimes) + + # coords + n_times = len(data_arr.time) + lat = np.tile(data_arr.latitude.values[:, np.newaxis], (n_times, 1)) + lon = np.tile(data_arr.longitude.values[:, np.newaxis], (n_times, 1)) + + coords = np.concatenate([lat, lon], axis=1) + + # data + data_per_channel.append(np.asarray(data_arr.data.reshape(-1, 1))) + + data = np.concatenate(data_per_channel, axis=1) + except Exception as e: + _logger.debug(f"Date not present in ICON dataset: {str(e)}. Skipping.") + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + if data_per_channel[0].shape[0] == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx) + ) + + # Empty geoinfos + geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + check_reader_data(rd, dtr) + return rd diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py new file mode 100644 index 000000000..761628944 --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -0,0 +1,24 @@ +from collections.abc import Callable +from dataclasses import dataclass + +from weathergen.common.config import Config + + +@dataclass +class ReaderEntry: + data_path: str | None + constructor: Callable + + +def get_extra_reader(name: str, cf: Config) -> object | None: + """Get an extra reader by name.""" + # Uses lazy imports to avoid circular dependencies and to not load all the readers at start. + # There is no sanity check on them, so they may fail at runtime during imports + + match name: + case "icon": + from weathergen.readers_extra.data_reader_icon import DataReaderIcon + + return ReaderEntry(cf.data_path_icon, DataReaderIcon) + case _: + return None diff --git a/pyproject.toml b/pyproject.toml index 016552e8b..250a53ebc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "numexpr>=2.11.0", "weathergen-common", "weathergen-evaluate", + "weathergen-readers-extra", ] @@ -213,8 +214,9 @@ explicit = true [tool.uv.sources] weathergen-common = { workspace = true } -weathergen-metrics = { workspace = true } weathergen-evaluate = { workspace = true } +weathergen-metrics = { workspace = true } +weathergen-readers-extra = { workspace = true } flash-attn = [ @@ -246,8 +248,9 @@ log_cli_date_format = "%Y-%m-%d %H:%M:%S" [tool.uv.workspace] members = [ - "packages/evaluate", "packages/common", + "packages/evaluate", "packages/metrics", + "packages/readers_extra", ] diff --git a/scripts/check_tomls.py b/scripts/check_tomls.py index 4d11efec1..cb709c42b 100644 --- a/scripts/check_tomls.py +++ b/scripts/check_tomls.py @@ -60,6 +60,6 @@ def check_tomls(main_toml, *tomls): if __name__ == "__main__": main_toml = _REPO_ROOT / "pyproject.toml" - eval_toml = _REPO_ROOT / "packages" / "evaluate" / "pyproject.toml" - common_toml = _REPO_ROOT / "packages" / "common" / "pyproject.toml" - check_tomls(main_toml, eval_toml, common_toml) + sub_packages = ["evaluate", "common", "metrics", "readers_extra"] + tomls = [_REPO_ROOT / "packages" / package / "pyproject.toml" for package in sub_packages] + check_tomls(main_toml, *tomls) diff --git a/src/weathergen/datasets/icon_dataset.py b/src/weathergen/datasets/icon_dataset.py deleted file mode 100644 index abc17e32a..000000000 --- a/src/weathergen/datasets/icon_dataset.py +++ /dev/null @@ -1,484 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -from datetime import datetime -from pathlib import Path - -import numpy as np -import torch -import zarr - - -class IconDataset: - """ - A data reader for ICON model output stored in zarr. - - Parameters - ---------- - start : datetime | int - Start time of the data period as datetime object or integer in "%Y%m%d%H%M" format - end : datetime | int - End time of the data period (inclusive) with same format as start - len_hrs : int - Length of temporal windows in days - step_hrs : int - (Currently unused) Intended step size between windows in hours - filename : Path - Path to Zarr dataset containing ICON output - stream_info : dict[str, list[str]] - Dictionary with "source" and "target" keys specifying channel subsets to use - (e.g., {"source": ["temp_00"], "target": ["TRCH4_chemtr_00"]}) - - Attributes - ---------- - len_hrs : int - Temporal window length in days - mesh_size : int - Number of nodes in the ICON mesh - source_channels : list[str] - Patterns of selected source channels - target_channels : list[str] - Patterns of selected target channels - mean : np.ndarray - Per-channel means for normalization (includes coordinates) - stdev : np.ndarray - Per-channel standard deviations for normalization (includes coordinates) - properties : dict[str, list[str]] - Dataset metadata including 'stream_id' from Zarr attributes - - """ - - def __init__( - self, - start: datetime | int, - end: datetime | int, - len_hrs: int, - step_hrs: int, - filename: Path, - stream_info: dict, - ): - self.len_hrs = len_hrs - - format_str = "%Y%m%d%H%M" - if type(start) is not datetime: - start = datetime.strptime(str(start), format_str) - start = np.datetime64(start).astype("datetime64[D]") - - if type(end) is not datetime: - end = datetime.strptime(str(end), format_str) - end = np.datetime64(end).astype("datetime64[D]") - - # loading datafile - self.filename = filename - self.ds = zarr.open(filename, mode="r") - self.mesh_size = self.ds.attrs["ncells"] - - # Loading stat file - stats_filename = Path(filename).with_suffix(".json") - with open(stats_filename) as stats_file: - self.stats = json.load(stats_file) - - time_as_in_data_file = np.array(self.ds["time"], dtype="timedelta64[D]") + np.datetime64( - self.ds["time"].attrs["units"].split("since ")[-1] - ) - - start_ds = time_as_in_data_file[0] - end_ds = time_as_in_data_file[-1] - - # asserting start and end times - if start_ds > end or end_ds < start: - # TODO: this should be set in the base class - self.source_channels = [] - self.target_channels = [] - self.source_idx = np.array([]) - self.target_idx = np.array([]) - self.geoinfo_idx = [] - self.len = 0 - self.ds = None - return - - self.start_idx = (start - start_ds).astype("timedelta64[D]").astype(int) * self.mesh_size - self.end_idx = ( - (end - start_ds).astype("timedelta64[D]").astype(int) + 1 - ) * self.mesh_size - 1 - - self.len = (self.end_idx - self.start_idx) // self.mesh_size - - assert self.end_idx > self.start_idx, ( - f"Abort: Final index of {self.end_idx} is the same of larger than", - f" start index {self.start_idx}", - ) - - len_data_entries = len(self.ds["time"]) * self.mesh_size - - assert self.end_idx + len_hrs <= len_data_entries, ( - f"Abort: end_date must be set at least {len_hrs} before the last date in the dataset" - ) - - # variables - self.colnames = list(self.ds) - self.cols_idx = np.array(list(np.arange(len(self.colnames)))) - - # Ignore step_hrs, idk how it supposed to work - # TODO, TODO, TODO: - self.step_hrs = 1 - - # time - repeated_times = np.repeat(time_as_in_data_file, self.mesh_size).reshape(-1, 1) - self.time = repeated_times - - # coordinates - coords_units = self.ds["clat"].attrs["units"] - - if coords_units == "radian": - lat_as_in_data_file = np.rad2deg(self.ds["clat"][:].astype("f")) - lon_as_in_data_file = np.rad2deg(self.ds["clon"][:].astype("f")) - - else: - lat_as_in_data_file = self.ds["clat"][:].astype("f") - lon_as_in_data_file = self.ds["clon"][:].astype("f") - - self.lat = np.tile(lat_as_in_data_file, len(time_as_in_data_file)) - self.lon = np.tile(lon_as_in_data_file, len(time_as_in_data_file)) - - self.properties = {"stream_id": 0} - - # stats - stats_vars = self.stats["metadata"]["variables"] - assert stats_vars == self.colnames, ( - f"Variables in normalization file {stats_vars}" - f"do not match dataset columns {self.colnames}" - ) - - self.mean = np.array(self.stats["statistics"]["mean"], dtype="d") - self.stdev = np.array(self.stats["statistics"]["std"], dtype="d") - - # Channel selection and indexing - source_channels = stream_info["source"] if "source" in stream_info else None - if source_channels: - self.source_channels, self.source_idx = self.select(source_channels) - else: - self.source_channels = self.colnames - self.source_idx = self.cols_idx - - target_channels = stream_info["target"] if "target" in stream_info else None - if target_channels: - self.target_channels, self.target_idx = self.select(target_channels) - else: - self.target_channels = self.colnames - self.target_idx = self.cols_idx - - # Check if standard deviations are strictly positive for selected channels - selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx))) - non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0] - assert len(non_positive_stds) == 0, ( - f"Abort: Encountered non-positive standard deviations " - f"for selected columns { - [self.colnames[selected_channel_indices][i] for i in non_positive_stds] - }." - ) - # TODO: define in base class - self.geoinfo_idx = [] - - def select(self, ch_filters: list[str]) -> tuple[list[str], np.array]: - """ - Allow user to specify which columns they want to access. - Get functions only returned for these specified columns. - """ - - mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames] - - selected_cols_idx = np.where(mask)[0] - selected_colnames = [self.colnames[i] for i in selected_cols_idx] - - return selected_colnames, selected_cols_idx - - def __len__(self) -> int: - """ - Length of dataset - - Parameters - ---------- - None - - Returns - ------- - length of dataset - """ - return self.len - - def _get(self, idx: int, channels: np.array) -> tuple: - """ - Get data for window - - Parameters - ---------- - idx : int - Index of temporal window - channels_idx : np.array - Selection of channels - - Returns - ------- - data (coords, geoinfos, data, datetimes) - """ - if self.ds is None: - fp32 = np.float32 - return ( - np.array([], dtype=fp32), - np.array([], dtype=fp32), - np.array([], dtype=fp32), - np.array([], dtype=fp32), - ) - - # indexing - start_row = self.start_idx + idx * self.mesh_size - end_row = start_row + self.len_hrs * self.mesh_size - - # data - data_reshaped = [ - np.asarray(self.ds[ch_]).reshape(-1, 1)[start_row:end_row] for ch_ in channels - ] - data = np.concatenate(data_reshaped, axis=1) - - lat = np.expand_dims(self.lat[start_row:end_row], 1) - lon = np.expand_dims(self.lon[start_row:end_row], 1) - - latlon = np.concatenate([lat, lon], 1) - - # empty geoinfos - geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype) - datetimes = np.squeeze(self.time[start_row:end_row]) - - return (latlon, geoinfos, data, datetimes) - - def get_source(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]: - """ - Get source data for idx - - Parameters - ---------- - idx : int - Index of temporal window - - Returns - ------- - source data (coords, geoinfos, data, datetimes) - """ - return self._get(idx, self.source_channels) - - def get_target(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]: - """ - Get target data for idx - - Parameters - ---------- - idx : int - Index of temporal window - - Returns - ------- - target data (coords, geoinfos, data, datetimes) - """ - return self._get(idx, self.target_channels) - - def get_source_size(self) -> int: - """ - Get size of all columns, including coordinates and geoinfo, with source - - Parameters - ---------- - None - - Returns - ------- - size of coords - """ - return 2 + len(self.geoinfo_idx) + len(self.source_idx) if self.ds else 0 - - def get_target_size(self) -> int: - """ - Get size of all columns, including coordinates and geoinfo, with source - - Parameters - ---------- - None - - Returns - ------- - size of coords - """ - return 2 + len(self.geoinfo_idx) + len(self.target_idx) if self.ds else 0 - - def get_coords_size(self) -> int: - """ - Get size of coords - - Parameters - ---------- - None - - Returns - ------- - size of coords - """ - return 2 - - def normalize_coords(self, coords: torch.tensor) -> torch.tensor: - """ - Normalize coordinates - - Parameters - ---------- - coords : - coordinates to be normalized - - Returns - ------- - Normalized coordinates - """ - coords[..., 0] = np.sin(np.deg2rad(coords[..., 0])) - coords[..., 1] = np.sin(0.5 * np.deg2rad(coords[..., 1])) - - return coords - - def normalize_source_channels(self, source: torch.tensor) -> torch.tensor: - """ - Normalize source channels - - Parameters - ---------- - source : - data to be normalized - - Returns - ------- - Normalized data - """ - assert source.shape[1] == len(self.source_idx) - for i, ch in enumerate(self.source_idx): - source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch] - - return source - - def normalize_target_channels(self, target: torch.tensor) -> torch.tensor: - """ - Normalize target channels - - Parameters - ---------- - target : - data to be normalized - - Returns - ------- - Normalized data - """ - assert target.shape[1] == len(self.target_idx) - for i, ch in enumerate(self.target_idx): - target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch] - - return target - - def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: - """ - Temporal window corresponding to index - - Parameters - ---------- - idx : - index of temporal window - - Returns - ------- - start and end of temporal window - """ - start_row = self.start_idx + idx * self.mesh_size - end_row = start_row + self.len_hrs * self.mesh_size - - return (self.time[start_row, 0], self.time[end_row, 0]) - - def denormalize_target_channels(self, data: torch.tensor) -> torch.tensor: - """ - Denormalize target channels - - Parameters - ---------- - data : - data to be denormalized (target or pred) - - Returns - ------- - Denormalized data - """ - assert data.shape[-1] == len(self.target_idx), "incorrect number of channels" - for i, ch in enumerate(self.target_idx): - data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch] - - return data - - def get_source_num_channels(self) -> int: - """ - Get number of source channels - - Parameters - ---------- - None - - Returns - ------- - number of source channels - """ - return len(self.source_idx) - - def get_target_num_channels(self) -> int: - """ - Get number of target channels - - Parameters - ---------- - None - - Returns - ------- - number of target channels - """ - return len(self.target_idx) - - def get_geoinfo_size(self) -> int: - """ - Get size of geoinfos - - Parameters - ---------- - None - - Returns - ------- - size of geoinfos - """ - return len(self.geoinfo_idx) - - def normalize_geoinfos(self, geoinfos: torch.tensor) -> torch.tensor: - """ - Normalize geoinfos - - Parameters - ---------- - geoinfos : - geoinfos to be normalized - - Returns - ------- - Normalized geoinfo - """ - - assert geoinfos.shape[-1] == 0, "incorrect number of geoinfo channels" - return geoinfos diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 6519bbd92..e38d518da 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -23,7 +23,6 @@ ) from weathergen.datasets.data_reader_fesom import DataReaderFesom from weathergen.datasets.data_reader_obs import DataReaderObs -from weathergen.datasets.icon_dataset import IconDataset from weathergen.datasets.masking import Masker from weathergen.datasets.stream_data import StreamData, spoof from weathergen.datasets.tokenizer_forecast import TokenizerForecast @@ -33,6 +32,7 @@ compute_offsets_scatter_embed, compute_source_cell_lens, ) +from weathergen.readers_extra.registry import get_extra_reader from weathergen.utils.distributed import is_root from weathergen.utils.train_logger import Stage @@ -145,13 +145,15 @@ def __init__( case "fesom": dataset = DataReaderFesom datapath = cf.data_path_fesom - case "icon": - dataset = IconDataset - datapath = cf.data_path_icon - case _: - msg = f"Unsupported stream type {stream_info['type']}" - f"for stream name '{stream_info['name']}'." - raise ValueError(msg) + case type_name: + reader_entry = get_extra_reader(type_name, cf) + if reader_entry is not None: + dataset = reader_entry.constructor + datapath = reader_entry.data_path + else: + msg = f"Unsupported stream type {stream_info['type']}" + f"for stream name '{stream_info['name']}'." + raise ValueError(msg) datapath = pathlib.Path(datapath) fname = pathlib.Path(fname) diff --git a/uv.lock b/uv.lock index c4f489d1b..4cdcbdcc5 100644 --- a/uv.lock +++ b/uv.lock @@ -20,6 +20,7 @@ members = [ "weathergen-common", "weathergen-evaluate", "weathergen-metrics", + "weathergen-readers-extra", ] [[package]] @@ -2730,6 +2731,7 @@ dependencies = [ { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-evaluate", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "weathergen-readers-extra", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "wheel", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "zarr", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] @@ -2783,6 +2785,7 @@ requires-dist = [ { name = "tqdm" }, { name = "weathergen-common", editable = "packages/common" }, { name = "weathergen-evaluate", editable = "packages/evaluate" }, + { name = "weathergen-readers-extra", editable = "packages/readers_extra" }, { name = "wheel" }, { name = "zarr", specifier = "~=2.17" }, ] @@ -2915,6 +2918,39 @@ dev = [ { name = "ruff", specifier = "==0.9.7" }, ] +[[package]] +name = "weathergen-readers-extra" +version = "0.1.0" +source = { editable = "packages/readers_extra" } +dependencies = [ + { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "xarray", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "zarr", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyrefly", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pytest", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pytest-mock", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "ruff", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] + +[package.metadata] +requires-dist = [ + { name = "weathergen-common", editable = "packages/common" }, + { name = "xarray" }, + { name = "zarr" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyrefly", specifier = "==0.36.0" }, + { name = "pytest", specifier = "~=8.3.5" }, + { name = "pytest-mock", specifier = ">=3.14.1" }, + { name = "ruff", specifier = "==0.9.7" }, +] + [[package]] name = "webencodings" version = "0.5.1" From 2979ddd24c131d232526ea2538dbaba8e2c82048 Mon Sep 17 00:00:00 2001 From: Javad kasravi Date: Mon, 10 Nov 2025 11:11:17 +0100 Subject: [PATCH 22/32] add module to annotations.json (#1142) Co-authored-by: Javad Kasravi --- config/profiling/annotations.json | 82 +++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/config/profiling/annotations.json b/config/profiling/annotations.json index bb42a38af..ea730997b 100644 --- a/config/profiling/annotations.json +++ b/config/profiling/annotations.json @@ -123,10 +123,62 @@ "Model.embed_cells", "Model.source_tokens", "Model.assimilate_local", + "Model.assimilate_global", "Model.forecast", "Model.predict" ] }, + { + "domain": "WeatherGen", + "color": "#C27BA0", + "module": "weathergen.model.attention", + "functions": [ + "MultiSelfAttentionHeadVarlen.__init__", + "MultiSelfAttentionHeadVarlen.forward", + "MultiSelfAttentionHeadVarlenFlex.__init__", + "MultiSelfAttentionHeadVarlenFlex.forward", + "MultiSelfAttentionHeadLocal.__init__", + "MultiSelfAttentionHeadLocal.forward", + "MultiCrossAttentionHeadVarlen.__init__", + "MultiCrossAttentionHeadVarlen.forward", + "MultiCrossAttentionHeadVarlenSlicedQ.__init__", + "MultiCrossAttentionHeadVarlenSlicedQ.forward", + "MultiSelfAttentionHead.__init__", + "MultiSelfAttentionHead.forward", + "MultiCrossAttentionHead.__init__", + "MultiCrossAttentionHead.forward" + ] + }, + { + "domain": "WeatherGen", + "color": "#83F5BF", + "module": "weathergen.model.layers", + "functions": [ + "NamedLinear.forward", + "MLP.__init__", + "MLP.forward" + ] + }, + { + "domain": "WeatherGen", + "color": "#50CDF3", + "module": "weathergen.model.norms", + "functions": [ + "RMSNorm.forward", + "RMSNorm.__init__", + "RMSNorm._norm", + "AdaLayerNorm.forward", + "AdaLayerNorm.__init__", + "SwiGLU.forward", + "SwiGLU.__init__", + "modulate", + "AdaLayerNormLayer.forward", + "AdaLayerNormLayer.__init__", + "AdaLayerNormLayer.initialise_weights", + "SaturateEncodings.forward", + "SaturateEncodings.__init__" + ] + }, { "domain": "WeatherGen", "color": "02dff7", @@ -164,6 +216,36 @@ ] }, + { + "domain": "flash_attention", + "color": "ffff00", + "module": "flash_attn", + "functions": [ + "flash_attn_func", + "flash_attn_varlen_func" + + ] + }, + { + "domain": "PyTorch_flash_attention", + "color": "808000", + "module": "torch.nn.attention.flex_attention", + "functions": [ + "create_block_mask", + "flex_attention" + ] + }, + { + "domain": "WeatherGen", + "color": "808000", + "module": "weathergen.model.positional_encoding", + "functions": [ + "positional_encoding_harmonic", + "positional_encoding_harmonic_idx", + "positional_encoding_harmonic_global", + "positional_encoding_harmonic_coord" + ] + }, { "domain": "WeatherGen", "color": "C6BAFF", From 14b8bf645e1f4c2eb19eecd163d1919a836ba4a2 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Date: Mon, 10 Nov 2025 17:51:56 +0100 Subject: [PATCH 23/32] Correct bug with score cards and bar plots for different metrics (#1192) * Rebase to develop * Linting * Address comments and linting --- .../src/weathergen/evaluate/plot_utils.py | 4 +- .../src/weathergen/evaluate/plotter.py | 171 ++++++++++++------ 2 files changed, 115 insertions(+), 60 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plot_utils.py index 3000b0767..361ab73c4 100644 --- a/packages/evaluate/src/weathergen/evaluate/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plot_utils.py @@ -167,7 +167,7 @@ def score_card_metric_region( if selected_data and len(selected_data) > 1.0: _logger.info(f"Creating score cards for {metric} - {region} - {stream}.") name = "_".join([metric, region, stream]) - sc_plotter.plot(selected_data, run_ids, channels_common, name) + sc_plotter.plot(selected_data, run_ids, metric, channels_common, name) else: _logger.info( f"Only one run_id under stream: {stream}. Creating score card is skipped..." @@ -212,7 +212,7 @@ def bar_plot_metric_region( if selected_data and len(selected_data) > 1.0: _logger.info(f"Creating bar plots for {metric} - {region} - {stream}.") name = "_".join([metric, region, stream]) - br_plotter.plot(selected_data, run_ids, channels_set, name) + br_plotter.plot(selected_data, run_ids, metric, channels_set, name) else: _logger.info( f"Only one run_id for ({region}) region under stream : {stream}. " diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index f9e34dbb3..42c7a17dc 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -909,7 +909,12 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path) -> None: os.makedirs(self.out_plot_dir, exist_ok=True) def plot( - self, data: list[xr.DataArray], runs: list[str], channels: list[str], tag: str + self, + data: list[xr.DataArray], + runs: list[str], + metric: str, + channels: list[str], + tag: str, ) -> None: """ Plot score cards comparing performance between run_ids against a baseline over channels @@ -921,6 +926,8 @@ def plot( List of (xarray) DataArrays with the scores (stream, region and metric specific) runs: List containing runs (in str format) to be compared (provided in the config) + metric: + Metric for which we are plotting channels: List containing channels (in str format) of interest (provided in the config) tag: @@ -935,14 +942,16 @@ def plot( for run_index in range(1, n_runs): skill_model = 0.0 for var_index, var in enumerate(channels): - diff, diff_mean, skill = self.compare_models(data, baseline, run_index, var) - skill_model += skill.values + diff, avg_diff, avg_skill = self.compare_models( + data, baseline, run_index, var, metric + ) + skill_model += avg_skill.values # Get symbols based on difference and performance as well as coordinates # for the position of the triangles. x, y, alt, color, triangle, size = self.get_plot_symbols( - run_index, var_index, skill, diff_mean + run_index, var_index, avg_skill, avg_diff, metric ) ax.scatter(x, y, marker=triangle, color=color, s=size.values, zorder=3) @@ -1023,6 +1032,7 @@ def compare_models( baseline: xr.DataArray, run_index: int, var: str, + metric: str, x_dim="forecast_step", ) -> tuple[xr.DataArray, xr.DataArray, xr.DataArray]: """ @@ -1063,86 +1073,110 @@ def compare_models( baseline_score, model_score = calculate_average_over_dim(x_dim, baseline_var, data_var) diff = baseline_score - model_score - skill = self.get_skill_score(model_score, baseline_score, 0.0) + skill = self.get_skill_score(model_score, baseline_score, metric) return diff, diff.mean(dim=x_dim), skill.mean(dim=x_dim) def get_skill_score( - self, score_model: xr.DataArray, score_ref: xr.DataArray, score_perf: float + self, score_model: xr.DataArray, score_ref: xr.DataArray, metric: str ) -> xr.DataArray: """ - Calculation function for calculating skill score between a model and the baseline. + Calculate skill score comparing a model against a baseline. + + Skill score is defined as: (model_score - baseline_score) / (perfect_score - baseline_score) Parameters ---------- - score_model: xr.DataArray - The scores of the model that we aim to compare with the baseline. + score_model : xr.DataArray + The scores of the model being evaluated + score_ref : xr.DataArray + The scores of the reference/baseline model + metric : str + The metric name for which to calculate skill score - score_ref: xr.DataArray - The scores of the baseline model. + Returns + ------- + xr.DataArray + Skill scores comparing model to baseline + """ + perf_score = self.get_perf_score(metric) + skill_score = (score_model - score_ref) / (perf_score - score_ref) + return skill_score - score_perf: float - The perfect score based on the metric. For example for RMSE is 0. + def get_perf_score(self, metric: str) -> float: + """ + Get the perfect score for a given metric. - Returns + Perfect scores represent ideal performance: + - Error metrics: 0 (lower is better) + - Skill/score metrics: 1 (higher is better) + - PSNR: 100 (higher is better) + + Parameters ---------- - skill_score: xr.DataArray - Skill scores of a model compared with baseline. + metric : str + Metric name + Returns + ------- + float + Perfect score for the specified metric """ + # Metrics where lower values indicate better performance (error metrics) + if lower_is_better(metric): + return 0.0 - skill_score = (score_model - score_ref) / (score_perf - score_ref) - return skill_score + # Metrics where higher values indicate better performance (with specific perfect score) + elif metric in ["psnr"]: + return 100.0 + + # Metrics where higher values indicate better performance (default perfect score) + else: + return 1.0 def get_plot_symbols( - self, run_index: int, var_index: int, skill: xr.DataArray, diff_mean: xr.DataArray + self, + run_index: int, + var_index: int, + avg_skill: xr.DataArray, + avg_diff: xr.DataArray, + metric: str, ) -> tuple[int, float, str, str, str, xr.DataArray]: """ - Get the triangle symbols per comparison model with the correct size and color - based on score improvement or deterioration. + Determine plot symbol properties based on performance difference. Parameters ---------- - run_index: int - The order index over the run_ids. - var_index: float - The order index over the channels. - skill: xarray.DataArray - The skill of the model - diff_mean: xr.DataArray - The average difference between the baseline and the model. Determines improvement or - deterioration over baseline. + run_index : int + Index of the model. + var_index : int + Index of the variable/channel. + avg_skill : xr.DataArray + Average skill score of the model. + avg_diff : xr.DataArray + Average difference between baseline and model. + metric : str + Metric used for interpretation. Returns - ---------- - x: int - x coordinate of the triangle that indicates improvement or deterioration over baseline. - - y: float - y coordinate of the triangle that indicates improvement or deterioration over baseline. - - alt: str - str that indicates the alternative hypothesis test for Wilcoxon test of significance. - - color: str - The color "red" or "blue" that indicates improvement or deterioration over baseline. - triangle: str - The triangle symbol "^" or "v" that indicates improvement or deterioration over - baseline. - size: xr.DataArray - Size of the triangles in the final plot + ------- + Tuple[int, float, str, str, str, xr.DataArray] + x, y coordinates, alternative hypothesis, color, triangle symbol, size. """ - if diff_mean > 0: - # A better than B + + # Determine if diff_mean indicates improvement + is_improvement = (avg_diff > 0 and lower_is_better(metric)) or ( + avg_diff < 0 and not lower_is_better(metric) + ) + + if is_improvement: alt = "greater" modus = "better" color = "blue" - elif diff_mean < 0: - # A worse than B + elif not is_improvement and avg_diff != 0: alt = "less" modus = "worse" color = "red" else: - # Equal performance (conservative fallback) alt = "two-sided" modus = "different" @@ -1153,7 +1187,7 @@ def get_plot_symbols( # First row is model 1 vs model 0 y = var_index + 0.5 - size = 200 * (1 - (1 / (1 + abs(skill) / self.improvement))) # Add base size to all + size = 200 * (1 - (1 / (1 + abs(avg_skill) / self.improvement))) # Add base size to all return x, y, alt, color, triangle, size @@ -1184,7 +1218,12 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path) -> None: os.makedirs(self.out_plot_dir, exist_ok=True) def plot( - self, data: list[xr.DataArray], runs: list[str], channels: list[str], tag: str + self, + data: list[xr.DataArray], + runs: list[str], + metric: str, + channels: list[str], + tag: str, ) -> None: """ Plot (ratio) bar plots comparing performance between different run_ids over channels of @@ -1196,6 +1235,8 @@ def plot( List of (xarray) DataArrays with the scores (stream, region and metric specific) runs: List containing runs (in str format) to be compared (provided in the config) + metric: + Metric name channels: List containing channels (in str format) of interest (provided in the config) tag: @@ -1219,7 +1260,7 @@ def plot( ax[run_index - 1].barh( np.arange(len(ratio_score)), ratio_score, - color=self.colors(ratio_score), + color=self.colors(ratio_score, metric), align="center", edgecolor="black", linewidth=0.5, @@ -1244,7 +1285,11 @@ def plot( plt.close(fig) def calc_ratio_per_run_id( - self, data: list[xr.DataArray], channels: list[str], run_index: int, x_dim="channel" + self, + data: list[xr.DataArray], + channels: list[str], + run_index: int, + x_dim="channel", ) -> tuple[np.array, str]: """ This function calculates the ratio per comparison model for each channel. @@ -1284,7 +1329,7 @@ def calc_ratio_per_run_id( ratio_score = np.array(ratio_score) - 1 return ratio_score, channels_per_comparison - def colors(self, ratio_score: np.array) -> list[tuple]: + def colors(self, ratio_score: np.array, metric: str) -> list[tuple]: """ This function calculates colormaps based on the skill scores. From negative value blue color variations should be given otherwise red color variations should be given. @@ -1293,13 +1338,18 @@ def colors(self, ratio_score: np.array) -> list[tuple]: ---------- ratio_score: np.array The (ratio) skill for a specific model + metric: str + The metric of interest Returns ---------- colors: list[tuple] The color magnitude (blue to red) of the bars in the plots """ max_val = np.abs(ratio_score).max() - cmap = plt.get_cmap("bwr") + if lower_is_better(metric): + cmap = plt.get_cmap("bwr") + else: + cmap = plt.get_cmap("bwr_r") colors = [cmap(0.5 + v / (2 * max_val)) for v in ratio_score] return colors @@ -1342,3 +1392,8 @@ def calculate_average_over_dim( model_score = data_var.mean(dim=[dim for dim in data_var.dims if dim != x_dim], skipna=True) return baseline_score, model_score + + +def lower_is_better(metric: str) -> bool: + # Determine whether lower or higher is better + return metric in {"l1", "l2", "mse", "rmse", "vrmse", "bias", "crps", "spread"} From a43c028e6e864cb5ba60adcf0bd7fc5bcfd21199 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:15:39 +0100 Subject: [PATCH 24/32] [eval][1122] Plot scores on a map (#1176) * first version of score maps * add maps to compute_scores * fix single sample situation * fix single sample * lint * restore score.py * fix bug in metric stream * default flag to false * Minor correction, a line was deleted by mistake? (#1193) * fix * working setup for regridded data * fix missing valid time case * lint and fix color in score cards * fix path for score maps * Allow plotting score maps every time --------- Co-authored-by: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> --- config/evaluate/eval_config.yml | 1 + .../src/weathergen/evaluate/clim_utils.py | 15 +- .../src/weathergen/evaluate/io_reader.py | 180 ++++++++++++---- .../src/weathergen/evaluate/plotter.py | 59 ++++-- .../src/weathergen/evaluate/run_evaluation.py | 5 +- .../evaluate/src/weathergen/evaluate/score.py | 47 +---- .../evaluate/src/weathergen/evaluate/utils.py | 193 +++++++++++++----- 7 files changed, 343 insertions(+), 157 deletions(-) diff --git a/config/evaluate/eval_config.yml b/config/evaluate/eval_config.yml index 3f436f736..85157728d 100644 --- a/config/evaluate/eval_config.yml +++ b/config/evaluate/eval_config.yml @@ -21,6 +21,7 @@ evaluation: summary_plots : true summary_dir: "./plots/" plot_ensemble: "members" #supported: false, "std", "minmax", "members" + plot_score_maps: false #plot scores on a 2D maps. it slows down score computation print_summary: false #print out score values on screen. it can be verbose log_scale: false add_grid: false diff --git a/packages/evaluate/src/weathergen/evaluate/clim_utils.py b/packages/evaluate/src/weathergen/evaluate/clim_utils.py index 65091dc4d..7ff75986f 100644 --- a/packages/evaluate/src/weathergen/evaluate/clim_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/clim_utils.py @@ -124,8 +124,13 @@ def align_clim_data( for fstep, target_data in target_output.items(): samples = np.unique(target_data.sample.values) for sample in tqdm(samples, f"Aligning climatology for forecast step {fstep}"): - sample_mask = target_data.sample.values == sample - timestamp = target_data.valid_time.values[sample_mask][0] + sel_key = "sample" if "sample" in target_data.dims else "ipoint" + sel_val = ( + sample if "sample" in target_data.dims else (target_data.sample.values == sample) + ) + sel_mask = {sel_key: sel_val} + + timestamp = target_data.sel(sel_mask).valid_time.values[0] # Prepare climatology data for each sample matching_time_idx = match_climatology_time(timestamp, clim_data) @@ -141,8 +146,8 @@ def align_clim_data( ) .transpose("grid_points", "channels") # dimensions specific to anemoi ) - target_lats = target_data.loc[{"ipoint": sample_mask}].lat.values - target_lons = target_data.loc[{"ipoint": sample_mask}].lon.values + target_lats = target_data.loc[sel_mask].lat.values + target_lons = target_data.loc[sel_mask].lon.values # check if target coords match cached target coords # if they do, use cached clim_indices if ( @@ -174,7 +179,7 @@ def align_clim_data( clim_values = prepared_clim_data.isel(grid_points=clim_indices).values try: if len(samples) > 1: - aligned_clim_data[fstep].loc[{"ipoint": sample_mask}] = clim_values + aligned_clim_data[fstep].loc[sel_mask] = clim_values else: aligned_clim_data[fstep] = clim_values except (ValueError, IndexError) as e: diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 5b78ad507..61bffef92 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -88,7 +88,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | self.run_id = run_id self.private_paths = private_paths self.streams = eval_cfg.streams.keys() - self.data = None + # TODO: propagate it to the other functions using global plotting opts + self.global_plotting_options = eval_cfg.get("global_plotting_options", {}) # If results_base_dir and model_base_dir are not provided, default paths are used self.model_base_dir = self.eval_cfg.get("model_base_dir", None) @@ -130,6 +131,13 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: """Placeholder implementation ensemble member names getter. Override in subclass.""" return list() + def is_regular(self, stream: str) -> bool: + """ + Placeholder implementation to check if lat/lon are regularly spaced. + Override in subclass. + """ + return True + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: """Placeholder to load pre-computed scores for a given run, stream, metric""" return None @@ -496,9 +504,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): _logger.error(f"Zarr file {self.fname_zarr} does not exist.") - # raise FileNotFoundError( - # f"Zarr file {self.fname_zarr} does not exist or is not a directory." - # ) + raise FileNotFoundError( + f"Zarr file {self.fname_zarr} does not exist or is not a directory." + ) def get_inference_config(self): """ @@ -610,8 +618,7 @@ def get_data( for fstep in fsteps: _logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...") - da_tars_fs, da_preds_fs = [], [] - pps = [] + da_tars_fs, da_preds_fs, pps = [], [], [] for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"): out = zio.get_data(sample, stream, fstep) @@ -642,60 +649,57 @@ def get_data( _logger.debug(f"Selecting ensemble members {ensemble}.") pred = pred.sel(ens=ensemble) - if ensemble == ["mean"]: - _logger.debug("Averaging over ensemble members.") - pred = pred.mean("ens", keepdims=True) - else: - _logger.debug(f"Selecting ensemble members {ensemble}.") - pred = pred.sel(ens=ensemble) - da_tars_fs.append(target.squeeze()) da_preds_fs.append(pred.squeeze()) - if len(da_tars_fs) > 0: - fsteps_final.append(fstep) + if not da_tars_fs: + _logger.info( + f"[{self.run_id} - {stream}] No valid data found for fstep {fstep}." + ) + continue + + fsteps_final.append(fstep) _logger.debug( f"Concatenating targets and predictions for stream {stream}, " f"forecast_step {fstep}..." ) - if da_tars_fs: + # faster processing + if self.is_regular(stream): + # Efficient concatenation for regular grid + da_preds_fs = _force_consistent_grids(da_preds_fs) + da_tars_fs = _force_consistent_grids(da_tars_fs) + + else: + # Irregular (scatter) case. concatenate over ipoint da_tars_fs = xr.concat(da_tars_fs, dim="ipoint") da_preds_fs = xr.concat(da_preds_fs, dim="ipoint") - if len(samples) == 1: - # Ensure sample coordinate is repeated along ipoint even if only one sample - da_tars_fs = da_tars_fs.assign_coords( - sample=( - "ipoint", - np.repeat(da_tars_fs.sample.values, len(da_tars_fs.ipoint)), - ) - ) - da_preds_fs = da_preds_fs.assign_coords( - sample=( - "ipoint", - np.repeat(da_preds_fs.sample.values, len(da_preds_fs.ipoint)), - ) - ) - if set(channels) != set(all_channels): - _logger.debug( - f"Restricting targets and predictions to channels {channels} " - f"for stream {stream}..." + if len(samples) == 1: + _logger.debug("Repeating sample coordinate for single-sample case.") + for da in (da_tars_fs, da_preds_fs): + da.assign_coords( + sample=("ipoint", np.repeat(da.sample.values, da.sizes["ipoint"])) ) - da_tars_fs, da_preds_fs, channels = dc.get_derived_channels( - da_tars_fs, da_preds_fs - ) + if set(channels) != set(all_channels): + _logger.debug( + f"Restricting targets and predictions to channels {channels} " + f"for stream {stream}..." + ) - da_tars_fs = da_tars_fs.sel(channel=channels) - da_preds_fs = da_preds_fs.sel(channel=channels) + da_tars_fs, da_preds_fs, channels = dc.get_derived_channels( + da_tars_fs, da_preds_fs + ) - da_tars.append(da_tars_fs) - da_preds.append(da_preds_fs) + da_tars_fs = da_tars_fs.sel(channel=channels) + da_preds_fs = da_preds_fs.sel(channel=channels) - if return_counts: - points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps) + da_tars.append(da_tars_fs) + da_preds.append(da_preds_fs) + if return_counts: + points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps) # Safer than a list da_tars = {fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True)} @@ -796,7 +800,17 @@ def get_channels(self, stream: str) -> list[str]: return all_channels def get_ensemble(self, stream: str | None = None) -> list[str]: - """Get the list of ensemble member names for a given stream from the config.""" + """Get the list of ensemble member names for a given stream from the config. + Parameters + ---------- + stream : str + The name of the stream to get channels for. + + Returns + ------- + list[str] + A list of ensemble members. + """ _logger.debug(f"Getting ensembles for stream {stream}...") # TODO: improve this to get ensemble from io class @@ -804,6 +818,47 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) return list(dummy.prediction.as_xarray().coords["ens"].values) + # TODO: improve this + def is_regular(self, stream: str) -> bool: + """Check if the latitude and longitude coordinates are regularly spaced for a given stream. + Parameters + ---------- + stream : str + The name of the stream to get channels for. + + Returns + ------- + bool + True if the stream is regularly spaced. False otherwise. + """ + _logger.debug(f"Checking regular spacing for stream {stream}...") + + with ZarrIO(self.fname_zarr) as zio: + dummy = zio.get_data(0, stream, zio.forecast_steps[0]) + + sample_idx = zio.samples[1] if len(zio.samples) > 1 else zio.samples[0] + fstep_idx = ( + zio.forecast_steps[1] if len(zio.forecast_steps) > 1 else zio.forecast_steps[0] + ) + dummy1 = zio.get_data(sample_idx, stream, fstep_idx) + + da = dummy.prediction.as_xarray() + da1 = dummy1.prediction.as_xarray() + + if ( + da["lat"].shape != da1["lat"].shape + or da["lon"].shape != da1["lon"].shape + or not ( + np.allclose(sorted(da["lat"].values), sorted(da1["lat"].values)) + and np.allclose(sorted(da["lon"].values), sorted(da1["lon"].values)) + ) + ): + _logger.debug("Latitude and/or longitude coordinates are not regularly spaced.") + return False + + _logger.debug("Latitude and longitude coordinates are regularly spaced.") + return True + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: """ Load the pre-computed scores for a given run, stream and metric and epoch. @@ -859,3 +914,40 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): if stream.get("name") == stream_name: return stream.get(key, default) return default + + +################### Helper functions ######################## + + +def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray: + """ + Force all samples to share the same ipoint order. + + Parameters + ---------- + ref: + Input dataset + Returns + ------- + xr.DataArray + Returns a Dataset where all samples have the same lat lon and ipoint ordering + """ + + # Pick first sample as reference + ref_lat = ref[0].lat + ref_lon = ref[0].lon + + sort_idx = np.lexsort((ref_lon.values, ref_lat.values)) + npoints = sort_idx.size + aligned = [] + for a in ref: + a_sorted = a.isel(ipoint=sort_idx) + + a_sorted = a_sorted.assign_coords( + ipoint=np.arange(npoints), + lat=("ipoint", ref_lat.values[sort_idx]), + lon=("ipoint", ref_lon.values[sort_idx]), + ) + aligned.append(a_sorted) + + return xr.concat(aligned, dim="sample") diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index 42c7a17dc..cb15e6f24 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -41,7 +41,7 @@ class Plotter: Contains all basic plotting functions. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None): """ Initialize the Plotter class. @@ -57,6 +57,9 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): output_basedir: Base directory under which the plots will be saved. Expected scheme `/`. + stream: + Stream identifier for which the plots will be created. + It can also be set later via update_data_selection. """ _logger.info(f"Taking cartopy paths from {work_dir}") @@ -77,7 +80,7 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path): os.makedirs(self.out_plot_basedir, exist_ok=True) self.sample = None - self.stream = None + self.stream = stream self.fstep = None self.select = {} @@ -378,6 +381,7 @@ def create_maps_per_sample( var, tag=tag, map_kwargs=dict(map_kwargs.get(var, {})) | map_kwargs_global, + title=f"{self.stream}, {var} : fstep = {self.fstep:03} ({valid_time})", ) plot_names.append(name) @@ -392,6 +396,7 @@ def scatter_plot( varname: str, tag: str = "", map_kwargs: dict | None = None, + title: str | None = None, ): """ Plot a 2D map for a data array using scatter plot. @@ -408,6 +413,8 @@ def scatter_plot( Any tag you want to add to the plot map_kwargs: dict | None Additional keyword arguments for the map. + title: str | None + Title for the plot. Returns ------- @@ -449,11 +456,9 @@ def scatter_plot( ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson()) ax.coastlines() - valid_time = ( - data["valid_time"][0] - .values.astype("datetime64[m]") - .astype(datetime.datetime) - .strftime("%Y-%m-%dT%H%M") + assert data["lon"].shape == data["lat"].shape == data.shape, ( + f"Scatter plot:: Data shape do not match. Shapes: " + f"lon {data['lon'].shape}, lat {data['lat'].shape}, data {data.shape}." ) scatter_plt = ax.scatter( @@ -470,22 +475,34 @@ def scatter_plot( ) plt.colorbar(scatter_plt, ax=ax, orientation="horizontal", label=f"Variable: {varname}") - plt.title(f"{self.stream}, {varname} : fstep = {self.fstep:03} ({valid_time})") + plt.title(title) ax.set_global() ax.gridlines(draw_labels=False, linestyle="--", color="black", linewidth=1) # TODO: make this nicer - parts = [ - "map", - self.run_id, - tag, - str(self.sample), - valid_time, - self.stream, - varname, - "fstep", - str(self.fstep).zfill(3), - ] + parts = ["map", self.run_id, tag] + + if self.sample: + parts.append(str(self.sample)) + + if "valid_time" in data.coords: + valid_time = data["valid_time"][0].values + if ~np.isnat(valid_time): + valid_time = ( + valid_time.astype("datetime64[m]") + .astype(datetime.datetime) + .strftime("%Y-%m-%dT%H%M") + ) + + parts.append(valid_time) + + if self.stream: + parts.append(self.stream) + + parts.append(varname) + + if self.fstep is not None: + parts.extend(["fstep", f"{self.fstep:03d}"]) name = "_".join(filter(None, parts)) fname = f"{map_output_dir.joinpath(name)}.{self.image_format}" @@ -1162,6 +1179,10 @@ def get_plot_symbols( Tuple[int, float, str, str, str, xr.DataArray] x, y coordinates, alternative hypothesis, color, triangle symbol, size. """ + # Conservative choice + alt = "two-sided" + modus = "different" + color = "gray" # Determine if diff_mean indicates improvement is_improvement = (avg_diff > 0 and lower_is_better(metric)) or ( diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 8884b5e91..3bf198d07 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -101,6 +101,7 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: metrics = cfg.evaluation.metrics regions = cfg.evaluation.get("regions", ["global"]) + plot_score_maps = cfg.evaluation.get("plot_score_maps", False) global_plotting_opts = cfg.get("global_plotting_options", {}) @@ -146,7 +147,7 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: metric, ) - if metric_data is None: + if metric_data is None or plot_score_maps: metrics_to_compute.append(metric) continue @@ -166,7 +167,7 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: if metrics_to_compute: all_metrics, points_per_sample = calc_scores_per_stream( - reader, stream, region, metrics_to_compute + reader, stream, region, metrics_to_compute, plot_score_maps ) metric_list_to_json( diff --git a/packages/evaluate/src/weathergen/evaluate/score.py b/packages/evaluate/src/weathergen/evaluate/score.py index 384a0a1bb..88d48a098 100755 --- a/packages/evaluate/src/weathergen/evaluate/score.py +++ b/packages/evaluate/src/weathergen/evaluate/score.py @@ -893,7 +893,6 @@ def _calc_act( self, x: xr.DataArray, c: xr.DataArray, - spatial_dims: list = None, ): """ Calculate activity metric as standard deviation of forecast or target anomaly. @@ -907,25 +906,14 @@ def _calc_act( Forecast or target data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - spatial_dims: List[str] - Names of spatial dimensions over which activity is calculated. - Note: No averaging is possible over these dimensions. """ - # Check if spatial_dims are in the data - spatial_dims = ["ipoint"] if spatial_dims is None else to_list(spatial_dims) - - for dim in spatial_dims: - if dim not in x.dims: - raise ValueError( - f"Spatial dimension '{dim}' not found in prediction data dimensions: {x.dims}" - ) if c is None: - return xr.full_like(x.sum(spatial_dims), np.nan) + return xr.full_like(x.sum(self._agg_dims), np.nan) # Calculate anomalies ano = x - c - act = ano.std(dim=spatial_dims) + act = ano.std(dim=self._agg_dims) return act @@ -933,7 +921,6 @@ def calc_fact( self, p: xr.DataArray, c: xr.DataArray, - spatial_dims: list = None, ): """ Calculate forecast activity metric as standard deviation of forecast anomaly. @@ -947,18 +934,14 @@ def calc_fact( Forecast data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - spatial_dims: List[str] - Names of spatial dimensions over which activity is calculated. - Note: No averaging is possible over these dimensions. """ - return self._calc_act(p, c, spatial_dims) + return self._calc_act(p, c) def calc_tact( self, gt: xr.DataArray, c: xr.DataArray, - spatial_dims: list = None, ): """ Calculate target activity metric as standard deviation of target anomaly. @@ -972,19 +955,15 @@ def calc_tact( Target data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - spatial_dims: List[str] - Names of spatial dimensions over which activity is calculated. - Note: No averaging is possible over these dimensions. """ - return self._calc_act(gt, c, spatial_dims) + return self._calc_act(gt, c) def calc_acc( self, p: xr.DataArray, gt: xr.DataArray, c: xr.DataArray, - spatial_dims: list = None, ) -> xr.DataArray: """ Calculate anomaly correlation coefficient (ACC). @@ -1001,32 +980,22 @@ def calc_acc( Ground truth data array c: xr.DataArray Climatological mean data array, which is used to calculate anomalies - spatial_dims: List[str] - Names of spatial dimensions over which ACC is calculated. - Note: No averaging is possible over these dimensions. + Returns ------- xr.DataArray Anomaly correlation coefficient (ACC) """ - # Check if spatial_dims are in the data - spatial_dims = ["ipoint"] if spatial_dims is None else to_list(spatial_dims) - - for dim in spatial_dims: - if dim not in p.dims: - raise ValueError( - f"Spatial dimension '{dim}' not found in prediction data dimensions: {p.dims}" - ) if c is None: - return xr.full_like(p.sum(spatial_dims), np.nan) + return xr.full_like(p.sum(self._agg_dims), np.nan) # Calculate anomalies fcst_ano, obs_ano = p - c, gt - c # Calculate ACC over spatial dimensions (no grouping) - acc = (fcst_ano * obs_ano).sum(spatial_dims) / np.sqrt( - (fcst_ano**2).sum(spatial_dims) * (obs_ano**2).sum(spatial_dims) + acc = (fcst_ano * obs_ano).sum(self._agg_dims) / np.sqrt( + (fcst_ano**2).sum(self._agg_dims) * (obs_ano**2).sum(self._agg_dims) ) return acc diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 98a463a92..09ff2e9b7 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -48,7 +48,11 @@ def get_next_data(fstep, da_preds, da_tars, fsteps): def calc_scores_per_stream( - reader: Reader, stream: str, region: str, metrics: list[str] + reader: Reader, + stream: str, + region: str, + metrics: list[str], + plot_score_maps: bool = False, ) -> tuple[xr.DataArray, xr.DataArray]: """ Calculate scores for a given run and stream using the specified metrics. @@ -63,13 +67,25 @@ def calc_scores_per_stream( Region name to calculate scores for. metrics : List of metric names to calculate. - + plot_score_maps : + When it is True and the stream is on a regular grid the scores are + recomputed as a function of the "ipoint" and plotted on a 2D scatter map. + NOTE: the scores are averaged over the "sample" dimension and for most + of the metrics this does not give the same results as averaging over + the "ipoint" dimension. Returns ------- Tuple of xarray DataArray containing the scores and the number of points per sample. """ _logger.info(f"RUN {reader.run_id} - {stream}: Calculating scores for metrics {metrics}...") + if plot_score_maps: + _logger.info(f"RUN {reader.run_id} - {stream}: Plotting scores is enabled.") + + map_dir = reader.runplot_dir / "plots" / stream / "score_maps" + map_dir.mkdir(parents=True, exist_ok=True) + + _logger.info(f"RUN {reader.run_id} - {stream}: Saving plotted scores to {map_dir}") available_data = reader.check_availability(stream, mode="evaluation") @@ -77,6 +93,8 @@ def calc_scores_per_stream( samples = available_data.samples channels = available_data.channels ensemble = available_data.ensemble + is_regular = reader.is_regular(stream) + group_by_coord = None if is_regular else "sample" output_data = reader.get_data( stream, @@ -109,55 +127,48 @@ def calc_scores_per_stream( ) for (fstep, tars), (_, preds) in zip(da_tars.items(), da_preds.items(), strict=False): + if preds.ipoint.size == 0: + _logger.warning( + f"No data for stream {stream} at fstep {fstep} in region {region}. Skipping." + ) + continue + _logger.debug(f"Verifying data for stream {stream}...") preds_next, tars_next = get_next_data(fstep, da_preds, da_tars, fsteps) - if preds.ipoint.size > 0: - climatology = aligned_clim_data[fstep] if aligned_clim_data else None - score_data = VerifiedData(preds, tars, preds_next, tars_next, climatology) - # Build up computation graphs for all metrics - _logger.debug(f"Build computation graphs for metrics for stream {stream}...") - - # Add it only if it is not None - valid_scores = [ - score - for metric in metrics - if ( - score := get_score( - score_data, - metric, - agg_dims="ipoint", - group_by_coord="sample", - ) - ) - is not None - ] - - # Keep only metrics corresponding to valid_scores - valid_metric_names = [ - metric - for metric, score in zip(metrics, valid_scores, strict=False) - if score is not None - ] - - # Concatenate along a new "metric" dimension and assign metric names - combined_metrics = xr.concat(valid_scores, dim="metric") - combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) - - _logger.debug(f"Running computation of metrics for stream {stream}...") - combined_metrics = combined_metrics.compute() - combined_metrics = scalar_coord_to_dim(combined_metrics, "channel") - combined_metrics = scalar_coord_to_dim(combined_metrics, "sample") - combined_metrics = scalar_coord_to_dim(combined_metrics, "ens") - else: - # depending on the datset, there might be no data (e.g. no CERRA in southern - # hemisphere region) - _logger.warning( - f"No data available for stream {stream} at forecast step {fstep} in " - f"region {region}. Skipping metrics calculation." + climatology = aligned_clim_data[fstep] if aligned_clim_data else None + score_data = VerifiedData(preds, tars, preds_next, tars_next, climatology) + # Build up computation graphs for all metrics + _logger.debug(f"Build computation graphs for metrics for stream {stream}...") + + # Add it only if it is not None + valid_scores = [] + for metric in metrics: + score = get_score( + score_data, + metric, + agg_dims="ipoint", + group_by_coord=group_by_coord, ) - continue + if score is not None: + valid_scores.append(score) + + # Keep only metrics corresponding to valid_scores + valid_metric_names = [ + metric + for metric, score in zip(metrics, valid_scores, strict=False) + if score is not None + ] + + combined_metrics = xr.concat(valid_scores, dim="metric") + combined_metrics = combined_metrics.assign_coords(metric=valid_metric_names) + + _logger.debug(f"Running computation of metrics for stream {stream}...") + combined_metrics = combined_metrics.compute() + + for coord in ["channel", "sample", "ens"]: + combined_metrics = scalar_coord_to_dim(combined_metrics, coord) assert int(combined_metrics.forecast_step) == int(fstep), ( "Different steps in data and metrics. Please check." @@ -174,11 +185,98 @@ def calc_scores_per_stream( criteria["ens"] = combined_metrics.ens metric_stream.loc[criteria] = combined_metrics + ######### + + if is_regular and plot_score_maps: + _logger.info(f"Plotting scores on a map {stream} - forecast step: {fstep}...") + _plot_score_maps_per_stream(reader, map_dir, stream, region, score_data, metrics, fstep) + _logger.info(f"Scores for run {reader.run_id} - {stream} calculated successfully.") return metric_stream, points_per_sample +def _plot_score_maps_per_stream( + reader: Reader, + map_dir: str, + stream: str, + region: str, + score_data: VerifiedData, + metrics: list[str], + fstep: int, +) -> None: + """Plot 2D score maps for all metrics and channels. + Parameters + ---------- + reader: Reader + Reader object containing all infos about the run + map_dir: str + Directory where the plots are saved. + stream: str + Stream name to plot score maps for. + region : + Region name to plot score maps for. + score_data: VerifiedData + prediction and target stored in the data class. + metrics: str + List of all metrics to plot. + fstep: + forecast step to plot. + + Return + ------ + None + """ + + cfg = reader.global_plotting_options + + # TODO: add support for climatology-dependent metrics as well + + plotter = Plotter( + { + "image_format": cfg.get("image_format", "png"), + "dpi_val": cfg.get("dpi_val", 300), + "fig_size": cfg.get("fig_size", (8, 10)), + }, + reader.runplot_dir, + stream, + ) + + preds = score_data.prediction + + plot_metrics = xr.concat( + [get_score(score_data, m, agg_dims="sample") for m in metrics], dim="metric" + ) + + plot_metrics = plot_metrics.assign_coords( + lat=preds.lat.reset_coords(drop=True), + lon=preds.lon.reset_coords(drop=True), + metric=metrics, + ).compute() + + if "ens" in preds.dims: + plot_metrics["ens"] = preds.ens + + has_ens = "ens" in plot_metrics.coords + ens_values = plot_metrics.coords["ens"].values if has_ens else [None] + + for metric in plot_metrics.coords["metric"].values: + for ens_val in tqdm(ens_values, f"Plotting metric - {metric}"): + tag = f"score_maps_{region}_{metric}_fstep_{fstep}" + ( + f"_ens_{ens_val}" if ens_val is not None else "" + ) + for channel in plot_metrics.coords["channel"].values: + sel = {"metric": metric, "channel": channel} + if ens_val is not None: + sel["ens"] = ens_val + + data = plot_metrics.sel(**sel).squeeze() + title = f"{metric} - {channel}: fstep {fstep}" + ( + f", ens {ens_val}" if ens_val is not None else "" + ) + plotter.scatter_plot(data, map_dir, channel, tag=tag, title=title) + + def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: """ Plot the data for a given run and stream. @@ -445,7 +543,6 @@ def common_ranges( the global plotting configuration with the ranges added and included for each variable (and for each stream). """ - for var in plot_chs: if var in maps_config: if not isinstance(maps_config[var].get("vmax"), (int | float)): @@ -455,7 +552,7 @@ def common_ranges( if not isinstance(maps_config[var].get("vmin"), (int | float)): list_min = calc_bounds(data_tars, data_preds, var, "min") - + list_min = np.concatenate([arr.flatten() for arr in list_min]).tolist() maps_config[var].update({"vmin": float(min(list_min))}) else: From 7e7141960b2349a04ad49eddcc4428fbc0a84e35 Mon Sep 17 00:00:00 2001 From: Christian Lessig Date: Mon, 10 Nov 2025 20:46:06 +0100 Subject: [PATCH 25/32] Fix DDP without FSDP (#1227) * Fix DDP without FSDP * Fixed taht freezing would not have worked with only DDP --- src/weathergen/train/trainer.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index a792c67c7..0c85df93a 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -162,9 +162,13 @@ def init_model_and_shard(self, cf, devices): targets_num_channels = self.dataset.get_targets_num_channels() targets_coords_size = self.dataset.get_targets_coords_size() - with torch.device("meta"): + if cf.with_ddp and not cf.with_fsdp: model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + else: + with torch.device("meta"): + model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create() + # freeze request model part for name, module in model.named_modules(): name = module.name if hasattr(module, "name") else name # avoid the whole model element which has name '' @@ -174,6 +178,8 @@ def init_model_and_shard(self, cf, devices): freeze_weights(module) if cf.with_ddp and not cf.with_fsdp: + # create DDP model if running without FSDP + model = model.to("cuda") model = torch.nn.parallel.DistributedDataParallel( model, broadcast_buffers=True, @@ -182,7 +188,8 @@ def init_model_and_shard(self, cf, devices): bucket_cap_mb=512, ) - if cf.with_ddp and cf.with_fsdp: + elif cf.with_ddp and cf.with_fsdp: + # with DDP *and() FSDP fsdp_kwargs = { "mp_policy": ( MixedPrecisionPolicy( @@ -243,14 +250,14 @@ def init_model_and_shard(self, cf, devices): for tensor in itertools.chain(model.parameters(), model.buffers()): assert tensor.device == torch.device("meta") - # For reasons we do not yet fully understand, when using train continue in some - # instances, FSDP2 does not register the forward_channels and forward_columns - # functions in the embedding engine as forward functions. Thus, yielding a crash - # because the input tensors are not converted to DTensors. This seems to primarily - # occur during validation. - for embed in model.embed_engine.embeds: - torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_channels") - torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_columns") + # For reasons we do not yet fully understand, when using train continue in some + # instances, FSDP2 does not register the forward_channels and forward_columns + # functions in the embedding engine as forward functions. Thus, yielding a crash + # because the input tensors are not converted to DTensors. This seems to primarily + # occur during validation. + for embed in model.embed_engine.embeds: + torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_channels") + torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_columns") return model, model_params @@ -298,7 +305,8 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): if run_id_contd is None: self.model.to_empty(device="cuda") - self.model.reset_parameters() + if cf.with_fsdp: + self.model.reset_parameters() else: if is_root(): logger.info( @@ -327,7 +335,7 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): ) # if with_fsdp then parameter count is unreliable - if (is_root() and not cf.with_fsdp) or not cf.with_ddp: + if is_root() and not cf.with_fsdp and not cf.with_ddp: self.model.print_num_parameters() # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ From 5cb1ac4f1237651414f22f1038222c9089eeedf4 Mon Sep 17 00:00:00 2001 From: iluise <72020169+iluise@users.noreply.github.com> Date: Tue, 11 Nov 2025 09:53:50 +0100 Subject: [PATCH 26/32] refactor export scripts - Part 1 (#1223) * move stuff around * move stuff around * rename files --- .../weathergen/evaluate/export/__init__.py | 1 + .../weathergen/evaluate/export/cf_utils.py | 186 +++++ .../weathergen/evaluate/export/export_core.py | 198 +++++ .../evaluate/export/export_inference.py | 211 +++++ .../weathergen/evaluate/export/io_utils.py | 67 ++ .../src/weathergen/evaluate/export/reshape.py | 135 ++++ .../weathergen/evaluate/export_inference.py | 738 ------------------ pyproject.toml | 2 +- 8 files changed, 799 insertions(+), 739 deletions(-) create mode 100644 packages/evaluate/src/weathergen/evaluate/export/__init__.py create mode 100644 packages/evaluate/src/weathergen/evaluate/export/cf_utils.py create mode 100644 packages/evaluate/src/weathergen/evaluate/export/export_core.py create mode 100755 packages/evaluate/src/weathergen/evaluate/export/export_inference.py create mode 100644 packages/evaluate/src/weathergen/evaluate/export/io_utils.py create mode 100644 packages/evaluate/src/weathergen/evaluate/export/reshape.py delete mode 100755 packages/evaluate/src/weathergen/evaluate/export_inference.py diff --git a/packages/evaluate/src/weathergen/evaluate/export/__init__.py b/packages/evaluate/src/weathergen/evaluate/export/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/__init__.py @@ -0,0 +1 @@ + diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py new file mode 100644 index 000000000..20fde50f7 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -0,0 +1,186 @@ +import logging + +import numpy as np +import xarray as xr +from omegaconf import OmegaConf + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset: + """ + Add Gaussian grid metadata following CF conventions. + + Parameters + ---------- + ds : xr.Dataset + Dataset to add metadata to + grid_info : dict, optional + Dictionary with grid information: + - 'N': Gaussian grid number (e.g., N320) + - 'reduced': Whether it's a reduced Gaussian grid + + Returns + ------- + xr.Dataset + Dataset with added grid metadata + """ + ds = ds.copy() + # Add grid mapping information + ds.attrs["grid_type"] = "gaussian" + + # If grid info provided, add it + if grid_info: + ds.attrs["gaussian_grid_number"] = grid_info.get("N", "unknown") + ds.attrs["gaussian_grid_type"] = "reduced" if grid_info.get("reduced", False) else "regular" + + return ds + + +def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: + """ + Add CF conventions to the dataset attributes. + + Parameters + ---------- + stream : Stream name to include in the title attribute. + run_id : Run ID to include in the title attribute. + ds : Input xarray Dataset to add conventions to. + Returns + ------- + xarray Dataset with CF conventions added to attributes. + """ + ds = ds.copy() + ds.attrs["title"] = f"WeatherGenerator Output for {run_id} using stream {stream}" + ds.attrs["institution"] = "WeatherGenerator Project" + ds.attrs["source"] = "WeatherGenerator v0.0" + ds.attrs["history"] = ( + "Created using the export_inference.py script on " + + np.datetime_as_string(np.datetime64("now"), unit="s") + ) + ds.attrs["Conventions"] = "CF-1.12" + return ds + + +def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: + """ + Modified CF parser that handles both regular and Gaussian grids. + + Parameters + ---------- + config : OmegaConf + Configuration for CF parsing + ds : xr.Dataset + Input dataset + + Returns + ------- + xr.Dataset + Parsed dataset with appropriate structure for grid type + """ + # Detect if this is a Gaussian grid + is_gaussian = "ncells" in ds.dims + + variables = {} + mapping = config["variables"] + + # Handle dimensions based on grid type + if is_gaussian: + # For Gaussian grids, keep ncells and don't try to create lat/lon dimensions + for var_name in ds.data_vars: + if var_name in ["lat", "lon"]: + continue + + variable = ds[var_name] + + if var_name not in mapping: + # Variable not in mapping - skip or keep as-is + variables[var_name] = variable + continue + + dims = list(variable.dims) + + attributes = dict( + standard_name=mapping[var_name].get("std", var_name), + units=mapping[var_name].get("std_unit", "unknown"), + coordinates="lat lon", # Mark auxiliary coordinates + ) + + # Get mapped variable name or use original + mapped_name = mapping[var_name].get("var", var_name) + + variables[mapped_name] = xr.DataArray( + data=variable.values, + dims=dims, + coords={coord: ds.coords[coord] for coord in variable.coords if coord in ds.coords}, + attrs=attributes, + name=mapped_name, + ) + + # Preserve lat/lon as coordinate variables with proper attributes + if "lat" in ds.coords: + ds.coords["lat"].attrs = { + "standard_name": "latitude", + "long_name": "latitude", + "units": "degrees_north", + } + if "lon" in ds.coords: + ds.coords["lon"].attrs = { + "standard_name": "longitude", + "long_name": "longitude", + "units": "degrees_east", + } + + else: + # Original logic for regular grids + ds_attributes = {} + for dim_name, dim_dict in config["dimensions"].items(): + if dim_name == dim_dict["wg"]: + dim_attributes = dict(standard_name=dim_dict.get("std", None)) + if dim_dict.get("std_unit", None) is not None: + dim_attributes["units"] = dim_dict["std_unit"] + ds_attributes[dim_dict["wg"]] = dim_attributes + continue + + if dim_name in ds.dims: + ds = ds.rename_dims({dim_name: dim_dict["wg"]}) + + dim_attributes = dict(standard_name=dim_dict.get("std", None)) + if "std_unit" in dim_dict and dim_dict["std_unit"] is not None: + dim_attributes["units"] = dim_dict["std_unit"] + ds_attributes[dim_dict["wg"]] = dim_attributes + + for var_name in ds.data_vars: + dims = ["pressure", "valid_time", "latitude", "longitude"] + if mapping[var_name]["level_type"] == "sfc": + dims.remove("pressure") + + coordinates = {} + for coord, new_name in config["coordinates"][mapping[var_name]["level_type"]].items(): + coordinates |= { + new_name: ( + ds.coords[coord].dims, + ds.coords[coord].values, + ds_attributes[new_name], + ) + } + + variable = ds[var_name] + attributes = dict( + standard_name=mapping[var_name]["std"], + units=mapping[var_name]["std_unit"], + ) + + variables[mapping[var_name]["var"]] = xr.DataArray( + data=variable.values, + dims=dims, + coords={**coordinates, "valid_time": ds["valid_time"].values}, + attrs=attributes, + name=mapping[var_name]["var"], + ) + + dataset = xr.merge(variables.values()) + dataset.attrs = ds.attrs + + return dataset diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_core.py b/packages/evaluate/src/weathergen/evaluate/export/export_core.py new file mode 100644 index 000000000..c444e743c --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/export_core.py @@ -0,0 +1,198 @@ +import logging +from multiprocessing import Pool + +import xarray as xr +from omegaconf import OmegaConf +from tqdm import tqdm + +from weathergen.common.config import get_model_results +from weathergen.common.io import ZarrIO +from weathergen.evaluate.export.cf_utils import ( + add_conventions, + add_gaussian_grid_metadata, + cf_parser_gaussian_aware, +) +from weathergen.evaluate.export.io_utils import get_data_worker, output_filename +from weathergen.evaluate.export.reshape import reshape_dataset_adaptive + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +def export_model_outputs( + run_id: str, + samples: list, + stream: str, + dtype: str, + fsteps: list, + channels: list, + fstep_hours: int, + n_processes: list, + epoch: int, + rank: int, + output_dir: str, + output_format: str, + config: OmegaConf, +) -> None: + """ + Retrieve data from Zarr store and save one sample to each NetCDF file. + Using multiprocessing to speed up data retrieval. + + Parameters + ---------- + run_id : str + Run ID to identify the Zarr store. + samples : list + Sample to process + stream : str + Stream name to retrieve data for (e.g., 'ERA5'). + dtype : str + Type of data to retrieve ('target' or 'prediction'). + fsteps : list + List of forecast steps to retrieve. If None, retrieves all available forecast steps. + channels : list + List of channels to retrieve. If None, retrieves all available channels. + n_processes : list + Number of parallel processes to use for data retrieval. + ecpoch : int + Epoch number to identify the Zarr store. + rank : int + Rank number to identify the Zarr store. + output_dir : str + Directory to save the NetCDF files. + output_format : str + Output file format (currently only 'netcdf' supported). + config : OmegaConf + Loaded config for cf_parser function. + """ + if dtype not in ["target", "prediction"]: + raise ValueError(f"Invalid type: {dtype}. Must be 'target' or 'prediction'.") + + fname_zarr = get_model_results(run_id, epoch, rank) + with ZarrIO(fname_zarr) as zio: + zio_forecast_steps = sorted([int(step) for step in zio.forecast_steps]) + zio_samples = sorted([int(sample) for sample in zio.samples]) + dummy_out = zio.get_data(0, stream, zio_forecast_steps[0]) + all_channels = dummy_out.target.channels + channels = all_channels if channels is None else channels + + fsteps = zio_forecast_steps if fsteps is None else sorted([int(fstep) for fstep in fsteps]) + + samples = ( + zio_samples + if samples is None + else sorted([int(sample) for sample in samples if sample in samples]) + ) + with Pool(processes=n_processes, maxtasksperchild=5) as pool: + for sample_idx in tqdm(samples): + da_fs = [] + step_tasks = [ + (sample_idx, fstep, run_id, stream, dtype, epoch, rank) for fstep in fsteps + ] + for result in tqdm( + pool.imap_unordered(get_data_worker, step_tasks, chunksize=1), + total=len(step_tasks), + desc=f"Processing {run_id} - stream: {stream} - sample: {sample_idx}", + ): + if result is not None: + # Select only requested channels + result = result.as_xarray().squeeze() + if set(channels) != set(all_channels): + available_channels = result.channel.values + existing_channels = [ch for ch in channels if ch in available_channels] + if len(existing_channels) < len(channels): + _logger.info( + f"The following channels were not found: " + f"{list(set(channels) - set(existing_channels))}. Skipping them." + ) + result = result.sel(channel=existing_channels) + # reshape result - use adaptive function to handle both regular and Gaussian + # grids + result = reshape_dataset_adaptive(result) + da_fs.append(result) + + _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {dtype}.") + _logger.info( + f"Saving sample {sample_idx} data to {output_format} format in {output_dir}." + ) + + save_sample_to_netcdf( + str(dtype)[:4], + da_fs, + fstep_hours, + run_id, + output_dir, + output_format, + config, + ) + pool.terminate() + pool.join() + + +def save_sample_to_netcdf( + type_str, + array_list, + fstep_hours, + run_id, + output_dir, + output_format, + config, +) -> None: + """ + Uses list of pred/target xarray DataArrays to save one sample to a NetCDF file. + + Parameters + ---------- + type_str : str + Type of data ('pred' or 'targ') to include in the filename. + dict_sample_all_steps : dict + Dictionary where keys is sample index and values is a list of xarray DataArrays + for all the forecast steps + fstep_hours : np.timedelta64 + Time difference between forecast steps (e.g., 6 hours). + run_id : str + Run ID to include in the filename. + output_dir : str + Directory to save the NetCDF files. + output_format : str + Output file format (currently only 'netcdf' supported). + config : OmegaConf + Loaded config for cf_parser function. + """ + # find forecast_ref_time + frt = array_list[0].valid_time.values[0] - fstep_hours * int(array_list[0].forecast_step.values) + out_fname = output_filename(type_str, run_id, output_dir, output_format, frt) + # check if file already exists + if out_fname.exists(): + _logger.info(f"File {out_fname} already exists. Skipping.") + else: + sample_all_steps = xr.concat( + array_list, + dim="valid_time", + data_vars="minimal", + coords="different", + compat="equals", + combine_attrs="drop", + ).sortby("valid_time") + _logger.info(f"Saving to {out_fname}.") + sample_all_steps = sample_all_steps.assign_coords(forecast_ref_time=frt) + stream = str(sample_all_steps.coords["stream"].values) + + if "sample" in sample_all_steps.coords: + sample_all_steps = sample_all_steps.drop_vars("sample") + + sample_all_steps = cf_parser_gaussian_aware(config, sample_all_steps) + # Add Gaussian grid metadata if detected + if "ncells" in sample_all_steps.dims: + sample_all_steps = add_gaussian_grid_metadata(sample_all_steps) + _logger.info("Detected and preserved Gaussian grid structure") + # add forecast_period attributes + n_hours = fstep_hours.astype("int64") + sample_all_steps["forecast_period"] = sample_all_steps["forecast_step"] * n_hours + sample_all_steps["forecast_period"].attrs = { + "standard_name": "forecast_period", + "long_name": "time since forecast_reference_time", + "units": "hours", + } + sample_all_steps = add_conventions(stream, run_id, sample_all_steps) + sample_all_steps.to_netcdf(out_fname, mode="w", compute=False) diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py new file mode 100755 index 000000000..4120a0200 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -0,0 +1,211 @@ +#!/usr/bin/env -S uv run +# /// script +# dependencies = [ +# "weathergen-evaluate", +# "weathergen-common", +# "weathergen" +# ] +# [tool.uv.sources] +# weathergen-evaluate = { path = "../../../../../packages/evaluate" } +# weathergen-common = { path = "../../../../../packages/common" } +# weathergen = { path = "../../../../../" } +# /// +## Example USAGE: uv run export --run-id grwnhykd --stream ERA5 \ +## --output-dir /p/home/jusers/owens1/jureca/WeatherGen/test_output1 \ +## --format netcdf --type prediction target --fsteps 1 --samples 1 +import argparse +import logging +import sys +from pathlib import Path + +import numpy as np +from omegaconf import OmegaConf + +from weathergen.common.config import _REPO_ROOT +from weathergen.evaluate.export.export_core import export_model_outputs + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +if not _logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + _logger.addHandler(handler) + + +def parse_args(args: list) -> argparse.Namespace: + """ + Parse command line arguments. + + Parameters + ---------- + args : + List of command line arguments. + + Returns + ------- + Parsed command line arguments. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--run-id", + type=str, + help=" Zarr folder which contains target and inference results", + required=True, + ) + + parser.add_argument( + "--type", + type=str, + choices=["prediction", "target"], + nargs="+", + help="List of type of data to convert (e.g. prediction target)", + required=True, + ) + + parser.add_argument( + "--output-dir", + type=str, + help="Output directory to save the NetCDF files", + required=True, + ) + + parser.add_argument( + "--format", + type=str, + choices=["netcdf", "grib"], + help="Output file format (currently only netcdf supported)", + required=True, + ) + + parser.add_argument( + "--stream", + type=str, + choices=["ERA5"], + help="Stream name to retrieve data for", + required=True, + ) + + parser.add_argument( + "--fsteps", + type=int, + nargs="+", + default=None, + help="List of forecast steps to retrieve (e.g. 1 2 3). " + "If not provided, retrieves all available forecast steps.", + ) + + parser.add_argument( + "--samples", + type=int, + nargs="+", + default=None, + help="List of samples to process (e.g. 0 1 2). If not provided, processes all samples.", + ) + + parser.add_argument( + "--channels", + type=str, + nargs="+", + default=None, + help="List of channels to retrieve (e.g., 'q_500 t_2m'). " + "If not provided, retrieves all available channels.", + ) + + parser.add_argument( + "--n-processes", + type=int, + default=8, + help="Number of parallel processes to use for data retrieval", + ) + + parser.add_argument( + "--fstep-hours", + type=int, + default=6, + help="Time difference between forecast steps in hours (e.g., 6)", + ) + + parser.add_argument( + "--epoch", + type=int, + default=0, + help="Epoch number to identify the Zarr store", + ) + + parser.add_argument( + "--rank", + type=int, + default=0, + help="Rank number to identify the Zarr store", + ) + + args, unknown_args = parser.parse_known_args(args) + if unknown_args: + _logger.warning(f"Unknown arguments: {unknown_args}") + return args + + +def export() -> None: + """ + Main function to export data from Zarr store to NetCDF files. + """ + # By default, arguments from the command line are read. + export_from_args(sys.argv[1:]) + + +def export_from_args(args: list) -> None: + # Get run_id zarr data as lists of xarray DataArrays + """ + Export data from Zarr store to NetCDF files based on command line arguments. + Parameters + ---------- + args : List of command line arguments. + """ + args = parse_args(sys.argv[1:]) + run_id = args.run_id + data_type = args.type + output_dir = args.output_dir + output_format = args.format + samples = args.samples + stream = args.stream + fsteps = args.fsteps + fstep_hours = np.timedelta64(args.fstep_hours, "h") + channels = args.channels + n_processes = args.n_processes + epoch = args.epoch + rank = args.rank + + # Ensure output directory exists + out_dir = Path(output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # Load configuration + config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") + config = OmegaConf.load(config_file) + # check config loaded correctly + assert len(config["variables"].keys()) > 0, "Config file not loaded correctly" + + for dtype in data_type: + _logger.info(f"Starting processing {dtype} for run ID {run_id}.") + export_model_outputs( + run_id, + samples, + stream, + dtype, + fsteps, + channels, + fstep_hours, + n_processes, + epoch, + rank, + output_dir, + output_format, + config, + ) + _logger.info(f"Finished processing {dtype} for run ID {run_id}.") + + +if __name__ == "__main__": + export() diff --git a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py new file mode 100644 index 000000000..98cdbb04d --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py @@ -0,0 +1,67 @@ +import logging +from pathlib import Path + +import numpy as np +import xarray as xr + +from weathergen.common.config import get_model_results +from weathergen.common.io import ZarrIO + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +def output_filename( + prefix: str, + run_id: str, + output_dir: str, + output_format: str, + forecast_ref_time: np.datetime64, +) -> Path: + """ + Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample + index, output directory, format and forecast_ref_time. + + Parameters + ---------- + prefix : Prefix for file name (e.g., 'pred' or 'targ'). + run_id :Run ID to include in the filename. + output_dir : Directory to save the output file. + output_format : Output file format (currently only 'netcdf' supported). + forecast_ref_time : Forecast reference time to include in the filename. + + Returns + ------- + Full path to the output file. + """ + if output_format not in ["netcdf"]: + raise ValueError( + f"Unsupported output format: {output_format}, supported formates are ['netcdf']" + ) + file_extension = "nc" + frt = np.datetime_as_string(forecast_ref_time, unit="h") + out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" + return out_fname + + +def get_data_worker(args: tuple) -> xr.DataArray: + """ + Worker function to retrieve data for a single sample and forecast step. + + Parameters + ---------- + args : Tuple containing (sample, fstep, run_id, stream, type). + + Returns + ------- + xarray DataArray for the specified sample and forecast step. + """ + sample, fstep, run_id, stream, dtype, epoch, rank = args + fname_zarr = get_model_results(run_id, epoch, rank) + with ZarrIO(fname_zarr) as zio: + out = zio.get_data(sample, stream, fstep) + if dtype == "target": + data = out.target + elif dtype == "prediction": + data = out.prediction + return data diff --git a/packages/evaluate/src/weathergen/evaluate/export/reshape.py b/packages/evaluate/src/weathergen/evaluate/export/reshape.py new file mode 100644 index 000000000..b4122ace2 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/export/reshape.py @@ -0,0 +1,135 @@ +import logging +import re + +import numpy as np +import xarray as xr + +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + +""" +Enhanced functions to handle Gaussian grids when converting from Zarr to NetCDF. +""" + + +def detect_grid_type(data: xr.DataArray) -> str: + """ + Detect whether data is on a regular lat/lon grid or Gaussian grid. + + Parameters + ---------- + data: + input dataset. + + Returns + ------- + str: + String with the grid type. + Supported options at the moment: "unknown", "regular", "gaussian" + """ + if "lat" not in data.coords or "lon" not in data.coords: + return "unknown" + + lats = data.coords["lat"].values + lons = data.coords["lon"].values + + unique_lats = np.unique(lats) + unique_lons = np.unique(lons) + + # Check if all (lat, lon) combinations exist (regular grid) + if len(lats) == len(unique_lats) * len(unique_lons): + lat_lon_pairs = set(zip(lats, lons, strict=False)) + expected_pairs = {(lat, lon) for lat in unique_lats for lon in unique_lons} + if lat_lon_pairs == expected_pairs: + return "regular" + + # Otherwise it's Gaussian (irregular spacing or reduced grid) + return "gaussian" + + +def find_pl(vars: list) -> tuple[dict[str, list[str]], list[int]]: + """ + Find all the pressure levels for each variable using regex and returns a dictionary + mapping variable names to their corresponding pressure levels. + + Parameters + ---------- + vars : list of variable names with pressure levels (e.g.,'q_500','t_2m'). + + Returns + ------- + A tuple containing: + - var_dict: dict + Dictionary mapping variable names to lists of their corresponding pressure levels. + - pl: list of int + List of unique pressure levels found in the variable names. + """ + var_dict = {} + pl = [] + for var in vars: + match = re.search(r"^([a-zA-Z0-9_]+)_(\d+)$", var) + if match: + var_name = match.group(1) + pressure_level = int(match.group(2)) + pl.append(pressure_level) + var_dict.setdefault(var_name, []).append(var) + else: + var_dict.setdefault(var, []).append(var) + pl = list(set(pl)) + return var_dict, pl + + +def reshape_dataset_adaptive(data: xr.DataArray) -> xr.Dataset: + """ + Reshape dataset while preserving grid structure (regular or Gaussian). + + Parameters + ---------- + data : xr.DataArray + Input data with dimensions (ipoint, channel) + + Returns + ------- + xr.Dataset + Reshaped dataset appropriate for the grid type + """ + grid_type = detect_grid_type(data) + + # Original logic + var_dict, pl = find_pl(data.channel.values) + data_vars = {} + + for new_var, old_vars in var_dict.items(): + if len(old_vars) > 1: + data_vars[new_var] = xr.DataArray( + data.sel(channel=old_vars).values, + dims=["ipoint", "pressure_level"], + ) + else: + data_vars[new_var] = xr.DataArray( + data.sel(channel=old_vars[0]).values, + dims=["ipoint"], + ) + + reshaped_dataset = xr.Dataset(data_vars) + reshaped_dataset = reshaped_dataset.assign_coords( + ipoint=data.coords["ipoint"], + pressure_level=pl, + ) + + if grid_type == "regular": + # Use original reshape logic for regular grids + # This is safe for regular grids + reshaped_dataset = reshaped_dataset.set_index(ipoint=("valid_time", "lat", "lon")).unstack( + "ipoint" + ) + else: + # Use new logic for Gaussian/unstructured grids + reshaped_dataset = reshaped_dataset.set_index(ipoint2=("ipoint", "valid_time")).unstack( + "ipoint2" + ) + # rename ipoint to ncells + reshaped_dataset = reshaped_dataset.rename_dims({"ipoint": "ncells"}) + reshaped_dataset = reshaped_dataset.rename_vars({"ipoint": "ncells"}) + + return reshaped_dataset diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py deleted file mode 100755 index bda1c6cd3..000000000 --- a/packages/evaluate/src/weathergen/evaluate/export_inference.py +++ /dev/null @@ -1,738 +0,0 @@ -#!/usr/bin/env -S uv run -# /// script -# dependencies = [ -# "weathergen-evaluate", -# "weathergen-common", -# "weathergen" -# ] -# [tool.uv.sources] -# weathergen-evaluate = { path = "../../../../../packages/evaluate" } -# weathergen-common = { path = "../../../../../packages/common" } -# weathergen = { path = "../../../../../" } -# /// -## Example USAGE: uv run export --run-id grwnhykd --stream ERA5 \ -## --output-dir /p/home/jusers/owens1/jureca/WeatherGen/test_output1 \ -## --format netcdf --type prediction target --fsteps 1 --samples 1 -import argparse -import logging -import re -import sys -from multiprocessing import Pool -from pathlib import Path - -import numpy as np -import xarray as xr -from omegaconf import OmegaConf -from tqdm import tqdm - -from weathergen.common.config import _REPO_ROOT, get_model_results -from weathergen.common.io import ZarrIO - -_logger = logging.getLogger(__name__) -_logger.setLevel(logging.INFO) - -if not _logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - handler.setFormatter(formatter) - _logger.addHandler(handler) - -""" -Enhanced functions to handle Gaussian grids when converting from Zarr to NetCDF. -""" - - -def detect_grid_type(input_data_array: xr.DataArray) -> str: - """Detect whether data is on a regular lat/lon grid or Gaussian grid.""" - if "lat" not in input_data_array.coords or "lon" not in input_data_array.coords: - return "unknown" - - lats = input_data_array.coords["lat"].values - lons = input_data_array.coords["lon"].values - - unique_lats = np.unique(lats) - unique_lons = np.unique(lons) - - # Check if all (lat, lon) combinations exist (regular grid) - if len(lats) == len(unique_lats) * len(unique_lons): - lat_lon_pairs = set(zip(lats, lons, strict=False)) - expected_pairs = {(lat, lon) for lat in unique_lats for lon in unique_lons} - if lat_lon_pairs == expected_pairs: - return "regular" - - # Otherwise it's Gaussian (irregular spacing or reduced grid) - return "gaussian" - - -def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]: - """ - Find all the pressure levels for each variable using regex and returns a dictionary - mapping variable names to their corresponding pressure levels. - - Parameters - ---------- - all_variables : list of variable names with pressure levels (e.g.,'q_500','t_2m'). - - Returns - ------- - A tuple containing: - - var_dict: dict - Dictionary mapping variable names to lists of their corresponding pressure levels. - - pl: list of int - List of unique pressure levels found in the variable names. - """ - var_dict = {} - pl = [] - for var in all_variables: - match = re.search(r"^([a-zA-Z0-9_]+)_(\d+)$", var) - if match: - var_name = match.group(1) - pressure_level = int(match.group(2)) - pl.append(pressure_level) - var_dict.setdefault(var_name, []).append(var) - else: - var_dict.setdefault(var, []).append(var) - pl = list(set(pl)) - return var_dict, pl - - -def reshape_dataset_adaptive(input_data_array: xr.DataArray) -> xr.Dataset: - """ - Reshape dataset while preserving grid structure (regular or Gaussian). - - Parameters - ---------- - input_data_array : xr.DataArray - Input data with dimensions (ipoint, channel) - - Returns - ------- - xr.Dataset - Reshaped dataset appropriate for the grid type - """ - grid_type = detect_grid_type(input_data_array) - - # Original logic - var_dict, pl = find_pl(input_data_array.channel.values) - data_vars = {} - - for new_var, old_vars in var_dict.items(): - if len(old_vars) > 1: - data_vars[new_var] = xr.DataArray( - input_data_array.sel(channel=old_vars).values, - dims=["ipoint", "pressure_level"], - ) - else: - data_vars[new_var] = xr.DataArray( - input_data_array.sel(channel=old_vars[0]).values, - dims=["ipoint"], - ) - - reshaped_dataset = xr.Dataset(data_vars) - reshaped_dataset = reshaped_dataset.assign_coords( - ipoint=input_data_array.coords["ipoint"], - pressure_level=pl, - ) - - if grid_type == "regular": - # Use original reshape logic for regular grids - # This is safe for regular grids - reshaped_dataset = reshaped_dataset.set_index(ipoint=("valid_time", "lat", "lon")).unstack( - "ipoint" - ) - else: - # Use new logic for Gaussian/unstructured grids - reshaped_dataset = reshaped_dataset.set_index(ipoint2=("ipoint", "valid_time")).unstack( - "ipoint2" - ) - # rename ipoint to ncells - reshaped_dataset = reshaped_dataset.rename_dims({"ipoint": "ncells"}) - reshaped_dataset = reshaped_dataset.rename_vars({"ipoint": "ncells"}) - - return reshaped_dataset - - -def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) -> xr.Dataset: - """ - Add Gaussian grid metadata following CF conventions. - - Parameters - ---------- - ds : xr.Dataset - Dataset to add metadata to - grid_info : dict, optional - Dictionary with grid information: - - 'N': Gaussian grid number (e.g., N320) - - 'reduced': Whether it's a reduced Gaussian grid - - Returns - ------- - xr.Dataset - Dataset with added grid metadata - """ - ds = ds.copy() - # Add grid mapping information - ds.attrs["grid_type"] = "gaussian" - - # If grid info provided, add it - if grid_info: - ds.attrs["gaussian_grid_number"] = grid_info.get("N", "unknown") - ds.attrs["gaussian_grid_type"] = "reduced" if grid_info.get("reduced", False) else "regular" - - return ds - - -def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset: - """ - Add CF conventions to the dataset attributes. - Parameters - ---------- - stream : Stream name to include in the title attribute. - run_id : Run ID to include in the title attribute. - ds : Input xarray Dataset to add conventions to. - Returns - ------- - xarray Dataset with CF conventions added to attributes. - """ - ds = ds.copy() - ds.attrs["title"] = f"WeatherGenerator Output for {run_id} using stream {stream}" - ds.attrs["institution"] = "WeatherGenerator Project" - ds.attrs["source"] = "WeatherGenerator v0.0" - ds.attrs["history"] = ( - "Created using the export_inference.py script on " - + np.datetime_as_string(np.datetime64("now"), unit="s") - ) - ds.attrs["Conventions"] = "CF-1.12" - return ds - - -def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset: - """ - Modified CF parser that handles both regular and Gaussian grids. - - Parameters - ---------- - config : OmegaConf - Configuration for CF parsing - ds : xr.Dataset - Input dataset - - Returns - ------- - xr.Dataset - Parsed dataset with appropriate structure for grid type - """ - # Detect if this is a Gaussian grid - is_gaussian = "ncells" in ds.dims - - variables = {} - mapping = config["variables"] - - # Handle dimensions based on grid type - if is_gaussian: - # For Gaussian grids, keep ncells and don't try to create lat/lon dimensions - for var_name in ds.data_vars: - if var_name in ["lat", "lon"]: - continue - - variable = ds[var_name] - - if var_name not in mapping: - # Variable not in mapping - skip or keep as-is - variables[var_name] = variable - continue - - dims = list(variable.dims) - - attributes = dict( - standard_name=mapping[var_name].get("std", var_name), - units=mapping[var_name].get("std_unit", "unknown"), - coordinates="lat lon", # Mark auxiliary coordinates - ) - - # Get mapped variable name or use original - mapped_name = mapping[var_name].get("var", var_name) - - variables[mapped_name] = xr.DataArray( - data=variable.values, - dims=dims, - coords={coord: ds.coords[coord] for coord in variable.coords if coord in ds.coords}, - attrs=attributes, - name=mapped_name, - ) - - # Preserve lat/lon as coordinate variables with proper attributes - if "lat" in ds.coords: - ds.coords["lat"].attrs = { - "standard_name": "latitude", - "long_name": "latitude", - "units": "degrees_north", - } - if "lon" in ds.coords: - ds.coords["lon"].attrs = { - "standard_name": "longitude", - "long_name": "longitude", - "units": "degrees_east", - } - - else: - # Original logic for regular grids - ds_attributes = {} - for dim_name, dim_dict in config["dimensions"].items(): - if dim_name == dim_dict["wg"]: - dim_attributes = dict(standard_name=dim_dict.get("std", None)) - if dim_dict.get("std_unit", None) is not None: - dim_attributes["units"] = dim_dict["std_unit"] - ds_attributes[dim_dict["wg"]] = dim_attributes - continue - - if dim_name in ds.dims: - ds = ds.rename_dims({dim_name: dim_dict["wg"]}) - - dim_attributes = dict(standard_name=dim_dict.get("std", None)) - if "std_unit" in dim_dict and dim_dict["std_unit"] is not None: - dim_attributes["units"] = dim_dict["std_unit"] - ds_attributes[dim_dict["wg"]] = dim_attributes - - for var_name in ds.data_vars: - dims = ["pressure", "valid_time", "latitude", "longitude"] - if mapping[var_name]["level_type"] == "sfc": - dims.remove("pressure") - - coordinates = {} - for coord, new_name in config["coordinates"][mapping[var_name]["level_type"]].items(): - coordinates |= { - new_name: ( - ds.coords[coord].dims, - ds.coords[coord].values, - ds_attributes[new_name], - ) - } - - variable = ds[var_name] - attributes = dict( - standard_name=mapping[var_name]["std"], - units=mapping[var_name]["std_unit"], - ) - - variables[mapping[var_name]["var"]] = xr.DataArray( - data=variable.values, - dims=dims, - coords={**coordinates, "valid_time": ds["valid_time"].values}, - attrs=attributes, - name=mapping[var_name]["var"], - ) - - dataset = xr.merge(variables.values()) - dataset.attrs = ds.attrs - - return dataset - - -def output_filename( - prefix: str, - run_id: str, - output_dir: str, - output_format: str, - forecast_ref_time: np.datetime64, -) -> Path: - """ - Generate output filename based on prefix (should refer to type e.g. pred/targ), run_id, sample - index, output directory, format and forecast_ref_time. - - Parameters - ---------- - prefix : Prefix for file name (e.g., 'pred' or 'targ'). - run_id :Run ID to include in the filename. - output_dir : Directory to save the output file. - output_format : Output file format (currently only 'netcdf' supported). - forecast_ref_time : Forecast reference time to include in the filename. - - Returns - ------- - Full path to the output file. - """ - if output_format not in ["netcdf"]: - raise ValueError( - f"Unsupported output format: {output_format}, supported formates are ['netcdf']" - ) - file_extension = "nc" - frt = np.datetime_as_string(forecast_ref_time, unit="h") - out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" - return out_fname - - -def get_data_worker(args: tuple) -> xr.DataArray: - """ - Worker function to retrieve data for a single sample and forecast step. - - Parameters - ---------- - args : Tuple containing (sample, fstep, run_id, stream, type). - - Returns - ------- - xarray DataArray for the specified sample and forecast step. - """ - sample, fstep, run_id, stream, dtype, mini_epoch, rank = args - fname_zarr = get_model_results(run_id, mini_epoch, rank) - with ZarrIO(fname_zarr) as zio: - out = zio.get_data(sample, stream, fstep) - if dtype == "target": - data = out.target - elif dtype == "prediction": - data = out.prediction - return data - - -def get_data( - run_id: str, - samples: list, - stream: str, - dtype: str, - fsteps: list, - channels: list, - fstep_hours: int, - n_processes: list, - mini_epoch: int, - rank: int, - output_dir: str, - output_format: str, - config: OmegaConf, -) -> None: - """ - Retrieve data from Zarr store and save one sample to each NetCDF file. - Using multiprocessing to speed up data retrieval. - - Parameters - ---------- - run_id : str - Run ID to identify the Zarr store. - samples : list - Sample to process - stream : str - Stream name to retrieve data for (e.g., 'ERA5'). - dtype : str - Type of data to retrieve ('target' or 'prediction'). - fsteps : list - List of forecast steps to retrieve. If None, retrieves all available forecast steps. - channels : list - List of channels to retrieve. If None, retrieves all available channels. - n_processes : list - Number of parallel processes to use for data retrieval. - mini_epoch : int - Mini_epoch number to identify the Zarr store. - rank : int - Rank number to identify the Zarr store. - output_dir : str - Directory to save the NetCDF files. - output_format : str - Output file format (currently only 'netcdf' supported). - config : OmegaConf - Loaded config for cf_parser function. - """ - if dtype not in ["target", "prediction"]: - raise ValueError(f"Invalid type: {dtype}. Must be 'target' or 'prediction'.") - - fname_zarr = get_model_results(run_id, mini_epoch, rank) - with ZarrIO(fname_zarr) as zio: - zio_forecast_steps = sorted([int(step) for step in zio.forecast_steps]) - zio_samples = sorted([int(sample) for sample in zio.samples]) - dummy_out = zio.get_data(0, stream, zio_forecast_steps[0]) - all_channels = dummy_out.target.channels - channels = all_channels if channels is None else channels - - fsteps = zio_forecast_steps if fsteps is None else sorted([int(fstep) for fstep in fsteps]) - - samples = ( - zio_samples - if samples is None - else sorted([int(sample) for sample in samples if sample in samples]) - ) - with Pool(processes=n_processes, maxtasksperchild=5) as pool: - for sample_idx in tqdm(samples): - da_fs = [] - step_tasks = [ - (sample_idx, fstep, run_id, stream, dtype, mini_epoch, rank) for fstep in fsteps - ] - for result in tqdm( - pool.imap_unordered(get_data_worker, step_tasks, chunksize=1), - total=len(step_tasks), - desc=f"Processing {run_id} - stream: {stream} - sample: {sample_idx}", - ): - if result is not None: - # Select only requested channels - result = result.as_xarray().squeeze() - if set(channels) != set(all_channels): - available_channels = result.channel.values - existing_channels = [ch for ch in channels if ch in available_channels] - if len(existing_channels) < len(channels): - _logger.info( - f"The following channels were not found: " - f"{list(set(channels) - set(existing_channels))}. Skipping them." - ) - result = result.sel(channel=existing_channels) - # reshape result - use adaptive function to handle both regular and Gaussian - # grids - result = reshape_dataset_adaptive(result) - da_fs.append(result) - - _logger.info(f"Retrieved {len(da_fs)} forecast steps for type {dtype}.") - _logger.info( - f"Saving sample {sample_idx} data to {output_format} format in {output_dir}." - ) - - save_sample_to_netcdf( - str(dtype)[:4], - da_fs, - fstep_hours, - run_id, - output_dir, - output_format, - config, - ) - pool.terminate() - pool.join() - - -def save_sample_to_netcdf( - type_str, - array_list, - fstep_hours, - run_id, - output_dir, - output_format, - config, -) -> None: - """ - Uses list of pred/target xarray DataArrays to save one sample to a NetCDF file. - - Parameters - ---------- - type_str : str - Type of data ('pred' or 'targ') to include in the filename. - dict_sample_all_steps : dict - Dictionary where keys is sample index and values is a list of xarray DataArrays - for all the forecast steps - fstep_hours : np.timedelta64 - Time difference between forecast steps (e.g., 6 hours). - run_id : str - Run ID to include in the filename. - output_dir : str - Directory to save the NetCDF files. - output_format : str - Output file format (currently only 'netcdf' supported). - config : OmegaConf - Loaded config for cf_parser function. - """ - # find forecast_ref_time - frt = array_list[0].valid_time.values[0] - fstep_hours * int(array_list[0].forecast_step.values) - out_fname = output_filename(type_str, run_id, output_dir, output_format, frt) - # check if file already exists - if out_fname.exists(): - _logger.info(f"File {out_fname} already exists. Skipping.") - else: - sample_all_steps = xr.concat( - array_list, - dim="valid_time", - data_vars="minimal", - coords="different", - compat="equals", - combine_attrs="drop", - ).sortby("valid_time") - _logger.info(f"Saving to {out_fname}.") - sample_all_steps = sample_all_steps.assign_coords(forecast_ref_time=frt) - stream = str(sample_all_steps.coords["stream"].values) - - if "sample" in sample_all_steps.coords: - sample_all_steps = sample_all_steps.drop_vars("sample") - - sample_all_steps = cf_parser_gaussian_aware(config, sample_all_steps) - # Add Gaussian grid metadata if detected - if "ncells" in sample_all_steps.dims: - sample_all_steps = add_gaussian_grid_metadata(sample_all_steps) - _logger.info("Detected and preserved Gaussian grid structure") - # add forecast_period attributes - n_hours = fstep_hours.astype("int64") - sample_all_steps["forecast_period"] = sample_all_steps["forecast_step"] * n_hours - sample_all_steps["forecast_period"].attrs = { - "standard_name": "forecast_period", - "long_name": "time since forecast_reference_time", - "units": "hours", - } - sample_all_steps = add_conventions(stream, run_id, sample_all_steps) - sample_all_steps.to_netcdf(out_fname, mode="w", compute=False) - - -def parse_args(args: list) -> argparse.Namespace: - """ - Parse command line arguments. - - Parameters - ---------- - args : List of command line arguments. - Returns - ------- - Parsed command line arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument( - "--run-id", - type=str, - help=" Zarr folder which contains target and inference results", - required=True, - ) - - parser.add_argument( - "--type", - type=str, - choices=["prediction", "target"], - nargs="+", - help="List of type of data to convert (e.g. prediction target)", - required=True, - ) - - parser.add_argument( - "--output-dir", - type=str, - help="Output directory to save the NetCDF files", - required=True, - ) - - parser.add_argument( - "--format", - type=str, - choices=["netcdf", "grib"], - help="Output file format (currently only netcdf supported)", - required=True, - ) - - parser.add_argument( - "--stream", - type=str, - choices=["ERA5"], - help="Stream name to retrieve data for", - required=True, - ) - - parser.add_argument( - "--fsteps", - type=int, - nargs="+", - default=None, - help="List of forecast steps to retrieve (e.g. 1 2 3). " - "If not provided, retrieves all available forecast steps.", - ) - - parser.add_argument( - "--samples", - type=int, - nargs="+", - default=None, - help="List of samples to process (e.g. 0 1 2). If not provided, processes all samples.", - ) - - parser.add_argument( - "--channels", - type=str, - nargs="+", - default=None, - help="List of channels to retrieve (e.g., 'q_500 t_2m'). " - "If not provided, retrieves all available channels.", - ) - - parser.add_argument( - "--n-processes", - type=int, - default=8, - help="Number of parallel processes to use for data retrieval", - ) - - parser.add_argument( - "--fstep-hours", - type=int, - default=6, - help="Time difference between forecast steps in hours (e.g., 6)", - ) - - parser.add_argument( - "--mini_epoch", - type=int, - default=0, - help="mini_epoch number to identify the Zarr store", - ) - - parser.add_argument( - "--rank", - type=int, - default=0, - help="Rank number to identify the Zarr store", - ) - - args, unknown_args = parser.parse_known_args(args) - if unknown_args: - _logger.warning(f"Unknown arguments: {unknown_args}") - return args - - -def export() -> None: - """ - Main function to export data from Zarr store to NetCDF files. - """ - # By default, arguments from the command line are read. - export_from_args(sys.argv[1:]) - - -def export_from_args(args: list) -> None: - # Get run_id zarr data as lists of xarray DataArrays - """ - Export data from Zarr store to NetCDF files based on command line arguments. - Parameters - ---------- - args : List of command line arguments. - """ - args = parse_args(sys.argv[1:]) - run_id = args.run_id - data_type = args.type - output_dir = args.output_dir - output_format = args.format - samples = args.samples - stream = args.stream - fsteps = args.fsteps - fstep_hours = np.timedelta64(args.fstep_hours, "h") - channels = args.channels - n_processes = args.n_processes - mini_epoch = args.mini_epoch - rank = args.rank - - # Ensure output directory exists - out_dir = Path(output_dir) - out_dir.mkdir(parents=True, exist_ok=True) - - # Load configuration - config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml") - config = OmegaConf.load(config_file) - # check config loaded correctly - assert len(config["variables"].keys()) > 0, "Config file not loaded correctly" - - for dtype in data_type: - _logger.info(f"Starting processing {dtype} for run ID {run_id}.") - get_data( - run_id, - samples, - stream, - dtype, - fsteps, - channels, - fstep_hours, - n_processes, - mini_epoch, - rank, - output_dir, - output_format, - config, - ) - _logger.info(f"Finished processing {dtype} for run ID {run_id}.") - - -if __name__ == "__main__": - export() diff --git a/pyproject.toml b/pyproject.toml index 250a53ebc..0f0f7a296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ train_continue = "weathergen.run_train:train_continue" inference = "weathergen.run_train:inference" evaluate = "weathergen.evaluate.run_evaluation:evaluate" plot_train = "weathergen.utils.plot_training:plot_train" -export = "weathergen.evaluate.export_inference:export" +export = "weathergen.evaluate.export.export_inference:export" [build-system] requires = ["hatchling"] From 224d4bdfff9a4eccd16dece3cc2affca28e77917 Mon Sep 17 00:00:00 2001 From: Simone Norberti <63310821+simone99n@users.noreply.github.com> Date: Tue, 11 Nov 2025 10:41:57 +0100 Subject: [PATCH 27/32] [1034][reader_extra] E-Obs datareader (#1228) * [1034] rebase * [1034] add dataloader * [1034] Zarr3-->Zarr2 * [1034] lint * [1034] lint * [1034] Moved to reader_extra * [1034] registry E-Obs --- .../readers_extra/data_reader_eobs.py | 415 ++++++++++++++++++ .../src/weathergen/readers_extra/registry.py | 4 + 2 files changed, 419 insertions(+) create mode 100644 packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py diff --git a/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py b/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py new file mode 100644 index 000000000..4f0157792 --- /dev/null +++ b/packages/readers_extra/src/weathergen/readers_extra/data_reader_eobs.py @@ -0,0 +1,415 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from pathlib import Path +from typing import override + +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from weathergen.datasets.data_reader_base import ( + DataReaderTimestep, + ReaderData, + TimeWindowHandler, + TIndex, + check_reader_data, + str_to_timedelta, +) + +_logger = logging.getLogger(__name__) + + +# TODO make this datareader works with multiple datasets in ZARR format +class DataReaderEObs(DataReaderTimestep): + """ + Data reader for gridded Zarr datasets with regular lat/lon structure. + + This reader handles datasets stored as Zarr with dimensions (time, latitude, longitude) + and converts the gridded data to point-wise format required by the framework. + + The reader implements lazy initialization to work efficiently with multiple dataloader workers. + """ + + def __init__( + self, + tw_handler: TimeWindowHandler, + filename: Path, + stream_info: dict, + ) -> None: + """ + Construct data reader for gridded Zarr dataset. + + Parameters + ---------- + tw_handler : TimeWindowHandler + Handler for time windows + filename : Path + Path to the Zarr dataset + stream_info : dict + Stream configuration containing channel selection and other metadata + + Returns + ------- + None + """ + # Store configuration but DO NOT open files here + self._filename = filename + self._tw_handler = tw_handler + self._stream_info = stream_info + + # Initialize data-dependent attributes to None + self.ds: xr.Dataset | None = None + self.len = 0 + self.source_channels = [] + self.source_idx = [] + self.target_channels = [] + self.target_idx = [] + self.geoinfo_channels = [] + self.geoinfo_idx = [] + self.properties = {} + + # Grid properties + self.latitudes: NDArray | None = None + self.longitudes: NDArray | None = None + self.n_lat: int = 0 + self.n_lon: int = 0 + self.n_points: int = 0 + + # Statistics + self.mean: NDArray | None = None + self.stdev: NDArray | None = None + + # Call super() with temporary values + super().__init__(self._tw_handler, self._stream_info) + + # Flag to ensure initialization happens only once per worker + self._initialized = False + + def _lazy_init(self) -> None: + """ + Initialize the dataset. Called once per worker process to ensure + proper handling of file handles across processes. + """ + if self._initialized: + return + + try: + # Open the Zarr dataset with xarray + self.ds = xr.open_zarr(self._filename, consolidated=True, chunks=None, zarr_format=2) + except Exception as e: + name = self._stream_info["name"] + _logger.error(f"Failed to open {name} at {self._filename}: {e}") + self.init_empty() + self._initialized = True + return + + # Extract time coordinate + time_coord = self.ds.coords["time"].values + data_start_time = np.datetime64(time_coord[0]) + data_end_time = np.datetime64(time_coord[-1]) + + # Check if dataset overlaps with requested time window + if self._tw_handler.t_start >= data_end_time or self._tw_handler.t_end <= data_start_time: + name = self._stream_info["name"] + _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.") + self.init_empty() + self._initialized = True + return + + # Determine the period/frequency + if len(time_coord) > 1: + period = np.timedelta64(time_coord[1] - time_coord[0]) + else: + # Default to daily if only one timestep + period = np.timedelta64(1, "D") + + # Handle frequency override from stream_info + if "frequency" in self._stream_info: + period = str_to_timedelta(self._stream_info["frequency"]) + + # Re-initialize parent class with correct time info + super().__init__( + self._tw_handler, + self._stream_info, + data_start_time, + data_end_time, + period, + ) + + # Calculate valid time range indices + time_mask = (time_coord >= self._tw_handler.t_start) & (time_coord < self._tw_handler.t_end) + self.len = int(np.sum(time_mask)) + + if self.len <= 0: + self.init_empty() + self._initialized = True + return + + # Extract and validate spatial coordinates + self.latitudes = self.ds.coords["latitude"].values.astype(np.float32) + self.longitudes = self.ds.coords["longitude"].values.astype(np.float32) + + # Validate coordinate ranges + if np.any(self.latitudes < -90) or np.any(self.latitudes > 90): + _logger.warning( + f"Latitude values outside valid range [-90, 90] in stream " + f"'{self._stream_info['name']}'" + ) + self.latitudes = np.clip(self.latitudes, -90.0, 90.0) + + if np.any(self.longitudes < -180) or np.any(self.longitudes > 180): + _logger.warning( + f"Longitude values outside valid range [-180, 180] in stream " + f"'{self._stream_info['name']}'. Converting from [0, 360] format." + ) + self.longitudes = ((self.longitudes + 180.0) % 360.0 - 180.0).astype(np.float32) + + self.n_lat = len(self.latitudes) + self.n_lon = len(self.longitudes) + self.n_points = self.n_lat * self.n_lon + + # Identify available data variables (exclude coordinate and statistics variables) + available_vars = [ + var + for var in self.ds.data_vars + if not var.endswith("_mean") + and not var.endswith("_std") + and "time" in self.ds[var].dims + ] + + # Select source channels + source_channels_filter = self._stream_info.get("source") + source_exclude = self._stream_info.get("source_exclude", []) + self.source_channels, self.source_idx = self._select_channels( + available_vars, source_channels_filter, source_exclude + ) + + # Select target channels + target_channels_filter = self._stream_info.get("target") + target_exclude = self._stream_info.get("target_exclude", []) + self.target_channels, self.target_idx = self._select_channels( + available_vars, target_channels_filter, target_exclude + ) + + # No geoinfo channels for gridded data + self.geoinfo_channels = [] + self.geoinfo_idx = [] + + # Get target channel weights + self.target_channel_weights = self.parse_target_channel_weights() + + # Load or compute statistics + all_channels = sorted(set(self.source_channels + self.target_channels)) + self._load_statistics(all_channels) + + # Log configuration + ds_name = self._stream_info["name"] + _logger.info(f"{ds_name}: source channels: {self.source_channels}") + _logger.info(f"{ds_name}: target channels: {self.target_channels}") + _logger.info(f"{ds_name}: grid shape: {self.n_lat} x {self.n_lon}") + + self.properties = { + "stream_id": self._stream_info.get("id", 0), + } + + self._initialized = True + + def _select_channels( + self, + available_vars: list[str], + include_filters: list[str] | None, + exclude_filters: list[str] | None = None, + ) -> tuple[list[str], list[int]]: + """ + Select channels based on include/exclude filters. + + Parameters + ---------- + available_vars : list[str] + List of available variable names + include_filters : list[str] | None + List of patterns to include (None means include all) + exclude_filters : list[str] | None + List of patterns to exclude + + Returns + ------- + tuple[list[str], list[int]] + Selected channel names and their indices + """ + if exclude_filters is None: + exclude_filters = [] + + selected = [] + for var in available_vars: + # Check inclusion + if include_filters is not None: + if not any(f in var or f == var for f in include_filters): + continue + + # Check exclusion + if any(f in var for f in exclude_filters): + continue + + selected.append(var) + + # Return channels and their indices in the original list + indices = [available_vars.index(ch) for ch in selected] + return selected, indices + + def _load_statistics(self, channels: list[str]) -> None: + """ + Load or compute statistics (mean and standard deviation) for channels. + + Parameters + ---------- + channels : list[str] + List of channel names for which to load statistics + """ + means = [] + stds = [] + + for ch in channels: + # Try to load pre-computed statistics + mean_var = f"{ch}_mean" + std_var = f"{ch}_std" + + if mean_var in self.ds.data_vars: + mean = float(self.ds[mean_var].values) + else: + _logger.warning( + f"No pre-computed mean for {ch}, using 0.0. " + "Consider computing statistics offline." + ) + mean = 0.0 + + if std_var in self.ds.data_vars: + std = float(self.ds[std_var].values) + else: + _logger.warning( + f"No pre-computed std for {ch}, using 1.0. " + "Consider computing statistics offline." + ) + std = 1.0 + + means.append(mean) + stds.append(std) + + self.mean = np.array(means, dtype=np.float32) + self.stdev = np.array(stds, dtype=np.float32) + + # Avoid division by zero + self.stdev[self.stdev <= 1e-5] = 1.0 + + @override + def init_empty(self) -> None: + """Initialize an empty reader.""" + super().init_empty() + self.ds = None + self.len = 0 + self.n_points = 0 + + @override + def length(self) -> int: + """Return the length of the dataset.""" + self._lazy_init() + return self.len + + @override + def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData: + """ + Get data for a time window. + + Parameters + ---------- + idx : TIndex + Index of temporal window + channels_idx : list[int] + Selection of channel indices + + Returns + ------- + ReaderData + Data structure containing coords, geoinfos, data, and datetimes + """ + self._lazy_init() + + (t_idxs, dtr) = self._get_dataset_idxs(idx) + + if self.ds is None or self.len == 0 or len(t_idxs) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), + num_geo_fields=len(self.geoinfo_idx), + ) + + # Get the actual channel names + all_channels = sorted(set(self.source_channels + self.target_channels)) + selected_channels = [all_channels[i] for i in channels_idx] + + # Extract data for selected timesteps and channels + data_arrays = [] + datetimes_list = [] + + for t_idx in t_idxs: + if t_idx < 0 or t_idx >= len(self.ds.coords["time"]): + continue + + # Extract data for this timestep + timestep_data = [] + for ch in selected_channels: + # Load data using isel for efficient indexing + var_data = self.ds[ch].isel(time=t_idx).values.astype(np.float32) + # Flatten spatial dimensions (lat, lon) -> (n_points,) + var_data_flat = var_data.flatten() + timestep_data.append(var_data_flat) + + # Stack channels: (n_points, n_channels) + timestep_data = np.stack(timestep_data, axis=1) + data_arrays.append(timestep_data) + + # Get datetime for this timestep + dt = np.datetime64(self.ds.coords["time"].values[t_idx]) + datetimes_list.extend([dt] * self.n_points) + + if len(data_arrays) == 0: + return ReaderData.empty( + num_data_fields=len(channels_idx), + num_geo_fields=len(self.geoinfo_idx), + ) + + # Concatenate all timesteps: (n_timesteps * n_points, n_channels) + data = np.vstack(data_arrays) + + # Create coordinate grid + lon_grid, lat_grid = np.meshgrid(self.longitudes, self.latitudes) + coords_single = np.stack([lat_grid.flatten(), lon_grid.flatten()], axis=1).astype( + np.float32 + ) + + # Repeat coordinates for each timestep + coords = np.tile(coords_single, (len(t_idxs), 1)) + + # Empty geoinfos + geoinfos = np.zeros((len(data), 0), dtype=np.float32) + + # Convert datetimes to numpy array + datetimes = np.array(datetimes_list, dtype="datetime64[ns]") + + rd = ReaderData( + coords=coords, + geoinfos=geoinfos, + data=data, + datetimes=datetimes, + ) + + check_reader_data(rd, dtr) + + return rd diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 761628944..27ff2c101 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -20,5 +20,9 @@ def get_extra_reader(name: str, cf: Config) -> object | None: from weathergen.readers_extra.data_reader_icon import DataReaderIcon return ReaderEntry(cf.data_path_icon, DataReaderIcon) + case "eobs": + from weathergen.readers_extra.data_reader_eobs import DataReaderEObs + + return ReaderEntry(cf.data_path_eobs, DataReaderEObs) case _: return None From ab2e5d8c2035cf26e8b5750881e2bba32149f0e4 Mon Sep 17 00:00:00 2001 From: TillHae Date: Thu, 30 Oct 2025 15:16:00 +0100 Subject: [PATCH 28/32] training progress unit realignment from epoch to mini_epoch --- .../common/src/weathergen/common/config.py | 30 +++++++++++++------ .../src/weathergen/evaluate/io_reader.py | 10 ++++++- src/weathergen/model/model.py | 4 +++ src/weathergen/train/trainer.py | 4 +++ 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 195f7394c..5a0d521a2 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -61,7 +61,7 @@ def save(config: Config, mini_epoch: int | None): dirname = path_models / config.run_id dirname.mkdir(exist_ok=True, parents=True) - fname = dirname / _get_model_config_file_name(config.run_id, mini_epoch) + fname = _get_model_config_file_name(path_models, config.run_id, mini_epoch) json_str = json.dumps(OmegaConf.to_container(config)) with fname.open("w") as f: @@ -84,7 +84,7 @@ def load_model_config(run_id: str, mini_epoch: int | None, model_path: str | Non config=pconf, attribute_name="model_path", fallback="models" ) path = Path(model_path) - fname = path / run_id / _get_model_config_file_name(run_id, mini_epoch) + fname = _get_model_config_file_name(path, run_id, mini_epoch) assert fname.exists(), ( "The fallback path to the model does not exist. Please provide a `model_path`.", fname, @@ -100,14 +100,15 @@ def load_model_config(run_id: str, mini_epoch: int | None, model_path: str | Non return _apply_fixes(config) -def _get_model_config_file_name(run_id: str, mini_epoch: int | None): +def _get_model_config_file_name(path: pathlib.Path, run_id: str, mini_epoch: int | None): if mini_epoch is None: mini_epoch_str = "" elif mini_epoch == -1: mini_epoch_str = "_latest" - else: - mini_epoch_str = f"_chkpt{mini_epoch:05d}" - return f"model_{run_id}{mini_epoch_str}.json" + elif (path / run_id / f"model_{run_id}_chkpt{mini_epoch:05d}.json").exists(): + return path / run_id / f"model_{run_id}_chkpt{mini_epoch:05d}.json" + + return path / run_id / f"model_{run_id}_epoch{mini_epoch:05d}.json" def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: @@ -115,9 +116,20 @@ def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: Get the path to the model results zarr store from a given run_id and mini_epoch. """ run_results = Path(_load_private_conf(None)["path_shared_working_dir"]) / f"results/{run_id}" - zarr_path = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr" - if not zarr_path.exists() or not zarr_path.is_dir(): - raise FileNotFoundError(f"Zarr file {zarr_path} does not exist or is not a directory.") + + zarr_path_new = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr" + zarr_path_old = run_results / f"validation_epoch{mini_epoch:05d}_rank{rank:04d}.zarr" + + if zarr_path_new.exists() or zarr_path_new.is_dir(): + zarr_path = zarr_path_new + elif zarr_path_old.exists() or zarr_path_old.is_dir(): + zarr_path = zarr_path_old + else: + raise FileNotFoundError( + f"Zarr file with run_id {run_id}, mini_epoch {mini_epoch} and rank {rank} does not " + f"exist or is not a directory." + ) + return zarr_path diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 61bffef92..23f0f8a98 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -498,9 +498,17 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation") ) - self.fname_zarr = self.results_dir.joinpath( + fname_zarr_new = self.results_dir.joinpath( f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" ) + fname_zarr_old = self.results_dir.joinpath( + f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" + ) + + if fname_zarr_new.exists(): + self.fname_zarr = fname_zarr_new + + self.fname_zarr = fname_zarr_old if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): _logger.error(f"Zarr file {self.fname_zarr} does not exist.") diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f5b2bac00..8a4524c1a 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -562,6 +562,10 @@ def load(self, run_id: str, mini_epoch: str = -1) -> None: ) filename = f"{run_id}_{mini_epoch_id}.chkpt" + if not (path_run / filename).exists(): + mini_epoch_id = f"epoch{mini_epoch:05d}" + filename = f"{run_id}_{mini_epoch_id}.chkpt" + params = torch.load( path_run / filename, map_location=torch.device("cpu"), weights_only=True ) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 0c85df93a..01efdd6c6 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -770,6 +770,10 @@ def load_model(self, run_id: str, mini_epoch=-1): ) filename = f"{run_id}_{mini_epoch_id}.chkpt" + if not (path_run / filename).exists(): + mini_epoch_id = f"epoch{mini_epoch:05d}" + filename = f"{run_id}_{mini_epoch_id}.chkpt" + params = torch.load( path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True ) From d45d67e8d16a6ffffcea63a26f799c5e1a2541e0 Mon Sep 17 00:00:00 2001 From: TillHae Date: Thu, 13 Nov 2025 11:25:10 +0100 Subject: [PATCH 29/32] ruffed --- packages/common/src/weathergen/common/config.py | 10 ++++++---- packages/evaluate/src/weathergen/evaluate/io_reader.py | 2 +- .../src/weathergen/readers_extra/registry.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 5a0d521a2..0dea9773a 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -100,15 +100,17 @@ def load_model_config(run_id: str, mini_epoch: int | None, model_path: str | Non return _apply_fixes(config) -def _get_model_config_file_name(path: pathlib.Path, run_id: str, mini_epoch: int | None): +def _get_model_config_file_name(path: Path, run_id: str, mini_epoch: int | None): if mini_epoch is None: mini_epoch_str = "" elif mini_epoch == -1: mini_epoch_str = "_latest" elif (path / run_id / f"model_{run_id}_chkpt{mini_epoch:05d}.json").exists(): - return path / run_id / f"model_{run_id}_chkpt{mini_epoch:05d}.json" - - return path / run_id / f"model_{run_id}_epoch{mini_epoch:05d}.json" + mini_epoch_str = f"_chkpt{mini_epoch:05d}" + else: + mini_epoch_str = f"_epoch{mini_epoch:05d}" + + return path / run_id / f"model_{run_id}{mini_epoch_str}.json" def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path: diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 23f0f8a98..c55229aa3 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -507,7 +507,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if fname_zarr_new.exists(): self.fname_zarr = fname_zarr_new - + self.fname_zarr = fname_zarr_old if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): diff --git a/packages/readers_extra/src/weathergen/readers_extra/registry.py b/packages/readers_extra/src/weathergen/readers_extra/registry.py index 27ff2c101..8920354b4 100644 --- a/packages/readers_extra/src/weathergen/readers_extra/registry.py +++ b/packages/readers_extra/src/weathergen/readers_extra/registry.py @@ -22,7 +22,7 @@ def get_extra_reader(name: str, cf: Config) -> object | None: return ReaderEntry(cf.data_path_icon, DataReaderIcon) case "eobs": from weathergen.readers_extra.data_reader_eobs import DataReaderEObs - + return ReaderEntry(cf.data_path_eobs, DataReaderEObs) case _: return None From 89f60c8aeb677a7de16a45997c340a200497e8fe Mon Sep 17 00:00:00 2001 From: TillHae Date: Thu, 13 Nov 2025 15:06:21 +0100 Subject: [PATCH 30/32] check if path is dir in io_reader --- packages/evaluate/src/weathergen/evaluate/io_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index c55229aa3..32981d276 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -505,7 +505,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr" ) - if fname_zarr_new.exists(): + if fname_zarr_new.exists() or fname_zarr_new.is_dir(): self.fname_zarr = fname_zarr_new self.fname_zarr = fname_zarr_old From ae981295905ce2580d04195440e6aedc84f89df8 Mon Sep 17 00:00:00 2001 From: TillHae Date: Fri, 14 Nov 2025 09:45:08 +0100 Subject: [PATCH 31/32] fix overwrite of fname_zarr in io_reader --- packages/evaluate/src/weathergen/evaluate/io_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 32981d276..66fb2602d 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -507,8 +507,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if fname_zarr_new.exists() or fname_zarr_new.is_dir(): self.fname_zarr = fname_zarr_new - - self.fname_zarr = fname_zarr_old + else: + self.fname_zarr = fname_zarr_old if not self.fname_zarr.exists() or not self.fname_zarr.is_dir(): _logger.error(f"Zarr file {self.fname_zarr} does not exist.") From 5eef1b9d901947ba7b0b842df82cef6c3d9d3e0b Mon Sep 17 00:00:00 2001 From: TillHae Date: Fri, 14 Nov 2025 09:47:44 +0100 Subject: [PATCH 32/32] add backward compatibility to config read --- packages/common/src/weathergen/common/config.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 0dea9773a..d173f3381 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -105,10 +105,10 @@ def _get_model_config_file_name(path: Path, run_id: str, mini_epoch: int | None) mini_epoch_str = "" elif mini_epoch == -1: mini_epoch_str = "_latest" - elif (path / run_id / f"model_{run_id}_chkpt{mini_epoch:05d}.json").exists(): - mini_epoch_str = f"_chkpt{mini_epoch:05d}" - else: + elif (path / run_id / f"model_{run_id}_epoch{mini_epoch:05d}.json").exists(): mini_epoch_str = f"_epoch{mini_epoch:05d}" + else: + mini_epoch_str = f"_chkpt{mini_epoch:05d}" return path / run_id / f"model_{run_id}{mini_epoch_str}.json" @@ -214,6 +214,12 @@ def load_config( # use OmegaConf.unsafe_merge if too slow c = OmegaConf.merge(base_config, private_config, *overwrite_configs) assert isinstance(c, Config) + + # Ensure the config has mini-epoch notation + if hasattr(c, "samples_per_epoch"): + c.samples_per_mini_epoch = c.samples_per_epoch + c.num_mini_epochs = c.num_epochs + return c