Generate embeddings from CLAYModule trained with latlon/time encodings #96

Merged 7 commits on Jan 12, 2024
130 changes: 129 additions & 1 deletion src/model_clay.py
@@ -1,4 +1,12 @@
+import os
+import re
+
+import geopandas as gpd
import lightning as L
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import shapely
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
@@ -10,6 +18,7 @@
torch.set_float32_matmul_precision(precision="medium")


# %%
class Patchify(nn.Module):
"""
Patchify the input cube & create embeddings per patch
@@ -154,7 +163,7 @@ def add_encodings(self, patches):
A tensor of shape (B, G, L, D) containing the embeddings of the
patches + position & band encoding.
"""
-B, G, L, D = patches.shape
+self.B, G, L, D = patches.shape

# Align position & band embeddings across patches
pos_encoding = repeat(
@@ -842,3 +851,122 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):

def validation_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
return self.shared_step(batch, batch_idx, phase="val")

def predict_step(
self, batch: dict[str, torch.Tensor | list[str]], batch_idx: int
) -> gpd.GeoDataFrame:
"""
Logic for the neural network's prediction loop.
"""
# Get image, bounding box, EPSG code, and date inputs
# x: torch.Tensor = batch["pixels"] # image of shape (1, 13, 512, 512) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']
source_urls: list[str] = batch[ # URLs, e.g. ['s3://1.tif', 's3://2.tif']
"source_url"
]

# Forward encoder
self.model.encoder.mask_ratio = 0.0 # disable masking
outputs_encoder: dict = self.model.encoder(
datacube=batch # input (pixels, timestep, latlon)
)

# Get embeddings generated from encoder
# (encoded_unmasked_patches, _, _, _) = outputs_encoder
embeddings_raw: torch.Tensor = outputs_encoder[0]
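# Note: sequence length 1538 = 1536 patch embeddings + 2 trailing
# (latlon_ and time_) embeddings appended by the encoder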
assert embeddings_raw.shape == torch.Size(
[self.model.encoder.B, 1538, 768] # (batch_size, seq_length, hidden_size)
)
assert not torch.isnan(embeddings_raw).any() # ensure no NaNs in embedding

# Take the mean of the embeddings along the sequence_length dimension
# excluding the last two latlon_ and time_ embeddings, i.e. compute
# mean over patch embeddings only
embeddings_mean: torch.Tensor = embeddings_raw[:, :-2, :].mean(dim=1)
assert embeddings_mean.shape == torch.Size(
[self.model.encoder.B, 768] # (batch_size, hidden_size)
)

# Create table to store the embeddings with spatiotemporal metadata
unique_epsg_codes = set(int(epsg) for epsg in epsgs)
if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG
epsg: int = batch["epsg"][0]
else:
raise NotImplementedError(
f"More than 1 EPSG code detected: {unique_epsg_codes}"
)

gdf = gpd.GeoDataFrame(
data={
"source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"),
"date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
dtype="date32[day][pyarrow]"
),
"embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
embeddings_mean.cpu().detach().__array__()
),
},
geometry=shapely.box(
xmin=bboxes[:, 0],
ymin=bboxes[:, 1],
xmax=bboxes[:, 2],
ymax=bboxes[:, 3],
),
crs=f"EPSG:{epsg}",
)
gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates

return gdf

def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
"""
Logic to gather all the results from one epoch in a prediction loop.
"""
# Combine list of geopandas.GeoDataFrame objects
results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions
if results:
gdf: gpd.GeoDataFrame = pd.concat(
objs=results, axis="index", ignore_index=True
)
else:
print(
"No embeddings generated, "
f"possibly no GeoTIFF files in {self.trainer.datamodule.data_dir}"
)
return

# Save embeddings in GeoParquet format, one file for each MGRS code
outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
os.makedirs(name=outfolder, exist_ok=True)

# Find unique MGRS codes (e.g. '12ABC'): from a source_url like
# 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', extract '12ABC'
mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1]
unique_mgrs_codes = mgrs_codes.unique()
for mgrs_code in unique_mgrs_codes:
if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None:
raise ValueError(
"MGRS code should have 2 numbers and 3 letters (e.g. 12ABC), "
f"but got {mgrs_code} instead"
)

# Subset GeoDataFrame to a single MGRS code
_gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index()

# Get min/max date from GeoDataFrame
minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"])
min_date: str = minmax_date["min"].strftime("%Y%m%d")
max_date: str = minmax_date["max"].strftime("%Y%m%d")

# Output to a GeoParquet filename like
# {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq
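# e.g. 12ABC_20200101_20201231_v001.gpq (dates here are illustrative)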
outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq"
_gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
print(
f"Saved {len(_gdf)} rows of embeddings of "
f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
)

return gdf
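To see how these pieces fit together, here is a minimal prediction-loop sketch (not part of this diff): the checkpoint path is an assumed placeholder, and `datapipe` stands for a datapipe shaped like the test fixture in src/tests/test_model.py further down.

import lightning as L
import torchdata.dataloader2

from src.model_clay import CLAYModule

# Assumed checkpoint path (placeholder); any CLAYModule checkpoint works.
model = CLAYModule.load_from_checkpoint(checkpoint_path="checkpoints/clay.ckpt")

# `datapipe` is assumed to be built like fixture_datapipe() in the tests below.
dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe)
trainer = L.Trainer(accelerator="auto", devices="auto", default_root_dir=".")

# predict_step returns one GeoDataFrame per batch; on_predict_epoch_end then
# concatenates them and writes one GeoParquet file per MGRS code under
# {default_root_dir}/data/embeddings/.
gdfs = trainer.predict(model=model, dataloaders=dataloader)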
18 changes: 13 additions & 5 deletions src/model_vit.py
@@ -211,7 +211,7 @@ def predict_step(
| s3://3C.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
"""
# Get image, bounding box, EPSG code, and date inputs
-x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW
+x: torch.Tensor = batch["image"] # image of shape (1, 13, 512, 512) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']
@@ -319,10 +319,18 @@ def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
f"but got {mgrs_code} instead"
)

-# Output to a GeoParquet filename like {MGRS:5}_v{VERSION:2}.gpq
-outpath = f"{outfolder}/{mgrs_code}_v01.gpq"
-_gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code]
-_gdf.to_parquet(path=outpath, schema_version="1.0.0", compression="ZSTD")
+# Subset GeoDataFrame to a single MGRS code
+_gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code].reset_index()
+
+# Get min/max date from GeoDataFrame
+minmax_date: pd.Series = _gdf.date.agg(func=["min", "max"])
+min_date: str = minmax_date["min"].strftime("%Y%m%d")
+max_date: str = minmax_date["max"].strftime("%Y%m%d")
+
+# Output to a GeoParquet filename like
+# {MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq
+outpath = f"{outfolder}/{mgrs_code}_{min_date}_{max_date}_v001.gpq"
+_gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
print(
f"Saved {len(_gdf)} rows of embeddings of "
f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
)
68 changes: 59 additions & 9 deletions src/tests/test_model.py
@@ -15,6 +15,7 @@
import torchdata
import torchdata.dataloader2

from src.model_clay import CLAYModule
from src.model_vit import ViTLitModule


@@ -27,15 +28,26 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe:
datapipe = torchdata.datapipes.iter.IterableWrapper(
iterable=[
{
# For ViTLitModule
"image": torch.randn(3, 13, 512, 512).to(dtype=torch.float16),
# For CLAYModule
"pixels": torch.randn(3, 13, 512, 512).to(dtype=torch.float32),
"timestep": torch.tensor(
data=[(2020, 1, 1), (2021, 6, 15), (2022, 12, 31)],
dtype=torch.float32,
),
"latlon": torch.tensor(
data=[(12, 34), (56, 78), (90, 100)], dtype=torch.float32
),
# For both
"bbox": torch.tensor(
data=[
[499975.0, 3397465.0, 502535.0, 3400025.0],
[530695.0, 3397465.0, 533255.0, 3400025.0],
[561415.0, 3397465.0, 563975.0, 3400025.0],
]
),
"date": ["2020-01-01", "2020-12-31", "2020-12-31"],
"date": ["2020-01-01", "2021-06-15", "2022-12-31"],
"epsg": torch.tensor(data=[32760, 32760, 32760]),
"source_url": [
"s3://claytile_60HTE_1.tif",
@@ -49,9 +61,9 @@


# %%
-def test_model_vit(datapipe):
+def test_model_vit_fit(datapipe):
"""
-Run a full train, val, test and prediction loop using 1 batch.
+Run a full train and validation loop using 1 batch.
"""
# Get some random data
dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe)
@@ -71,19 +83,57 @@ def test_model_vit(datapipe):
)
trainer.fit(model=model, train_dataloaders=dataloader)


@pytest.mark.parametrize(
"litmodule,precision",
[
(CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
Contributor Author commented:
There are some torch.cat calls in CLAYModule that don't work when run on CPU with float16 tensors, see pytorch/pytorch#100932. The patch at pytorch/pytorch#96093 fixing this issue is already merged, so this if-then statement can be removed once PyTorch 2.2 is out. Note that running CLAYModule on CUDA-enabled GPUs should be fine with float16 or bfloat16.

(ViTLitModule, "bf16-mixed"),
],
)
def test_model_predict(datapipe, litmodule, precision):
"""
Run a single prediction loop using 1 batch.
"""
# Get some random data
dataloader = torchdata.dataloader2.DataLoader2(datapipe=datapipe)

# Initialize model
model: L.LightningModule = litmodule()

# Run tests in a temporary folder
with tempfile.TemporaryDirectory() as tmpdirname:
# Training
trainer: L.Trainer = L.Trainer(
accelerator="auto",
devices="auto",
precision=precision,
fast_dev_run=True,
default_root_dir=tmpdirname,
)

# Prediction
trainer.predict(model=model, dataloaders=dataloader)
assert (
len(os.listdir(path=f"{tmpdirname}/data/embeddings")) == 2 # noqa: PLR2004
)
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60HTE_v01.gpq")
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60GUV_v01.gpq")
assert os.path.exists(
path := f"{tmpdirname}/data/embeddings/60HTE_20200101_20200101_v001.gpq"
)
assert os.path.exists(
path := f"{tmpdirname}/data/embeddings/60GUV_20210615_20221231_v001.gpq"
)
geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path)

-assert geodataframe.shape == (2, 4) # 2 rows, 4 columns
-assert all(
-    geodataframe.columns == ["source_url", "date", "embeddings", "geometry"]
-)
+assert geodataframe.shape == (2, 5) # 2 rows, 5 columns
+assert list(geodataframe.columns) == [
+    "index",
+    "source_url",
+    "date",
+    "embeddings",
+    "geometry",
+]
assert geodataframe.index.dtype == "int64"
assert geodataframe.source_url.dtype == "string"
assert geodataframe.date.dtype == "date32[day][pyarrow]"
assert geodataframe.embeddings.dtype == "object"
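As a follow-up, a short sketch of reading one of the saved GeoParquet files back (the filename is taken from the test above; stacking assumes each row holds one 768-long embedding, matching the earlier asserts):

import geopandas as gpd
import numpy as np

gdf = gpd.read_parquet("data/embeddings/60GUV_20210615_20221231_v001.gpq")
# Each row's 'embeddings' cell is a length-768 vector; stack into (n_rows, 768)
embeddings = np.stack(gdf["embeddings"].to_numpy())
print(embeddings.shape, gdf[["source_url", "date"]].head())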
Expand Down