diff --git a/presto/dataset.py b/presto/dataset.py
index aae3eac..b4d5d31 100644
--- a/presto/dataset.py
+++ b/presto/dataset.py
@@ -8,10 +8,9 @@
 import geopandas as gpd
 import numpy as np
 import pandas as pd
-import rioxarray
 import xarray as xr
-from einops import rearrange, repeat
-from pyproj import Transformer
+from einops import rearrange
+from pyproj import CRS, Transformer
 from sklearn.utils.class_weight import compute_class_weight
 from torch.utils.data import Dataset
 
@@ -369,22 +368,6 @@ def __getitem__(self, idx):
 class WorldCerealInferenceDataset(Dataset):
     _NODATAVALUE = 65535
     Y = "worldcereal_cropland"
-    BAND_MAPPING = {
-        "B02": "B2",
-        "B03": "B3",
-        "B04": "B4",
-        "B05": "B5",
-        "B06": "B6",
-        "B07": "B7",
-        "B08": "B8",
-        # B8A is missing
-        "B11": "B11",
-        "B12": "B12",
-        "VH": "VH",
-        "VV": "VV",
-        "precipitation-flux": "total_precipitation",
-        "temperature-mean": "temperature_2m",
-    }
 
     def __init__(self, path_to_files: Path = data_dir / "inference_areas"):
         self.path_to_files = path_to_files
@@ -394,49 +377,146 @@ def __len__(self):
         return len(self.all_files)
 
     @classmethod
-    def nc_to_arrays(
-        cls, filepath: Path
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
-        ds = cast(xr.Dataset, rioxarray.open_rasterio(filepath, decode_times=False))
-        epsg_coords = ds.rio.crs.to_epsg()
-
-        num_instances = len(ds.x) * len(ds.y)
-        num_timesteps = len(ds.t)
-        eo_data = np.zeros((num_instances, num_timesteps, len(BANDS)))
-        mask = np.zeros((num_instances, num_timesteps, len(BANDS_GROUPS_IDX)))
-        # for now, B8A is missing
-        mask[:, :, IDX_TO_BAND_GROUPS["B8A"]] = 1
-
-        for org_band, presto_val in cls.BAND_MAPPING.items():
-            # flatten the values
-            values = np.swapaxes(ds[org_band].values.reshape((num_timesteps, -1)), 0, 1)
-            idx_valid = values != cls._NODATAVALUE
+    def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Extracts EO data and mask arrays from the input xarray.DataArray.
+
+        Args:
+            inarr (xr.DataArray): Input xarray.DataArray containing EO data.
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray]: Tuple containing EO data array and mask array.
+        """
+        num_pixels = len(inarr.x) * len(inarr.y)
+
+        # Use valid_time attribute to extract the right part of the time series
+        valid_time = pd.to_datetime(inarr.attrs["valid_time"]).replace(day=1)
+        end_time = valid_time + pd.DateOffset(months=5)
+        start_time = valid_time - pd.DateOffset(months=6)
+        inarr = inarr.sel(t=slice(start_time, end_time))
+        num_timesteps = len(inarr.t)
+        assert num_timesteps == 12, "Expected 12 timesteps, only found {}".format(num_timesteps)
+
+        # Handle NaN values in Presto compatible way
+        inarr = inarr.astype(np.float32)
+        inarr = inarr.fillna(cls._NODATAVALUE)
+
+        eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS)))
+        mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX)))
+
+        for presto_band in NORMED_BANDS:
+            if presto_band in inarr.coords["bands"]:
+                values = np.swapaxes(
+                    inarr.sel(bands=presto_band).values.reshape((num_timesteps, -1)),
+                    0,
+                    1,
+                )
+                idx_valid = values != cls._NODATAVALUE
+                values = cls._preprocess_band_values(values, presto_band)
+                eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
+                mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
+            elif presto_band == "NDVI":
+                # Band NDVI will be computed by Presto
+                continue
+            else:
+                logger.warning(f"Band {presto_band} not found in input data.")
+                eo_data[:, :, BANDS.index(presto_band)] = 0
+                mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1
 
-            if presto_val in ["VV", "VH"]:
-                # convert to dB
-                values = 20 * np.log10(values) - 83
-            elif presto_val == "total_precipitation":
-                # scaling, and AgERA5 is in mm, Presto expects m
-                values = values / (100 * 1000.0)
-            elif presto_val == "temperature_2m":
-                # remove scaling
-                values = values / 100
+        return eo_data, mask
 
-            eo_data[:, :, BANDS.index(presto_val)] = values
-            mask[:, :, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid
+    @staticmethod
+    def _extract_latlons(inarr: xr.DataArray, epsg: int) -> np.ndarray:
+        """
+        Extracts latitudes and longitudes from the input xarray.DataArray.
+
+        Args:
+            inarr (xr.DataArray): Input xarray.DataArray containing spatial coordinates.
+            epsg (int): EPSG code for coordinate reference system.
+
+        Returns:
+            np.ndarray: Array containing extracted latitudes and longitudes.
+        """
+        # EPSG:4326 is the supported crs for presto
+        lon, lat = np.meshgrid(inarr.x, inarr.y)
+        transformer = Transformer.from_crs(f"EPSG:{epsg}", "EPSG:4326", always_xy=True)
+        lon, lat = transformer.transform(lon, lat)
+        latlons = rearrange(np.stack([lat, lon]), "c y x -> (y x) c")
+
+        # 2D array where each row represents a pair of latitude and longitude coordinates.
+        return latlons
+
+    @classmethod
+    def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray:
+        """
+        Preprocesses the band values based on the given presto_band.
+
+        Args:
+            values (np.ndarray): Array of band values to preprocess.
+            presto_band (str): Name of the band for preprocessing.
+
+        Returns:
+            np.ndarray: Preprocessed array of band values.
+        """
+        if presto_band in ["VV", "VH"]:
+            # Convert to dB
+            values = 20 * np.log10(values) - 83
+        elif presto_band == "total_precipitation":
+            # Scale precipitation and convert mm to m
+            values = values / (100 * 1000.0)
+        elif presto_band == "temperature_2m":
+            # Remove scaling
+            values = values / 100
+        return values
+
+    @staticmethod
+    def _extract_months(inarr: xr.DataArray) -> np.ndarray:
+        """
+        Calculate the 0-indexed start month from the first timestamp in the input array,
+        and create an array with one entry per pixel filled with that start month value.
+
+        Parameters:
+        - inarr: xarray.DataArray
+            Input array with "x", "y" and "t" coordinates.
+
+        Returns:
+        - months: numpy.ndarray
+            Array of start month values, of length len(inarr.x) * len(inarr.y).
+        """
+        num_instances = len(inarr.x) * len(inarr.y)
+
+        start_month = (inarr.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1
 
-        y = rearrange(ds[cls.Y].values, "t x y -> (x y) t")
-        # -1 because we index from 0
-        start_month = (ds.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1
         months = np.ones((num_instances)) * start_month
+        return months
 
-        transformer = Transformer.from_crs(f"EPSG:{epsg_coords}", "EPSG:4326", always_xy=True)
-        lon, lat = transformer.transform(ds.x, ds.y)
+    @classmethod
+    def nc_to_arrays(
+        cls, filepath: Path
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Opens the dataset at filepath with xarray and converts it into Presto inputs.
+
+        Args:
+            filepath (Path): File to read; must contain a "crs" variable whose
+                "crs_wkt" attribute describes the coordinate reference system.
+
+        Returns:
+            Tuple of (eo_data, expanded mask, latlons, months, y) arrays.
+        """
+        ds = xr.open_dataset(filepath)
+        epsg_coords = CRS.from_wkt(ds["crs"].attrs["crs_wkt"]).to_epsg()
+        inarr = ds.drop_vars("crs").to_array(dim="bands")
 
-        latlons = np.stack(
-            [np.repeat(lat, repeats=len(lon)), repeat(lon, "c -> (h c)", h=len(lat))],
-            axis=-1,
-        )
+        eo_data, mask = cls._extract_eo_data(inarr)
+        latlons = cls._extract_latlons(inarr, epsg_coords)
+        months = cls._extract_months(inarr)
+
+        if cls.Y not in ds:
+            y = np.ones_like(months) * cls._NODATAVALUE
+        else:
+            # TODO: needs to be checked once the labels are back
+            y = rearrange(ds[cls.Y].values, "t x y -> (x y) t")
 
         return eo_data, np.repeat(mask, BAND_EXPANSION, axis=-1), latlons, months, y
 