Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"same_strategy_per_batch": false
}

num_epochs: 32
samples_per_epoch: 4096
num_mini_epochs: 32
samples_per_mini_epoch: 4096
samples_per_validation: 512
shuffle: True

Expand Down
4 changes: 2 additions & 2 deletions config/evaluate/eval_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ run_ids :
ar40mckx:
label: "pretrained model ar40mckx"
results_base_dir : "./results/"
epoch: 0
mini_epoch: 0
rank: 0
streams:
ERA5:
Expand Down Expand Up @@ -61,7 +61,7 @@ run_ids :
c8g5katp:
label: "2 steps window"
results_base_dir : "./results/"
epoch: 0
mini_epoch: 0
rank: 0
streams:
ERA5:
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/small1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ run_path: "./results"
model_path: "./models"
loss_fcts: [["mse", 1.0]]
loss_fcts_val: [["mse", 1.0]]
num_epochs: 1
samples_per_epoch: 10
num_mini_epochs: 1
samples_per_mini_epoch: 10
samples_per_validation: 5
lr_steps: 4
lr_steps_warmup: 2
Expand Down
10 changes: 5 additions & 5 deletions integration_tests/small1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_train(setup, test_run_id):
def infer(run_id):
logger.info("run inference")
inference_from_args(
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"]
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"]
+ [
"--from_run_id",
run_id,
Expand All @@ -84,7 +84,7 @@ def infer(run_id):
def infer_with_missing(run_id):
logger.info("run inference")
inference_from_args(
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"]
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"]
+ [
"--from_run_id",
run_id,
Expand Down Expand Up @@ -128,7 +128,7 @@ def evaluate_results(run_id):
}
},
"label": "MTM ERA5",
"epoch": 0,
"mini_epoch": 0,
"rank": 0,
}
},
Expand Down Expand Up @@ -170,7 +170,7 @@ def assert_train_loss_below_threshold(run_id):
assert loss_metric is not None, (
"'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file"
)
# Check that the loss does not explode in a single epoch
# Check that the loss does not explode in a single mini_epoch
# This is meant to be a quick test, not a convergence test
target = 1.5
assert loss_metric < target, (
Expand All @@ -192,7 +192,7 @@ def assert_val_loss_below_threshold(run_id):
assert loss_metric is not None, (
"'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file"
)
# Check that the loss does not explode in a single epoch
# Check that the loss does not explode in a single mini_epoch
# This is meant to be a quick test, not a convergence test
assert loss_metric < 1.25, (
f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below 0.25"
Expand Down
46 changes: 24 additions & 22 deletions packages/common/src/weathergen/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ def format_cf(config: Config) -> str:
return stream.getvalue()


def save(config: Config, epoch: int | None):
def save(config: Config, mini_epoch: int | None):
"""Save current config into the current runs model directory."""
path_models = Path(config.model_path)
# save in directory with model files
dirname = path_models / config.run_id
dirname.mkdir(exist_ok=True, parents=True)

fname = dirname / _get_model_config_file_name(config.run_id, epoch)
fname = dirname / _get_model_config_file_name(config.run_id, mini_epoch)

json_str = json.dumps(OmegaConf.to_container(config))
with fname.open("w") as f:
f.write(json_str)


def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> Config:
def load_model_config(run_id: str, mini_epoch: int | None, model_path: str | None) -> Config:
"""
Load a configuration file from a given run_id and epoch.
Load a configuration file from a given run_id and mini_epoch.
If run_id is a full path, loads it from the full path.
"""
if Path(run_id).exists(): # load from the full path if a full path is provided
Expand All @@ -84,13 +84,13 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
config=pconf, attribute_name="model_path", fallback="models"
)
path = Path(model_path)
fname = path / run_id / _get_model_config_file_name(run_id, epoch)
fname = path / run_id / _get_model_config_file_name(run_id, mini_epoch)
assert fname.exists(), (
"The fallback path to the model does not exist. Please provide a `model_path`.",
fname,
)

_logger.info(f"Loading config from specified run_id and epoch: {fname}")
_logger.info(f"Loading config from specified run_id and mini_epoch: {fname}")

with fname.open() as f:
json_str = f.read()
Expand All @@ -100,22 +100,22 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
return _apply_fixes(config)


def _get_model_config_file_name(run_id: str, epoch: int | None):
if epoch is None:
epoch_str = ""
elif epoch == -1:
epoch_str = "_latest"
def _get_model_config_file_name(run_id: str, mini_epoch: int | None):
if mini_epoch is None:
mini_epoch_str = ""
elif mini_epoch == -1:
mini_epoch_str = "_latest"
else:
epoch_str = f"_epoch{epoch:05d}"
return f"model_{run_id}{epoch_str}.json"
mini_epoch_str = f"_chkpt{mini_epoch:05d}"
return f"model_{run_id}{mini_epoch_str}.json"


def get_model_results(run_id: str, epoch: int, rank: int) -> Path:
def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path:
"""
Get the path to the model results zarr store from a given run_id and epoch.
Get the path to the model results zarr store from a given run_id and mini_epoch.
"""
run_results = Path(_load_private_conf(None)["path_shared_working_dir"]) / f"results/{run_id}"
zarr_path = run_results / f"validation_epoch{epoch:05d}_rank{rank:04d}.zarr"
zarr_path = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this backward-compatible? Something like this (not tested):

zarr_path_old = run_results / f"validation_epoch{mini_epoch:05d}_rank{rank:04d}.zarr"
zarr_path_new = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr"
zarr_path = zarr_path_new if zarr_path_new.exists() else zarr_path_old

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check whether this method is still in use? Ideally, the knowledge of the specifics of the path should be encapsulated in get_path_output, and get_path_output should be used anywhere this information is required. A simple backward-compatibility mechanism such as the one @MatKbauer describes can then be implemented there.

if not zarr_path.exists() or not zarr_path.is_dir():
raise FileNotFoundError(f"Zarr file {zarr_path} does not exist or is not a directory.")
return zarr_path
Expand Down Expand Up @@ -150,7 +150,7 @@ def _check_logging(config: Config) -> Config:
def load_config(
private_home: Path | None,
from_run_id: str | None,
epoch: int | None,
mini_epoch: int | None,
*overwrites: Path | dict | Config,
) -> Config:
"""
Expand All @@ -161,7 +161,7 @@ def load_config(
private_home: Configuration file containing platform dependent information and secrets
from_run_id: Run id of the pretrained WeatherGenerator model
to continue training or inference
epoch: epoch of the checkpoint to load. -1 indicates last checkpoint available.
mini_epoch: mini_epoch of the checkpoint to load. -1 indicates last checkpoint available.
*overwrites: Additional overwrites from different sources

Note: The order of precedence for merging the final config is in ascending order:
Expand Down Expand Up @@ -191,7 +191,9 @@ def load_config(
if from_run_id is None:
base_config = _load_default_conf()
else:
base_config = load_model_config(from_run_id, epoch, private_config.get("model_path", None))
base_config = load_model_config(
from_run_id, mini_epoch, private_config.get("model_path", None)
)
from_run_id = base_config.run_id
with open_dict(base_config):
base_config.from_run_id = from_run_id
Expand Down Expand Up @@ -456,9 +458,9 @@ def get_path_model(config: Config) -> Path:
return Path(config.model_path) / config.run_id


def get_path_output(config: Config, epoch: int) -> Path:
def get_path_output(config: Config, mini_epoch: int) -> Path:
base_path = get_path_run(config)
fname = f"validation_epoch{epoch:05d}_rank{config.rank:04d}.zarr"
fname = f"validation_chkpt{mini_epoch:05d}_rank{config.rank:04d}.zarr"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backward compatibility?


return base_path / fname

Expand Down Expand Up @@ -523,7 +525,7 @@ def validate_forecast_policy_and_steps(cf: OmegaConf):
valid_forecast_policies = (
"Valid values for 'forecast_policy' are, e.g., 'fixed' when using constant "
"forecast steps throughout the training, or 'sequential' when varying the forecast "
"steps over epochs, such as, e.g., 'forecast_steps: [2, 2, 4, 4]'. "
"steps over mini_epochs, such as, e.g., 'forecast_steps: [2, 2, 4, 4]'. "
)
valid_forecast_steps = (
"'forecast_steps' must be a positive integer or a non-empty list of positive integers. "
Expand Down
26 changes: 14 additions & 12 deletions packages/evaluate/src/weathergen/evaluate/export_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def detect_grid_type(input_data_array: xr.DataArray) -> str:
# Otherwise it's Gaussian (irregular spacing or reduced grid)
return "gaussian"


def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]:
"""
Find all the pressure levels for each variable using regex and returns a dictionary
Expand Down Expand Up @@ -90,6 +91,7 @@ def find_pl(all_variables: list) -> tuple[dict[str, list[str]], list[int]]:
pl = list(set(pl))
return var_dict, pl


def reshape_dataset_adaptive(input_data_array: xr.DataArray) -> xr.Dataset:
"""
Reshape dataset while preserving grid structure (regular or Gaussian).
Expand Down Expand Up @@ -176,8 +178,6 @@ def add_gaussian_grid_metadata(ds: xr.Dataset, grid_info: dict | None = None) ->
return ds




def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset:
"""
Add CF conventions to the dataset attributes.
Expand All @@ -201,6 +201,7 @@ def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset:
ds.attrs["Conventions"] = "CF-1.12"
return ds


def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset:
"""
Modified CF parser that handles both regular and Gaussian grids.
Expand Down Expand Up @@ -323,6 +324,7 @@ def cf_parser_gaussian_aware(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset:

return dataset


def output_filename(
prefix: str,
run_id: str,
Expand Down Expand Up @@ -363,8 +365,8 @@ def get_data_worker(args: tuple) -> xr.DataArray:
-------
xarray DataArray for the specified sample and forecast step.
"""
sample, fstep, run_id, stream, dtype, epoch, rank = args
fname_zarr = get_model_results(run_id, epoch, rank)
sample, fstep, run_id, stream, dtype, mini_epoch, rank = args
fname_zarr = get_model_results(run_id, mini_epoch, rank)
with ZarrIO(fname_zarr) as zio:
out = zio.get_data(sample, stream, fstep)
if dtype == "target":
Expand All @@ -383,7 +385,7 @@ def get_data(
channels: list,
fstep_hours: int,
n_processes: list,
epoch: int,
mini_epoch: int,
rank: int,
output_dir: str,
output_format: str,
Expand All @@ -402,7 +404,7 @@ def get_data(
fsteps : List of forecast steps to retrieve. If None, retrieves all available forecast steps.
channels :List of channels to retrieve. If None, retrieves all available channels.
n_processes : Number of parallel processes to use for data retrieval.
ecpoch : Epoch number to identify the Zarr store.
mini_epoch : Mini_epoch number to identify the Zarr store.
rank : Rank number to identify the Zarr store.
output_dir : Directory to save the NetCDF files.
output_format : Output file format (currently only 'netcdf' supported).
Expand All @@ -411,7 +413,7 @@ def get_data(
if dtype not in ["target", "prediction"]:
raise ValueError(f"Invalid type: {dtype}. Must be 'target' or 'prediction'.")

fname_zarr = get_model_results(run_id, epoch, rank)
fname_zarr = get_model_results(run_id, mini_epoch, rank)
with ZarrIO(fname_zarr) as zio:
zio_forecast_steps = sorted([int(step) for step in zio.forecast_steps])
zio_samples = sorted([int(sample) for sample in zio.samples])
Expand All @@ -430,7 +432,7 @@ def get_data(
for sample_idx in tqdm(samples):
da_fs = []
step_tasks = [
(sample_idx, fstep, run_id, stream, dtype, epoch, rank) for fstep in fsteps
(sample_idx, fstep, run_id, stream, dtype, mini_epoch, rank) for fstep in fsteps
]
for result in tqdm(
pool.imap_unordered(get_data_worker, step_tasks, chunksize=1),
Expand Down Expand Up @@ -627,10 +629,10 @@ def parse_args(args: list) -> argparse.Namespace:
)

parser.add_argument(
"--epoch",
"--mini_epoch",
type=int,
default=0,
help="Epoch number to identify the Zarr store",
help="mini_epoch number to identify the Zarr store",
)

parser.add_argument(
Expand Down Expand Up @@ -673,7 +675,7 @@ def export_from_args(args: list) -> None:
fstep_hours = np.timedelta64(args.fstep_hours, "h")
channels = args.channels
n_processes = args.n_processes
epoch = args.epoch
mini_epoch = args.mini_epoch
rank = args.rank

# Ensure output directory exists
Expand All @@ -697,7 +699,7 @@ def export_from_args(args: list) -> None:
channels,
fstep_hours,
n_processes,
epoch,
mini_epoch,
rank,
output_dir,
output_format,
Expand Down
8 changes: 4 additions & 4 deletions packages/evaluate/src/weathergen/evaluate/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non

super().__init__(eval_cfg, run_id, private_paths)

self.epoch = eval_cfg.epoch
self.mini_epoch = eval_cfg.mini_epoch
self.rank = eval_cfg.rank

# Load model configuration and set (run-id specific) directories
Expand Down Expand Up @@ -334,7 +334,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
)

self.fname_zarr = self.results_dir.joinpath(
f"validation_epoch{self.epoch:05d}_rank{self.rank:04d}.zarr"
f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backward compatibility

)

if not self.fname_zarr.exists() or not self.fname_zarr.is_dir():
Expand All @@ -356,12 +356,12 @@ def get_inference_config(self):
_logger.info(
f"Loading config for run {self.run_id} from private paths: {self.private_paths}"
)
config = load_config(self.private_paths, self.run_id, self.epoch)
config = load_config(self.private_paths, self.run_id, self.mini_epoch)
else:
_logger.info(
f"Loading config for run {self.run_id} from model directory: {self.model_base_dir}"
)
config = load_model_config(self.run_id, self.epoch, self.model_base_dir)
config = load_model_config(self.run_id, self.mini_epoch, self.model_base_dir)

if type(config) not in [dict, oc.DictConfig]:
_logger.warning("Model config not found. inference config will be empty.")
Expand Down
Loading
Loading