diff --git a/src/gz21_ocean_momentum/analysis/analysis.py b/src/gz21_ocean_momentum/analysis/analysis.py index 6b18dfa4..addd904f 100644 --- a/src/gz21_ocean_momentum/analysis/analysis.py +++ b/src/gz21_ocean_momentum/analysis/analysis.py @@ -14,6 +14,7 @@ import numpy as np import matplotlib.pyplot as plt from os.path import join +from functools import wraps data_location = "/data/ag7531/" figures_directory = "figures" @@ -23,15 +24,28 @@ def allow_hold_on(f): """Decorator that allows to specify a hold_on parameter that makes the plotting use the current figure instead of creating a new one.""" - def wrapper_f(*args, **kargs): - if "hold_on" in kargs and kargs["hold_on"]: + @wraps(f) # preserves the name and docstring of the function + def wrapped(*args, **kargs): + if kargs.pop("hold_on", False): plt.gcf() - del kargs["hold_on"] else: plt.figure() f(*args, **kargs) - return wrapper_f + return wrapped + + +def allow_save_fig(f): + """Decorator that allows to specify a save_file parameter that saves the plot.""" + + @wraps(f) + def wrapped(*args, **kargs): + save_file = kargs.pop("save_file", None) + f(*args, **kargs) + if save_file: # save_file gives the filename of the saved figure + plt.savefig(join(data_location, figures_directory, save_file)) + + return wrapped class TimeSeriesForPoint: @@ -71,6 +85,7 @@ def true_values(self): return self._time_series["true values"] @allow_hold_on + @allow_save_fig def plot_pred_vs_true(self): """Plots the predictions and the true target accross time for the instance's point.""" @@ -87,4 +102,4 @@ def plot_pred_vs_true(self): def save_fig(self): if not self._fig: self.plot_pred_vs_true() - plt.savefig(join(data_location, figures_directory, self.name)) + plt.savefig(join(data_location, figures_directory, self._name)) diff --git a/src/gz21_ocean_momentum/analysis/utils.py b/src/gz21_ocean_momentum/analysis/utils.py index 7b1c382c..fc944968 100755 --- a/src/gz21_ocean_momentum/analysis/utils.py +++ b/src/gz21_ocean_momentum/analysis/utils.py @@ -14,6 +14,7 @@ import pandas as pd from analysis.analysis import TimeSeriesForPoint import xarray as xr +from typing import Optional from scipy.ndimage import gaussian_filter from data.pangeo_catalog import get_patch, get_whole_data from cartopy.crs import PlateCarree @@ -165,7 +166,7 @@ def onClick(event): fig.canvas.mpl_connect("button_press_event", onClick) -def sample(data: np.ndarray, step_time: int = 1, nb_per_time: int = 5): +def sample(data: np.ndarray, step_time: int = 1, nb_per_time: int = 5, random_state: Optional[int] = None): """Samples points from the data, where it is assumed that the data is 4-D, with the first dimension representing time , the second the channel, and the others representing spatial dimensions. @@ -185,6 +186,9 @@ def sample(data: np.ndarray, step_time: int = 1, nb_per_time: int = 5): :nb_per_time: int, Number of points used (chosen randomly according to a uniform distribution over the spatial domain) for each image. + + :random_state: int, optional, + Random state used for the random number generator. Returns @@ -194,6 +198,7 @@ def sample(data: np.ndarray, step_time: int = 1, nb_per_time: int = 5): """ if data.ndim != 4: raise ValueError("The data is expected to have 4 dimensions.") + np.random.seed(random_state) n_times, n_channels, n_x, n_y = data.shape time_indices = np.arange(0, n_times, step_time) x_indices = np.random.randint(0, n_x, (time_indices.shape[0], 2, nb_per_time)) diff --git a/src/gz21_ocean_momentum/models/base.py b/src/gz21_ocean_momentum/models/base.py index 2b667f5d..3d9feaf0 100755 --- a/src/gz21_ocean_momentum/models/base.py +++ b/src/gz21_ocean_momentum/models/base.py @@ -19,7 +19,7 @@ class DetectOutputSizeMixin: """Class to detect the shape of a neural net.""" - # TODO: protect this with `@no_grad` decorator to conserve memory/time etc. + @torch.no_grad() def output_width(self, input_height, input_width): """ Generate a tensor and run forward model to get output width. @@ -34,17 +34,14 @@ def output_width(self, input_height, input_width): dummy_out.size(3) : int width of the output tensor """ - # TODO: following 2 lines can be combined for speedup as - # e.g. `torch.zeros(10, 10, device=self.device)` - dummy_in = torch.zeros((1, self.n_in_channels, input_height, input_width)) - dummy_in = dummy_in.to(device=self.device) + dummy_in = torch.zeros((1, self.n_in_channels, input_height, input_width), device=self.device) # AB - Self here is assuming access to a neural net forward method? # If so I think this should really be contained in FullyCNN. # We can discuss and I am happy to perform the refactor. dummy_out = self(dummy_in) return dummy_out.size(3) - # TODO: protect this with `@no_grad` decorator to conserve memory/time etc. + @torch.no_grad() def output_height(self, input_height, input_width): """ Generate a tensor and run forward model to get output height. @@ -59,10 +56,7 @@ def output_height(self, input_height, input_width): dummy_out.size(2) : int height of the output tensor """ - # TODO: following 2 lines can be combined for speedup as - # e.g. `torch.zeros(10, 10, device=self.device)` - dummy_in = torch.zeros((1, self.n_in_channels, input_height, input_width)) - dummy_in = dummy_in.to(device=self.device) + dummy_in = torch.zeros((1, self.n_in_channels, input_height, input_width), device=self.device) dummy_out = self(dummy_in) return dummy_out.size(2) diff --git a/src/gz21_ocean_momentum/models/fully_conv_net.py b/src/gz21_ocean_momentum/models/fully_conv_net.py index 32282aad..994e8069 100755 --- a/src/gz21_ocean_momentum/models/fully_conv_net.py +++ b/src/gz21_ocean_momentum/models/fully_conv_net.py @@ -76,6 +76,8 @@ def __init__( # store in_chans as attribute self._n_in_channels = in_chans + self._final_transformation = lambda x: x + @staticmethod def _process_padding(padding: Optional[str] = None) -> Tuple[int, int]: diff --git a/src/gz21_ocean_momentum/trainScript.py b/src/gz21_ocean_momentum/trainScript.py index fa3c6c4b..b47976a5 100755 --- a/src/gz21_ocean_momentum/trainScript.py +++ b/src/gz21_ocean_momentum/trainScript.py @@ -153,6 +153,9 @@ def check_str_is_None(string_in: str): "models.transforms.", ) parser.add_argument("--submodel", type=str, default="transform1") +parser.add_argument( + "--device", type=str, default="auto", help="Device to use for training." +) parser.add_argument( "--features_transform_cls_name", type=str, default="None", help="Depreciated" ) @@ -216,9 +219,11 @@ def check_str_is_None(string_in: str): _check_dir(os.path.join(data_location, directory)) # Device selection. If available we use the GPU. -# TODO Allow CLI argument to select the GPU -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -device_type = DEVICE_TYPE.GPU if torch.cuda.is_available() else DEVICE_TYPE.CPU +device = torch.device( + params.device if params.device != "auto" else + "cuda:0" if torch.cuda.is_available() else "cpu" +) +device_type = DEVICE_TYPE.CPU if device.type == 'cpu' else DEVICE_TYPE.GPU print("Selected device type: ", device_type.value) diff --git a/src/gz21_ocean_momentum/utils.py b/src/gz21_ocean_momentum/utils.py index e413b26a..8efe45af 100755 --- a/src/gz21_ocean_momentum/utils.py +++ b/src/gz21_ocean_momentum/utils.py @@ -6,8 +6,12 @@ @author: arthur """ +import os import mlflow from mlflow.tracking import client +import torch +import random +import numpy as np import pandas as pd import pickle import gz21_ocean_momentum.models as models @@ -126,3 +130,16 @@ def pickle_artifact(run_id: str, path: str): file = client.download_artifacts(run_id, path) f = open(file, "rb") return pickle.load(f) + + +def seed_all(seed: int = 0): + random.seed(seed) + # seed hash + os.environ['PYTHONHASHSEED'] = str(seed) + # seed numpy + np.random.seed(seed) + # seed torch + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True diff --git a/tests/models/test_fully_conv_net.py b/tests/models/test_fully_conv_net.py index 49df2434..4a2f10d7 100755 --- a/tests/models/test_fully_conv_net.py +++ b/tests/models/test_fully_conv_net.py @@ -10,7 +10,6 @@ def test_construct_valid(): Simple check migrated from `models.models1`. """ net = FullyCNN() - net._final_transformation = lambda x: x input_ = torch.randint(0, 10, (17, 2, 35, 30)).to(dtype=torch.float) input_[0, 0, 0, 0] = np.nan output = net(input_)