diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 8b939ed6d6..07917b9b0b 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -492,7 +492,7 @@ def __init__( offset_provider=self._grid.connectivities, ) - self._apply_divergence_damping_and_update_vn = setup_program( + self._apply_divergence_damping_and_update_vn_first_half = setup_program( backend=backend, program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn, constant_args={ @@ -517,6 +517,35 @@ def __init__( }, vertical_sizes={ "vertical_start": gtx.int32(0), + "vertical_end": gtx.int32(self._grid.num_levels // 2), + }, + offset_provider=self._grid.connectivities, + ) + self._apply_divergence_damping_and_update_vn_second_half = setup_program( + backend=backend, + program=compute_edge_diagnostics_for_dycore_and_update_vn.apply_divergence_damping_and_update_vn, + constant_args={ + "horizontal_mask_for_3d_divdamp": self._metric_state_nonhydro.horizontal_mask_for_3d_divdamp, + "scaling_factor_for_3d_divdamp": self._metric_state_nonhydro.scaling_factor_for_3d_divdamp, + "inv_dual_edge_length": self._edge_geometry.inverse_dual_edge_lengths, + "nudgecoeff_e": self._interpolation_state.nudgecoeff_e, + "geofac_grdiv": self._interpolation_state.geofac_grdiv, + "advection_explicit_weight_parameter": self._params.advection_explicit_weight_parameter, + "advection_implicit_weight_parameter": self._params.advection_implicit_weight_parameter, + "iau_wgt_dyn": self._config.iau_wgt_dyn, + "is_iau_active": self._config.is_iau_active, + "limited_area": self._grid.limited_area, + }, + variants={ + "apply_2nd_order_divergence_damping": [False, True], + "apply_4th_order_divergence_damping": [False, True], + }, + horizontal_sizes={ + "horizontal_start": gtx.int32(self._start_edge_nudging_level_2), + "horizontal_end": self._end_edge_local, + }, + vertical_sizes={ + "vertical_start": gtx.int32(self._grid.num_levels // 2), "vertical_end": gtx.int32(self._grid.num_levels), }, offset_provider=self._grid.connectivities, @@ -547,26 +576,51 @@ def __init__( offset_provider=self._grid.connectivities, ) - self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection = setup_program( - backend=backend, - program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, - constant_args={ - "e_flx_avg": self._interpolation_state.e_flx_avg, - "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, - }, - variants={ - "at_first_substep": [False, True], - "prepare_advection": [False, True], - }, - horizontal_sizes={ - "horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5), - "horizontal_end": self._end_edge_halo_level_2, - }, - vertical_sizes={ - "vertical_start": gtx.int32(0), - "vertical_end": gtx.int32(self._grid.num_levels), - }, - offset_provider=self._grid.connectivities, + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half = ( + setup_program( + backend=backend, + program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, + constant_args={ + "e_flx_avg": self._interpolation_state.e_flx_avg, + "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, + }, + variants={ + "at_first_substep": [False, True], + "prepare_advection": [False, True], + }, + horizontal_sizes={ + "horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5), + "horizontal_end": self._end_edge_halo_level_2, + }, + vertical_sizes={ + "vertical_start": gtx.int32(0), + "vertical_end": gtx.int32(self._grid.num_levels // 2), + }, + offset_provider=self._grid.connectivities, + ) + ) + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half = ( + setup_program( + backend=backend, + program=compute_averaged_vn_and_fluxes_and_prepare_tracer_advection, + constant_args={ + "e_flx_avg": self._interpolation_state.e_flx_avg, + "ddqz_z_full_e": self._metric_state_nonhydro.ddqz_z_full_e, + }, + variants={ + "at_first_substep": [False, True], + "prepare_advection": [False, True], + }, + horizontal_sizes={ + "horizontal_start": gtx.int32(self._start_edge_lateral_boundary_level_5), + "horizontal_end": self._end_edge_halo_level_2, + }, + vertical_sizes={ + "vertical_start": gtx.int32(self._grid.num_levels // 2), + "vertical_end": gtx.int32(self._grid.num_levels), + }, + offset_provider=self._grid.connectivities, + ) ) self._vertically_implicit_solver_at_predictor_step = setup_program( @@ -819,6 +873,8 @@ def __init__( self.p_test_run = False + self._first_half_cache = {} + self._second_half_cache = {} self._dtime_previous_substep: float = 0.0 """ Dynamic substep length of previous substep in order to track if rayleigh damping coefficients need to be @@ -1268,6 +1324,38 @@ def run_predictor_step( log.debug("exchanging prognostic field 'w'") self._exchange.exchange_and_wait(dims.CellDim, prognostic_states.next.w) + def get_first_half_vn(self, vn: gtx.Field): + try: + return self._first_half_cache[vn.__gt_buffer_info__.hash_key] + except KeyError: + self._first_half_cache[vn.__gt_buffer_info__.hash_key] = gtx_common._field( + vn.ndarray[:, : self._grid.num_levels // 2], + domain=gtx_common.Domain( + dims=vn.domain.dims, + ranges=( + vn.domain.ranges[0], + gtx_common.UnitRange(0, self._grid.num_levels // 2), + ), + ), + ) + return self._first_half_cache[vn.__gt_buffer_info__.hash_key] + + def get_second_half_vn(self, vn: gtx.Field): + try: + return self._second_half_cache[vn.__gt_buffer_info__.hash_key] + except KeyError: + self._second_half_cache[vn.__gt_buffer_info__.hash_key] = gtx_common._field( + vn.ndarray[:, self._grid.num_levels // 2 :], + domain=gtx_common.Domain( + dims=vn.domain.dims, + ranges=( + vn.domain.ranges[0], + gtx_common.UnitRange(self._grid.num_levels // 2, self._grid.num_levels), + ), + ), + ) + return self._second_half_cache[vn.__gt_buffer_info__.hash_key] + def run_corrector_step( self, diagnostic_state_nh: dycore_states.DiagnosticStateNonHydro, @@ -1341,7 +1429,32 @@ def run_corrector_step( ) ) - self._apply_divergence_damping_and_update_vn( + # EXCHANGE OVERLAP EXPERIMENT START + self._apply_divergence_damping_and_update_vn_first_half( + horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, + next_vn=prognostic_states.next.vn, + current_vn=prognostic_states.current.vn, + dwdz_at_cells_on_model_levels=z_fields.dwdz_at_cells_on_model_levels, + predictor_normal_wind_advective_tendency=diagnostic_state_nh.normal_wind_advective_tendency.predictor, + corrector_normal_wind_advective_tendency=diagnostic_state_nh.normal_wind_advective_tendency.corrector, + normal_wind_tendency_due_to_slow_physics_process=diagnostic_state_nh.normal_wind_tendency_due_to_slow_physics_process, + normal_wind_iau_increment=diagnostic_state_nh.normal_wind_iau_increment, + theta_v_at_edges_on_model_levels=z_fields.theta_v_at_edges_on_model_levels, + horizontal_pressure_gradient=z_fields.horizontal_pressure_gradient, + reduced_fourth_order_divdamp_coeff_at_nest_boundary=self.reduced_fourth_order_divdamp_coeff_at_nest_boundary, + fourth_order_divdamp_scaling_coeff=self.fourth_order_divdamp_scaling_coeff, + second_order_divdamp_scaling_coeff=second_order_divdamp_scaling_coeff, + dtime=dtime, + apply_2nd_order_divergence_damping=apply_2nd_order_divergence_damping, + apply_4th_order_divergence_damping=apply_4th_order_divergence_damping, + ) + + log.debug("exchanging prognostic field 'vn' first half") + + first_half_vn = self.get_first_half_vn(prognostic_states.next.vn) + first_half_exchange = self._exchange.exchange(dims.EdgeDim, first_half_vn) + + self._apply_divergence_damping_and_update_vn_second_half( horizontal_gradient_of_normal_wind_divergence=z_fields.horizontal_gradient_of_normal_wind_divergence, next_vn=prognostic_states.next.vn, current_vn=prognostic_states.current.vn, @@ -1360,10 +1473,27 @@ def run_corrector_step( apply_4th_order_divergence_damping=apply_4th_order_divergence_damping, ) - log.debug("exchanging prognostic field 'vn'") - self._exchange.exchange_and_wait(dims.EdgeDim, (prognostic_states.next.vn)) + log.debug("exchanging prognostic field 'vn' second half") + # TODO(havogt): this wait could be after the next exchange starts, but we need to duplicate the ghex communication object + first_half_exchange.wait() + second_half_vn = self.get_second_half_vn(prognostic_states.next.vn) + second_half_exchange = self._exchange.exchange(dims.EdgeDim, second_half_vn) + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_first_half( + spatially_averaged_vn=self.z_vn_avg, + mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, + theta_v_flux_at_edges_on_model_levels=self.theta_v_flux_at_edges_on_model_levels, + substep_and_spatially_averaged_vn=prep_adv.vn_traj, + substep_averaged_mass_flux=prep_adv.mass_flx_me, + vn=prognostic_states.next.vn, + rho_at_edges_on_model_levels=z_fields.rho_at_edges_on_model_levels, + theta_v_at_edges_on_model_levels=z_fields.theta_v_at_edges_on_model_levels, + prepare_advection=lprep_adv, + at_first_substep=at_first_substep, + r_nsubsteps=r_nsubsteps, + ) - self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection( + second_half_exchange.wait() + self._compute_averaged_vn_and_fluxes_and_prepare_tracer_advection_second_half( spatially_averaged_vn=self.z_vn_avg, mass_flux_at_edges_on_model_levels=diagnostic_state_nh.mass_flux_at_edges_on_model_levels, theta_v_flux_at_edges_on_model_levels=self.theta_v_flux_at_edges_on_model_levels, @@ -1376,6 +1506,7 @@ def run_corrector_step( at_first_substep=at_first_substep, r_nsubsteps=r_nsubsteps, ) + # EXCHANGE OVERLAP EXPERIMENT END self._vertically_implicit_solver_at_corrector_step( next_w=prognostic_states.next.w, diff --git a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py index 6145782474..e68925ebb4 100644 --- a/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py +++ b/model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py @@ -239,11 +239,8 @@ def exchange(self, dim: gtx.Dimension, *fields: gtx.Field) -> MultiNodeResult: the granule context where fields otherwise have length nproma. """ applied_patterns = [self._get_applied_pattern(dim, f) for f in fields] - assert hasattr(fields[0], "array_ns") - if hasattr(fields[0].array_ns, "cuda"): - # TODO(havogt): this is a workaround as ghex does not know that it should synchronize - # the GPU before the exchange. This is necessary to ensure that all data is ready for the exchange. - fields[0].array_ns.cuda.runtime.deviceSynchronize() + # With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream, + # otherwise we need an explicit device synchronize here. handle = self._comm.exchange(applied_patterns) log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.") return MultiNodeResult(handle, applied_patterns) diff --git a/pyproject.toml b/pyproject.toml index 8f343dfdf8..ab0c7b845c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -355,7 +355,7 @@ url = "https://test.pypi.org/simple/" [tool.uv.sources] dace = {git = "https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_11_05"} -# ghex = {git = "https://github.com/ghex-org/GHEX.git", branch = "master"} +ghex = {git = "https://github.com/msimberg/GHEX.git", branch = "async-mpi"} # gt4py = {git = "https://github.com/GridTools/gt4py", branch = "main"} # gt4py = {index = "test.pypi"} icon4py-atmosphere-advection = {workspace = true} diff --git a/uv.lock b/uv.lock index 5889f36141..6a1cbeea7c 100644 --- a/uv.lock +++ b/uv.lock @@ -1352,13 +1352,12 @@ wheels = [ [[package]] name = "ghex" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } +version = "0.4.1" +source = { git = "https://github.com/msimberg/GHEX.git?branch=async-mpi#6d896166994cedbcfc50da1873239a5edb212e3f" } dependencies = [ { name = "mpi4py" }, { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d4/4f/d6217b2afcecff78620c8d3df315b3a354820447ad48962889fe029a3b2c/ghex-0.4.0.tar.gz", hash = "sha256:65135fee88a0bea16bbcc6a48fda9065850db7af4340726c0ea804affed04890", size = 8309041, upload-time = "2024-12-18T14:40:05.407Z" } [[package]] name = "gitdb" @@ -1881,7 +1880,7 @@ requires-dist = [ { name = "cupy-cuda12x", marker = "extra == 'cuda12'", specifier = ">=13.0" }, { name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_11_05" }, { name = "datashader", marker = "extra == 'io'", specifier = ">=0.16.1" }, - { name = "ghex", marker = "extra == 'distributed'", specifier = ">=0.3.0" }, + { name = "ghex", marker = "extra == 'distributed'", git = "https://github.com/msimberg/GHEX.git?branch=async-mpi" }, { name = "gt4py", specifier = "==1.1.0" }, { name = "gt4py", extras = ["cuda11"], marker = "extra == 'cuda11'" }, { name = "gt4py", extras = ["cuda12"], marker = "extra == 'cuda12'" },