diff --git a/src/gz21_ocean_momentum/cli/data.py b/src/gz21_ocean_momentum/cli/data.py index 6187a298..db082bc2 100755 --- a/src/gz21_ocean_momentum/cli/data.py +++ b/src/gz21_ocean_momentum/cli/data.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import gz21_ocean_momentum.step.data.lib as lib +import gz21_ocean_momentum.lib.data as lib import gz21_ocean_momentum.common.cli as cli from gz21_ocean_momentum.common.bounding_box import BoundingBox import gz21_ocean_momentum.common.bounding_box as bounding_box @@ -60,14 +60,15 @@ logger.debug("dropping irrelevant data variables...") surface_fields = surface_fields[["usurf", "vsurf"]] -if options.ntimes is not None: - logger.info(f"slicing {options.ntimes} time points...") - surface_fields = surface_fields.isel(time=slice(options.ntimes)) - logger.info("selecting input data bounding box...") surface_fields = bounding_box.bound_dataset("yu_ocean", "xu_ocean", surface_fields, bbox) grid = bounding_box.bound_dataset("yu_ocean", "xu_ocean", grid, bbox) +# TODO 2023-11-29 raehik: original bounded first, sliced (immediately) after +if options.ntimes is not None: + logger.info(f"slicing {options.ntimes} time points...") + surface_fields = surface_fields.isel(time=slice(options.ntimes)) + logger.debug("placing grid dataset into local memory...") grid = grid.compute() diff --git a/src/gz21_ocean_momentum/cli/infer.py b/src/gz21_ocean_momentum/cli/infer.py index 1e32cda7..696b7cd2 100755 --- a/src/gz21_ocean_momentum/cli/infer.py +++ b/src/gz21_ocean_momentum/cli/infer.py @@ -1,9 +1,11 @@ import configargparse +import gz21_ocean_momentum.common.cli as cli import logging from dask.diagnostics import ProgressBar from gz21_ocean_momentum.utils import TaskInfo +import gz21_ocean_momentum.lib.model as lib from gz21_ocean_momentum.data.datasets import ( pytorch_dataset_from_cm2_6_forcing_dataset, #DatasetPartitioner, @@ -13,27 +15,18 @@ ) import xarray as xr - import torch from torch.utils.data import DataLoader -# TODO hardcode submodel, transformation, NN loss function -# unlikely for a CLI we need to provide dynamic code loading -- let's just give -# options -# we could enable such "dynamic loading" in the "library" interface!-- but, due -# to the class-based setup, it's a little complicated for a user to come in with -# their own code for some of these, and it needs documentation. so a task for -# later import gz21_ocean_momentum.models.models1 as model import gz21_ocean_momentum.models.submodels as submodels import gz21_ocean_momentum.models.transforms as transforms import gz21_ocean_momentum.train.losses as loss_funcs from gz21_ocean_momentum.inference.utils import predict_lazy_cm2_6 -#from gz21_ocean_momentum.train.base import Trainer submodel = submodels.transform3 -DESCRIPTION = """ +_cli_desc = """ Use a trained GZ21 neural net to predict forcing for input ocean velocity data. This script is intended as example of how use the GZ21 neural net, generating @@ -54,21 +47,21 @@ into your GCM of choice. """ -p = configargparse.ArgParser(description=DESCRIPTION) +p = configargparse.ArgParser(description=_cli_desc) p.add("--config-file", is_config_file=True, help="config file path") - p.add("--input-data-dir", type=str, required=True, help="path to input ocean velocity data, in zarr format (folder)") p.add("--model-state-dict-file", type=str, required=True, help="model state dict file (*.pth)") p.add("--out-dir", type=str, required=True, help="folder to save forcing predictions dataset to (in zarr format)") - p.add("--device", type=str, default="cuda", help="neural net device (e.g. cuda, cuda:0, cpu)") -p.add("--splits", type=int) options = p.parse_args() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +cli.fail_if_path_is_nonempty_dir( + 1, f"--out-dir \"{options.out_dir}\" invalid", options.out_dir) + # --- logger.info("loading input (coarse) ocean momentum data...") @@ -77,14 +70,7 @@ with ProgressBar(), TaskInfo("Applying transforms to dataset"): ds_computed_xr = submodel.fit_transform(ds_computed_xr) -ds_computed_torch = pytorch_dataset_from_cm2_6_forcing_dataset(ds_computed_xr) - -logger.info("performing various dataset transforms...") -features_transform_ = ComposeTransforms() -targets_transform_ = ComposeTransforms() -transform = DatasetTransformer(features_transform_, targets_transform_) -dataset = DatasetWithTransform(ds_computed_torch, transform) - +dataset = lib.gz21_train_data_subdomain_xr_to_torch(ds_computed_xr) loader = DataLoader(dataset) criterion = loss_funcs.HeteroskedasticGaussianLossV2(dataset.n_targets) diff --git a/src/gz21_ocean_momentum/cli/train.py b/src/gz21_ocean_momentum/cli/train.py index 93de42ac..b7c15bd1 100755 --- a/src/gz21_ocean_momentum/cli/train.py +++ b/src/gz21_ocean_momentum/cli/train.py @@ -7,7 +7,7 @@ import gz21_ocean_momentum.common.cli as cli import gz21_ocean_momentum.common.assorted as common import gz21_ocean_momentum.common.bounding_box as bounding_box -import gz21_ocean_momentum.unsorted.train_data_xr_to_pytorch as lib +import gz21_ocean_momentum.lib.model as lib import gz21_ocean_momentum.models.submodels as submodels import gz21_ocean_momentum.models.transforms as transforms import gz21_ocean_momentum.models.models1 as model diff --git a/src/gz21_ocean_momentum/data/datasets.py b/src/gz21_ocean_momentum/data/datasets.py index 0c7ab637..815c0cbf 100644 --- a/src/gz21_ocean_momentum/data/datasets.py +++ b/src/gz21_ocean_momentum/data/datasets.py @@ -429,7 +429,6 @@ def transform(self, x): return np.concatenate((left, x, right), axis=self.axis) def transform_coordinate(self, coords, dim): - print(f"{dim}, {self.dim_name}") if dim == self.dim_name: left = coords[-self.nb_points :] - self.length right = coords[: self.nb_points] + self.length @@ -697,8 +696,6 @@ def __len__(self): Number of samples of the dataset. """ - print("xrrawdataset len called") - print(len(self.xr_dataset[self._index])) try: return len(self.xr_dataset[self._index]) except KeyError as e: @@ -797,7 +794,6 @@ def __getattr__(self, attr): raise AttributeError() def __len__(self): - print("len on datasetwithtransform") return len(self.dataset) def add_transforms_from_model(self, model): @@ -914,6 +910,8 @@ class ConcatDataset_(ConcatDataset): - enforces the concatenated dataset to have the same shapes - passes on attributes (from the first dataset, assuming they are equal accross concatenated datasets) + + TODO input datasets need to have .height, .width """ def __init__(self, datasets): diff --git a/src/gz21_ocean_momentum/step/data/lib.py b/src/gz21_ocean_momentum/lib/data.py similarity index 100% rename from src/gz21_ocean_momentum/step/data/lib.py rename to src/gz21_ocean_momentum/lib/data.py diff --git a/src/gz21_ocean_momentum/unsorted/train_data_xr_to_pytorch.py b/src/gz21_ocean_momentum/lib/model.py similarity index 83% rename from src/gz21_ocean_momentum/unsorted/train_data_xr_to_pytorch.py rename to src/gz21_ocean_momentum/lib/model.py index 22f82bd6..1af01e32 100644 --- a/src/gz21_ocean_momentum/unsorted/train_data_xr_to_pytorch.py +++ b/src/gz21_ocean_momentum/lib/model.py @@ -1,7 +1,11 @@ +# Common functions relating to neural net model, training data. + import xarray as xr import numpy as np import torch.utils.data as torch +from gz21_ocean_momentum.common.assorted import at_idx_pct + from gz21_ocean_momentum.data.datasets import ( DatasetWithTransform, DatasetTransformer, @@ -11,7 +15,11 @@ ComposeTransforms, ) -def cm26_xarray_to_torch(ds_xr: xr.Dataset): +def cm26_xarray_to_torch(ds_xr: xr.Dataset) -> torch.Dataset: + """ + Obtain a PyTorch `Dataset` view over an xarray dataset, specifically for + CM2.6 ocean velocity data annotated with forcings in `S_x` and `S_y`. + """ ds_torch = RawDataFromXrDataset(ds_xr) ds_torch.index = "time" ds_torch.add_input("usurf") @@ -20,7 +28,7 @@ def cm26_xarray_to_torch(ds_xr: xr.Dataset): ds_torch.add_output("S_y") return ds_torch -def gz21_train_data_subdomain_xr_to_torch(ds_xr: xr.Dataset): +def gz21_train_data_subdomain_xr_to_torch(ds_xr: xr.Dataset) -> torch.Dataset: """ Convert GZ21 training data (coarsened CM2.6 data with diagnosed forcings) into a PyTorch dataset. @@ -39,22 +47,6 @@ def gz21_train_data_subdomain_xr_to_torch(ds_xr: xr.Dataset): return ds_torch_with_transform -def at_idx_pct(pct: float, a) -> int: - """ - Obtain the index into the given list-like to the given percent. - No interpolation is performed: we choose the leftmost closest index i.e. the - result is floored. - - e.g. `at_idx_pct(0.5, [0,1,2]) == 1` - - Must be able to `len(a)`. - - Invariant: `0<=pct<=1`. - - Returns a valid index into `a`. - """ - return int(pct * len(a)) - def prep_train_test_dataloaders( dss: list, pct_train_end: float, diff --git a/src/gz21_ocean_momentum/step/inference/lib.py b/src/gz21_ocean_momentum/step/inference/lib.py deleted file mode 100644 index 15684272..00000000 --- a/src/gz21_ocean_momentum/step/inference/lib.py +++ /dev/null @@ -1,9 +0,0 @@ -def cm2_6_prep_pytorch(_: xr.Dataset, idx_start: float, idx_end) -> _: - """ - Various transformations, subsetting of a CM2.6 dataset. - - Retrieve using data step lib, slice & restrict spatial domain. - - idx_start: 0->1 subset of dataset to use, start - idx_end: 0->1 subset of dataset to use, end (must be > idx_start) - """