Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated WorldCerealInferenceDataset #103

Merged
merged 27 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
47034ac
Updated inference dataset
kvantricht Sep 3, 2024
d7d541e
Updated WorldCerealInferenceDataset for new files
kvantricht Sep 3, 2024
5d6fe88
Remove unused import
kvantricht Sep 3, 2024
f0b1c8c
Fix typing
kvantricht Sep 3, 2024
ce52f9c
Ignore this type check
kvantricht Sep 3, 2024
1ac7ea2
Black fixes
kvantricht Sep 3, 2024
906652e
More line length fixes
kvantricht Sep 3, 2024
1a19ec1
:facepalm
kvantricht Sep 3, 2024
8ee5b47
:facepalm:
kvantricht Sep 3, 2024
3f75632
Added ground truth labels
kvantricht Sep 3, 2024
9155516
Only subset of the file for faster tests
kvantricht Sep 3, 2024
5f6b58d
Fix gt selection and rearranged t subset
kvantricht Sep 3, 2024
794065e
test update
kvantricht Sep 3, 2024
817da45
More consistent handling of inference datasets
kvantricht Sep 4, 2024
666d05f
use h5netcdf instead of rioxarray
kvantricht Sep 4, 2024
7f14c8c
Dont import at the top to avoid dependency
kvantricht Sep 4, 2024
70015af
Add rioxarray and relax xarray version
kvantricht Sep 4, 2024
eab0569
Relax rioxarray library version
kvantricht Sep 4, 2024
3eea359
Test file update
kvantricht Sep 5, 2024
4f9a18f
Clarified comment
kvantricht Sep 5, 2024
6547114
Black fix
kvantricht Sep 5, 2024
8ac58df
Only create the grid before feeding to presto
kvantricht Sep 5, 2024
fab066c
Formatting
gabrieltseng Sep 5, 2024
6079001
Revert "Only create the grid before feeding to presto"
kvantricht Sep 5, 2024
2575d39
Merge branch 'updated-inferencedatasets' of github.com:WorldCereal/pr…
kvantricht Sep 5, 2024
be7a31b
Clarified variable naming
kvantricht Sep 5, 2024
13f8433
Black fix
kvantricht Sep 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
274 changes: 203 additions & 71 deletions presto/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
from math import modf
from pathlib import Path
from random import sample
from typing import Callable, Dict, List, Optional, Tuple, cast
from typing import Callable, Dict, List, Optional, Tuple, Union

import geopandas as gpd
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from einops import rearrange, repeat
from pyproj import Transformer
from sklearn.utils.class_weight import compute_class_weight
from einops import rearrange
from pyproj import CRS, Transformer
from torch.utils.data import Dataset

from .dataops import (
Expand Down Expand Up @@ -320,6 +318,8 @@ def __getitem__(self, idx):

@property
def class_weights(self) -> np.ndarray:
from sklearn.utils.class_weight import compute_class_weight

if self._class_weights is None:
ys = []
for _, row in self.df.iterrows():
Expand Down Expand Up @@ -368,23 +368,7 @@ def __getitem__(self, idx):

class WorldCerealInferenceDataset(Dataset):
_NODATAVALUE = 65535
Y = "worldcereal_cropland"
BAND_MAPPING = {
"B02": "B2",
"B03": "B3",
"B04": "B4",
"B05": "B5",
"B06": "B6",
"B07": "B7",
"B08": "B8",
# B8A is missing
"B11": "B11",
"B12": "B12",
"VH": "VH",
"VV": "VV",
"precipitation-flux": "total_precipitation",
"temperature-mean": "temperature_2m",
}
Y = "WORLDCEREAL_TEMPORARYCROPS_2021"

def __init__(self, path_to_files: Path = data_dir / "inference_areas"):
self.path_to_files = path_to_files
Expand All @@ -394,72 +378,220 @@ def __len__(self):
return len(self.all_files)

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
ds = cast(xr.Dataset, rioxarray.open_rasterio(filepath, decode_times=False))
epsg_coords = ds.rio.crs.to_epsg()

num_instances = len(ds.x) * len(ds.y)
num_timesteps = len(ds.t)
eo_data = np.zeros((num_instances, num_timesteps, len(BANDS)))
mask = np.zeros((num_instances, num_timesteps, len(BANDS_GROUPS_IDX)))
# for now, B8A is missing
mask[:, :, IDX_TO_BAND_GROUPS["B8A"]] = 1

for org_band, presto_val in cls.BAND_MAPPING.items():
# flatten the values
values = np.swapaxes(ds[org_band].values.reshape((num_timesteps, -1)), 0, 1)
idx_valid = values != cls._NODATAVALUE
def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
"""
Extracts EO data and mask arrays from the input xarray.DataArray.

Args:
inarr (xr.DataArray): Input xarray.DataArray containing EO data.

Returns:
Tuple[np.ndarray, np.ndarray]: Tuple containing EO data array and mask array.
"""
num_pixels = len(inarr.x) * len(inarr.y)
num_timesteps = len(inarr.t)

# Handle NaN values in Presto compatible way
inarr = inarr.astype(np.float32)
inarr = inarr.fillna(65535)

eo_data = np.zeros((num_pixels, num_timesteps, len(BANDS)))
mask = np.zeros((num_pixels, num_timesteps, len(BANDS_GROUPS_IDX)))

for presto_band in NORMED_BANDS:
if presto_band in inarr.coords["bands"]:
values = np.swapaxes(
inarr.sel(bands=presto_band).values.reshape((num_timesteps, -1)),
0,
1,
)
idx_valid = values != cls._NODATAVALUE
values = cls._preprocess_band_values(values, presto_band)
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
elif presto_band == "NDVI":
# Band NDVI will be computed by Presto
kvantricht marked this conversation as resolved.
Show resolved Hide resolved
continue
else:
logger.warning(f"Band {presto_band} not found in input data.")
eo_data[:, :, BANDS.index(presto_band)] = 0
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

if presto_val in ["VV", "VH"]:
# convert to dB
values = 20 * np.log10(values) - 83
elif presto_val == "total_precipitation":
# scaling, and AgERA5 is in mm, Presto expects m
values = values / (100 * 1000.0)
elif presto_val == "temperature_2m":
# remove scaling
values = values / 100
return eo_data, mask

@staticmethod
def _extract_latlons(inarr: xr.DataArray, epsg: int) -> np.ndarray:
    """
    Build a flat array of (lat, lon) pairs, one per grid cell of ``inarr``.

    The ``x`` and ``y`` coordinates of ``inarr`` are expanded into a full
    grid with ``np.meshgrid`` and reprojected from ``EPSG:{epsg}`` to
    EPSG:4326 before being flattened.

    Args:
        inarr (xr.DataArray): Array exposing ``x`` and ``y`` coordinates
            (assumed to be 1-D axes expressed in the CRS given by ``epsg``).
        epsg (int): EPSG code of the CRS the ``x``/``y`` coordinates use.

    Returns:
        np.ndarray: 2D array with one row per grid cell; column 0 holds the
        latitude and column 1 the longitude.
    """
    # One (x, y) coordinate pair per grid cell.
    grid_x, grid_y = np.meshgrid(inarr.x, inarr.y)

    # EPSG:4326 is the supported crs for presto. With always_xy=True the
    # transformer takes (x, y) in and hands (lon, lat) back.
    to_wgs84 = Transformer.from_crs(f"EPSG:{epsg}", "EPSG:4326", always_xy=True)
    grid_lon, grid_lat = to_wgs84.transform(grid_x, grid_y)

    # Latitude goes first on the leading axis, so after flattening the two
    # grid axes each row reads (lat, lon).
    stacked = np.stack([grid_lat, grid_lon])
    return rearrange(stacked, "c x y -> (x y) c")

@classmethod
def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray:
"""
Preprocesses the band values based on the given presto_val.

Args:
values (np.ndarray): Array of band values to preprocess.
presto_val (str): Name of the band for preprocessing.

Returns:
np.ndarray: Preprocessed array of band values.
"""
if presto_band in ["VV", "VH"]:
# Convert to dB
values = 20 * np.log10(values) - 83
elif presto_band == "total_precipitation":
# Scale precipitation and convert mm to m
values = values / (100 * 1000.0)
elif presto_band == "temperature_2m":
# Remove scaling
values = values / 100
return values

@staticmethod
def _extract_months(inarr: xr.DataArray) -> np.ndarray:
"""
Calculate the start month based on the first timestamp in the input array,
and create an array of the same length filled with that start month value.

Parameters:
- inarr: xarray.DataArray or numpy.ndarray
Input array containing timestamps.

Returns:
- months: numpy.ndarray
Array of start month values, with the same length as the input array.
"""
num_instances = len(inarr.x) * len(inarr.y)

eo_data[:, :, BANDS.index(presto_val)] = values
mask[:, :, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid
start_month = (inarr.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1

y = rearrange(ds[cls.Y].values, "t x y -> (x y) t")
# -1 because we index from 0
start_month = (ds.t.values[0].astype("datetime64[M]").astype(int) % 12 + 1) - 1
months = np.ones((num_instances)) * start_month
return months

transformer = Transformer.from_crs(f"EPSG:{epsg_coords}", "EPSG:4326", always_xy=True)
lon, lat = transformer.transform(ds.x, ds.y)
@staticmethod
def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray:
"""
Subset the input xarray.DataArray temporally based on `valid_time` attribute.

latlons = np.stack(
[np.repeat(lat, repeats=len(lon)), repeat(lon, "c -> (h c)", h=len(lat))],
axis=-1,
)
Args:
inarr (xr.DataArray): Input xarray.DataArray containing EO data.

Returns:
xr.DataArray: Temporally subsetted xarray.DataArray.
"""

return eo_data, np.repeat(mask, BAND_EXPANSION, axis=-1), latlons, months, y
# Use valid_time attribute to extract the right part of the time series
valid_time = pd.to_datetime(inarr.attrs["valid_time"]).replace(day=1)
end_time = valid_time + pd.DateOffset(months=5)
start_time = valid_time - pd.DateOffset(months=6)
inarr = inarr.sel(t=slice(start_time, end_time))
num_timesteps = len(inarr.t)
assert num_timesteps == 12, "Expected 12 timesteps, only found {}".format(num_timesteps)

return inarr

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
]:
ds = xr.open_dataset(filepath)
epsg = CRS.from_wkt(xr.open_dataset(filepath).crs.attrs["crs_wkt"]).to_epsg()

if epsg is None:
raise ValueError("EPSG code not found in the input file.")
inarr = ds.drop("crs").to_array(dim="bands")
lon, lat = inarr.y.values, inarr.x.values

# Temporal subsetting to 12 timesteps
inarr = cls._subset_array_temporally(inarr)

eo_data, mask = cls._extract_eo_data(inarr)
latlons = cls._extract_latlons(inarr, epsg)
months = cls._extract_months(inarr)

if cls.Y not in ds:
target = np.ones_like(months) * cls._NODATAVALUE
else:
target = rearrange(inarr.sel(bands=cls.Y).values, "t x y -> (x y) t")

return (
eo_data,
np.repeat(mask, BAND_EXPANSION, axis=-1),
latlons,
months,
target,
lon,
lat,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need latlons and lon, lat here? If I understand correctly, lon == latlons[:, 1] and lat == latlons[:, 0], which means we don't need lon, lat?

Copy link
Contributor Author

@kvantricht kvantricht Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not so simple I'm afraid, I struggled with this a lot. After the meshgrid and flattening of latlons it becomes quite hard to get back to original lon, lat we need to properly reconstruct the DataArray.

latlons.shape
(2500, 2)
lon.shape
(50,)

So no, lon != latlons[:, 1]. I've been thinking about easier ways but haven't found them as of yet.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm okay. I think this is probably (?) easier, especially considering we additionally apply the transformation. The latlons take up quite a bit of RAM, and for large tiles this might become an issue.

Just in case it's useful, here is some code to go from the flat latlons back to the original ones (but without the transformation):

import numpy as np
from einops import rearrange

org_lat = np.array([1, 2, 3])
org_lon = np.array([4, 5, 6, 7])

def to_flat_latlons(lat, lon):
    lon, lat = np.meshgrid(lon, lat)
    latlons = rearrange(np.stack([lat, lon]), "c x y -> (x y) c")
    return latlons


def from_latlons(latlons):
    x = len(np.unique(latlons[:, 0]))
    y = len(np.unique(latlons[:, 1]))
    latlons = rearrange(latlons, "(x y) c -> c x y", x=x, y=y)
    lats, lons = latlons[0], latlons[1]
    return lats[:, 0], lons[0, ]

output_lat, output_lon = from_latlons(to_flat_latlons(org_lat, org_lon))
assert np.equal(output_lon, org_lon).all()
assert np.equal(output_lat, org_lat).all()

I wonder if it's worth passing around the original lats and lons and only applying that transformation right before the model ingests the values.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope I understood this well. Could you check whether my latest commit addresses this the way you suggested? Feel free to suggest doing it differently. Unfortunately I'm away for the rest of the day; I will catch up tomorrow morning.

)

def __getitem__(self, idx):
filepath = self.all_files[idx]
eo, mask, latlons, months, y = self.nc_to_arrays(filepath)
eo, mask, latlons, months, target, lon, lat = self.nc_to_arrays(filepath)

dynamic_world = np.ones((eo.shape[0], eo.shape[1])) * (DynamicWorld2020_2021.class_amount)

return S1_S2_ERA5_SRTM.normalize(eo), dynamic_world, mask, latlons, months, y
return (
S1_S2_ERA5_SRTM.normalize(eo),
dynamic_world,
mask,
latlons,
months,
target,
lon,
lat,
)

@staticmethod
def combine_predictions(
latlons: np.ndarray, all_preds: np.ndarray, gt: np.ndarray, ndvi: np.ndarray
) -> pd.DataFrame:
flat_lat, flat_lon = latlons[:, 0], latlons[:, 1]
all_preds: np.ndarray,
gt: np.ndarray,
ndvi: np.ndarray,
lon: Union[xr.DataArray, np.ndarray, List[float]],
lat: Union[xr.DataArray, np.ndarray, List[float]],
) -> xr.DataArray:

if len(all_preds.shape) == 1:
all_preds = np.expand_dims(all_preds, axis=-1)

data_dict: Dict[str, np.ndarray] = {"lat": flat_lat, "lon": flat_lon}
# Get band names
bands = [f"prediction_{i}" for i in range(all_preds.shape[1])] + [
"ground_truth",
"ndvi",
]

# Initialize gridded data array
data = np.empty((len(bands), len(lon), len(lat)))

# Fill with gridded predictions
for i in range(all_preds.shape[1]):
prediction_label = f"prediction_{i}"
data_dict[prediction_label] = all_preds[:, i]
data_dict["ground_truth"] = gt[:, 0]
data_dict["ndvi"] = ndvi
return pd.DataFrame(data=data_dict).set_index(["lat", "lon"])
data[i, ...] = rearrange(all_preds[:, i], "(y x) -> 1 y x", y=len(lon), x=len(lat))

# Fill with ground truth and NDVI
data[-2, ...] = rearrange(gt[:, 0], "(y x) -> 1 y x", y=len(lon), x=len(lat))
data[-1, ...] = rearrange(ndvi, "(y x) -> 1 y x", y=len(lon), x=len(lat))

return xr.DataArray(coords=[bands, lon, lat], dims=["bands", "lon", "lat"], data=data)
14 changes: 9 additions & 5 deletions presto/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def finetune_sklearn_model(
assert model_mode in ["Regression", "Random Forest", "CatBoostClassifier"]
pretrained_model.eval()

def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
def dataloader_to_encodings_and_targets(
dl: DataLoader,
) -> Tuple[np.ndarray, np.ndarray]:
encoding_list, target_list = [], []
for x, y, dw, latlons, month, variable_mask in dl:
x_f, dw_f, latlons_f, month_f, variable_mask_f = [
Expand Down Expand Up @@ -184,7 +186,9 @@ def dataloader_to_encodings_and_targets(dl: DataLoader) -> Tuple[np.ndarray, np.
if model == "CatBoostClassifier":
fit_models.append(
clone(model_dict[model]).fit(
train_encodings, train_targets, eval_set=Pool(val_encodings, val_targets)
train_encodings,
train_targets,
eval_set=Pool(val_encodings, val_targets),
)
)
else:
Expand Down Expand Up @@ -245,7 +249,7 @@ def spatial_inference(
assert self.spatial_inference_savedir is not None
ds = WorldCerealInferenceDataset()
for i in range(len(ds)):
eo, dynamic_world, mask, latlons, months, y = ds[i]
eo, dynamic_world, mask, latlons, months, y, lon, lat = ds[i]
dl = DataLoader(
TensorDataset(
torch.from_numpy(eo).float(),
Expand All @@ -263,13 +267,13 @@ def spatial_inference(
# take the middle timestep's ndvi
middle_timestep = eo.shape[1] // 2
ndvi = eo[:, middle_timestep, NORMED_BANDS.index("NDVI")]
df = ds.combine_predictions(latlons, test_preds_np, y, ndvi)
da = ds.combine_predictions(test_preds_np, y, ndvi, lon, lat)
prefix = f"{self.name}_{ds.all_files[i].stem}"
if pretrained_model is None:
filename = f"{prefix}_finetuning.nc"
else:
filename = f"{prefix}_{finetuned_model.__class__.__name__}.nc"
df.to_xarray().to_netcdf(self.spatial_inference_savedir / filename)
da.to_netcdf(self.spatial_inference_savedir / filename)

@torch.no_grad()
def evaluate(
Expand Down
2 changes: 1 addition & 1 deletion requirements.full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ black==22.10.0
catboost==1.2.2
fastparquet
flake8==5.0.4
h5netcdf==1.3.0
isort==5.10.1
mypy==1.1.1
matplotlib==3.7.5
rioxarray==0.13.1
rtree==1.1.0
tqdm==4.64.1
types-requests~=2.32.0
Expand Down
3 changes: 2 additions & 1 deletion requirements.inference.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
einops==0.6.0
geopandas>=0.13.2
numpy==1.23.5
rioxarray>=0.15.0
torch==2.3.1
tqdm==4.64.1
xarray==2023.1.0
xarray>=2023.1.0
requests==2.32.3
Loading
Loading