Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MAE with support for position, time, latlon & channel embeddings #47

Merged
merged 32 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4942a19
Add modified ViT to encode latlon, time, channels & position embeddings
Nov 21, 2023
6d17a4c
Add MAE for modified ViT
Nov 28, 2023
81792c6
Add docstrings & fix issue with complex indexing
Nov 28, 2023
9db48c9
Fix the comments on loss computation
Nov 29, 2023
05ecb9e
Merge main to vit-pytorch
Nov 29, 2023
3773045
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2023
f0c4170
Add datamodule & trainer to run an epoch of training
Nov 29, 2023
134ba4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2023
1987a1a
Normalize data before feeding to the model
Nov 30, 2023
37f6f82
Add fixed sincos embedding for position & bands
Nov 30, 2023
ec5a547
Fix pre-commit CI issue
Nov 30, 2023
bfc4bad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
c0393dd
Add logging & ckpt options
Dec 4, 2023
088a8c5
Fix the order of coords from lat,lon to lon,lat
Dec 7, 2023
25b59a3
Add clay tiny,small,medium,large model versions
Dec 7, 2023
aabbc83
Fix pre-commit formatting issue
Dec 7, 2023
ec7122a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2023
f3c2a6e
Remove hardcoded patch size in LogIntermediatePredictions callback
weiji14 Dec 8, 2023
ed6138f
Run clay small on image size 512 for 10 epochs with grad_acc
Dec 13, 2023
87dd31a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
ac5b71e
Make the clay construction configurable
Dec 14, 2023
2dd8ab0
Return the data path to reference for vector embeddings
Dec 14, 2023
eacee3b
Remove duplicate dataset.py & geovit.py
Dec 14, 2023
4c96387
:twisted_rightwards_arrows: Merge srm_trainer.py into trainer.py
weiji14 Dec 14, 2023
3dd7ffc
:twisted_rightwards_arrows: Merge branch 'main' into vit-pytorch
weiji14 Dec 14, 2023
f3d06eb
:twisted_rightwards_arrows: Combine clay.py and model.py into model_clay
weiji14 Dec 15, 2023
3f8a330
:heavy_plus_sign: Add matplotlib-base
weiji14 Dec 15, 2023
7375f09
:truck: Move ClayDataset and ClayDataModule into datamodule.py
weiji14 Dec 15, 2023
c202fdf
:twisted_rightwards_arrows: Merge branch 'main' into vit-pytorch
weiji14 Dec 15, 2023
4d011d0
:truck: Move LogIntermediatePredictions callback into callbacks_wandb
weiji14 Dec 15, 2023
ded0877
:recycle: Get WandB logger properly using a shared function
weiji14 Dec 15, 2023
2af0eed
:rotating_light: Wrap docstring and fix too-many-arguments lint error
weiji14 Dec 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 189 additions & 37 deletions conda-lock.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
version: 1
metadata:
content_hash:
linux-64: b7b524b1ba497b9648707a5bfb9c9dd4abb412a6b07af7b0472985be7ba39abe
linux-64: ee11aee8049ddee9b71ce4ec66285c84b8d915535c1a19e608f515e8ac97f7a9
channels:
- url: conda-forge
used_env_vars: []
Expand Down Expand Up @@ -529,33 +529,33 @@ package:
category: main
optional: false
- name: boto3
version: 1.33.13
version: 1.34.1
manager: conda
platform: linux-64
dependencies:
botocore: '>=1.33.13,<1.34.0'
botocore: '>=1.34.1,<1.35.0'
jmespath: '>=0.7.1,<2.0.0'
python: '>=3.7'
s3transfer: '>=0.8.2,<0.9.0'
url: https://conda.anaconda.org/conda-forge/noarch/boto3-1.33.13-pyhd8ed1ab_0.conda
python: '>=3.8'
s3transfer: '>=0.9.0,<0.10.0'
url: https://conda.anaconda.org/conda-forge/noarch/boto3-1.34.1-pyhd8ed1ab_0.conda
hash:
md5: 4a1b38a0938b9fc23fb4fc202d832097
sha256: 01f797f967ac92346a5bffdae165c5d4bacda8405b828fcd58534a0f82287f76
md5: 506cb54f3b548a2d372d860b591afcd6
sha256: 9a9a97588d5d4a526ce76fece8d1c9a9870047c2a9a04ca5600a6c687378e207
category: main
optional: false
- name: botocore
version: 1.33.13
version: 1.34.1
manager: conda
platform: linux-64
dependencies:
jmespath: '>=0.7.1,<2.0.0'
python: '>=3.7'
python: '>=3.8'
python-dateutil: '>=2.1,<3.0.0'
urllib3: '>=1.25.4,<1.27'
url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.33.13-pyhd8ed1ab_0.conda
url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.34.1-pyhd8ed1ab_0.conda
hash:
md5: d2566fd9134b6f8b8e69a07e9a1fa17e
sha256: 498d08274880ef279e9a6dd68f66b384d71321d92301fa60330712f9edceab0c
md5: b60f64c4bef6aee697f8291bfffa42b2
sha256: 73a1348256e47aa620b014b3a7e3a0de29f8adfff97ded965fa4e05dd9a8d7e2
category: main
optional: false
- name: brotli
Expand Down Expand Up @@ -981,6 +981,22 @@ package:
sha256: c6fc314161263f031eb23ac53868e0d9b0242efe669e176901effdac4bd87376
category: main
optional: false
- name: contourpy
version: 1.2.0
manager: conda
platform: linux-64
dependencies:
libgcc-ng: '>=12'
libstdcxx-ng: '>=12'
numpy: '>=1.20,<2'
python: '>=3.11,<3.12.0a0'
python_abi: 3.11.*
url: https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py311h9547e67_0.conda
hash:
md5: 40828c5b36ef52433e21f89943e09f33
sha256: 2c76e2a970b74eef92ef9460aa705dbdc506dd59b7382bfbedce39d9c189d7f4
category: main
optional: false
- name: crashtest
version: 0.4.1
manager: conda
Expand Down Expand Up @@ -1032,10 +1048,10 @@ package:
cuda-version: '>=12.0,<12.1.0a0'
libgcc-ng: '>=12'
libstdcxx-ng: '>=12'
url: https://conda.anaconda.org/conda-forge/linux-64/cuda-cudart-12.0.107-hd3aeb46_7.conda
url: https://conda.anaconda.org/conda-forge/linux-64/cuda-cudart-12.0.107-hd3aeb46_8.conda
hash:
md5: def0e966c7fad7e13f8840b1cdb92dbd
sha256: bc953b78fa1b0b989da5be01a3f8adc4aa80bc7eaa8a2f08850f5558871b6654
md5: 3a747a8e83767681eba39d7b8957e5fb
sha256: 5680d6d1ab9d61d28319b8b1f3b5f29fc9e55c0905051dd4277a662fceff0e59
category: main
optional: false
- name: cuda-cudart_linux-64
Expand All @@ -1044,10 +1060,10 @@ package:
platform: linux-64
dependencies:
cuda-version: '>=12.0,<12.1.0a0'
url: https://conda.anaconda.org/conda-forge/noarch/cuda-cudart_linux-64-12.0.107-h59595ed_7.conda
url: https://conda.anaconda.org/conda-forge/noarch/cuda-cudart_linux-64-12.0.107-h59595ed_8.conda
hash:
md5: 36de472d13cff5e19543998deb4a9093
sha256: 01eb5aba8324fbf99fb9eac3ea3423bc3a0a9ecf42e40da27aa3aea557c33cdf
md5: 839830d4d2b02950bf6ccf18d252d69f
sha256: df6b88e0bc323da65b82f0658e2fa5722ac81b251abfa98291c3b409f56ec253
category: main
optional: false
- name: cuda-nvrtc
Expand Down Expand Up @@ -1108,6 +1124,18 @@ package:
sha256: 0cd17a35182f3d35b6f7813a345f430814113fdf23417c37bdff4c310a4ce03b
category: main
optional: false
- name: cycler
version: 0.12.1
manager: conda
platform: linux-64
dependencies:
python: '>=3.8'
url: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda
hash:
md5: 5cd86562580f274031ede6aa6aa24441
sha256: f221233f21b1d06971792d491445fd548224641af9443739b4b7b6d5d72954a8
category: main
optional: false
- name: dask-core
version: 2023.12.0
manager: conda
Expand Down Expand Up @@ -1324,6 +1352,18 @@ package:
sha256: 190dbafc9e699f74cf8d287e91246acac1e14afda8ce6aedafac87e392e1bc96
category: main
optional: false
- name: einops
version: 0.7.0
manager: conda
platform: linux-64
dependencies:
python: '>=3.8'
url: https://conda.anaconda.org/conda-forge/noarch/einops-0.7.0-pyhd8ed1ab_1.conda
hash:
md5: 1641890c9375ddb22381f3eb9ac157df
sha256: cc08bb969a4458b7afd48e7ba8151c95b48f1c315d3567644ed4a97ee2987247
category: main
optional: false
- name: ensureconda
version: 1.4.3
manager: conda
Expand Down Expand Up @@ -1541,6 +1581,22 @@ package:
sha256: 53f23a3319466053818540bcdf2091f253cbdbab1e0e9ae7b9e509dcaa2a5e38
category: main
optional: false
- name: fonttools
version: 4.46.0
manager: conda
platform: linux-64
dependencies:
brotli: ''
libgcc-ng: '>=12'
munkres: ''
python: '>=3.11,<3.12.0a0'
python_abi: 3.11.*
url: https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.46.0-py311h459d7ec_0.conda
hash:
md5: a14114f70e23f7fd5ab9941fec45b095
sha256: a40f8415d9ceaf5f217034814b984d13017e4dab577085a83a2d0cc39b9d7239
category: main
optional: false
- name: fqdn
version: 1.5.1
manager: conda
Expand Down Expand Up @@ -1964,17 +2020,17 @@ package:
category: main
optional: false
- name: imageio
version: 2.31.5
version: 2.33.1
manager: conda
platform: linux-64
dependencies:
numpy: ''
pillow: '>=8.3.2'
python: '>=3'
url: https://conda.anaconda.org/conda-forge/noarch/imageio-2.31.5-pyh8c1a49c_0.conda
url: https://conda.anaconda.org/conda-forge/noarch/imageio-2.33.1-pyh8c1a49c_0.conda
hash:
md5: 6820ccf6a3a27df348f18c85dd89014a
sha256: 0554fbf2136a1ab380551963c5884941f7852034cbe40f002ae040e10e457365
md5: 1c34d58ac469a34e7e96832861368bce
sha256: 0565f3666de4d02a83c5c8e14b7d878c382720f84318d6ce1ff83b66603c01d7
category: main
optional: false
- name: importlib-metadata
Expand Down Expand Up @@ -2543,6 +2599,21 @@ package:
sha256: 150c05a6e538610ca7c43beb3a40d65c90537497a4f6a5f4d15ec0451b6f5ebb
category: main
optional: false
- name: kiwisolver
version: 1.4.5
manager: conda
platform: linux-64
dependencies:
libgcc-ng: '>=12'
libstdcxx-ng: '>=12'
python: '>=3.11,<3.12.0a0'
python_abi: 3.11.*
url: https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py311h9547e67_1.conda
hash:
md5: 2c65bdf442b0d37aad080c8a4e0d452f
sha256: 723b0894d2d2b05a38f9c5a285d5a0a5baa27235ceab6531dbf262ba7c6955c1
category: main
optional: false
- name: krb5
version: 1.21.2
manager: conda
Expand Down Expand Up @@ -3896,6 +3967,33 @@ package:
sha256: e1a9930f35e39bf65bc293e24160b83ebf9f800f02749f65358e1c04882ee6b0
category: main
optional: false
- name: matplotlib-base
version: 3.8.2
manager: conda
platform: linux-64
dependencies:
certifi: '>=2020.06.20'
contourpy: '>=1.0.1'
cycler: '>=0.10'
fonttools: '>=4.22.0'
freetype: '>=2.12.1,<3.0a0'
kiwisolver: '>=1.3.1'
libgcc-ng: '>=12'
libstdcxx-ng: '>=12'
numpy: '>=1.23.5,<2.0a0'
packaging: '>=20.0'
pillow: '>=8'
pyparsing: '>=2.3.1'
python: '>=3.11,<3.12.0a0'
python-dateutil: '>=2.7'
python_abi: 3.11.*
tk: '>=8.6.13,<8.7.0a0'
url: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.2-py311h54ef318_0.conda
hash:
md5: 9f80753bc008bfc9b95f39d9ff9f1694
sha256: 69319da0e6bad1711cac1573710370f31e9630fe6c972ff7eac95649e0c04114
category: main
optional: false
- name: matplotlib-inline
version: 0.1.6
manager: conda
Expand Down Expand Up @@ -4061,6 +4159,18 @@ package:
sha256: eca27e6fb5fb4ee73f04ae030bce29f5daa46fea3d6abdabb91740646f0d188e
category: main
optional: false
- name: munkres
version: 1.1.4
manager: conda
platform: linux-64
dependencies:
python: ''
url: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2
hash:
md5: 2ba8498c1018c1e9c61eb99b973dfe19
sha256: f86fb22b58e93d04b6f25e0d811b56797689d598788b59dcb47f59045b568306
category: main
optional: false
- name: nbclient
version: 0.8.0
manager: conda
Expand Down Expand Up @@ -5543,16 +5653,16 @@ package:
category: main
optional: false
- name: s3transfer
version: 0.8.2
version: 0.9.0
manager: conda
platform: linux-64
dependencies:
botocore: '>=1.33.2,<2.0a.0'
python: '>=3.7'
url: https://conda.anaconda.org/conda-forge/noarch/s3transfer-0.8.2-pyhd8ed1ab_0.conda
python: '>=3.8'
url: https://conda.anaconda.org/conda-forge/noarch/s3transfer-0.9.0-pyhd8ed1ab_0.conda
hash:
md5: 75e12933f4bf755c9cdd37072bcb6203
sha256: 2e5679abcec8eb646df37518ecdbdaa224d7ff5295a1e56707317d52b47d9c79
md5: 27ad14e5fc6a13f05b90140debc72cd2
sha256: c9fc315d830238113160471467259740593966bcf59f17287a0baeaf1f6a76d8
category: main
optional: false
- name: sacremoses
Expand Down Expand Up @@ -5661,17 +5771,17 @@ package:
category: main
optional: false
- name: sentry-sdk
version: 1.39.0
version: 1.39.1
manager: conda
platform: linux-64
dependencies:
certifi: ''
python: '>=3.7'
urllib3: '>=1.25.7'
url: https://conda.anaconda.org/conda-forge/noarch/sentry-sdk-1.39.0-pyhd8ed1ab_0.conda
url: https://conda.anaconda.org/conda-forge/noarch/sentry-sdk-1.39.1-pyhd8ed1ab_0.conda
hash:
md5: c9cf03170f08f8c27a792572c19a6925
sha256: c91ab95dab71f76618d9d90cfdc3fdcf8f7189fece57d111f19ddc67f65f9c0e
md5: 7acf7f9aa3b5659b764ac602c7bfe8f2
sha256: 1165a3d68f91653d3174b87a8743272ac3db06c803222c3dd199c8d9e02845e5
category: main
optional: false
- name: setproctitle
Expand Down Expand Up @@ -5962,7 +6072,7 @@ package:
category: main
optional: false
- name: tiledb
version: 2.18.2
version: 2.18.3
manager: conda
platform: linux-64
dependencies:
Expand All @@ -5972,15 +6082,15 @@ package:
libgcc-ng: '>=12'
libgoogle-cloud: '>=2.12.0,<2.13.0a0'
libstdcxx-ng: '>=12'
libxml2: '>=2.12.2,<2.13.0a0'
libxml2: '>=2.12.3,<2.13.0a0'
libzlib: '>=1.2.13,<1.3.0a0'
lz4-c: '>=1.9.3,<1.10.0a0'
openssl: '>=3.2.0,<4.0a0'
zstd: '>=1.5.5,<1.6.0a0'
url: https://conda.anaconda.org/conda-forge/linux-64/tiledb-2.18.2-hc1131af_2.conda
url: https://conda.anaconda.org/conda-forge/linux-64/tiledb-2.18.3-hc1131af_0.conda
hash:
md5: d86a5ef28245c435828b3f2ef1ee3f28
sha256: 58c0d9906ceda465b0d44ef4878e4a2e22512670bf7b12023d7495d3fd988bdc
md5: 57e713fdf115812155e4b52e0b0f0f2a
sha256: d90525476d804a373355b6de2bb0b9725af7a9074914990f14f097efdba6f9e3
category: main
optional: false
- name: tinycss2
Expand Down Expand Up @@ -6098,6 +6208,33 @@ package:
sha256: 26abe526d9514f096f196628ede28fb10d8e65cc19350b3be19f7bc465a22cec
category: main
optional: false
- name: torchvision
version: 0.16.1
manager: conda
platform: linux-64
dependencies:
__glibc: '>=2.17,<3.0.a0'
cuda-version: '>=12.0,<13'
cudnn: '>=8.8.0.121,<9.0a0'
libcublas: '>=12.0.1.189,<13.0a0'
libcusolver: '>=11.4.2.57,<12.0a0'
libcusparse: '>=12.0.0.76,<13.0a0'
libgcc-ng: '>=12'
libjpeg-turbo: '>=3.0.0,<4.0a0'
libpng: '>=1.6.39,<1.7.0a0'
libstdcxx-ng: '>=12'
numpy: '>=1.23.5,<2.0a0'
pillow: '>=5.3.0,!=8.3.0,!=8.3.1'
python: '>=3.11,<3.12.0a0'
python_abi: 3.11.*
pytorch: '>=2.1.0,<2.2.0a0'
requests: ''
url: https://conda.anaconda.org/conda-forge/linux-64/torchvision-0.16.1-cuda120py311h6416cd9_2.conda
hash:
md5: 97382c92246c44859ee53c3a62b65822
sha256: 57c294734e67b616f308c1dea1ab825bf9ca01f7ea97502f3fb5bd5a2a63ad0b
category: main
optional: false
- name: tornado
version: 6.3.3
manager: conda
Expand Down Expand Up @@ -6357,6 +6494,21 @@ package:
sha256: 50827c3721a9dbf973b568709d4381add2a6552fa562f26a385c5edc16a534af
category: main
optional: false
- name: vit-pytorch
version: 1.6.4
manager: conda
platform: linux-64
dependencies:
einops: '>=0.7.0'
python: '>=3.6'
pytorch: '>=1.10'
torchvision: ''
url: https://conda.anaconda.org/conda-forge/noarch/vit-pytorch-1.6.4-pyhd8ed1ab_0.conda
hash:
md5: b190197078841ee7072474fa34fc6b98
sha256: 8101b5a00b0fb9317acdd8072a6e6b2727a055cb00d58d8de19cabd0704167ab
category: main
optional: false
- name: wandb
version: 0.15.12
manager: conda
Expand Down
Loading