From 218b95eb0024bd38a486222ed85deea4dcd7d55f Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 21 Dec 2023 09:53:52 +1300 Subject: [PATCH 1/6] :beers: Implement CLAYModule's predict_step to generate embeddings table Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copying and adapting the code from a76716409e05f1b9dfb992c373fc6acd1e194bcf in #73, but modifying how the encoder's masking is disabled, and how the mean/average of the embeddings is computed over a slice of the raw embeddings. --- src/model_clay.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++- src/model_vit.py | 2 +- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/model_clay.py b/src/model_clay.py index eabc654a..b1dd8f0e 100644 --- a/src/model_clay.py +++ b/src/model_clay.py @@ -1,4 +1,9 @@ +import geopandas as gpd import lightning as L +import numpy as np +import pandas as pd +import pyarrow as pa +import shapely import torch import torch.nn.functional as F from einops import rearrange, reduce, repeat @@ -8,6 +13,7 @@ from src.utils import posemb_sincos_1d, posemb_sincos_2d +# %% class Patchify(nn.Module): """ Patchify the input cube & create embeddings per patch @@ -152,7 +158,7 @@ def add_encodings(self, patches): A tensor of shape (B, G, L, D) containing the embeddings of the patches + position & band encoding. """ - B, G, L, D = patches.shape + self.B, G, L, D = patches.shape # Align position & band embeddings across patches pos_encoding = repeat( @@ -824,3 +830,71 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int): return self.shared_step(batch, batch_idx, phase="val") + + def predict_step( + self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int + ) -> gpd.GeoDataFrame: + """ + Logic for the neural network's prediction loop. + """ + # Get image, bounding box, EPSG code, and date inputs + # x: torch.Tensor = batch["pixels"] # image of shape (1, 13, 512, 512) # BCHW + bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes + epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code + dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12'] + source_urls: list[str] = batch[ # URLs, e.g. ['s3://1.tif', 's3://2.tif'] + "source_url" + ] + + # Forward encoder + self.model.encoder.mask_ratio = 0.0 # disable masking + outputs_encoder: dict = self.model.encoder( + datacube=batch # input (pixels, timestep, latlon) + ) + + # Get embeddings generated from encoder + # (encoded_unmasked_patches, _, _, _) = outputs_encoder + embeddings_raw: torch.Tensor = outputs_encoder[0] + assert embeddings_raw.shape == torch.Size( + [self.model.encoder.B, 1538, 768] # (batch_size, seq_length, hidden_size) + ) + assert not torch.isnan(embeddings_raw).any() # ensure no NaNs in embedding + + # Take the mean of the embeddings along the sequence_length dimension + # excluding the last two latlon_ and time_ embeddings, i.e. compute + # mean over patch embeddings only + embeddings_mean: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1) + assert embeddings_mean.shape == torch.Size( + [self.model.encoder.B, 768] # (batch_size, hidden_size) + ) + + # Create table to store the embeddings with spatiotemporal metadata + unique_epsg_codes = set(int(epsg) for epsg in epsgs) + if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG + epsg: int = batch["epsg"][0] + else: + raise NotImplementedError( + f"More than 1 EPSG code detected: {unique_epsg_codes}" + ) + + gdf = gpd.GeoDataFrame( + data={ + "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"), + "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype( + dtype="date32[day][pyarrow]" + ), + "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray( + embeddings_mean.cpu().detach().__array__() + ), + }, + geometry=shapely.box( + xmin=bboxes[:, 0], + ymin=bboxes[:, 1], + xmax=bboxes[:, 2], + ymax=bboxes[:, 3], + ), + crs=f"EPSG:{epsg}", + ) + gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates + + return gdf diff --git a/src/model_vit.py b/src/model_vit.py index 46086dc2..4253baba 100644 --- a/src/model_vit.py +++ b/src/model_vit.py @@ -211,7 +211,7 @@ def predict_step( | s3://3C.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) | """ # Get image, bounding box, EPSG code, and date inputs - x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW + x: torch.Tensor = batch["image"] # image of shape (1, 13, 512, 512) # BCHW bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12'] From f19cf8f8522cbf508c735a3f8095867755551509 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 21 Dec 2023 11:48:50 +1300 Subject: [PATCH 2/6] :truck: Rename output file to {MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq The output GeoParquet file now has a filename with a format like "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq", e.g. "12ABC_20210101_20231231_v001.gpq". Have implemented this in model_vit.py, and copied over the same `on_predict_epoch_end` method to model_clay.py. Also, we are no longer saving out the index column to the GeoParquet file. --- src/model_clay.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++ src/model_vit.py | 16 +++++++++++--- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/src/model_clay.py b/src/model_clay.py index b1dd8f0e..e6bde892 100644 --- a/src/model_clay.py +++ b/src/model_clay.py @@ -1,3 +1,6 @@ +import os +import re + import geopandas as gpd import lightning as L import numpy as np @@ -898,3 +901,56 @@ def predict_step( gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates return gdf + + def on_predict_epoch_end(self) -> gpd.GeoDataFrame: + """ + Logic to gather all the results from one epoch in a prediction loop. + """ + # Combine list of geopandas.GeoDataFrame objects + results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions + if results: + gdf: gpd.GeoDataFrame = pd.concat( + objs=results, axis="index", ignore_index=True + ) + else: + print( + "No embeddings generated, " + f"possibly no GeoTIFF files in {self.trainer.datamodule.data_dir}" + ) + return + + # Save embeddings in GeoParquet format, one file for each MGRS code + outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings" + os.makedirs(name=outfolder, exist_ok=True) + + # Find unique MGRS names (e.g. '12ABC'), e.g. + # from 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', get 12ABC + mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1] + unique_mgrs_codes = mgrs_codes.unique() + for mgrs_code in unique_mgrs_codes: + if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None: + raise ValueError( + "MGRS code should have 2 numbers and 3 letters (e.g. 12ABC), " + f"but got {mgrs_code} instead" + ) + + # Subset GeoDataFrame to a single MGRS code + _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code] + + # Get min/max date from GeoDataFrame + minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"]) + min_date: str = minmax_date["min"].strftime("%Y%m%d") + max_date: str = minmax_date["max"].strftime("%Y%m%d") + + # Output to a GeoParquet filename like + # {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq + outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq" + _gdf.to_parquet( + path=outpath, index=False, compression="ZSTD", schema_version="1.0.0" + ) + print( + f"Saved {len(_gdf)} rows of embeddings of " + f"shape {gdf.embeddings.iloc[0].shape} to {outpath}" + ) + + return gdf diff --git a/src/model_vit.py b/src/model_vit.py index 4253baba..383315bf 100644 --- a/src/model_vit.py +++ b/src/model_vit.py @@ -319,10 +319,20 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame: f"but got {mgrs_code} instead" ) - # Output to a GeoParquet filename like {MGRS:5}_v{VERSION:2}.gpq - outpath = f"{outfolder}/{mgrs_code}_v01.gpq" + # Subset GeoDataFrame to a single MGRS code _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code] - _gdf.to_parquet(path=outpath, schema_version="1.0.0", compression="ZSTD") + + # Get min/max date from GeoDataFrame + minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"]) + min_date: str = minmax_date["min"].strftime("%Y%m%d") + max_date: str = minmax_date["max"].strftime("%Y%m%d") + + # Output to a GeoParquet filename like + # {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq + outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq" + _gdf.to_parquet( + path=outpath, index=False, compression="ZSTD", schema_version="1.0.0" + ) print( f"Saved {len(_gdf)} rows of embeddings of " f"shape {gdf.embeddings.iloc[0].shape} to {outpath}" From 6030cf7c3e372e229daf24263fc7db4ff7296930 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 21 Dec 2023 14:11:36 +1300 Subject: [PATCH 3/6] :white_check_mark: Fix failing test by updating to new output filename Forgot to update the filename in the unit test to conform to the new `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` format. Patches f19cf8f8522cbf508c735a3f8095867755551509. --- src/tests/test_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/tests/test_model.py b/src/tests/test_model.py index ee17ecfe..19ae8b4e 100644 --- a/src/tests/test_model.py +++ b/src/tests/test_model.py @@ -35,7 +35,7 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe: [561415.0, 3397465.0, 563975.0, 3400025.0], ] ), - "date": ["2020-01-01", "2020-12-31", "2020-12-31"], + "date": ["2020-01-01", "2021-06-15", "2022-12-31"], "epsg": torch.tensor(data=[32760, 32760, 32760]), "source_url": [ "s3://claytile_60HTE_1.tif", @@ -76,8 +76,12 @@ def test_model_vit(datapipe): assert ( len(os.listdir(path=f"{tmpdirname}/data/embeddings")) == 2 # noqa: PLR2004 ) - assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60HTE_v01.gpq") - assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60GUV_v01.gpq") + assert os.path.exists( + path := f"{tmpdirname}/data/embeddings/60HTE_20200101_20200101_v001.gpq" + ) + assert os.path.exists( + path := f"{tmpdirname}/data/embeddings/60GUV_20210615_20221231_v001.gpq" + ) geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path) assert geodataframe.shape == (2, 4) # 2 rows, 4 columns From 98a39b730a6aa0b145ad65fb0060df84083ad286 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 21 Dec 2023 21:54:18 +1300 Subject: [PATCH 4/6] :white_check_mark: Parametrized test to check CLAYModule's predict loop Splitting the previous integration test on the neural network model into separate fit and predict unit tests. Only testing the prediction loop of CLAYModule, because training/validating the model might be too much for CPU-based Continuous Integration. Also for testing CLAYModule, we are using 32-true precision instead of bf16-mixed, because `torch.cat` doesn't work with float16 tensors on the CPU, see https://github.com/pytorch/pytorch/issues/100932 (should be fixed with Pytorch 2.2). --- src/tests/test_model.py | 45 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/tests/test_model.py b/src/tests/test_model.py index 19ae8b4e..989aad20 100644 --- a/src/tests/test_model.py +++ b/src/tests/test_model.py @@ -15,6 +15,7 @@ import torchdata import torchdata.dataloader2 +from src.model_clay import CLAYModule from src.model_vit import ViTLitModule @@ -27,7 +28,18 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe: datapipe = torchdata.datapipes.iter.IterableWrapper( iterable=[ { + # For ViTLitModule "image": torch.randn(3, 13, 512, 512).to(dtype=torch.float16), + # For CLAYModule + "pixels": torch.randn(3, 13, 512, 512).to(dtype=torch.float32), + "timestep": torch.tensor( + data=[(2020, 1, 1), (2021, 6, 15), (2022, 12, 31)], + dtype=torch.float32, + ), + "latlon": torch.tensor( + data=[(12, 34), (56, 78), (90, 100)], dtype=torch.float32 + ), + # For both "bbox": torch.tensor( data=[ [499975.0, 3397465.0, 502535.0, 3400025.0], @@ -49,9 +61,9 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe: # %% -def test_model_vit(datapipe): +def test_model_vit_fit(datapipe): """ - Run a full train, val, test and prediction loop using 1 batch. + Run a full train and validation loop using 1 batch. """ # Get some random data dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe) @@ -71,6 +83,35 @@ def test_model_vit(datapipe): ) trainer.fit(model=model, train_dataloaders=dataloader) + +@pytest.mark.parametrize( + "litmodule,precision", + [ + (CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"), + (ViTLitModule, "bf16-mixed"), + ], +) +def test_model_predict(datapipe, litmodule, precision): + """ + Run a single prediction loop using 1 batch. + """ + # Get some random data + dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe) + + # Initialize model + model: L.LightningModule = litmodule() + + # Run tests in a temporary folder + with tempfile.TemporaryDirectory() as tmpdirname: + # Training + trainer: L.Trainer = L.Trainer( + accelerator="auto", + devices="auto", + precision=precision, + fast_dev_run=True, + default_root_dir=tmpdirname, + ) + # Prediction trainer.predict(model=model, dataloaders=dataloader) assert ( From f1439e3f11915559d22feb8926d3224c15d932a8 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:53:25 +1300 Subject: [PATCH 5/6] :rewind: Save index column to GeoParquet file Decided that the index column might be good to keep for now, since it might help to speed up row counts? But we are resetting the index first before saving it. Partially reverts f19cf8f8522cbf508c735a3f8095867755551509. --- src/model_clay.py | 6 ++---- src/model_vit.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/model_clay.py b/src/model_clay.py index e6bde892..313e84f7 100644 --- a/src/model_clay.py +++ b/src/model_clay.py @@ -935,7 +935,7 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame: ) # Subset GeoDataFrame to a single MGRS code - _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code] + _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index() # Get min/max date from GeoDataFrame minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"]) @@ -945,9 +945,7 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame: # Output to a GeoParquet filename like # {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq" - _gdf.to_parquet( - path=outpath, index=False, compression="ZSTD", schema_version="1.0.0" - ) + _gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0") print( f"Saved {len(_gdf)} rows of embeddings of " f"shape {gdf.embeddings.iloc[0].shape} to {outpath}" diff --git a/src/model_vit.py b/src/model_vit.py index 383315bf..be916b57 100644 --- a/src/model_vit.py +++ b/src/model_vit.py @@ -320,7 +320,7 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame: ) # Subset GeoDataFrame to a single MGRS code - _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code] + _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index() # Get min/max date from GeoDataFrame minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"]) @@ -330,9 +330,7 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame: # Output to a GeoParquet filename like # {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq" - _gdf.to_parquet( - path=outpath, index=False, compression="ZSTD", schema_version="1.0.0" - ) + _gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0") print( f"Saved {len(_gdf)} rows of embeddings of " f"shape {gdf.embeddings.iloc[0].shape} to {outpath}" From 4c71db3beb94ce988fa1c2a94ad0245ad4d1fa42 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Fri, 12 Jan 2024 10:50:48 +1300 Subject: [PATCH 6/6] :white_check_mark: Fix unit test to include index column After f1439e3f11915559d22feb8926d3224c15d932a8, need to ensure that the index column is checked in the output geodataframe. --- src/tests/test_model.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/tests/test_model.py b/src/tests/test_model.py index 989aad20..8a2dcfba 100644 --- a/src/tests/test_model.py +++ b/src/tests/test_model.py @@ -125,10 +125,15 @@ def test_model_predict(datapipe, litmodule, precision): ) geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path) - assert geodataframe.shape == (2, 4) # 2 rows, 4 columns - assert all( - geodataframe.columns == ["source_url", "date", "embeddings", "geometry"] - ) + assert geodataframe.shape == (2, 5) # 2 rows, 5 columns + assert list(geodataframe.columns) == [ + "index", + "source_url", + "date", + "embeddings", + "geometry", + ] + assert geodataframe.index.dtype == "int64" assert geodataframe.source_url.dtype == "string" assert geodataframe.date.dtype == "date32[day][pyarrow]" assert geodataframe.embeddings.dtype == "object"