Commit

Updated WorldCerealInferenceDataset for new files
kvantricht committed Sep 3, 2024
1 parent 47034ac commit d7d541e
Showing 1 changed file with 126 additions and 56 deletions.
182 changes: 126 additions & 56 deletions presto/dataset.py
@@ -8,10 +8,9 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from einops import rearrange, repeat
from pyproj import Transformer
from einops import rearrange
from pyproj import CRS, Transformer
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset

@@ -369,22 +368,6 @@ def __getitem__(self, idx):
class WorldCerealInferenceDataset(Dataset):
_NODATAVALUE = 65535
Y = "worldcereal_cropland"
BAND_MAPPING = {
"B02": "B2",
"B03": "B3",
"B04": "B4",
"B05": "B5",
"B06": "B6",
"B07": "B7",
"B08": "B8",
# B8A is missing
"B11": "B11",
"B12": "B12",
"VH": "VH",
"VV": "VV",
"precipitation-flux": "total_precipitation",
"temperature-mean": "temperature_2m",
}

def __init__(self, path_to_files: Path = data_dir / "inference_areas"):
self.path_to_files = path_to_files
@@ -394,49 +377,136 @@ def __len__(self):
return len(self.all_files)

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
ds = cast(xr.Dataset, rioxarray.open_rasterio(filepath, decode_times=False))
epsg_coords = ds.rio.crs.to_epsg()

num_instances = len(ds.x) * len(ds.y)
num_timesteps = len(ds.t)
eo_data = np.zeros((num_instances, num_timesteps, len(BANDS)))
mask = np.zeros((num_instances, num_timesteps, len(BANDS_GROUPS_IDX)))
# for now, B8A is missing
mask[:, :, IDX_TO_BAND_GROUPS["B8A"]] = 1

for org_band, presto_val in cls.BAND_MAPPING.items():
# flatten the values
values = np.swapaxes(ds[org_band].values.reshape((num_timesteps, -1)), 0, 1)
idx_valid = values != cls._NODATAVALUE
def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
"""
Extracts EO data and mask arrays from the input xarray.DataArray.
Args:
inarr (xr.DataArray): Input xarray.DataArray containing EO data.
Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing EO data array and mask array.
"""
num_pixels = len(inarr.x) * len(inarr.y)

# Use valid_time attribute to extract the right part of the time series
valid_time = pd.to_datetime(inarr.attrs["valid_time"]).replace(day=1)
end_time = valid_time + pd.DateOffset(months=5)
start_time = valid_time - pd.DateOffset(months=6)
inarr = inarr.sel(t=slice(start_time, end_time))
num_timesteps = len(inarr.t)
assert num_timesteps == 12, "Expected 12 timesteps, only found {}".format(num_timesteps)

# Handle NaN values in a Presto-compatible way
inarr = inarr.astype(np.float32)
inarr = inarr.fillna(cls._NODATAVALUE)

eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS)))
mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX)))

for presto_band in NORMED_BANDS:
if presto_band in inarr.coords["bands"]:
values = np.swapaxes(
inarr.sel(bands=presto_band).values.reshape((num_timesteps, -1)),
0,
1,
)
idx_valid = values != cls._NODATAVALUE
values = cls._preprocess_band_values(values, presto_band)
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
elif presto_band == "NDVI":
# Band NDVI will be computed by Presto
continue
else:
logger.warning(f"Band {presto_band} not found in input data.")
eo_data[:, :, BANDS.index(presto_band)] = 0
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

if presto_val in ["VV", "VH"]:
# convert to dB
values = 20 * np.log10(values) - 83
elif presto_val == "total_precipitation":
# scaling, and AgERA5 is in mm, Presto expects m
values = values / (100 * 1000.0)
elif presto_val == "temperature_2m":
# remove scaling
values = values / 100
return eo_data, mask

eo_data[:, :, BANDS.index(presto_val)] = values
mask[:, :, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid
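For context (not part of this commit), a minimal sketch of how the 12-month window around the valid_time attribute is derived before slicing the time axis, using a made-up acquisition date:

import pandas as pd

valid_time = pd.to_datetime("2021-07-15").replace(day=1)  # snap to the first of the month
start_time = valid_time - pd.DateOffset(months=6)         # 2021-01-01
end_time = valid_time + pd.DateOffset(months=5)           # 2021-12-01
# inarr.sel(t=slice(start_time, end_time)) on monthly composites then keeps
# exactly the 12 timesteps that the assert above checks for.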
@staticmethod
def _extract_latlons(inarr: xr.DataArray, epsg: int) -> np.ndarray:
"""
Extracts latitudes and longitudes from the input xarray.DataArray.
Args:
inarr (xr.DataArray): Input xarray.DataArray containing spatial coordinates.
epsg (int): EPSG code for coordinate reference system.
Returns:
np.ndarray: Array containing extracted latitudes and longitudes.
"""
# EPSG:4326 is the supported CRS for Presto
lon, lat = np.meshgrid(inarr.x, inarr.y)
transformer = Transformer.from_crs(f"EPSG:{epsg}", "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(lon, lat)
latlons = rearrange(np.stack([lat, lon]), "c x y -> (x y) c")

# 2D array where each row represents a pair of latitude and longitude coordinates.
return latlons
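As an illustration (not part of this commit), the same reprojection and flattening applied to a tiny made-up UTM grid:

import numpy as np
from einops import rearrange
from pyproj import Transformer

x = np.array([500000.0, 500010.0, 500020.0])  # hypothetical eastings (m)
y = np.array([5650000.0, 5650010.0])          # hypothetical northings (m)
lon, lat = np.meshgrid(x, y)                  # shape (len(y), len(x))
transformer = Transformer.from_crs("EPSG:32631", "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(lon, lat)
latlons = rearrange(np.stack([lat, lon]), "c x y -> (x y) c")
print(latlons.shape)  # (6, 2): one (lat, lon) pair per pixel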

@classmethod
def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray:
"""
Preprocesses the band values based on the given presto_band.
Args:
values (np.ndarray): Array of band values to preprocess.
presto_band (str): Name of the band for preprocessing.
Returns:
np.ndarray: Preprocessed array of band values.
"""
if presto_band in ["VV", "VH"]:
# Convert to dB
values = 20 * np.log10(values) - 83
elif presto_band == "total_precipitation":
# Scale precipitation and convert mm to m
values = values / (100 * 1000.0)
elif presto_band == "temperature_2m":
# Remove scaling
values = values / 100
return values
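For reference (not part of this commit), the three conversions shown on made-up values:

import numpy as np

print(20 * np.log10(np.array([0.01, 0.1])) - 83)  # VV/VH linear -> dB: [-123. -103.]
print(np.array([250000.0]) / (100 * 1000.0))      # scaled AgERA5 precipitation, mm -> m: [2.5]
print(np.array([29315.0]) / 100)                  # unscale 2m temperature: [293.15]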

@staticmethod
def _extract_months(inarr: xr.DataArray) -> np.ndarray:
"""
Calculate the start month from the first timestamp in the input array
and return an array filled with that value, one entry per pixel.
Args:
inarr (xr.DataArray): Input xarray.DataArray containing the time coordinate t.
Returns:
np.ndarray: Array of start month values of length len(inarr.x) * len(inarr.y).
"""
num_instances = len(inarr.x) * len(inarr.y)

start_month = (inarr.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1

y = rearrange(ds[cls.Y].values, "t x y -> (x y) t")
# -1 because we index from 0
start_month = (ds.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1
months = np.ones((num_instances)) * start_month
return months

transformer = Transformer.from_crs(f"EPSG:{epsg_coords}", "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(ds.x, ds.y)
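For context (not part of this commit), how the zero-based start month falls out of the first timestamp, using a made-up time axis:

import numpy as np

t = np.array(["2021-01-01", "2021-02-01"], dtype="datetime64[D]")
# datetime64[M] counts whole months since 1970-01, so % 12 yields a 0-based month index
start_month = (t[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1
print(start_month)  # 0, i.e. January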
@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
ds = xr.open_dataset(filepath)
epsg_coords = CRS.from_wkt(ds.crs.attrs["crs_wkt"]).to_epsg()
inarr = ds.drop("crs").to_array(dim="bands")

latlons = np.stack(
[np.repeat(lat, repeats=len(lon)), repeat(lon, "c -> (h c)", h=len(lat))],
axis=-1,
)
eo_data, mask = cls._extract_eo_data(inarr)
latlons = cls._extract_latlons(inarr, epsg_coords)
months = cls._extract_months(inarr)

if cls.Y not in ds:
y = np.ones_like(months) * cls._NODATAVALUE
else:
# TODO: needs to be checked once the labels are back
y = rearrange(inarr[cls.Y].values, "t x y -> (x y) t")

return eo_data, np.repeat(mask, BAND_EXPANSION, axis=-1), latlons, months, y
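A usage sketch (not part of this commit), assuming the class and the module-level data_dir are importable from presto.dataset and using a hypothetical file name under the default inference_areas directory:

from presto.dataset import WorldCerealInferenceDataset, data_dir  # data_dir as referenced above

eo, mask, latlons, months, y = WorldCerealInferenceDataset.nc_to_arrays(
    data_dir / "inference_areas" / "example_patch.nc"  # hypothetical file name
)
print(eo.shape, mask.shape, latlons.shape, months.shape, y.shape)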
