Skip to content

Commit 46fd3a7

Browse files
srmsoumyaSRMpre-commit-ci[bot]weiji14
authored
Implement MAE with support for position, time, latlon & channel embeddings (#47)
* Add modified ViT to encode latlon, time, channels & position embeddings * Add MAE for modified ViT * Add docstrings & fix issue with complex indexing * Fix the comments on loss computation * Add datamodule & trainer to run an epoch of training * Normalize data before feeding to the model * Add fixed sincos embedding for position & bands * Add logging & ckpt options * Fix the order of coords from lat,lon to lon,lat * Add clay tiny,small,medium,large model versions * Remove hardcoded patch size in LogIntermediatePredictions callback Retrieve the patch size value from the model architecture, rather than hardcoding as 32. Also ensure that the input image shape is the same as the predicted image from the decoder. * Run clay small on image size 512 for 10 epochs with grad_acc * Make the clay construction configurable * Return the data path to reference for vector embeddings * Remove duplicate dataset.py & geovit.py * 🔀 Merge srm_trainer.py into trainer.py Have one entrypoint to run the model using Lightning CLI. Switched model from VitLitModule to CLAYModule, and datamodule from GeoTIFFDataPipeModule to ClayDataModule. Temporarily disabling the logging and monitoring callbacks for now. * 🔀 Combine clay.py and model.py into model_clay Putting the CLAYModule (LightningModule) together with the CLAY torch.nn.Module in a single model_clay.py file. Have mentioned in src/README.md that model_clay.py is the one with custom spatiotemporal encoders, while the previous model_vit.py contains vanilla Vision Transformer implementation. * ➕ Add matplotlib-base Publication quality figures in Python! * 🚚 Move ClayDataset and ClayDataModule into datamodule.py Putting the DataLoader code in one file - datamodule.py. The regular torch Dataset classes are placed on top of the existing torchdata-based functions/classes. * 🚚 Move LogIntermediatePredictions callback into callbacks_wandb Moving the LogIntermediatePredictions callback class from callbacks.py into callbacks_wandb.py. * ♻️ Get WandB logger properly using a shared function Getting the WandbLogger directly from the trainer, rather than having to pass it through __init__. Adapted from https://github.com/ashleve/lightning-hydra-template/blob/334601c0326a50ff301fbd76057b36408cf97ffa/src/callbacks/wandb_callbacks.py#L16C1-L34C6 * 🚨 Wrap docstring and fix too-many-arguments lint error Converted docstrings from numpydoc style which uses less horizontal space but more vertical space. Also added a noqa comment for three instances of `PLR0913 Too many arguments in function definition`. --------- Co-authored-by: SRM <soumya@developmentseed.org> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>
1 parent e259165 commit 46fd3a7

File tree

8 files changed

+1318
-53
lines changed

8 files changed

+1318
-53
lines changed

conda-lock.yml

Lines changed: 189 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
version: 1
1414
metadata:
1515
content_hash:
16-
linux-64: b7b524b1ba497b9648707a5bfb9c9dd4abb412a6b07af7b0472985be7ba39abe
16+
linux-64: ee11aee8049ddee9b71ce4ec66285c84b8d915535c1a19e608f515e8ac97f7a9
1717
channels:
1818
- url: conda-forge
1919
used_env_vars: []
@@ -529,33 +529,33 @@ package:
529529
category: main
530530
optional: false
531531
- name: boto3
532-
version: 1.33.13
532+
version: 1.34.1
533533
manager: conda
534534
platform: linux-64
535535
dependencies:
536-
botocore: '>=1.33.13,<1.34.0'
536+
botocore: '>=1.34.1,<1.35.0'
537537
jmespath: '>=0.7.1,<2.0.0'
538-
python: '>=3.7'
539-
s3transfer: '>=0.8.2,<0.9.0'
540-
url: https://conda.anaconda.org/conda-forge/noarch/boto3-1.33.13-pyhd8ed1ab_0.conda
538+
python: '>=3.8'
539+
s3transfer: '>=0.9.0,<0.10.0'
540+
url: https://conda.anaconda.org/conda-forge/noarch/boto3-1.34.1-pyhd8ed1ab_0.conda
541541
hash:
542-
md5: 4a1b38a0938b9fc23fb4fc202d832097
543-
sha256: 01f797f967ac92346a5bffdae165c5d4bacda8405b828fcd58534a0f82287f76
542+
md5: 506cb54f3b548a2d372d860b591afcd6
543+
sha256: 9a9a97588d5d4a526ce76fece8d1c9a9870047c2a9a04ca5600a6c687378e207
544544
category: main
545545
optional: false
546546
- name: botocore
547-
version: 1.33.13
547+
version: 1.34.1
548548
manager: conda
549549
platform: linux-64
550550
dependencies:
551551
jmespath: '>=0.7.1,<2.0.0'
552-
python: '>=3.7'
552+
python: '>=3.8'
553553
python-dateutil: '>=2.1,<3.0.0'
554554
urllib3: '>=1.25.4,<1.27'
555-
url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.33.13-pyhd8ed1ab_0.conda
555+
url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.34.1-pyhd8ed1ab_0.conda
556556
hash:
557-
md5: d2566fd9134b6f8b8e69a07e9a1fa17e
558-
sha256: 498d08274880ef279e9a6dd68f66b384d71321d92301fa60330712f9edceab0c
557+
md5: b60f64c4bef6aee697f8291bfffa42b2
558+
sha256: 73a1348256e47aa620b014b3a7e3a0de29f8adfff97ded965fa4e05dd9a8d7e2
559559
category: main
560560
optional: false
561561
- name: brotli
@@ -981,6 +981,22 @@ package:
981981
sha256: c6fc314161263f031eb23ac53868e0d9b0242efe669e176901effdac4bd87376
982982
category: main
983983
optional: false
984+
- name: contourpy
985+
version: 1.2.0
986+
manager: conda
987+
platform: linux-64
988+
dependencies:
989+
libgcc-ng: '>=12'
990+
libstdcxx-ng: '>=12'
991+
numpy: '>=1.20,<2'
992+
python: '>=3.11,<3.12.0a0'
993+
python_abi: 3.11.*
994+
url: https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py311h9547e67_0.conda
995+
hash:
996+
md5: 40828c5b36ef52433e21f89943e09f33
997+
sha256: 2c76e2a970b74eef92ef9460aa705dbdc506dd59b7382bfbedce39d9c189d7f4
998+
category: main
999+
optional: false
9841000
- name: crashtest
9851001
version: 0.4.1
9861002
manager: conda
@@ -1032,10 +1048,10 @@ package:
10321048
cuda-version: '>=12.0,<12.1.0a0'
10331049
libgcc-ng: '>=12'
10341050
libstdcxx-ng: '>=12'
1035-
url: https://conda.anaconda.org/conda-forge/linux-64/cuda-cudart-12.0.107-hd3aeb46_7.conda
1051+
url: https://conda.anaconda.org/conda-forge/linux-64/cuda-cudart-12.0.107-hd3aeb46_8.conda
10361052
hash:
1037-
md5: def0e966c7fad7e13f8840b1cdb92dbd
1038-
sha256: bc953b78fa1b0b989da5be01a3f8adc4aa80bc7eaa8a2f08850f5558871b6654
1053+
md5: 3a747a8e83767681eba39d7b8957e5fb
1054+
sha256: 5680d6d1ab9d61d28319b8b1f3b5f29fc9e55c0905051dd4277a662fceff0e59
10391055
category: main
10401056
optional: false
10411057
- name: cuda-cudart_linux-64
@@ -1044,10 +1060,10 @@ package:
10441060
platform: linux-64
10451061
dependencies:
10461062
cuda-version: '>=12.0,<12.1.0a0'
1047-
url: https://conda.anaconda.org/conda-forge/noarch/cuda-cudart_linux-64-12.0.107-h59595ed_7.conda
1063+
url: https://conda.anaconda.org/conda-forge/noarch/cuda-cudart_linux-64-12.0.107-h59595ed_8.conda
10481064
hash:
1049-
md5: 36de472d13cff5e19543998deb4a9093
1050-
sha256: 01eb5aba8324fbf99fb9eac3ea3423bc3a0a9ecf42e40da27aa3aea557c33cdf
1065+
md5: 839830d4d2b02950bf6ccf18d252d69f
1066+
sha256: df6b88e0bc323da65b82f0658e2fa5722ac81b251abfa98291c3b409f56ec253
10511067
category: main
10521068
optional: false
10531069
- name: cuda-nvrtc
@@ -1108,6 +1124,18 @@ package:
11081124
sha256: 0cd17a35182f3d35b6f7813a345f430814113fdf23417c37bdff4c310a4ce03b
11091125
category: main
11101126
optional: false
1127+
- name: cycler
1128+
version: 0.12.1
1129+
manager: conda
1130+
platform: linux-64
1131+
dependencies:
1132+
python: '>=3.8'
1133+
url: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda
1134+
hash:
1135+
md5: 5cd86562580f274031ede6aa6aa24441
1136+
sha256: f221233f21b1d06971792d491445fd548224641af9443739b4b7b6d5d72954a8
1137+
category: main
1138+
optional: false
11111139
- name: dask-core
11121140
version: 2023.12.0
11131141
manager: conda
@@ -1324,6 +1352,18 @@ package:
13241352
sha256: 190dbafc9e699f74cf8d287e91246acac1e14afda8ce6aedafac87e392e1bc96
13251353
category: main
13261354
optional: false
1355+
- name: einops
1356+
version: 0.7.0
1357+
manager: conda
1358+
platform: linux-64
1359+
dependencies:
1360+
python: '>=3.8'
1361+
url: https://conda.anaconda.org/conda-forge/noarch/einops-0.7.0-pyhd8ed1ab_1.conda
1362+
hash:
1363+
md5: 1641890c9375ddb22381f3eb9ac157df
1364+
sha256: cc08bb969a4458b7afd48e7ba8151c95b48f1c315d3567644ed4a97ee2987247
1365+
category: main
1366+
optional: false
13271367
- name: ensureconda
13281368
version: 1.4.3
13291369
manager: conda
@@ -1541,6 +1581,22 @@ package:
15411581
sha256: 53f23a3319466053818540bcdf2091f253cbdbab1e0e9ae7b9e509dcaa2a5e38
15421582
category: main
15431583
optional: false
1584+
- name: fonttools
1585+
version: 4.46.0
1586+
manager: conda
1587+
platform: linux-64
1588+
dependencies:
1589+
brotli: ''
1590+
libgcc-ng: '>=12'
1591+
munkres: ''
1592+
python: '>=3.11,<3.12.0a0'
1593+
python_abi: 3.11.*
1594+
url: https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.46.0-py311h459d7ec_0.conda
1595+
hash:
1596+
md5: a14114f70e23f7fd5ab9941fec45b095
1597+
sha256: a40f8415d9ceaf5f217034814b984d13017e4dab577085a83a2d0cc39b9d7239
1598+
category: main
1599+
optional: false
15441600
- name: fqdn
15451601
version: 1.5.1
15461602
manager: conda
@@ -1964,17 +2020,17 @@ package:
19642020
category: main
19652021
optional: false
19662022
- name: imageio
1967-
version: 2.31.5
2023+
version: 2.33.1
19682024
manager: conda
19692025
platform: linux-64
19702026
dependencies:
19712027
numpy: ''
19722028
pillow: '>=8.3.2'
19732029
python: '>=3'
1974-
url: https://conda.anaconda.org/conda-forge/noarch/imageio-2.31.5-pyh8c1a49c_0.conda
2030+
url: https://conda.anaconda.org/conda-forge/noarch/imageio-2.33.1-pyh8c1a49c_0.conda
19752031
hash:
1976-
md5: 6820ccf6a3a27df348f18c85dd89014a
1977-
sha256: 0554fbf2136a1ab380551963c5884941f7852034cbe40f002ae040e10e457365
2032+
md5: 1c34d58ac469a34e7e96832861368bce
2033+
sha256: 0565f3666de4d02a83c5c8e14b7d878c382720f84318d6ce1ff83b66603c01d7
19782034
category: main
19792035
optional: false
19802036
- name: importlib-metadata
@@ -2543,6 +2599,21 @@ package:
25432599
sha256: 150c05a6e538610ca7c43beb3a40d65c90537497a4f6a5f4d15ec0451b6f5ebb
25442600
category: main
25452601
optional: false
2602+
- name: kiwisolver
2603+
version: 1.4.5
2604+
manager: conda
2605+
platform: linux-64
2606+
dependencies:
2607+
libgcc-ng: '>=12'
2608+
libstdcxx-ng: '>=12'
2609+
python: '>=3.11,<3.12.0a0'
2610+
python_abi: 3.11.*
2611+
url: https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py311h9547e67_1.conda
2612+
hash:
2613+
md5: 2c65bdf442b0d37aad080c8a4e0d452f
2614+
sha256: 723b0894d2d2b05a38f9c5a285d5a0a5baa27235ceab6531dbf262ba7c6955c1
2615+
category: main
2616+
optional: false
25462617
- name: krb5
25472618
version: 1.21.2
25482619
manager: conda
@@ -3896,6 +3967,33 @@ package:
38963967
sha256: e1a9930f35e39bf65bc293e24160b83ebf9f800f02749f65358e1c04882ee6b0
38973968
category: main
38983969
optional: false
3970+
- name: matplotlib-base
3971+
version: 3.8.2
3972+
manager: conda
3973+
platform: linux-64
3974+
dependencies:
3975+
certifi: '>=2020.06.20'
3976+
contourpy: '>=1.0.1'
3977+
cycler: '>=0.10'
3978+
fonttools: '>=4.22.0'
3979+
freetype: '>=2.12.1,<3.0a0'
3980+
kiwisolver: '>=1.3.1'
3981+
libgcc-ng: '>=12'
3982+
libstdcxx-ng: '>=12'
3983+
numpy: '>=1.23.5,<2.0a0'
3984+
packaging: '>=20.0'
3985+
pillow: '>=8'
3986+
pyparsing: '>=2.3.1'
3987+
python: '>=3.11,<3.12.0a0'
3988+
python-dateutil: '>=2.7'
3989+
python_abi: 3.11.*
3990+
tk: '>=8.6.13,<8.7.0a0'
3991+
url: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py311h54ef318_0.conda
3992+
hash:
3993+
md5: 9f80753bc008bfc9b95f39d9ff9f1694
3994+
sha256: 69319da0e6bad1711cac1573710370f31e9630fe6c972ff7eac95649e0c04114
3995+
category: main
3996+
optional: false
38993997
- name: matplotlib-inline
39003998
version: 0.1.6
39013999
manager: conda
@@ -4061,6 +4159,18 @@ package:
40614159
sha256: eca27e6fb5fb4ee73f04ae030bce29f5daa46fea3d6abdabb91740646f0d188e
40624160
category: main
40634161
optional: false
4162+
- name: munkres
4163+
version: 1.1.4
4164+
manager: conda
4165+
platform: linux-64
4166+
dependencies:
4167+
python: ''
4168+
url: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2
4169+
hash:
4170+
md5: 2ba8498c1018c1e9c61eb99b973dfe19
4171+
sha256: f86fb22b58e93d04b6f25e0d811b56797689d598788b59dcb47f59045b568306
4172+
category: main
4173+
optional: false
40644174
- name: nbclient
40654175
version: 0.8.0
40664176
manager: conda
@@ -5543,16 +5653,16 @@ package:
55435653
category: main
55445654
optional: false
55455655
- name: s3transfer
5546-
version: 0.8.2
5656+
version: 0.9.0
55475657
manager: conda
55485658
platform: linux-64
55495659
dependencies:
55505660
botocore: '>=1.33.2,<2.0a.0'
5551-
python: '>=3.7'
5552-
url: https://conda.anaconda.org/conda-forge/noarch/s3transfer-0.8.2-pyhd8ed1ab_0.conda
5661+
python: '>=3.8'
5662+
url: https://conda.anaconda.org/conda-forge/noarch/s3transfer-0.9.0-pyhd8ed1ab_0.conda
55535663
hash:
5554-
md5: 75e12933f4bf755c9cdd37072bcb6203
5555-
sha256: 2e5679abcec8eb646df37518ecdbdaa224d7ff5295a1e56707317d52b47d9c79
5664+
md5: 27ad14e5fc6a13f05b90140debc72cd2
5665+
sha256: c9fc315d830238113160471467259740593966bcf59f17287a0baeaf1f6a76d8
55565666
category: main
55575667
optional: false
55585668
- name: sacremoses
@@ -5661,17 +5771,17 @@ package:
56615771
category: main
56625772
optional: false
56635773
- name: sentry-sdk
5664-
version: 1.39.0
5774+
version: 1.39.1
56655775
manager: conda
56665776
platform: linux-64
56675777
dependencies:
56685778
certifi: ''
56695779
python: '>=3.7'
56705780
urllib3: '>=1.25.7'
5671-
url: https://conda.anaconda.org/conda-forge/noarch/sentry-sdk-1.39.0-pyhd8ed1ab_0.conda
5781+
url: https://conda.anaconda.org/conda-forge/noarch/sentry-sdk-1.39.1-pyhd8ed1ab_0.conda
56725782
hash:
5673-
md5: c9cf03170f08f8c27a792572c19a6925
5674-
sha256: c91ab95dab71f76618d9d90cfdc3fdcf8f7189fece57d111f19ddc67f65f9c0e
5783+
md5: 7acf7f9aa3b5659b764ac602c7bfe8f2
5784+
sha256: 1165a3d68f91653d3174b87a8743272ac3db06c803222c3dd199c8d9e02845e5
56755785
category: main
56765786
optional: false
56775787
- name: setproctitle
@@ -5962,7 +6072,7 @@ package:
59626072
category: main
59636073
optional: false
59646074
- name: tiledb
5965-
version: 2.18.2
6075+
version: 2.18.3
59666076
manager: conda
59676077
platform: linux-64
59686078
dependencies:
@@ -5972,15 +6082,15 @@ package:
59726082
libgcc-ng: '>=12'
59736083
libgoogle-cloud: '>=2.12.0,<2.13.0a0'
59746084
libstdcxx-ng: '>=12'
5975-
libxml2: '>=2.12.2,<2.13.0a0'
6085+
libxml2: '>=2.12.3,<2.13.0a0'
59766086
libzlib: '>=1.2.13,<1.3.0a0'
59776087
lz4-c: '>=1.9.3,<1.10.0a0'
59786088
openssl: '>=3.2.0,<4.0a0'
59796089
zstd: '>=1.5.5,<1.6.0a0'
5980-
url: https://conda.anaconda.org/conda-forge/linux-64/tiledb-2.18.2-hc1131af_2.conda
6090+
url: https://conda.anaconda.org/conda-forge/linux-64/tiledb-2.18.3-hc1131af_0.conda
59816091
hash:
5982-
md5: d86a5ef28245c435828b3f2ef1ee3f28
5983-
sha256: 58c0d9906ceda465b0d44ef4878e4a2e22512670bf7b12023d7495d3fd988bdc
6092+
md5: 57e713fdf115812155e4b52e0b0f0f2a
6093+
sha256: d90525476d804a373355b6de2bb0b9725af7a9074914990f14f097efdba6f9e3
59846094
category: main
59856095
optional: false
59866096
- name: tinycss2
@@ -6098,6 +6208,33 @@ package:
60986208
sha256: 26abe526d9514f096f196628ede28fb10d8e65cc19350b3be19f7bc465a22cec
60996209
category: main
61006210
optional: false
6211+
- name: torchvision
6212+
version: 0.16.1
6213+
manager: conda
6214+
platform: linux-64
6215+
dependencies:
6216+
__glibc: '>=2.17,<3.0.a0'
6217+
cuda-version: '>=12.0,<13'
6218+
cudnn: '>=8.8.0.121,<9.0a0'
6219+
libcublas: '>=12.0.1.189,<13.0a0'
6220+
libcusolver: '>=11.4.2.57,<12.0a0'
6221+
libcusparse: '>=12.0.0.76,<13.0a0'
6222+
libgcc-ng: '>=12'
6223+
libjpeg-turbo: '>=3.0.0,<4.0a0'
6224+
libpng: '>=1.6.39,<1.7.0a0'
6225+
libstdcxx-ng: '>=12'
6226+
numpy: '>=1.23.5,<2.0a0'
6227+
pillow: '>=5.3.0,!=8.3.0,!=8.3.1'
6228+
python: '>=3.11,<3.12.0a0'
6229+
python_abi: 3.11.*
6230+
pytorch: '>=2.1.0,<2.2.0a0'
6231+
requests: ''
6232+
url: https://conda.anaconda.org/conda-forge/linux-64/torchvision-0.16.1-cuda120py311h6416cd9_2.conda
6233+
hash:
6234+
md5: 97382c92246c44859ee53c3a62b65822
6235+
sha256: 57c294734e67b616f308c1dea1ab825bf9ca01f7ea97502f3fb5bd5a2a63ad0b
6236+
category: main
6237+
optional: false
61016238
- name: tornado
61026239
version: 6.3.3
61036240
manager: conda
@@ -6357,6 +6494,21 @@ package:
63576494
sha256: 50827c3721a9dbf973b568709d4381add2a6552fa562f26a385c5edc16a534af
63586495
category: main
63596496
optional: false
6497+
- name: vit-pytorch
6498+
version: 1.6.4
6499+
manager: conda
6500+
platform: linux-64
6501+
dependencies:
6502+
einops: '>=0.7.0'
6503+
python: '>=3.6'
6504+
pytorch: '>=1.10'
6505+
torchvision: ''
6506+
url: https://conda.anaconda.org/conda-forge/noarch/vit-pytorch-1.6.4-pyhd8ed1ab_0.conda
6507+
hash:
6508+
md5: b190197078841ee7072474fa34fc6b98
6509+
sha256: 8101b5a00b0fb9317acdd8072a6e6b2727a055cb00d58d8de19cabd0704167ab
6510+
category: main
6511+
optional: false
63606512
- name: wandb
63616513
version: 0.15.12
63626514
manager: conda

0 commit comments

Comments
 (0)