Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions data_process/convert_makani_output_to_wb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ def convert(file_names_to_convert: List[str], output_file: str, batch_size: Opti
longitudes = comm.bcast(longitudes, root=0)
timestamps = comm.bcast(timestamps, root=0)
entries_per_year = comm.bcast(entries_per_year, root=0)

# IMPORTANT! ECMWF convention flips the latitudes, so that they start on the south pole
# we use co-latitude definition, where 90 degrees is the north pole
latitudes = np.flip(latitudes)

# total hours:
total_entries = sum(entries_per_year)
Expand Down
2 changes: 1 addition & 1 deletion data_process/convert_wb2_to_makani_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# MPI
from mpi4py import MPI

from .wb2_helpers import surface_variables, atmospheric_variables, split_convert_channel_names, DistributedProgressBar, gcs_storage_options
from .wb2_helpers import split_convert_channel_names, DistributedProgressBar, gcs_storage_options


def convert(input_file: str, output_dir: str, metadata_file: str, years: List[int],
Expand Down
20 changes: 12 additions & 8 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

ARG DLFW_VERSION
ARG DLFW_VERSION=25.12
FROM nvcr.io/nvidia/pytorch:${DLFW_VERSION}-py3

# update repo info
Expand All @@ -30,21 +30,22 @@ RUN SETUPTOOLS_USE_DISTUTILS=local pip install mpi4py
RUN pip install onnx onnxruntime onnxruntime-gpu

# hdf5 and h5py
ENV HDF5_VERSION=1.14.6
RUN cd /tmp && wget https://support.hdfgroup.org/releases/hdf5/v1_14/v1_14_6/downloads/hdf5-${HDF5_VERSION}.tar.gz && \
ENV HDF5_VERSION=2.0.0
RUN cd /tmp && wget https://support.hdfgroup.org/releases/hdf5/v2_0/v2_0_0/downloads/hdf5-${HDF5_VERSION}.tar.gz && \
gzip -cd hdf5-${HDF5_VERSION}.tar.gz | tar xvf - && \
mkdir hdf5-${HDF5_VERSION}/build && cd hdf5-${HDF5_VERSION}/build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/hdf5 \
-DHDF5_ENABLE_OPTIMIZATION=ON \
-DHDF5_ENABLE_ZLIB_SUPPORT=ON \
-DHDF5_ENABLE_DIRECT_VFD=ON \
-DHDF5_ENABLE_ROS3_VFD=ON \
-DHDF5_ENABLE_PARALLEL=ON \
-DHDF5_TEST_API=ON \
-DHDF5_TEST_VFD=ON \
-DHDF5_TEST_PARALLEL=ON \
.. && \
make -j 8 && make install
RUN CC="mpicc" HDF5_MPI=ON H5PY_ROS3=1 H5PY_DIRECT_VFD=1 HDF5_DIR=/opt/hdf5 pip install --no-build-isolation --no-binary=h5py h5py==3.13.0
make -j 8 && make install && \
mkdir -p /opt/hdf5/lib/plugin
RUN CC="mpicc" HDF5_MPI=ON H5PY_DIRECT_VFD=1 HDF5_DIR=/opt/hdf5 pip install --no-build-isolation --no-binary=h5py h5py==3.15.1
ENV PATH=/opt/hdf5/bin:${PATH}

# install cdsapi for downloading the dataset
Expand All @@ -66,13 +67,16 @@ ENV NUMBA_DISABLE_CUDA=1
# scoring tools
RUN pip install xskillscore properscoring

# weatherbench2
RUN pip install "git+https://github.com/google-research/weatherbench2.git"

# some useful scripts from mlperf
RUN pip install --ignore-installed "git+https://github.com/NVIDIA/mlperf-common.git"

# torch-harmonics
ENV FORCE_CUDA_EXTENSION 1
ENV TORCH_CUDA_ARCH_LIST "8.0 8.6 9.0 10.0 12.0+PTX"
ENV HARMONICS_VERSION 0.8.3
ENV TORCH_CUDA_ARCH_LIST "8.0 8.6 8.7 9.0 10.0 12.0+PTX"
ENV HARMONICS_VERSION 0.8.5
RUN cd /opt && git clone https://github.com/NVIDIA/torch-harmonics.git && \
cd torch-harmonics && \
pip install --no-build-isolation -e .
Expand Down
183 changes: 183 additions & 0 deletions docker/Dockerfile.aws
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This Dockerfile adds S3 support to HDF5 so that makani can read datasets from AWS storage.
# Since it requires installing a lot of dependencies, we decided to keep it in a separate file

ARG DLFW_VERSION=25.12
FROM nvcr.io/nvidia/pytorch:${DLFW_VERSION}-py3

# update repo info
RUN apt update -y && apt install -y libibmad5

# upgrade cmake
RUN apt remove cmake -y && \
    pip install cmake --upgrade

# install mpi4py
RUN SETUPTOOLS_USE_DISTUTILS=local pip install mpi4py

# install onnx
RUN pip install onnx onnxruntime onnxruntime-gpu

# Build the AWS C libraries needed by the HDF5 ROS3 VFD. The steps below
# follow the dependency chain and are numbered in actual build order:
# common -> checksums / s2n-tls / cal / sdkutils -> io -> compression -> http -> auth -> s3

# 1. aws-c-common (base library for all other aws-c-* packages)
RUN cd /opt && git clone https://github.com/awslabs/aws-c-common.git && \
    cd aws-c-common && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release && \
    make -j$(nproc) && \
    make install

# 2. aws-checksums (needs common)
RUN cd /opt && git clone https://github.com/awslabs/aws-checksums.git && \
    cd aws-checksums && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=/usr && \
    make -j$(nproc) && \
    make install

# 3. s2n-tls (TLS implementation required by aws-c-io)
RUN cd /opt && git clone https://github.com/aws/s2n-tls.git && \
    cd s2n-tls && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release -GNinja && \
    ninja -j$(nproc) && \
    ninja install

# 4. aws-c-cal (crypto abstraction layer, needs common)
RUN cd /opt && git clone https://github.com/awslabs/aws-c-cal.git && \
    cd aws-c-cal && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release && \
    make -j$(nproc) && \
    make install

# 5. aws-c-sdkutils (needed by auth/http/s3)
RUN cd /opt && \
    git clone https://github.com/awslabs/aws-c-sdkutils.git && \
    cd aws-c-sdkutils && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=/usr/local && \
    make -j$(nproc) && \
    make install

# 6. aws-c-io (needs common + cal + s2n-tls)
RUN cd /opt && git clone https://github.com/awslabs/aws-c-io.git && \
    cd aws-c-io && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release && \
    make -j$(nproc) && \
    make install

# 7. aws-c-compression (needs common)
RUN cd /opt && \
    git clone https://github.com/awslabs/aws-c-compression.git && \
    cd aws-c-compression && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=/usr/local && \
    make -j$(nproc) && \
    make install

# 8. aws-c-http (needs io + compression)
RUN cd /opt && git clone https://github.com/awslabs/aws-c-http.git && \
    cd aws-c-http && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release && \
    make -j$(nproc) && \
    make install

# 9. aws-c-auth (needs http + sdkutils)
RUN cd /opt && git clone https://github.com/awslabs/aws-c-auth.git && \
    cd aws-c-auth && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release && \
    make -j$(nproc) && \
    make install

# 10. aws-c-s3 (final dependency for HDF5 S3 support)
RUN cd /opt && git clone https://github.com/awslabs/aws-c-s3.git && \
    cd aws-c-s3 && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_BUILD_TYPE=Release && \
    make -j$(nproc) && \
    make install

# hdf5 and h5py: build HDF5 from source with the direct, ROS3 (S3) and
# MPI-parallel VFDs enabled, then build h5py against that installation
ENV HDF5_VERSION=2.0.0
RUN cd /tmp && wget https://support.hdfgroup.org/releases/hdf5/v2_0/v2_0_0/downloads/hdf5-${HDF5_VERSION}.tar.gz && \
    gzip -cd hdf5-${HDF5_VERSION}.tar.gz | tar xvf - && \
    mkdir hdf5-${HDF5_VERSION}/build && cd hdf5-${HDF5_VERSION}/build && \
    cmake -DCMAKE_INSTALL_PREFIX=/opt/hdf5 \
    -DHDF5_ENABLE_OPTIMIZATION=ON \
    -DHDF5_ENABLE_ZLIB_SUPPORT=ON \
    -DHDF5_ENABLE_DIRECT_VFD=ON \
    -DHDF5_ENABLE_ROS3_VFD=ON \
    -DHDF5_ENABLE_PARALLEL=ON \
    -DHDF5_TEST_API=ON \
    -DHDF5_TEST_VFD=ON \
    -DHDF5_TEST_PARALLEL=ON \
    .. && \
    make -j 8 && make install && \
    mkdir -p /opt/hdf5/lib/plugin
RUN CC="mpicc" HDF5_MPI=ON H5PY_ROS3=1 H5PY_DIRECT_VFD=1 HDF5_DIR=/opt/hdf5 pip install --no-build-isolation --no-binary=h5py h5py==3.15.1
ENV PATH=/opt/hdf5/bin:${PATH}

# install cdsapi for downloading the dataset
# the requirement must be quoted: an unquoted ">=" is parsed by the shell as an
# output redirection, silently dropping the version constraint
RUN pip install "cdsapi>=0.7.2"

# install zarr and data stuff
RUN pip install more_itertools zarr xarray pandas gcsfs boto3

# moviepy imageio for wandb video logging
RUN pip install moviepy imageio

# other python stuff
RUN pip install --upgrade wandb ruamel.yaml tqdm progressbar2

# numba
RUN pip install numba
ENV NUMBA_DISABLE_CUDA=1

# scoring tools
RUN pip install xskillscore properscoring

# weatherbench2
RUN pip install "git+https://github.com/google-research/weatherbench2.git"

# some useful scripts from mlperf
RUN pip install --ignore-installed "git+https://github.com/NVIDIA/mlperf-common.git"

# torch-harmonics
# (the legacy space-separated "ENV KEY value" form is deprecated; use "ENV KEY=value")
ENV FORCE_CUDA_EXTENSION=1
ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 8.7 9.0 10.0 12.0+PTX"
# NOTE(review): HARMONICS_VERSION is not referenced by the clone below, which
# builds the default branch — consider `git checkout v${HARMONICS_VERSION}`; verify tag name
ENV HARMONICS_VERSION=0.8.5
RUN cd /opt && git clone https://github.com/NVIDIA/torch-harmonics.git && \
    cd torch-harmonics && \
    pip install --no-build-isolation -e .

# physicsnemo
RUN pip install git+https://github.com/NVIDIA/physicsnemo.git@v1.1.1

# copy source code
RUN mkdir -p /opt/makani
COPY config /opt/makani/config
COPY docker /opt/makani/docker
COPY data_process /opt/makani/data_process
COPY datasets /opt/makani/datasets
COPY makani /opt/makani/makani
COPY tests /opt/makani/tests
COPY pyproject.toml /opt/makani/pyproject.toml
COPY README.md /opt/makani/README.md
RUN cd /opt/makani && pip install -e .
3 changes: 0 additions & 3 deletions makani/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
import argparse
import os
import glob
import tempfile
from collections import OrderedDict

import pynvml
import torch
from torch import nn

Expand All @@ -29,7 +27,6 @@
from makani.models import model_registry

# distributed computing stuff
from makani.utils import comm
from makani.utils.driver import Driver
from makani.utils.YParams import ParamsBase

Expand Down
1 change: 0 additions & 1 deletion makani/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
import numpy as np
import torch
import logging
from functools import partial
Expand Down
2 changes: 0 additions & 2 deletions makani/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.

import os
import numpy as np
import argparse
import torch
import logging
from functools import partial
Expand Down
11 changes: 5 additions & 6 deletions makani/models/common/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from torch import amp

from typing import Tuple, List, Optional
from typing import Tuple, Optional

# quadrature stuff
from makani.utils.grids import grid_to_quadrature_rule, GridQuadrature
Expand Down Expand Up @@ -73,19 +73,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape

xtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.to(torch.float32)
with amp.autocast(device_type=x.device.type, enabled=False):
xf = x.to(torch.float32)

# compute var and mean
mean = self.quadrature(x)
var = self.quadrature(torch.square(x - mean.reshape(B, C, 1, 1)))
mean = self.quadrature(xf)
var = self.quadrature(torch.square(xf - mean.reshape(B, C, 1, 1)))

# reshape
var = var.reshape(B, C, 1, 1)
mean = mean.reshape(B, C, 1, 1)

# convert types
x = x.to(xtype)
mean = mean.to(xtype)
var = var.to(xtype)

Expand Down
3 changes: 1 addition & 2 deletions makani/models/common/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

import torch
import torch.nn as nn
from torch.nn.modules.container import Sequential
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torch.utils.checkpoint import checkpoint
import math

from makani.utils.context import rng_context
Expand Down
Loading
Loading