-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate embeddings from CLAYModule trained with latlon/time encodings (
#96) * 🍻 Implement CLAYModule's predict_step to generate embeddings table Output embeddings to a geopandas.GeoDataFrame with columns 'source_url', 'date', 'embeddings', and 'geometry'. Essentially copying and adapting the code from a767164 in #73, but modifying how the encoder's masking is disabled, and how the mean/average of the embeddings is computed over a slice of the raw embeddings. * 🚚 Rename output file to {MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq The output GeoParquet file now has a filename with a format like "{MGRS:5}_{MINDATE:8}_{MAXDATE:8}_v{VERSION:3}.gpq", e.g. "12ABC_20210101_20231231_v001.gpq". Have implemented this in model_vit.py, and copied over the same `on_predict_epoch_end` method to model_clay.py. Also, we are no longer saving out the index column to the GeoParquet file. * ✅ Fix failing test by updating to new output filename Forgot to update the filename in the unit test to conform to the new `{MGRS}_{MINDATE}_{MAXDATE}_v{VERSION}.gpq` format. Patches f19cf8f. * ✅ Parametrized test to check CLAYModule's predict loop Splitting the previous integration test on the neural network model into separate fit and predict unit tests. Only testing the prediction loop of CLAYModule, because training/validating the model might be too much for CPU-based Continuous Integration. Also for testing CLAYModule, we are using 32-true precision instead of bf16-mixed, because `torch.cat` doesn't work with float16 tensors on the CPU, see pytorch/pytorch#100932 (should be fixed with Pytorch 2.2). * ⏪ Save index column to GeoParquet file Decided that the index column might be good to keep for now, since it might help to speed up row counts? But we are resetting the index first before saving it. Partially reverts f19cf8f. * ✅ Fix unit test to include index column After f1439e3, need to ensure that the index column is checked in the output geodataframe.
- Loading branch information
Showing
3 changed files
with
201 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters