
Implement MAE with support for position, time, latlon & channel embeddings #47

Merged: 32 commits into main from vit-pytorch on Dec 15, 2023

Conversation

@srmsoumya (Collaborator) commented Nov 21, 2023

  • Add GeoViT from vit-pytorch that encodes position, time, latlon & channel embeddings
  • Add GeoMAE from vit-pytorch
  • Run a sample epoch to test the loss curve & training

@weiji14 added the model-architecture label (Pull requests about the neural network model architecture) on Nov 29, 2023
@weiji14 modified the milestone: v0 Release on Nov 29, 2023
srm_trainer.py Outdated
Comment on lines 23 to 35
dm = ClayDataModule(batch_size=128, num_workers=8)
model = GeoMAEModule(lr=1e-2)
trainer = L.Trainer(
    fast_dev_run=False,
    accelerator="gpu",
    devices=1,
    # precision="16-mixed",
    max_epochs=50,
    # log_every_n_steps=100,
    logger=CSVLogger("logs", name="geomae"),
)

trainer.fit(model, dm)
Contributor:
Let me know once you've got the code to a stable state, and I can help with setting up the LightningCLI code to allow for different datamodules/models, following https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli_intermediate_2.html#multiple-lightningdatamodules. That way we can have a single trainer file and do:

python trainer.py fit --data=GeoTIFFDataPipeModule --model=ViTLitModule  # from PR52
python trainer.py fit --data=ClayDataModule --model=GeoMAEModule  # this PR47

Collaborator (Author):

Sure, I initially had issues with the CUDA 12.x setup on my machine and thought it might be because of Lightning CLI.
Yes, I would like to move to the CLI once the experiment is working fine.

Contributor:

Is it an error like RuntimeError: GET was unable to find an engine to execute this computation? I've been getting that on my AWS EC2 G5 instance, and had to set export TORCH_CUDNN_V8_API_DISABLED=1 as a workaround from https://discuss.pytorch.org/t/segment-anything-model-unable-to-find-engine-to-execute-this-computation/179172.

@weiji14 (Contributor) left a comment:

Spotted some issues with how you're getting the coordinates from the GeoTIFF (these are UTM coordinates, not latlon or lonlat).

Heads up also that I've opened a small-ish PR to have the LightningDataModule return the spatiotemporal metadata (see #66), since we'll need to save the embeddings with some spatiotemporal metadata too.

src/dataset.py Outdated
Comment on lines 21 to 23
bounds = chip.bounds
centroid_x = (bounds.left + bounds.right) / 2
centroid_y = (bounds.top + bounds.bottom) / 2
@weiji14 (Contributor) commented Dec 3, 2023:

Coordinates from chip.bounds are UTM coordinates and not latlon. You would need to reproject first, or use chip.lnglat (https://rasterio.readthedocs.io/en/latest/api/rasterio._base.html#rasterio._base.DatasetBase.lnglat).

Member:

I made an implementation to compute lat/lon centroids from bounds in the first tiler version; have a look:

def bbox_centroid(bounds, epsg):
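The snippet above is truncated; a possible completion (an assumption, using pyproj rather than whatever the original tiler used, with a namedtuple standing in for rasterio's BoundingBox) could look like:

```python
from collections import namedtuple

from pyproj import Transformer

# Stand-in for rasterio's BoundingBox, for a self-contained example
Bounds = namedtuple("Bounds", ["left", "bottom", "right", "top"])


def bbox_centroid(bounds, epsg):
    # Centroid in the chip's native (projected) CRS
    cx = (bounds.left + bounds.right) / 2
    cy = (bounds.top + bounds.bottom) / 2
    # Reproject to geographic coordinates; always_xy=True keeps
    # (x, y) == (lon, lat) ordering regardless of CRS axis order
    transformer = Transformer.from_crs(f"EPSG:{epsg}", "EPSG:4326", always_xy=True)
    lon, lat = transformer.transform(cx, cy)
    return lat, lon  # latitude first, to match a "latlon" key


# UTM zone 33N: easting 500000 sits on the central meridian (15°E),
# and northing 0 is the equator for northern-hemisphere zones
lat, lon = bbox_centroid(Bounds(499000, -1000, 501000, 1000), epsg=32633)
```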

src/dataset.py Outdated
return {
    "pixels": chip.read(),
    "timestep": (year, month, day),
    "latlon": (centroid_x, centroid_y),
@weiji14 (Contributor) commented Dec 3, 2023:

A tuple like (centroid_x, centroid_y) would be lonlat, not latlon. Longitude is X, latitude is Y.
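To make the ordering concrete, here is a minimal stdlib-only sketch (the coordinate values are made up for illustration):

```python
# Longitude is X and latitude is Y, so a key named "latlon" must store
# the (y, x) pair, not (x, y). Values below are illustrative only.
lon, lat = 147.3, -42.9  # reprojected centroid: x (longitude) first

sample = {
    "latlon": (lat, lon),  # latitude first, matching the key name
}
print(sample["latlon"])  # prints (-42.9, 147.3)
```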

Collaborator (Author):

Thanks @weiji14, swapping the values.

@srmsoumya srmsoumya marked this pull request as ready for review December 6, 2023 13:23
@srmsoumya (Collaborator, Author) commented:
@weiji14 @lillythomas I ran the clay-small model locally & was able to get a loss curve similar to vanilla MAE implementations from FAIR & vit-pytorch. Here is the wandb log: https://wandb.ai/devseed/CLAY-v0/runs/zj0su9df

@brunosan (Member) commented Dec 7, 2023:

> @weiji14 @lillythomas I ran the clay-small model locally & was able to get a loss curve similar to vanilla MAE implementations from FAIR & vit-pytorch. Here is the wandb log: https://wandb.ai/devseed/CLAY-v0/runs/zj0su9df

Link is 404 for me. Most probably not shared with me.

@weiji14 (Contributor) commented Dec 14, 2023:

Gonna tidy up a few things before merging in this PR. Specifically, moving code from:

  • callbacks.py -> callbacks_wandb.py (after #88 maybe) - logs matplotlib figures to WandB
  • srm_datamodule.py -> combine into datamodule.py. May need to recompute mean and std?
  • model.py + clay.py -> combine into single model_clay.py file
  • srm_trainer.py -> combine into trainer.py

Also, other things:

  • Check that all required dependencies are in environment.yml

@weiji14 weiji14 changed the title Implement MAE with support for position, time, latlon & channel embeddings. Implement MAE with support for position, time, latlon & channel embeddings Dec 15, 2023
@weiji14 (Contributor) left a comment:

Ok @srmsoumya, I've cleaned up the code a little, and fixed all the lint errors. Would be nice to get some unit tests in at some point, but we can iterate on this later. Thanks again for implementing this Clay model for v0 🚀
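As a seed for those future unit tests, a stdlib-only property check of the MAE random-masking split could look like this (`random_masking` is a hypothetical stand-in for the model's actual masking step, not code from this PR):

```python
import random


def random_masking(num_patches, mask_ratio, seed=0):
    # Shuffle patch indices and split them into kept vs masked sets,
    # mirroring the random masking used by MAE-style models.
    rng = random.Random(seed)
    idx = list(range(num_patches))
    rng.shuffle(idx)
    num_keep = int(num_patches * (1 - mask_ratio))
    return sorted(idx[:num_keep]), sorted(idx[num_keep:])


# 196 patches (a 14x14 grid) with a 75% mask ratio keeps 49 patches
keep, masked = random_masking(num_patches=196, mask_ratio=0.75)
assert len(keep) == 49 and len(masked) == 147
assert sorted(keep + masked) == list(range(196))  # partition covers all patches
```

A test like this pins down the mask-ratio arithmetic and the no-overlap invariant without needing a GPU or real data.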

@weiji14 weiji14 merged commit 46fd3a7 into main Dec 15, 2023
1 check passed
@weiji14 weiji14 deleted the vit-pytorch branch December 15, 2023 04:14
brunosan pushed a commit that referenced this pull request Dec 27, 2023
…dings (#47)

* Add modified ViT to encode latlon, time, channels & position embeddings

* Add MAE for modified ViT

* Add docstrings & fix issue with complex indexing

* Fix the comments on loss computation

* Add datamodule & trainer to run an epoch of training

* Normalize data before feeding to the model

* Add fixed sincos embedding for position & bands

* Add logging & ckpt options

* Fix the order of coords from lat,lon to lon,lat

* Add clay tiny,small,medium,large model versions

* Remove hardcoded patch size in LogIntermediatePredictions callback

Retrieve the patch size value from the model architecture, rather than hardcoding as 32. Also ensure that the input image shape is the same as the predicted image from the decoder.

* Run clay small on image size 512 for 10 epochs with grad_acc

* Make the clay construction configurable

* Return the data path to reference for vector embeddings

* Remove duplicate dataset.py & geovit.py

* 🔀 Merge srm_trainer.py into trainer.py

Have one entrypoint to run the model using Lightning CLI. Switched model from VitLitModule to CLAYModule, and datamodule from GeoTIFFDataPipeModule to ClayDataModule. Temporarily disabling the logging and monitoring callbacks for now.

* 🔀 Combine clay.py and model.py into model_clay

Putting the CLAYModule (LightningModule) together with the CLAY torch.nn.Module in a single model_clay.py file. Have mentioned in src/README.md that model_clay.py is the one with custom spatiotemporal encoders, while the previous model_vit.py contains vanilla Vision Transformer implementation.

* ➕ Add matplotlib-base

Publication quality figures in Python!

* 🚚 Move ClayDataset and ClayDataModule into datamodule.py

Putting the DataLoader code in one file - datamodule.py. The regular torch Dataset classes are placed on top of the existing torchdata-based functions/classes.

* 🚚 Move LogIntermediatePredictions callback into callbacks_wandb

Moving the LogIntermediatePredictions callback class from callbacks.py into callbacks_wandb.py.

* ♻️ Get WandB logger properly using a shared function

Getting the WandbLogger directly from the trainer, rather than having to pass it through __init__. Adapted from https://github.com/ashleve/lightning-hydra-template/blob/334601c0326a50ff301fbd76057b36408cf97ffa/src/callbacks/wandb_callbacks.py#L16C1-L34C6

* 🚨 Wrap docstring and fix too-many-arguments lint error

Converted docstrings from numpydoc style which uses less horizontal space but more vertical space. Also added a noqa comment for three instances of `PLR0913 Too many arguments in function definition`.

---------

Co-authored-by: SRM <soumya@developmentseed.org>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>