Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0adf769
initial commit for ICON data reader - new version
sbAsma Jul 2, 2025
f4f95f7
merging ICON Zarr and ICON CMIP6 NetCDF4 data readers
sbAsma Aug 4, 2025
21c0d27
merged with remote
sbAsma Aug 4, 2025
37d1de3
fixed inconsistencies + added descriptive comments
sbAsma Aug 4, 2025
177650f
removed unecessary attribute
sbAsma Aug 5, 2025
2e39465
added attribute cf.multiprocessing_method
sbAsma Aug 5, 2025
179a5ef
changed self.colnames self.cols_idx from parent to children classes
sbAsma Aug 5, 2025
563d12e
config file for CMIP6
sbAsma Aug 5, 2025
dc5d9b0
ICON config
sbAsma Aug 5, 2025
77509ee
code ruffing
sbAsma Aug 5, 2025
eeb131d
fixed lines that were too long
sbAsma Aug 5, 2025
9af5792
temporary fix for multiprocedding method
sbAsma Sep 28, 2025
4808871
unecessary code removed and added more channels
sbAsma Sep 28, 2025
6b42df0
title: changes to dataloading method
sbAsma Sep 28, 2025
905166e
resolved conflict from merging from develop
sbAsma Sep 28, 2025
a43bd1d
change in len_hrs and step_hrs
sbAsma Oct 5, 2025
3b0008d
config for CMIP6 monthly data
sbAsma Oct 5, 2025
0660a12
title: handling built-in ICON levels for CMIP6 data (details below)
sbAsma Oct 5, 2025
2c93f0b
removed breaking change
sbAsma Oct 5, 2025
83a72be
committing ruffing changes
sbAsma Oct 5, 2025
9f39221
added ICON CMIP6 monthly data to the stream
sbAsma Oct 12, 2025
efd39a3
removed the previous stream
sbAsma Oct 12, 2025
315f902
changed per-level reading for CMIP6 only
sbAsma Oct 12, 2025
ed5975c
dtype missmatch fix - training on ICON CMIP6 day + Amon channels
sbAsma Oct 12, 2025
8e23d16
change to accomodate with ICON CMIP6
sbAsma Oct 12, 2025
06562ea
changed len_hrs value to accommodate monthly data
sbAsma Oct 12, 2025
f333a60
resolved merge conflict
sbAsma Oct 14, 2025
3928ec6
resolving latest pull conflict
sbAsma Oct 14, 2025
1c987a9
excluding more variables
sbAsma Oct 15, 2025
27b33a2
working software stack on Levante for uv venv
sbAsma Oct 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,6 @@ run_id: ???
# Parameters for logging/printing in the training loop
train_log:
# The period to log metrics (in number of batch steps)
log_interval: 20
log_interval: 20

multiprocessing_method: "fork"
21 changes: 21 additions & 0 deletions config/icon_cmip6_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Run configuration for training on the ICON CMIP6 streams.

# Directory holding the stream definitions used by this run.
streams_directory: './config/streams/streams_icon_cmip6'

# Training window followed by the validation window.
# NOTE(review): values look like YYYYMMDDHHMM integers -- confirm against the config loader.
start_date: 185001011100
end_date: 200912311200
start_date_val: 201001011100
end_date_val: 201912311200

# Temporal sampling.
# NOTE(review): 745 h = 31 days + 1 h; presumably sized to span one monthly sample -- confirm.
len_hrs: 745
step_hrs: 24
# forecast_offset : 1

# Training schedule.
num_epochs: 200  # previously 10
masking_rate: 0.8

# Optional overrides, currently disabled (100 samples per epoch was reported to work):
# samples_per_epoch: 100
# samples_per_validation: 100
# shuffle: True

# Data loading.
loader_num_workers: 16
multiprocessing_method: 'spawn'
18 changes: 18 additions & 0 deletions config/icon_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Run configuration for training on the ICON stream.

# Directory holding the stream definitions used by this run.
# NOTE(review): the ICON stream file changed alongside this config lives under
# config/streams/icon/ -- confirm that this directory is the intended one.
streams_directory: './config/streams/streams_icon/'

# Training window.
start_date: 202107010000
end_date: 202107100000

# Validation window (the dataset itself ends at 2022-12-31T00:00:00).
start_date_val: 202208310000
end_date_val: 202212200000

# Training schedule.
num_epochs: 10
samples_per_epoch: 110
# TODO @asma: inspect -- going lower than this is reported to cause an error.
samples_per_validation: 17

# Model / sampling settings.
token_size: 64
step_hrs: 1
masking_rate: 0.8

# Data loading.
loader_num_workers: 8
13 changes: 11 additions & 2 deletions config/streams/icon/icon.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,17 @@
ICON :
type : icon
filenames : ['icon-art-NWP_OH_CHEMISTRY-chem_DOM01_ML_daily_repeat_reduced_levels.zarr']
source : ['u_00', 'v_00', 'w_80', 'temp_00']
target : ['u_00', 'v_00', 'w_80', 'temp_00']
source_channels : ['u_00', 'v_00', 'w_80', 'temp_00']
target_channels : ['u_00', 'v_00', 'w_80', 'temp_00']
attributes:
lon: 'clon'
lat: 'clat'
grid: 'ncells'
variables: ['TRCH4_chemtr_00', 'TRCH4_chemtr_20', 'TRCH4_chemtr_40', 'TRCH4_chemtr_60', 'TRCH4_chemtr_80',
'TRO3_chemtr_00', 'TRO3_chemtr_20', 'TRO3_chemtr_40', 'TRO3_chemtr_60', 'TRO3_chemtr_80', 'clat',
'clon', 'pres_00', 'pres_20', 'pres_40', 'pres_60', 'pres_80', 'temp_00', 'temp_20', 'temp_40', 'temp_60',
'temp_80', 'time', 'u_00', 'u_20', 'u_40', 'u_60', 'u_80', 'v_00', 'v_20', 'v_40', 'v_60', 'v_80', 'w_00',
'w_20', 'w_40', 'w_60', 'w_80']
loss_weight : 1.
diagnostic : False
masking_rate : 0.6
Expand Down
67 changes: 67 additions & 0 deletions config/streams/streams_icon_cmip6/icon_cmip6_Amon.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# (C) Copyright 2024 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# Stream definition for the ICON CMIP6 "Amon" data set.
#
# All settings are nested under the stream name so that the per-section keys
# (`type`, `net`, `dim_embed`, `num_heads`, `num_layers`) do not collide as
# duplicate top-level keys.
# NOTE(review): the nesting was reconstructed from the key order; in particular
# confirm that `pressure_levels` and `variables` belong at stream level rather
# than under `attributes`.
ICONCMIP6Amon :
  type : iconcmip6
  filenames : ['1pctCO2_r1i1p1f1_Amon.json']

  # Names used by the data set for its coordinates and grid dimension.
  attributes:
    lon: 'longitude'
    lat: 'latitude'
    grid: 'i' # i = grid indexes

  # Earlier, narrower exclusion lists kept for reference:
  # source_exclude: [
  #     "hur", # because it doesn't have data on all necessary levels
  #     'clt', 'hfss', 'pr', 'rlds', 'rlut', 'rsus', 'tas', 'vas', # duplicate from day
  #     'hfls', 'psl', 'rlus', 'rsds', 'sfcWind', 'uas' # duplicate from day
  #     ]
  # target_exclude: [
  #     "hur", # because it doesn't have data on all necessary levels
  #     'clt', 'hfss', 'pr', 'rlds', 'rlut', 'rsus', 'tas', 'vas', # duplicate from day
  #     'hfls', 'psl', 'rlus', 'rsds', 'sfcWind', 'uas' # duplicate from day
  #     ]

  # Both lists exclude every entry of `variables` except "clivi" and "ua".
  source_exclude: [
    "clwvi", "hfls", "hur", "pr", "prsn", "ps", "rlds", "rlus", "rlutcs", "rsdscs", "rsus", "rsut",
    "rtmt", "ta", "tauu", "ts", "uas", "vas", "zg", "clt", "evspsbl", "hfss", "hus", "prc", "prw",
    "psl", "rldscs", "rlut", "rsds", "rsdt", "rsuscs", "rsutcs", "sfcWind", "tas", "tauv", "va", "wap"
    ]
  target_exclude: [
    "clwvi", "hfls", "hur", "pr", "prsn", "ps", "rlds", "rlus", "rlutcs", "rsdscs", "rsus", "rsut",
    "rtmt", "ta", "tauu", "ts", "uas", "vas", "zg", "clt", "evspsbl", "hfss", "hus", "prc", "prw",
    "psl", "rldscs", "rlut", "rsds", "rsdt", "rsuscs", "rsutcs", "sfcWind", "tas", "tauv", "va", "wap"
    ]

  # Level labels are kept as quoted strings.
  # NOTE(review): presumably pressure in Pa -- confirm against the reader.
  pressure_levels: [
    "5000", "10000", "15000", "20000", "25000", "30000",
    "40000", "50000", "60000", "70000", "85000", "92500", "100000"
    ]

  # Full list of variable names declared for this stream.
  variables: [
    "clivi", "clwvi", "hfls", "hur", "pr", "prsn", "ps", "rlds", "rlus", "rlutcs", "rsdscs", "rsus", "rsut",
    "rtmt", "ta", "tauu", "ts", "uas", "vas", "zg", "clt", "evspsbl", "hfss", "hus", "prc", "prw",
    "psl", "rldscs", "rlut", "rsds", "rsdt", "rsuscs", "rsutcs", "sfcWind", "tas", "tauv", "ua", "va", "wap"
    ]

  loss_weight : 1.
  diagnostic : False
  masking_rate : 0.6
  masking_rate_none : 0.05
  token_size : 8

  # Embedding network for the stream's source tokens.
  embed :
    net : transformer
    num_tokens : 1
    num_heads : 8
    dim_embed : 256
    num_blocks : 2

  # Embedding network for the target coordinates.
  embed_target_coords :
    net : linear
    dim_embed : 256

  target_readout :
    type : 'obs_value' # token or obs_value
    num_layers : 2
    num_heads : 4
    # sampling_rate : 0.2

  pred_head :
    ens_size : 1
    num_layers : 1
15 changes: 14 additions & 1 deletion packages/evaluate/src/weathergen/evaluate/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,18 @@ def scatter_plot(
vmin = map_kwargs_save.pop("vmin", None)
vmax = map_kwargs_save.pop("vmax", None)
cmap = plt.get_cmap(map_kwargs_save.pop("colormap", "coolwarm"))

# data["lon"] = ((data["lon"] + 180) % 360) - 180
print(f"\n\nvmin = {vmin}", flush=True)
print(f"vmax = {vmax}", flush=True)
print(
f"min(lon) = {min(data['lon'].values)} max(lon) = {max(data['lon'].values)}",
flush=True,
)
print(
f"min(lat) = {min(data['lat'].values)} max(lat) = {max(data['lat'].values)}",
flush=True,
)
print(f"data = {data}\n\n", flush=True)
if isinstance(map_kwargs_save.get("levels", False), oc.listconfig.ListConfig):
norm = mpl.colors.BoundaryNorm(
map_kwargs_save.pop("levels", None), cmap.N, extend="both"
Expand Down Expand Up @@ -451,6 +462,8 @@ def scatter_plot(

valid_time = str(data["valid_time"][0].values.astype("datetime64[m]"))

# print(f"\n\nlons = {lons}\n\n", flush=True)

scatter_plt = ax.scatter(
data["lon"],
data["lat"],
Expand Down
111 changes: 19 additions & 92 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ authors = [
requires-python = ">=3.12,<3.13"
# TODO: split the plotting dependencies into their own dep groups, they are not required.
dependencies = [
'torch==2.6.0',
'numpy~=2.2',
'astropy_healpix~=1.1.2',
'zarr~=2.17',
Expand All @@ -21,16 +22,14 @@ dependencies = [
'packaging',
'wheel',
'psutil',
"flash-attn; sys_platform == 'linux'",
"polars~=1.25.2",
"omegaconf~=2.3.0",
"dask~=2025.5.1",
"hatchling",
"numexpr>=2.11.0",
"weathergen-common",
"weathergen-evaluate",
]


[project.urls]
Homepage = "https://www.weathergenerator.eu"
Documentation = "https://readthedocs.org"
Expand All @@ -41,7 +40,7 @@ Issues = "https://github.com/ecmwf/WeatherGenerator/issues"
train = "weathergen.run_train:train"
train_continue = "weathergen.run_train:train_continue"
inference = "weathergen.run_train:inference"
evaluate = "weathergen.evaluate.run_evaluation:evaluate"
plot = "weathergen.evaluate.plot:plot"
plot_train = "weathergen.utils.plot_training:plot_train"

[build-system]
Expand All @@ -54,33 +53,9 @@ packages = ["src/weathergen"]
[dependency-groups]
# The development dependencies
dev = [
"ipykernel>=6.30.0",
"jupytext>=1.17.2",
"pytest~=8.3.5",
"pytest-mock>=3.14.1",
"ruff==0.9.7",
"tensorboard>=2.20.0",
"pdbpp>=0.11.7",
"pyrefly==0.33.0",
]


# Torch listed as optional dependencies.
# uv and python can only filter dependencies by platform, not by capability.
# Following the recommendations from https://docs.astral.sh/uv/guides/integration/pytorch
# We need to support:
# x86_64: cpu (unit tests) + gpu
# aarch64: gpu
[project.optional-dependencies]

cpu = [
'torch==2.6.0',
]

gpu = [
'torch==2.6.0+cu126',
# flash-attn also has a torch dependency.
"flash-attn",
]


Expand Down Expand Up @@ -143,11 +118,9 @@ ignore = [
line-ending = "lf"




[tool.uv]
# Most work is done on a distributed filesystem, where hardlinking is not always possible.
# Also, trying to resolve some permissions issue, see 44.
# Also, trying to resolve some permissions issue, see #344.
link-mode = "symlink"
# This guarantees that the build is deterministic and will not be impacted
# by future releases of dependencies or sub-dependencies.
Expand All @@ -161,50 +134,17 @@ link-mode = "symlink"
# Also, relatively recent versions are required to support workspaces.
required-version = ">=0.7.0"

# The supported environments
# TODO: add macos and windows (CPU only, for running tests)
environments = [
"sys_platform == 'linux' and platform_machine == 'aarch64'",
"sys_platform == 'linux' and platform_machine == 'x86_64'",
# "sys_platform == 'darwin'",
]

# One can only have cpu or gpu.
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]


# Following the recommendations from https://docs.astral.sh/uv/guides/integration/pytorch
# The current setup is:
# linux == GPU + flashattention
# windows == GPU
# macos == CPU
[[tool.uv.index]]
name = "pytorch-cu126"
url = "https://download.pytorch.org/whl/cu126"
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true


[tool.pyrefly]
project-includes = ["src/"]
project-excludes = [
]

[tool.pyrefly.errors]
bad-argument-type = false
unsupported-operation = false
missing-attribute = false
no-matching-overload = false
bad-context-manager = false

# To do:
bad-assignment = false
bad-return = false
index-error = false
not-iterable = false
not-callable = false



[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
Expand All @@ -213,26 +153,14 @@ explicit = true
[tool.uv.sources]
weathergen-common = { workspace = true }
weathergen-evaluate = { workspace = true }


flash-attn = [
# The build of Cathal O'Brien is not compatible with the libc build on santis.
# Hardcode the reference to the swiss cluster for the time being.
# TODO: open issue
# { url = "https://github.com/cathalobrien/get-flash-attn/releases/download/v0.1-alpha/flash_attn-2.7.4+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
# This version was rebuilt locally on santis and uploaded.
{ url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl", marker = "sys_platform == 'linux' and platform_machine == 'aarch64'" },
{ url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" },
# { index = "pytorch-cpu", marker = "sys_platform == 'darwin'"},
]


# Where to fetch torch from, selected per platform through environment markers.
torch = [
    # Explicit pin for GPU
    { url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl", marker = 'sys_platform == "linux" and platform_machine == "aarch64"', extra="gpu" },
    { url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", marker = 'sys_platform == "linux" and platform_machine == "x86_64"', extra="gpu" },
    # Use the public repo for CPU versions.
    { index = "pytorch-cpu", marker = "sys_platform == 'linux'", extra="cpu"},
    { index = "pytorch-cu124", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
    # macOS reports sys_platform as 'darwin'; a 'macosx' marker never matches,
    # so the CPU index would never be selected there.
    { index = "pytorch-cpu", marker = "sys_platform == 'darwin'"},
]
# This URL was evaluated this way:
# uv run ~/WeatherGenerator-private/hpc/hpc2020/ecmwf/get-flash-atten.sh
flash-attn = [
{ url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux'" },
]

[tool.pytest.ini_options]
Expand All @@ -246,5 +174,4 @@ log_cli_date_format = "%Y-%m-%d %H:%M:%S"
members = [
"packages/evaluate",
"packages/common"
]

]
Loading
Loading