Generate embeddings from CLAYModule trained with latlon/time encodings #96
Conversation
Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copying and adapting the code from a767164 in #73, but modifying how the encoder's masking is disabled, and how the mean/average of the embeddings is computed over a slice of the raw embeddings.
The output GeoParquet file now has a filename with a format like "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq", e.g. "12ABC_20210101_20231231_v001.gpq". Have implemented this in model_vit.py, and copied over the same `on_predict_epoch_end` method to model_clay.py. Also, we are no longer saving out the index column to the GeoParquet file.
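The filename template described above can be sketched like this (a minimal illustration, not the repo's actual implementation; the variable names are hypothetical):

```python
# Illustrative sketch of the "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq"
# naming scheme. Values and variable names here are made up for the example.
mgrs_code = "12ABC"    # 5-character MGRS tile code
min_date = "20210101"  # earliest acquisition date, YYYYMMDD
max_date = "20231231"  # latest acquisition date, YYYYMMDD
version = 1            # embedding format version, zero-padded to 3 digits

outfile = f"{mgrs_code}_{min_date}_{max_date}_v{version:03d}.gpq"
print(outfile)  # 12ABC_20210101_20231231_v001.gpq
```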
Forgot to update the filename in the unit test to conform to the new `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` format. Patches f19cf8f.
Splitting the previous integration test on the neural network model into separate fit and predict unit tests. Only testing the prediction loop of CLAYModule, because training/validating the model might be too much for CPU-based Continuous Integration. Also, for testing CLAYModule, we are using 32-true precision instead of bf16-mixed, because `torch.cat` doesn't work with float16 tensors on the CPU, see pytorch/pytorch#100932 (should be fixed with PyTorch 2.2).
```python
@pytest.mark.parametrize(
    "litmodule,precision",
    [
        (CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
```
There are some `torch.cat` calls in CLAYModule that don't work when run on CPU with float16 tensors, see pytorch/pytorch#100932. The patch at pytorch/pytorch#96093 to fix this issue has already been merged though, so we can remove this if-then statement in the future when PyTorch 2.2 is out. Note that running CLAYModule on CUDA-enabled GPUs should be fine with float16 or bfloat16.
Decided that the index column might be good to keep for now, since it might help to speed up row counts. We are resetting the index first before saving it, though. Partially reverts f19cf8f.
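The reset-then-save step can be illustrated with a tiny sketch (pandas used for illustration; a `gpd.GeoDataFrame` behaves the same way, and the column values here are made up):

```python
import pandas as pd

# After filtering/merging, a DataFrame can end up with a gappy index like
# [7, 9]. Resetting it before saving writes a clean 0..n-1 index column.
df = pd.DataFrame({"embeddings": [[0.1], [0.2]]}, index=[7, 9])
df = df.reset_index(drop=True)  # index is now 0..n-1
print(list(df.index))  # [0, 1]
```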
After f1439e3, need to ensure that the index column is checked in the output geodataframe.
Still many things that could be improved, such as sharing duplicated code between model_vit.py and model_clay.py, but will merge this in first.
What I am changing

- Rename the output embeddings file to the `{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq` format (e.g. `12ABC_20210101_20231231_v001.gpq`), following discussions at "Rename embeddings file to include MGRS code and store GeoTIFF source_url" #86 (comment)

How I did it
- In the LightningModule's `predict_step`, implement the logic to do the forward pass and save-to-gpq step
- Raw embeddings are of shape (1, 1538, 768), and we take the mean of the patch embeddings (1, 1536, 768), which becomes a (1, 768) shape embedding
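The slice-then-mean step could be sketched as follows (NumPy used for illustration; the actual code operates on torch tensors, and exactly which two positions hold the non-patch tokens is an assumption here):

```python
import numpy as np

# Hedged sketch: raw embeddings are (1, 1538, 768), of which 1536 are
# patch tokens. We assume the two non-patch tokens sit at the end and
# slice them off before averaging.
raw = np.random.rand(1, 1538, 768).astype(np.float32)
patch_embeddings = raw[:, :1536, :]             # (1, 1536, 768) patch tokens
mean_embedding = patch_embeddings.mean(axis=1)  # average over patches -> (1, 768)
print(mean_embedding.shape)  # (1, 768)
```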
Sample output table would look like this (same as "Rename embeddings file to include MGRS code and store GeoTIFF source_url" #86):

TODO in this PR:

- `predict_step` to generate `gpd.GeoDataFrame` table
- `on_predict_epoch_end` to merge `gpd.GeoDataFrame` tables and output to GeoParquet file(s)

TODO in the future:
- Share duplicated code between `model_vit.py` and `model_clay.py`?

How you can test it
- Use input data from `s3://clay-tiles-02/02/`
- Download the checkpoint from `s3://clay-model-ckpt/v0/mae_epoch-02_val-loss-0.52.ckpt` to the `checkpoints/` folder.

This should produce a `48MYU_20180813_20210424_v001.gpq` file under the `data/embeddings/` folder. Sample file (need to unzip): 48MYU_20180813_20210424_v001.gpq.zip

Extra configuration options can be found using `python trainer.py predict --help`
To load the embeddings from the GeoParquet file:
Related Issues
Towards #3