diff --git a/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp b/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp index 171e3486ba4..82029abbc6a 100644 --- a/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp +++ b/components/eamxx/src/physics/ml_correction/eamxx_ml_correction_process_interface.cpp @@ -104,10 +104,9 @@ void MLCorrection::run_impl(const double dt) { const auto &v = get_field_out("horiz_winds").get_component(1).get_view(); // For precipitation adjustment we need to track the change in column integrated 'qv' - host_view2d_type qv_told("",qv.extent(0),qv.extent(1)); + decltype(qv) qv_told("", qv.extent(0), qv.extent(1)); Kokkos::deep_copy(qv_told,qv); - auto h_lat = m_lat.get_view(); auto h_lon = m_lon.get_view(); @@ -155,8 +154,8 @@ void MLCorrection::run_impl(const double dt) { using MT = typename KT::MemberType; using ESU = ekat::ExeSpaceUtils; const auto &pseudo_density = get_field_in("pseudo_density").get_view(); - const auto &precip_liq_surf_mass = get_field_out("precip_liq_surf_mass").get_view(); - const auto &precip_ice_surf_mass = get_field_out("precip_ice_surf_mass").get_view(); + const auto &precip_liq_surf_mass = get_field_out("precip_liq_surf_mass").get_view(); + const auto &precip_ice_surf_mass = get_field_out("precip_ice_surf_mass").get_view(); constexpr Real g = PC::gravit; const auto num_levs = m_num_levs; const auto policy = ESU::get_default_team_policy(m_num_cols, m_num_levs); diff --git a/components/eamxx/src/physics/ml_correction/ml_correction.py b/components/eamxx/src/physics/ml_correction/ml_correction.py index 4e22f711c5a..acecdd95b46 100644 --- a/components/eamxx/src/physics/ml_correction/ml_correction.py +++ b/components/eamxx/src/physics/ml_correction/ml_correction.py @@ -25,7 +25,7 @@ def ensure_correction_ordering(correction): return correction -def get_ML_correction_dQ1_dQ2(model, T_mid, qv, cos_zenith, dt): +def get_ML_correction_dQ1_dQ2(model, T_mid, qv, cos_zenith, lat, phis, dt): """Get ML correction for air temperature (dQ1) and specific humidity (dQ2) Args: @@ -40,6 +40,8 @@ def get_ML_correction_dQ1_dQ2(model, T_mid, qv, cos_zenith, dt): T_mid=(["ncol", "z"], T_mid), qv=(["ncol", "z"], qv), cos_zenith_angle=(["ncol"], cos_zenith), + lat=(["ncol"], lat), + surface_geopotential=(["ncol"], phis), ) ) return ensure_correction_ordering(predict(model, ds, dt)) @@ -180,7 +182,13 @@ def update_fields( ) if model_tq is not None: correction_tq = get_ML_correction_dQ1_dQ2( - model_tq, T_mid, qv[:, 0, :], cos_zenith, dt + model_tq, + T_mid, + qv[:, 0, :], + cos_zenith, + lat, + phis, + dt ) T_mid[:, :] += correction_tq["dQ1"].values * dt qv[:, 0, :] += correction_tq["dQ2"].values * dt