Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This commit addresses two fixes in Corrective-ML #2824

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,9 @@ void MLCorrection::run_impl(const double dt) {
const auto &v = get_field_out("horiz_winds").get_component(1).get_view<Real **, Host>();

// For precipitation adjustment we need to track the change in column integrated 'qv'
host_view2d_type qv_told("",qv.extent(0),qv.extent(1));
decltype(qv) qv_told("", qv.extent(0), qv.extent(1));
Kokkos::deep_copy(qv_told,qv);


auto h_lat = m_lat.get_view<const Real*,Host>();
auto h_lon = m_lon.get_view<const Real*,Host>();

Expand Down Expand Up @@ -155,8 +154,8 @@ void MLCorrection::run_impl(const double dt) {
using MT = typename KT::MemberType;
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
const auto &pseudo_density = get_field_in("pseudo_density").get_view<const Real**>();
const auto &precip_liq_surf_mass = get_field_out("precip_liq_surf_mass").get_view<Real *, Host>();
const auto &precip_ice_surf_mass = get_field_out("precip_ice_surf_mass").get_view<Real *, Host>();
const auto &precip_liq_surf_mass = get_field_out("precip_liq_surf_mass").get_view<Real *>();
const auto &precip_ice_surf_mass = get_field_out("precip_ice_surf_mass").get_view<Real *>();
constexpr Real g = PC::gravit;
const auto num_levs = m_num_levs;
const auto policy = ESU::get_default_team_policy(m_num_cols, m_num_levs);
Expand Down
12 changes: 10 additions & 2 deletions components/eamxx/src/physics/ml_correction/ml_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ensure_correction_ordering(correction):
return correction


def get_ML_correction_dQ1_dQ2(model, T_mid, qv, cos_zenith, dt):
def get_ML_correction_dQ1_dQ2(model, T_mid, qv, cos_zenith, lat, phis, dt):
"""Get ML correction for air temperature (dQ1) and specific humidity (dQ2)

Args:
Expand All @@ -40,6 +40,8 @@ def get_ML_correction_dQ1_dQ2(model, T_mid, qv, cos_zenith, dt):
T_mid=(["ncol", "z"], T_mid),
qv=(["ncol", "z"], qv),
cos_zenith_angle=(["ncol"], cos_zenith),
lat=(["ncol"], lat),
surface_geopotential=(["ncol"], phis),
)
)
return ensure_correction_ordering(predict(model, ds, dt))
Expand Down Expand Up @@ -180,7 +182,13 @@ def update_fields(
)
if model_tq is not None:
correction_tq = get_ML_correction_dQ1_dQ2(
model_tq, T_mid, qv[:, 0, :], cos_zenith, dt
model_tq,
T_mid,
qv[:, 0, :],
cos_zenith,
lat,
phis,
dt
)
T_mid[:, :] += correction_tq["dQ1"].values * dt
qv[:, 0, :] += correction_tq["dQ2"].values * dt
Expand Down