Generate embeddings from CLAYModule trained with latlon/time encodings (#96)

* 🍻 Implement CLAYModule's predict_step to generate embeddings table

Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copies and adapts the code from a767164 in #73, but modifies how the encoder's masking is disabled and how the mean of the embeddings is computed over a slice of the raw embeddings, as sketched below.
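
For reference, a minimal sketch of that mean-over-slice computation, with shapes taken from the asserts in the diff below (the random tensor is purely illustrative):

    import torch

    # Raw encoder output: (batch_size, seq_length, hidden_size), where the last
    # two sequence positions hold the latlon_ and time_ embeddings.
    embeddings_raw = torch.randn(2, 1538, 768)

    # Average over the patch embeddings only, excluding the last two positions.
    embeddings_mean = embeddings_raw[:, :-2, :].mean(dim=1)
    assert embeddings_mean.shape == (2, 768)  # (batch_size, hidden_size)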

* 🚚 Rename output file to {MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq

The output GeoParquet file now has a filename with a format like "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq", e.g. "12ABC_20210101_20231231_v001.gpq". Implemented this in model_vit.py and copied the same `on_predict_epoch_end` method over to model_clay.py. Also, the index column is no longer saved to the GeoParquet file.
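
A quick illustration of the naming convention, using the example values above:

    mgrs_code, min_date, max_date, version = "12ABC", "20210101", "20231231", "001"
    filename = f"{mgrs_code}_{min_date}_{max_date}_v{version}.gpq"
    assert filename == "12ABC_20210101_20231231_v001.gpq"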

* ✅ Fix failing test by updating to new output filename

Forgot to update the filename in the unit test to conform to the new `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` format. Patches f19cf8f.

* ✅ Parametrized test to check CLAYModule's predict loop

Split the previous integration test on the neural network model into separate fit and predict unit tests. Only the prediction loop of CLAYModule is tested, because training/validating the model might be too heavy for CPU-based Continuous Integration. Also, CLAYModule is tested with 32-true precision instead of bf16-mixed, because `torch.cat` doesn't work with float16 tensors on the CPU, see pytorch/pytorch#100932 (should be fixed with PyTorch 2.2); a small repro is sketched below.
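
A minimal repro of the upstream limitation motivating the 32-true fallback (behaviour depends on the PyTorch version; per the linked issue this fails before PyTorch 2.2):

    import torch

    t = torch.ones(2, 2, dtype=torch.float16)  # float16 tensor on the CPU
    torch.cat(tensors=[t, t])  # RuntimeError on affected PyTorch versions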

* ⏪ Save index column to GeoParquet file

Decided that the index column might be good to keep for now, since it might help to speed up row counts. We do reset the index first before saving it, though. Partially reverts f19cf8f.
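
For context, a toy pandas example of what resetting the index does before the save:

    import pandas as pd

    df = pd.DataFrame(data={"a": [10, 20]})
    # reset_index materializes the row labels as an explicit 'index' column,
    # which then gets written out alongside the data columns.
    print(df.reset_index().columns.tolist())  # ['index', 'a']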

* ✅ Fix unit test to include index column

After f1439e3, need to ensure that the index column is checked in the output geodataframe.
weiji14 authored Jan 12, 2024
1 parent de1556b commit 7082c54
Showing 3 changed files with 201 additions and 15 deletions.
130 changes: 129 additions & 1 deletion src/model_clay.py
@@ -1,4 +1,12 @@
import os
import re

import geopandas as gpd
import lightning as L
import numpy as np
import pandas as pd
import pyarrow as pa
import shapely
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
@@ -10,6 +18,7 @@
torch.set_float32_matmul_precision(precision="medium")


# %%
class Patchify(nn.Module):
"""
Patchify the input cube & create embeddings per patch
@@ -154,7 +163,7 @@ def add_encodings(self, patches):
A tensor of shape (B, G, L, D) containing the embeddings of the
patches + position & band encoding.
"""
-        B, G, L, D = patches.shape
+        self.B, G, L, D = patches.shape

# Align position & band embeddings across patches
pos_encoding = repeat(
@@ -842,3 +851,122 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):

def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
return self.shared_step(batch, batch_idx, phase="val")

def predict_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> gpd.GeoDataFrame:
"""
Logic for the neural network's prediction loop.
"""
# Get image, bounding box, EPSG code, and date inputs
# x: torch.Tensor = batch["pixels"] # image of shape (1, 13, 512, 512) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']
source_urls: list[str] = batch[ # URLs, e.g. ['s3://1.tif', 's3://2.tif']
"source_url"
]

# Forward encoder
self.model.encoder.mask_ratio = 0.0 # disable masking
outputs_encoder: dict = self.model.encoder(
datacube=batch # input (pixels, timestep, latlon)
)

# Get embeddings generated from encoder
# (encoded_unmasked_patches, _, _, _) = outputs_encoder
embeddings_raw: torch.Tensor = outputs_encoder[0]
assert embeddings_raw.shape == torch.Size(
[self.model.encoder.B, 1538, 768] # (batch_size, seq_length, hidden_size)
)
assert not torch.isnan(embeddings_raw).any() # ensure no NaNs in embedding

# Take the mean of the embeddings along the sequence_length dimension
# excluding the last two latlon_ and time_ embeddings, i.e. compute
# mean over patch embeddings only
embeddings_mean: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1)
assert embeddings_mean.shape == torch.Size(
[self.model.encoder.B, 768] # (batch_size, hidden_size)
)

# Create table to store the embeddings with spatiotemporal metadata
unique_epsg_codes = set(int(epsg) for epsg in epsgs)
if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG
epsg: int = batch["epsg"][0]
else:
raise NotImplementedError(
f"More than 1 EPSG code detected: {unique_epsg_codes}"
)

gdf = gpd.GeoDataFrame(
data={
"source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"),
"date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
dtype="date32[day][pyarrow]"
),
"embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
embeddings_mean.cpu().detach().__array__()
),
},
geometry=shapely.box(
xmin=bboxes[:, 0],
ymin=bboxes[:, 1],
xmax=bboxes[:, 2],
ymax=bboxes[:, 3],
),
crs=f"EPSG:{epsg}",
)
gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates

return gdf

def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
"""
Logic to gather all the results from one epoch in a prediction loop.
"""
# Combine list of geopandas.GeoDataFrame objects
results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions
if results:
gdf: gpd.GeoDataFrame = pd.concat(
objs=results, axis="index", ignore_index=True
)
else:
print(
"No embeddings generated, "
f"possibly no GeoTIFF files in {self.trainer.datamodule.data_dir}"
)
return

# Save embeddings in GeoParquet format, one file for each MGRS code
outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
os.makedirs(name=outfolder, exist_ok=True)

# Find unique MGRS names (e.g. '12ABC'), e.g.
# from 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', get 12ABC
mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1]
unique_mgrs_codes = mgrs_codes.unique()
for mgrs_code in unique_mgrs_codes:
if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None:
raise ValueError(
"MGRS code should have 2 numbers and 3 letters (e.g. 12ABC), "
f"but got {mgrs_code} instead"
)

# Subset GeoDataFrame to a single MGRS code
_gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index()

# Get min/max date from GeoDataFrame
minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"])
min_date: str = minmax_date["min"].strftime("%Y%m%d")
max_date: str = minmax_date["max"].strftime("%Y%m%d")

# Output to a GeoParquet filename like
# {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq
outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq"
_gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
print(
f"Saved {len(_gdf)} rows of embeddings of "
f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
)

return gdf
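
For context, a sketch of reading one of the saved embedding files back with geopandas (the path is illustrative; the column names match the unit test below):

    import geopandas as gpd

    gdf = gpd.read_parquet(path="data/embeddings/12ABC_20210101_20231231_v001.gpq")
    print(gdf.columns.tolist())
    # ['index', 'source_url', 'date', 'embeddings', 'geometry']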
18 changes: 13 additions & 5 deletions src/model_vit.py
@@ -211,7 +211,7 @@ def predict_step(
| s3://3C.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
"""
# Get image, bounding box, EPSG code, and date inputs
x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW
x: torch.Tensor = batch["image"] # image of shape (1, 13, 512, 512) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']
@@ -319,10 +319,18 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
f"but got {mgrs_code} instead"
)

-            # Output to a GeoParquet filename like {MGRS:5}_v{VERSION:2}.gpq
-            outpath = f"{outfolder}/{mgrs_code}_v01.gpq"
-            _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code]
-            _gdf.to_parquet(path=outpath, schema_version="1.0.0", compression="ZSTD")
+            # Subset GeoDataFrame to a single MGRS code
+            _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index()
+
+            # Get min/max date from GeoDataFrame
+            minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"])
+            min_date: str = minmax_date["min"].strftime("%Y%m%d")
+            max_date: str = minmax_date["max"].strftime("%Y%m%d")
+
+            # Output to a GeoParquet filename like
+            # {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq
+            outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq"
+            _gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
print(
f"Saved {len(_gdf)} rows of embeddings of "
f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
            )
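
As a small sanity check of the MGRS validation pattern used in both modules (the example codes are illustrative):

    import re

    assert re.match(pattern=r"(\d{2}[A-Z]{3})", string="12ABC") is not None
    assert re.match(pattern=r"(\d{2}[A-Z]{3})", string="1ABCD") is None  # needs 2 digits + 3 letters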
68 changes: 59 additions & 9 deletions src/tests/test_model.py
@@ -15,6 +15,7 @@
import torchdata
import torchdata.dataloader2

from src.model_clay import CLAYModule
from src.model_vit import ViTLitModule


@@ -27,15 +28,26 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe:
datapipe = torchdata.datapipes.iter.IterableWrapper(
iterable=[
{
# For ViTLitModule
"image": torch.randn(3, 13, 512, 512).to(dtype=torch.float16),
# For CLAYModule
"pixels": torch.randn(3, 13, 512, 512).to(dtype=torch.float32),
"timestep": torch.tensor(
data=[(2020, 1, 1), (2021, 6, 15), (2022, 12, 31)],
dtype=torch.float32,
),
"latlon": torch.tensor(
data=[(12, 34), (56, 78), (90, 100)], dtype=torch.float32
),
# For both
"bbox": torch.tensor(
data=[
[499975.0, 3397465.0, 502535.0, 3400025.0],
[530695.0, 3397465.0, 533255.0, 3400025.0],
[561415.0, 3397465.0, 563975.0, 3400025.0],
]
),
"date": ["2020-01-01", "2020-12-31", "2020-12-31"],
"date": ["2020-01-01", "2021-06-15", "2022-12-31"],
"epsg": torch.tensor(data=[32760, 32760, 32760]),
"source_url": [
"s3://claytile_60HTE_1.tif",
@@ -49,9 +61,9 @@


# %%
-def test_model_vit(datapipe):
+def test_model_vit_fit(datapipe):
"""
-    Run a full train, val, test and prediction loop using 1 batch.
+    Run a full train and validation loop using 1 batch.
"""
# Get some random data
dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe)
Expand All @@ -71,19 +83,57 @@ def test_model_vit(datapipe):
)
trainer.fit(model=model, train_dataloaders=dataloader)


@pytest.mark.parametrize(
"litmodule,precision",
[
(CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
(ViTLitModule, "bf16-mixed"),
],
)
def test_model_predict(datapipe, litmodule, precision):
"""
Run a single prediction loop using 1 batch.
"""
# Get some random data
dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe)

# Initialize model
model: L.LightningModule = litmodule()

# Run tests in a temporary folder
with tempfile.TemporaryDirectory() as tmpdirname:
# Training
trainer: L.Trainer = L.Trainer(
accelerator="auto",
devices="auto",
precision=precision,
fast_dev_run=True,
default_root_dir=tmpdirname,
)

# Prediction
trainer.predict(model=model, dataloaders=dataloader)
assert (
len(os.listdir(path=f"{tmpdirname}/data/embeddings")) == 2 # noqa: PLR2004
)
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60HTE_v01.gpq")
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60GUV_v01.gpq")
assert os.path.exists(
path := f"{tmpdirname}/data/embeddings/60HTE_20200101_20200101_v001.gpq"
)
assert os.path.exists(
path := f"{tmpdirname}/data/embeddings/60GUV_20210615_20221231_v001.gpq"
)
geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path)

-        assert geodataframe.shape == (2, 4)  # 2 rows, 4 columns
-        assert all(
-            geodataframe.columns == ["source_url", "date", "embeddings", "geometry"]
-        )
+        assert geodataframe.shape == (2, 5)  # 2 rows, 5 columns
+        assert list(geodataframe.columns) == [
+            "index",
+            "source_url",
+            "date",
+            "embeddings",
+            "geometry",
+        ]
assert geodataframe.index.dtype == "int64"
assert geodataframe.source_url.dtype == "string"
assert geodataframe.date.dtype == "date32[day][pyarrow]"
assert geodataframe.embeddings.dtype == "object"
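
Outside the test suite, the prediction loop could be driven in much the same way (a sketch under assumed defaults; checkpoint loading and the dataloader setup are omitted):

    import lightning as L

    from src.model_clay import CLAYModule

    model = CLAYModule()  # in practice, weights would come from a trained checkpoint
    trainer = L.Trainer(accelerator="auto", devices="auto", precision="32-true")
    trainer.predict(model=model, dataloaders=dataloader)  # dataloader assumed defined
    # GeoParquet files land in {trainer.default_root_dir}/data/embeddings/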
