From a1d7a2797e34dbd9be073c853fdc205f496c067a Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 12 Aug 2025 14:24:28 +0000 Subject: [PATCH 01/11] Update to develop, prepare for new experiment series --- config/default_config.yml | 24 ++++++++++++------------ config/eval_config.yml | 28 ++++++++++++++++++++++++++++ config/runs_plot_train.yml | 6 ++++++ 3 files changed, 46 insertions(+), 12 deletions(-) create mode 100644 config/eval_config.yml create mode 100644 config/runs_plot_train.yml diff --git a/config/default_config.yml b/config/default_config.yml index 76bdd2694..e3772e842 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -10,7 +10,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 +ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -24,7 +24,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -40,13 +40,13 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: "fixed" forecast_freeze_model: False -forecast_att_dense_rate: 0.25 -fe_num_blocks: 0 +forecast_att_dense_rate: 1.0 +fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -75,7 +75,7 @@ batch_size_validation_per_gpu: 1 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -91,17 +91,17 @@ masking_strategy: "random" # "channel": requires "mode" to be specified, "per_cell" or "global", masking_strategy_config: {"hl_mask": 3} -num_epochs: 32 +num_epochs: 64 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True lr_scaling_policy: "sqrt" lr_start: 1e-6 -lr_max: 5e-5 -lr_final_decay: 1e-6 +lr_max: 0.0001 +lr_final_decay: 2e-6 lr_final: 0.0 -lr_steps_warmup: 512 +lr_steps_warmup: 256 lr_steps_cooldown: 512 lr_policy_warmup: "cosine" lr_policy_decay: "linear" diff --git a/config/eval_config.yml b/config/eval_config.yml new file mode 100644 index 000000000..937bc59be --- /dev/null +++ b/config/eval_config.yml @@ -0,0 +1,28 @@ +verbose: true +image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. 
+dpi_val : 300 +summary_plots : true +print_summary: false + +evaluation: + metrics : ["rmse"] + regions: ["global"] + +run_ids : + + ptluswdo: + label: "ptluswdo: 64ep 2fs (naoj54ch) + 32ep 8fs 2e-5" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + #channels: ["2t", "q_850", ] + evaluation: + sample: "all" + forecast_step: "all" + plotting: + sample: [0] + forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] + plot_maps: true + plot_histograms: false \ No newline at end of file diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..49924b524 --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,6 @@ +train : + plot : + lnjzhore : + slurm_id: 0 + description: "Christian's naoj54ch with new code" + eval: vgbndhco \ No newline at end of file From c12e1905a1fa50390f626f76bd3e64c5c9b6f3a8 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 10 Oct 2025 21:25:36 +0200 Subject: [PATCH 02/11] Setting o48 as default in era5 config Committer: Matthias Karlbauer On branch mk/develop/fe_experiments Your branch is ahead of 'origin/mk/develop/fe_experiments' by 57 commits. (use "git push" to publish your local commits) Changes to be committed: modified: config/streams/era5_1deg/era5.yml --- config/streams/era5_1deg/era5.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index bb2234c4e..e9cc9a6b8 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -9,7 +9,8 @@ ERA5 : type : anemoi - filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + #filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. From d95277e33754969e0652005e08f7d9ab4c8c1785 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 10 Oct 2025 21:28:38 +0200 Subject: [PATCH 03/11] Updated default config to 256 dim latent size On branch mk/develop/fe_experiments Your branch is ahead of 'origin/mk/develop/fe_experiments' by 58 commits. (use "git push" to publish your local commits) Changes to be committed: modified: config/default_config.yml --- config/default_config.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 140d04892..3bb87c950 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -9,7 +9,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True -ae_local_dim_embed: 1024 +ae_local_dim_embed: 256 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -23,9 +23,9 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 +ae_global_dim_embed: 256 ae_global_num_blocks: 4 -ae_global_num_heads: 32 +ae_global_num_heads: 16 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True # TODO: switching to < 1 triggers triton-related issues. 
From a73447178f00993efcad4c7e1058dc2e47cf3b8e Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 13 Oct 2025 12:24:48 +0200 Subject: [PATCH 04/11] Update branch to latest develop --- uv.lock | 272 ++++++-------------------------------------------------- 1 file changed, 26 insertions(+), 246 deletions(-) diff --git a/uv.lock b/uv.lock index 56e875859..469c6a41f 100644 --- a/uv.lock +++ b/uv.lock @@ -1251,52 +1251,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/1c/6d343e030815c7c97a1f9fbad00211b47717c7fe446834c224bd5311e6f1/numpy-2.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:bd8df082b6c4695753ad6193018c05aac465d634834dca47a3ae06d4bb22d9ea", size = 9891498, upload-time = "2025-06-07T14:43:36.332Z" }, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771, upload-time = "2024-06-18T19:28:09.881Z" }, - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805, upload-time = "2024-04-03T20:57:06.025Z" }, - { url = "https://files.pythonhosted.org/packages/e2/2a/4f27ca96232e8b5269074a72e03b4e0d43aa68c9b965058b1684d07c6ff8/nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc", size = 396895858, upload-time = "2024-04-03T21:03:31.996Z" }, -] - [[package]] name = "nvidia-cublas-cu12" version = "12.6.4.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322, upload-time = "2024-11-20T17:40:25.65Z" }, { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615, upload-time = "2024-11-20T17:39:52.715Z" }, { url = "https://files.pythonhosted.org/packages/84/f7/985e9bdbe3e0ac9298fcc8cfa51a392862a46a0ffaccbbd56939b62a9c83/nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8", size = 434535301, upload-time = "2024-11-20T17:50:41.681Z" }, ] -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556, upload-time = "2024-06-18T19:30:40.546Z" }, - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957, upload-time = "2024-04-03T20:55:01.564Z" }, - { url = "https://files.pythonhosted.org/packages/f3/79/8cf313ec17c58ccebc965568e5bcb265cdab0a1df99c4e674bb7a3b99bfe/nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922", size = 9938035, upload-time = "2024-04-03T21:01:01.109Z" }, -] - [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.6.80" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764, upload-time = "2024-11-20T17:35:41.03Z" }, { url = "https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756, upload-time = "2024-10-01T16:57:45.507Z" }, @@ -1305,52 +1273,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/81/7796f096afaf726796b1b648f3bc80cafc61fe7f77f44a483c89e6c5ef34/nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a", size = 5724175, upload-time = "2024-10-01T17:09:47.955Z" }, ] -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372, upload-time = "2024-06-18T19:32:00.576Z" }, - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306, upload-time = "2024-04-03T20:56:01.463Z" }, - { url = "https://files.pythonhosted.org/packages/7c/30/8c844bfb770f045bcd8b2c83455c5afb45983e1a8abf0c4e5297b481b6a5/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec", size = 19751955, upload-time = "2024-04-03T21:01:51.133Z" }, -] - [[package]] name = "nvidia-cuda-nvrtc-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = 
"https://files.pythonhosted.org/packages/f4/2f/72df534873235983cc0a5371c3661bebef7c4682760c275590b972c7b0f9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13", size = 23162955, upload-time = "2024-10-01T16:59:50.922Z" }, { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380, upload-time = "2024-10-01T17:00:14.643Z" }, { url = "https://files.pythonhosted.org/packages/f5/46/d3a1cdda8bb113c80f43a0a6f3a853356d487b830f3483f92d49ce87fa55/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a", size = 39026742, upload-time = "2024-10-01T17:10:49.058Z" }, ] -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177, upload-time = "2024-06-18T19:32:52.877Z" }, - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737, upload-time = "2024-04-03T20:54:51.355Z" }, - { url = "https://files.pythonhosted.org/packages/a8/8b/450e93fab75d85a69b50ea2d5fdd4ff44541e0138db16f9cd90123ef4de4/nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e", size = 878808, upload-time = "2024-04-03T21:00:49.77Z" }, -] - [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052, upload-time = "2024-11-20T17:35:19.905Z" }, { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040, upload-time = "2024-10-01T16:57:22.221Z" }, @@ -1359,30 +1295,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/76/4c80fa138333cc975743fd0687a745fccb30d167f906f13c1c7f9a85e5ea/nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f", size = 891773, upload-time = "2024-10-01T17:09:26.362Z" }, ] -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - 
"platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741, upload-time = "2024-04-22T15:24:15.253Z" }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892, upload-time = "2024-04-22T15:24:53.333Z" }, -] - [[package]] name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/99/93/a201a12d3ec1caa8c6ac34c1c2f9eeb696b886f0c36ff23c638b46603bd0/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def", size = 570523509, upload-time = "2024-10-25T19:53:03.148Z" }, @@ -1390,31 +1308,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/b2/3f60d15f037fa5419d9d7f788b100ef33ea913ae5315c87ca6d6fa606c35/nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8", size = 565440743, upload-time = "2024-10-25T19:55:49.74Z" }, ] -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548, upload-time = "2024-06-18T19:33:39.396Z" }, - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117, upload-time = "2024-04-03T20:57:40.402Z" }, - { url = "https://files.pythonhosted.org/packages/f6/ee/3f3f8e9874f0be5bbba8fb4b62b3de050156d159f8b6edc42d6f1074113b/nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b", size = 210576476, upload-time = "2024-04-03T21:04:06.422Z" }, -] - [[package]] name = "nvidia-cufft-cu12" version = "11.3.0.4" source = 
{ registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144, upload-time = "2024-11-20T17:40:58.288Z" }, @@ -1424,26 +1323,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/38/36fd800cec8f6e89b7c1576edaaf8076e69ec631644cdbc1b5f2e2b5a9df/nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464", size = 199356881, upload-time = "2024-10-01T17:13:01.861Z" }, ] -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.5.147" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811, upload-time = "2024-06-18T19:34:48.575Z" }, - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206, upload-time = "2024-04-03T20:58:08.722Z" }, - { url = "https://files.pythonhosted.org/packages/1c/22/2573503d0d4e45673c263a313f79410e110eb562636b0617856fdb2ff5f6/nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771", size = 55799918, upload-time = "2024-04-03T21:04:34.45Z" }, -] - [[package]] name = "nvidia-curand-cu12" version = "10.3.7.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/42/ac/36543605358a355632f1a6faa3e2d5dfb91eab1e4bc7d552040e0383c335/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8", size = 56289881, upload-time = "2024-10-01T17:04:18.981Z" }, { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010, upload-time = "2024-11-20T17:42:50.958Z" }, @@ -1452,35 +1335,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/a8/0cd0cec757bd4b4b4ef150fca62ec064db7d08a291dced835a0be7d2c147/nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905", size = 55783873, upload-time = 
"2024-10-01T17:13:30.377Z" }, ] -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111, upload-time = "2024-06-18T19:35:01.793Z" }, - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057, upload-time = "2024-04-03T20:58:28.735Z" }, - { url = "https://files.pythonhosted.org/packages/f2/be/d435b7b020e854d5d5a682eb5de4328fd62f6182507406f2818280e206e2/nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c", size = 125224015, upload-time = "2024-04-03T21:04:53.339Z" }, -] - [[package]] name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628, upload-time = "2024-10-01T17:05:05.591Z" }, @@ -1490,31 +1352,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/53/fff50a0808df7113d77e3bbc7c2b7eaed6f57d5eb80fbe93ead2aea1e09a/nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7", size = 149287877, upload-time = "2024-10-01T17:13:49.804Z" }, ] 
-[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987, upload-time = "2024-06-18T19:35:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763, upload-time = "2024-04-03T20:58:59.995Z" }, - { url = "https://files.pythonhosted.org/packages/a2/e0/3155ca539760a8118ec94cc279b34293309bcd14011fc724f87f31988843/nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f", size = 204684315, upload-time = "2024-04-03T21:05:26.031Z" }, -] - [[package]] name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147, upload-time = "2024-11-20T17:44:18.055Z" }, @@ -1524,26 +1367,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/ef/876ad8e4260e1128e6d4aac803d9d51baf3791ebdb4a9b8d9b8db032b4b0/nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20", size = 213712630, upload-time = "2024-10-01T17:14:23.779Z" }, ] -[[package]] -name = "nvidia-cusparselt-cu12" -version = "0.6.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/8e/675498726c605c9441cf46653bd29cb1b8666da1fb1469ffa25f67f20c58/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8", size = 149422781, upload-time = "2024-07-23T17:35:27.203Z" }, - { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751, upload-time = "2024-07-23T02:35:53.074Z" }, - { url = 
"https://files.pythonhosted.org/packages/56/8f/2c33082238b6c5e783a877dc8786ab62619e3e6171c083bd3bba6e3fe75e/nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70", size = 148755794, upload-time = "2024-07-23T02:35:00.261Z" }, -] - [[package]] name = "nvidia-cusparselt-cu12" version = "0.6.3" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/62/da/4de092c61c6dea1fc9c936e69308a02531d122e12f1f649825934ad651b5/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1", size = 156402859, upload-time = "2024-10-16T02:23:17.184Z" }, { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796, upload-time = "2024-10-15T21:29:17.709Z" }, @@ -1567,52 +1394,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414, upload-time = "2024-04-03T15:32:57.427Z" }, ] -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510, upload-time = "2024-06-18T20:20:13.871Z" }, - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810, upload-time = "2024-04-03T20:59:46.957Z" }, - { url = "https://files.pythonhosted.org/packages/81/19/0babc919031bee42620257b9a911c528f05fb2688520dcd9ca59159ffea8/nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1", size = 95336325, upload-time = "2024-04-03T21:06:25.073Z" }, -] - [[package]] name = "nvidia-nvjitlink-cu12" version = "12.6.85" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971, upload-time = "2024-11-20T17:46:53.366Z" }, { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338, upload-time = "2024-11-20T17:46:29.758Z" }, { url = "https://files.pythonhosted.org/packages/89/76/93c1467b1387387440a4d25102d86b7794535449b689f8e2dc22c1c8ff7f/nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c", size = 161908572, upload-time = "2024-11-20T17:52:40.124Z" }, ] -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417, upload-time = "2024-06-18T20:16:22.484Z" }, - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144, upload-time = "2024-04-03T20:56:12.406Z" }, - { url = "https://files.pythonhosted.org/packages/54/1b/f77674fbb73af98843be25803bbd3b9a4f0a96c75b8d33a2854a5c7d2d77/nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485", size = 66307, upload-time = "2024-04-03T21:02:01.959Z" }, -] - [[package]] name = "nvidia-nvtx-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/b9/93/80f8a520375af9d7ee44571a6544653a176e53c2b8ccce85b97b83c2491b/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b", size = 90549, upload-time = "2024-11-20T17:38:17.387Z" }, { url = "https://files.pythonhosted.org/packages/2b/53/36e2fd6c7068997169b49ffc8c12d5af5e5ff209df6e1a2c4d373b3a638f/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059", size = 90539, upload-time = "2024-10-01T17:00:27.179Z" }, @@ -2426,8 +2221,8 @@ wheels = [ [[package]] name = "torch" -version = "2.6.0" -source = { registry = "https://pypi.org/simple" } +version = "2.6.0+cpu" +source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", "platform_machine == 'x86_64' and sys_platform == 'linux'", @@ -2437,29 +2232,14 @@ dependencies = [ { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "networkx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = 
"nvidia-cuda-cupti-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.1.0.70", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.2.1.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", version = "10.3.5.147", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.6.1.9", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", version = "0.6.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sympy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563, upload-time = "2025-01-29T16:23:19.084Z" }, - { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867, upload-time = "2025-01-29T16:25:55.649Z" }, - { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469, upload-time = 
"2025-01-29T16:24:01.821Z" }, - { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538, upload-time = "2025-01-29T16:24:18.976Z" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:59e78aa0c690f70734e42670036d6b541930b8eabbaa18d94e090abf14cc4d91" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:318290e8924353c61b125cdc8768d15208704e279e7757c113b9620740deca98" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:4027d982eb2781c93825ab9527f17fbbb12dbabf422298e4b954be60016f87d8" }, ] [[package]] @@ -2508,19 +2288,19 @@ dependencies = [ { name = "fsspec", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "jinja2", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "networkx", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", version = "10.3.7.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", version = "0.6.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = 
"nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -2696,7 +2476,7 @@ dependencies = [ [package.optional-dependencies] cpu = [ - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] gpu = [ { name = "flash-attn", version = "2.7.3", source = { url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-10-weathergen-gpu') or (platform_machine != 'aarch64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2737,9 +2517,9 @@ requires-dist = [ { name = "pynvml" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { name = "torch", marker = "sys_platform == 'linux' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, + { name = "torch", marker = "sys_platform != 'linux' and extra == 'cpu'", specifier = "==2.6.0" }, { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'gpu') or (sys_platform != 'linux' and extra == 'gpu')", specifier = "==2.6.0+cu126" }, - { name = "torch", marker = "sys_platform == 
'macosx' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, - { name = "torch", marker = "sys_platform != 'macosx' and extra == 'cpu'", specifier = "==2.6.0" }, { name = "tqdm" }, { name = "weathergen-common", editable = "packages/common" }, { name = "weathergen-evaluate", editable = "packages/evaluate" }, From eba89a6a8181ae3905fc64157cf247e5e3ce2fe2 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 13 Oct 2025 17:01:52 +0200 Subject: [PATCH 05/11] Change epochs from 64 to 32 --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index abbcb47f2..efb6e95b3 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -109,7 +109,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_epochs: 64 +num_epochs: 32 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True From 298c032c4240d31e587724bc82f88f97fc06c1b2 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Sat, 25 Oct 2025 23:20:43 +0200 Subject: [PATCH 06/11] changes to engines and layers --- src/weathergen/model/engines.py | 39 +++++++++++------ src/weathergen/model/layers.py | 77 +++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 13 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78d11a4a6..fbb930ad3 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -10,8 +10,8 @@ import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint - from weathergen.common.config import Config + from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -24,7 +24,7 @@ StreamEmbedLinear, StreamEmbedTransformer, ) -from weathergen.model.layers import MLP +from weathergen.model.layers import FEMLP, MLP from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype @@ -317,18 +317,31 @@ def create(self) -> torch.nn.ModuleList: attention_dtype=get_dtype(self.cf.attention_dtype), ) ) - # Add MLP block - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=1, - norm_eps=self.cf.mlp_norm_eps, + + if i + 1 == self.cf.ae_global_num_blocks: + self.fe_blocks.append( + FEMLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, + ) + ) + else: + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, + ) ) - ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..17cca11e8 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -93,3 +93,80 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x + + +class FEMLP(torch.nn.Module): + def __init__( + self, + dim_in, + dim_out, + num_layers=2, + hidden_factor=2, + pre_layer_norm=True, + 
dropout_rate=0.0, + nonlin=torch.nn.GELU, + with_residual=False, + norm_type="LayerNorm", + dim_aux=None, + norm_eps=1e-5, + name: str | None = None, + ): + """Constructor""" + + super(FEMLP, self).__init__() + + if name is not None: + self.name = name + + assert num_layers >= 2 + + self.with_residual = with_residual + self.with_aux = dim_aux is not None + dim_hidden = int(dim_in * hidden_factor) + + self.layers = torch.nn.ModuleList() + + norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm + + if pre_layer_norm: + self.layers.append( + norm(dim_in, eps=norm_eps) + if dim_aux is None + else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) + ) + + self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) + self.layers.append(nonlin()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + + for _ in range(num_layers - 2): + self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden)) + self.layers.append(nonlin()) + self.layers.append(torch.nn.Dropout(p=dropout_rate)) + + self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) + + # Add LayerNorm after skip connection if residuals are used + if self.with_residual: + # self.residual_norm = AdaLayerNorm( + # dim_out, dim_aux, norm_eps=norm_eps + # ) # norm(dim_out, eps=norm_eps) + self.residual_norm = torch.nn.LayerNorm(dim_out, eps=norm_eps, elementwise_affine=False) + + def forward(self, *args): + x, x_in, aux = args[0], args[0], args[-1] + + for i, layer in enumerate(self.layers): + x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + + if self.with_residual: + if x.shape[-1] == x_in.shape[-1]: + x = x_in + x + else: + assert x.shape[-1] % x_in.shape[-1] == 0 + x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) + + # Apply LayerNorm to the residual connection + x = self.residual_norm(x) + + return x From e36ee2adcbbf8a94c8cf99f0c18c77ca708f247d Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Mon, 27 Oct 2025 12:16:31 +0100 Subject: [PATCH 07/11] Correct bug with blocks --- src/weathergen/model/engines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index fbb930ad3..7d03b8e65 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -318,7 +318,7 @@ def create(self) -> torch.nn.ModuleList: ) ) - if i + 1 == self.cf.ae_global_num_blocks: + if i + 1 == self.cf.fe_num_blocks: self.fe_blocks.append( FEMLP( self.cf.ae_global_dim_embed, From ec514256782d259f27515faf1fa4695f7f3ef338 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Tue, 28 Oct 2025 19:22:18 +0100 Subject: [PATCH 08/11] Add the LayerNorm as block --- src/weathergen/model/engines.py | 34 ++++++-------- src/weathergen/model/layers.py | 83 ++++----------------------------- 2 files changed, 25 insertions(+), 92 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 7d03b8e65..061f70504 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -24,7 +24,7 @@ StreamEmbedLinear, StreamEmbedTransformer, ) -from weathergen.model.layers import FEMLP, MLP +from weathergen.model.layers import MLP, LayerNormBlock from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype @@ -318,28 +318,24 @@ def create(self) -> torch.nn.ModuleList: ) ) - if i + 1 == self.cf.fe_num_blocks: - self.fe_blocks.append( - FEMLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - 
dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=1, - norm_eps=self.cf.mlp_norm_eps, - ) + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, ) - else: + ) + + if i + 1 == self.cf.fe_num_blocks: self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, + LayerNormBlock( self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=1, norm_eps=self.cf.mlp_norm_eps, + elementwise_affine=False, ) ) diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 17cca11e8..cbec643fb 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -95,78 +95,15 @@ def forward(self, *args): return x -class FEMLP(torch.nn.Module): - def __init__( - self, - dim_in, - dim_out, - num_layers=2, - hidden_factor=2, - pre_layer_norm=True, - dropout_rate=0.0, - nonlin=torch.nn.GELU, - with_residual=False, - norm_type="LayerNorm", - dim_aux=None, - norm_eps=1e-5, - name: str | None = None, - ): - """Constructor""" - - super(FEMLP, self).__init__() - - if name is not None: - self.name = name - - assert num_layers >= 2 - - self.with_residual = with_residual - self.with_aux = dim_aux is not None - dim_hidden = int(dim_in * hidden_factor) - - self.layers = torch.nn.ModuleList() - - norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm - - if pre_layer_norm: - self.layers.append( - norm(dim_in, eps=norm_eps) - if dim_aux is None - else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) - ) - - self.layers.append(torch.nn.Linear(dim_in, dim_hidden)) - self.layers.append(nonlin()) - self.layers.append(torch.nn.Dropout(p=dropout_rate)) - - for _ in range(num_layers - 2): - self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden)) - self.layers.append(nonlin()) - self.layers.append(torch.nn.Dropout(p=dropout_rate)) - - self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) - - # Add LayerNorm after skip connection if residuals are used - if self.with_residual: - # self.residual_norm = AdaLayerNorm( - # dim_out, dim_aux, norm_eps=norm_eps - # ) # norm(dim_out, eps=norm_eps) - self.residual_norm = torch.nn.LayerNorm(dim_out, eps=norm_eps, elementwise_affine=False) +class LayerNormBlock(torch.nn.Module): + def __init__(self, dim_out, norm_eps=1e-5, elementwise_affine=False): + super().__init__() + self.ln = nn.LayerNorm( + dim_out, + eps=norm_eps, + elementwise_affine=elementwise_affine, + ) def forward(self, *args): - x, x_in, aux = args[0], args[0], args[-1] - - for i, layer in enumerate(self.layers): - x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) - - if self.with_residual: - if x.shape[-1] == x_in.shape[-1]: - x = x_in + x - else: - assert x.shape[-1] % x_in.shape[-1] == 0 - x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) - - # Apply LayerNorm to the residual connection - x = self.residual_norm(x) - - return x + x = args[0] + return self.ln(x) From 6c6b08bc6724de6fbd1fca9ebb3a38a727f2077c Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Tue, 28 Oct 2025 19:36:56 +0100 Subject: [PATCH 09/11] Add some doc comments --- src/weathergen/model/engines.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 061f70504..28ed837bf 100644 --- 
From 6c6b08bc6724de6fbd1fca9ebb3a38a727f2077c Mon Sep 17 00:00:00 2001
From: Savvas Melidonis
Date: Tue, 28 Oct 2025 19:36:56 +0100
Subject: [PATCH 09/11] Add some doc comments

---
 src/weathergen/model/engines.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py
index 061f70504..28ed837bf 100644
--- a/src/weathergen/model/engines.py
+++ b/src/weathergen/model/engines.py
@@ -318,6 +318,7 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
 
+            # Add MLP block
             self.fe_blocks.append(
                 MLP(
                     self.cf.ae_global_dim_embed,
@@ -330,6 +331,7 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
+            # Add a LayerNorm block as the last block of the FE
             if i + 1 == self.cf.fe_num_blocks:
                 self.fe_blocks.append(
                     LayerNormBlock(
                         self.cf.ae_global_dim_embed,

From a8ccaf54bb4658243087b42b196de74a5fa87e4e Mon Sep 17 00:00:00 2001
From: Savvas Melidonis
Date: Thu, 30 Oct 2025 11:45:44 +0100
Subject: [PATCH 10/11] Change back to the original code and also submit the configs

---
 config/default_config.yml         |  2 +-
 config/streams/era5_1deg/era5.yml |  4 ++--
 src/weathergen/model/engines.py   | 36 +++++++++---------
 src/weathergen/model/layers.py    | 83 +++++++++++++++++++++++++++----
 4 files changed, 95 insertions(+), 30 deletions(-)

diff --git a/config/default_config.yml b/config/default_config.yml
index efb6e95b3..8f0f2d459 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -23,7 +23,7 @@ ae_adapter_with_qk_lnorm: True
 ae_adapter_with_residual: True
 ae_adapter_dropout_rate: 0.1
 
-ae_global_dim_embed: 256
+ae_global_dim_embed: 2048
 ae_global_num_blocks: 4
 ae_global_num_heads: 16
 ae_global_dropout_rate: 0.1
diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml
index e9cc9a6b8..912075c4b 100644
--- a/config/streams/era5_1deg/era5.yml
+++ b/config/streams/era5_1deg/era5.yml
@@ -9,8 +9,8 @@ ERA5 :
   type : anemoi
 
-  #filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
-  filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
+  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
+  # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
   source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
   target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
   loss_weight : 1.
diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py
index 28ed837bf..fbb930ad3 100644
--- a/src/weathergen/model/engines.py
+++ b/src/weathergen/model/engines.py
@@ -24,7 +24,7 @@
     StreamEmbedLinear,
     StreamEmbedTransformer,
 )
-from weathergen.model.layers import MLP, LayerNormBlock
+from weathergen.model.layers import FEMLP, MLP
 from weathergen.model.utils import ActivationFactory
 from weathergen.utils.utils import get_dtype
 
@@ -318,26 +318,28 @@ def create(self) -> torch.nn.ModuleList:
                 )
             )
 
-            # Add MLP block
-            self.fe_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.fe_dropout_rate,
-                    norm_type=self.cf.norm_type,
-                    dim_aux=1,
-                    norm_eps=self.cf.mlp_norm_eps,
+            if i + 1 == self.cf.ae_global_num_blocks:
+                self.fe_blocks.append(
+                    FEMLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
                 )
-            )
-
-            # Add a LayerNorm block as the last block of the FE
-            if i + 1 == self.cf.fe_num_blocks:
+            else:
                 self.fe_blocks.append(
-                    LayerNormBlock(
+                    MLP(
                         self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
                         norm_eps=self.cf.mlp_norm_eps,
-                        elementwise_affine=False,
                     )
                 )
diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py
index cbec643fb..17cca11e8 100644
--- a/src/weathergen/model/layers.py
+++ b/src/weathergen/model/layers.py
@@ -95,15 +95,78 @@ def forward(self, *args):
         return x
 
 
-class LayerNormBlock(torch.nn.Module):
-    def __init__(self, dim_out, norm_eps=1e-5, elementwise_affine=False):
-        super().__init__()
-        self.ln = nn.LayerNorm(
-            dim_out,
-            eps=norm_eps,
-            elementwise_affine=elementwise_affine,
-        )
+class FEMLP(torch.nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        num_layers=2,
+        hidden_factor=2,
+        pre_layer_norm=True,
+        dropout_rate=0.0,
+        nonlin=torch.nn.GELU,
+        with_residual=False,
+        norm_type="LayerNorm",
+        dim_aux=None,
+        norm_eps=1e-5,
+        name: str | None = None,
+    ):
+        """Constructor"""
+
+        super(FEMLP, self).__init__()
+
+        if name is not None:
+            self.name = name
+
+        assert num_layers >= 2
+
+        self.with_residual = with_residual
+        self.with_aux = dim_aux is not None
+        dim_hidden = int(dim_in * hidden_factor)
+
+        self.layers = torch.nn.ModuleList()
+
+        norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
+
+        if pre_layer_norm:
+            self.layers.append(
+                norm(dim_in, eps=norm_eps)
+                if dim_aux is None
+                else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps)
+            )
+
+        self.layers.append(torch.nn.Linear(dim_in, dim_hidden))
+        self.layers.append(nonlin())
+        self.layers.append(torch.nn.Dropout(p=dropout_rate))
+
+        for _ in range(num_layers - 2):
+            self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden))
+            self.layers.append(nonlin())
+            self.layers.append(torch.nn.Dropout(p=dropout_rate))
+
+        self.layers.append(torch.nn.Linear(dim_hidden, dim_out))
+
+        # Add LayerNorm after skip connection if residuals are used
+        if self.with_residual:
+            # self.residual_norm = AdaLayerNorm(
+            #     dim_out, dim_aux, norm_eps=norm_eps
+            # )  # norm(dim_out, eps=norm_eps)
+            self.residual_norm = torch.nn.LayerNorm(dim_out, eps=norm_eps, elementwise_affine=False)
     def forward(self, *args):
-        x = args[0]
-        return self.ln(x)
+        x, x_in, aux = args[0], args[0], args[-1]
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x)
+
+        if self.with_residual:
+            if x.shape[-1] == x_in.shape[-1]:
+                x = x_in + x
+            else:
+                assert x.shape[-1] % x_in.shape[-1] == 0
+                x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])
+
+            # Apply LayerNorm to the residual connection
+            x = self.residual_norm(x)
+
+        return x
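As a sketch of the block layout restored by this patch (illustrative only; stand-in values, and it assumes the enclosing loop in create() runs for cf.fe_num_blocks iterations as the hunks suggest, with an attention block appended per iteration outside the shown context):

ae_global_num_blocks = 4  # cf.ae_global_num_blocks
fe_num_blocks = 8         # cf.fe_num_blocks

for i in range(fe_num_blocks):
    if i + 1 == ae_global_num_blocks:
        block_type = "FEMLP"  # residual MLP with the post-residual LayerNorm
    else:
        block_type = "MLP"    # plain residual MLP
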
From 66b3523b7976225d381044a05a90855dbc1761ec Mon Sep 17 00:00:00 2001
From: Kerem Tezcan
Date: Thu, 30 Oct 2025 12:17:25 +0100
Subject: [PATCH 11/11] config epoch 32->64

---
 config/default_config.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/config/default_config.yml b/config/default_config.yml
index 8f0f2d459..2a833e7ba 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -109,7 +109,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
                           "same_strategy_per_batch": false
                           }
 
-num_epochs: 32
+num_epochs: 64
 samples_per_epoch: 4096
 samples_per_validation: 512
 shuffle: True
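For scale — assuming samples_per_epoch stays at its default of 4096 — this doubles the total number of training samples seen per run from 32 x 4096 = 131,072 to 64 x 4096 = 262,144.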