Implement MAE with support for position, time, latlon & channel embeddings #47
Conversation
**srmsoumya** commented Nov 21, 2023 (edited)
- Add GeoViT from vit-pytorch that encodes position, time, latlon & channel embeddings
- Add GeoMAE from vit-pytorch
- Run a sample epoch to test the loss curve & training
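For context, the fixed sincos embedding for position & bands mentioned in this PR's commits can be sketched in pure Python. This is a simplified illustration of the standard ViT/MAE-style 1-D sine/cosine encoding, not the PR's actual implementation:

```python
import math

def sincos_embedding(pos: float, dim: int) -> list[float]:
    """1-D fixed sine/cosine embedding of even dimension `dim`,
    in the style of ViT/MAE position encodings (a simplified sketch)."""
    assert dim % 2 == 0, "dim must be even"
    emb = []
    for i in range(dim // 2):
        # Frequencies decay geometrically, as in the original Transformer paper.
        freq = 1.0 / (10000 ** (2 * i / dim))
        emb.append(math.sin(pos * freq))
        emb.append(math.cos(pos * freq))
    return emb
```

Because the embedding is fixed (not learned), the same scheme can encode patch position, timestep, or band index by feeding in the appropriate scalar.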
for more information, see https://pre-commit.ci
srm_trainer.py
Outdated
```python
dm = ClayDataModule(batch_size=128, num_workers=8)
model = GeoMAEModule(lr=1e-2)
trainer = L.Trainer(
    fast_dev_run=False,
    accelerator="gpu",
    devices=1,
    # precision="16-mixed",
    max_epochs=50,
    # log_every_n_steps=100,
    logger=CSVLogger("logs", name="geomae"),
)
trainer.fit(model, dm)
```
Let me know once you've got the code to a stable state, and I can help with setting up the LightningCLI code to allow for different datamodules/models, following https://lightning.ai/docs/pytorch/2.1.0/cli/lightning_cli_intermediate_2.html#multiple-lightningdatamodules. That way we can have a single trainer file and do:

```shell
python trainer.py fit --data=GeoTIFFDataPipeModule --model=ViTLitModule  # from PR52
python trainer.py fit --data=ClayDataModule --model=GeoMAEModule  # this PR47
```
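The dispatch idea behind those commands (resolving `--data`/`--model` names to classes) can be sketched without Lightning. This is only an illustration of what LightningCLI does for you via its class registries; the stub classes are stand-ins, not the real modules:

```python
import argparse

# Hypothetical stand-ins for the real LightningDataModule/LightningModule classes.
class ClayDataModule: pass
class GeoTIFFDataPipeModule: pass
class GeoMAEModule: pass
class ViTLitModule: pass

DATAMODULES = {cls.__name__: cls for cls in (ClayDataModule, GeoTIFFDataPipeModule)}
MODELS = {cls.__name__: cls for cls in (GeoMAEModule, ViTLitModule)}

def parse_cli(argv: list[str]):
    """Resolve --data/--model names to instances, mimicking LightningCLI dispatch."""
    parser = argparse.ArgumentParser()
    parser.add_argument("subcommand", choices=["fit"])
    parser.add_argument("--data", choices=DATAMODULES, required=True)
    parser.add_argument("--model", choices=MODELS, required=True)
    args = parser.parse_args(argv)
    return DATAMODULES[args.data](), MODELS[args.model]()
```

With this, both PR variants share one entrypoint and differ only in the command-line arguments.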
Sure, I initially had issues with the CUDA 12.x setup on my machine and thought it might be because of Lightning CLI.
Yes, I would like to move to the CLI once the experiment is working fine.
Is it an error like `RuntimeError: GET was unable to find an engine to execute this computation`? I've been getting that on my AWS EC2 G5 instance, and had to set `export TORCH_CUDNN_V8_API_DISABLED=1` as a workaround, from https://discuss.pytorch.org/t/segment-anything-model-unable-to-find-engine-to-execute-this-computation/179172.
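As an alternative to exporting the variable in the shell, it can be set from Python. Note this is my own sketch of applying the linked workaround, and it only takes effect if it runs before torch initializes cuDNN, so it belongs at the very top of the script:

```python
import os

# Workaround from the linked PyTorch forum thread: disable cuDNN's v8 API.
# Must be set before `import torch` triggers cuDNN initialization.
os.environ["TORCH_CUDNN_V8_API_DISABLED"] = "1"
```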
Spotted some issues with how you're getting the coordinates from the GeoTIFF (these are UTM coordinates, not latlon or lonlat).
Heads up also that I've opened a small-ish PR to have the LightningDataModule return the spatiotemporal metadata (see #66), since we'll need to save the embeddings with some spatiotemporal metadata too.
src/dataset.py
Outdated
```python
bounds = chip.bounds
centroid_x = (bounds.left + bounds.right) / 2
centroid_y = (bounds.top + bounds.bottom) / 2
```
Coordinates from `chip.bounds` are UTM coordinates and not latlon. You would need to reproject first, or use `chip.lnglat` (https://rasterio.readthedocs.io/en/latest/api/rasterio._base.html#rasterio._base.DatasetBase.lnglat).
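A cheap sanity check (my own sketch, not from the PR) would have flagged this bug early: geographic lon/lat values fall in narrow fixed ranges, while UTM eastings/northings are in metres and typically in the hundreds of thousands:

```python
def looks_like_lonlat(x: float, y: float) -> bool:
    """Heuristic: geographic coordinates fall within [-180, 180] x [-90, 90];
    UTM eastings/northings are metre values far outside those ranges."""
    return -180.0 <= x <= 180.0 and -90.0 <= y <= 90.0
```

Asserting this on the centroid before writing it into the sample dict would catch un-reprojected UTM values immediately.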
I made an implementation to compute lat/lon centroids from bounds in the first tiler version, have a look:

Line 97 in 9a41e3d

```python
def bbox_centroid(bounds, epsg):
```
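One possible shape for such a helper, sketched with pyproj (the actual `bbox_centroid` implementation may differ; the function name and signature here are illustrative):

```python
from pyproj import Transformer

def bbox_centroid_lonlat(bounds: tuple[float, float, float, float], epsg: int):
    """Centroid of (left, bottom, right, top) bounds in the given projected
    EPSG code, reprojected to geographic lon/lat (EPSG:4326)."""
    left, bottom, right, top = bounds
    x = (left + right) / 2
    y = (bottom + top) / 2
    # always_xy=True keeps (x, y) = (easting, northing) -> (lon, lat) ordering.
    transformer = Transformer.from_crs(f"EPSG:{epsg}", "EPSG:4326", always_xy=True)
    lon, lat = transformer.transform(x, y)
    return lon, lat
```

For example, a centroid at easting 500000 m in UTM zone 33N (EPSG:32633) sits on the zone's central meridian, 15°E.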
src/dataset.py
Outdated
```python
return {
    "pixels": chip.read(),
    "timestep": (year, month, day),
    "latlon": (centroid_x, centroid_y),
```
A tuple like `(centroid_x, centroid_y)` would be lonlat, not latlon: longitude is X, latitude is Y.
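The fix is a one-line swap. A tiny sketch of the corrected ordering (variable names mirror the snippet above):

```python
def as_latlon(centroid_x: float, centroid_y: float) -> tuple[float, float]:
    """centroid_x is longitude (X axis) and centroid_y is latitude (Y axis),
    so a latlon tuple must reverse them."""
    return (centroid_y, centroid_x)
```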
Thanks @weiji14, swapping the values.
@weiji14 @lillythomas I ran the
Link is 404 for me. Most probably not shared with me.
Retrieve the patch size value from the model architecture, rather than hardcoding it as 32. Also ensure that the input image shape is the same as the predicted image from the decoder.
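A sketch of that idea; `FakeModel` and the flat `patch_size` attribute are illustrative (the real architecture may expose it differently, e.g. nested under a patch-embedding submodule):

```python
class FakeModel:
    # Illustrative stand-in: the real model may nest this attribute.
    patch_size = 32

def patch_grid(model, image_size: int) -> tuple[int, int]:
    """Derive the patch grid from the model itself instead of hardcoding 32."""
    p = model.patch_size
    if image_size % p != 0:
        raise ValueError(f"image size {image_size} does not tile into {p}px patches")
    return (image_size // p, image_size // p)
```

Checking divisibility up front also guards the decoder-shape invariant: if the input tiles evenly, unpatchifying the prediction reproduces the input image shape.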
Gonna tidy up a few things before merging in this PR. Specifically, moving some code around, plus a few other things:
- Have one entrypoint to run the model using Lightning CLI. Switched the model from ViTLitModule to CLAYModule, and the datamodule from GeoTIFFDataPipeModule to ClayDataModule. Temporarily disabling the logging and monitoring callbacks for now.
- Putting the CLAYModule (LightningModule) together with the CLAY torch.nn.Module in a single model_clay.py file. Have mentioned in src/README.md that model_clay.py is the one with custom spatiotemporal encoders, while the previous model_vit.py contains the vanilla Vision Transformer implementation.
- Add matplotlib-base: publication quality figures in Python!
- Putting the DataLoader code in one file - datamodule.py. The regular torch Dataset classes are placed on top of the existing torchdata-based functions/classes.
- Moving the LogIntermediatePredictions callback class from callbacks.py into callbacks_wandb.py.
- Getting the WandbLogger directly from the trainer, rather than having to pass it through `__init__`. Adapted from https://github.com/ashleve/lightning-hydra-template/blob/334601c0326a50ff301fbd76057b36408cf97ffa/src/callbacks/wandb_callbacks.py#L16C1-L34C6
- Converted docstrings from numpydoc style, which uses less horizontal space but more vertical space. Also added a noqa comment for three instances of `PLR0913 Too many arguments in function definition`.
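The "get the WandbLogger directly from the trainer" pattern can be sketched like this. All names here are illustrative stand-ins (the real helper, adapted from the linked lightning-hydra-template, works against Lightning's actual classes); matching by class name keeps the sketch free of a hard wandb/lightning dependency:

```python
def get_wandb_logger(trainer):
    """Find the WandbLogger on the trainer, so callbacks need not
    receive it through __init__. `trainer` is anything with a `loggers` list."""
    for logger in getattr(trainer, "loggers", []):
        if type(logger).__name__ == "WandbLogger":
            return logger
    raise ValueError("WandbLogger not found on trainer")

# Stand-ins for lightning.pytorch.loggers.WandbLogger and the Trainer:
class WandbLogger: pass

class FakeTrainer:
    loggers = [WandbLogger()]
```

A callback can then call `get_wandb_logger(trainer)` inside its hooks instead of storing the logger at construction time.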
Ok @srmsoumya, I've cleaned up the code a little and fixed all the lint errors. It would be nice to get some unit tests in at some point, but we can iterate on this later. Thanks again for implementing this Clay model for v0 🚀
…dings (#47)

* Add modified ViT to encode latlon, time, channels & position embeddings
* Add MAE for modified ViT
* Add docstrings & fix issue with complex indexing
* Fix the comments on loss computation
* Add datamodule & trainer to run an epoch of training
* Normalize data before feeding to the model
* Add fixed sincos embedding for position & bands
* Add logging & ckpt options
* Fix the order of coords from lat,lon to lon,lat
* Add clay tiny, small, medium, large model versions
* Remove hardcoded patch size in LogIntermediatePredictions callback

  Retrieve the patch size value from the model architecture, rather than hardcoding as 32. Also ensure that the input image shape is the same as the predicted image from the decoder.

* Run clay small on image size 512 for 10 epochs with grad_acc
* Make the clay construction configurable
* Return the data path to reference for vector embeddings
* Remove duplicate dataset.py & geovit.py
* 🔀 Merge srm_trainer.py into trainer.py

  Have one entrypoint to run the model using Lightning CLI. Switched model from ViTLitModule to CLAYModule, and datamodule from GeoTIFFDataPipeModule to ClayDataModule. Temporarily disabling the logging and monitoring callbacks for now.

* 🔀 Combine clay.py and model.py into model_clay

  Putting the CLAYModule (LightningModule) together with the CLAY torch.nn.Module in a single model_clay.py file. Have mentioned in src/README.md that model_clay.py is the one with custom spatiotemporal encoders, while the previous model_vit.py contains the vanilla Vision Transformer implementation.

* ➕ Add matplotlib-base

  Publication quality figures in Python!

* 🚚 Move ClayDataset and ClayDataModule into datamodule.py

  Putting the DataLoader code in one file - datamodule.py. The regular torch Dataset classes are placed on top of the existing torchdata-based functions/classes.

* 🚚 Move LogIntermediatePredictions callback into callbacks_wandb

  Moving the LogIntermediatePredictions callback class from callbacks.py into callbacks_wandb.py.

* ♻️ Get WandB logger properly using a shared function

  Getting the WandbLogger directly from the trainer, rather than having to pass it through `__init__`. Adapted from https://github.com/ashleve/lightning-hydra-template/blob/334601c0326a50ff301fbd76057b36408cf97ffa/src/callbacks/wandb_callbacks.py#L16C1-L34C6

* 🚨 Wrap docstring and fix too-many-arguments lint error

  Converted docstrings from numpydoc style, which uses less horizontal space but more vertical space. Also added a noqa comment for three instances of `PLR0913 Too many arguments in function definition`.

---------

Co-authored-by: SRM <soumya@developmentseed.org>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>