Skip to content

Commit

Permalink
Add denormalisation routine for the gw_convect_dp net.
Browse files Browse the repository at this point in the history
  • Loading branch information
jatkinson1000 committed Aug 19, 2024
1 parent 5644950 commit fd4959b
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/physics/cam/gw_ml.F90
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,8 @@ subroutine gw_drag_convect_dp_ml(ncol, dt, &
call torch_delete(net_input_tensors)
call torch_delete(net_output_tensors)

! Extract data and return
do i = 1, ncol
utgw(i, :) = net_outputs(1:pver, i)
vtgw(i, :) = net_outputs(pver+1:2*pver, i)
end do
! Denormalise outputs and extract the data
call denormalise_data(ncol, utgw, vtgw, net_outputs)

end subroutine gw_drag_convect_dp_ml

Expand Down Expand Up @@ -393,5 +390,20 @@ subroutine normalise_data(ncol, u, v, t, dse, nm, netdt, zm, rhoi, ps, lat, lon,

end subroutine normalise_data

subroutine denormalise_data(ncol, utgw, vtgw, nn_output)

integer, intent(in) :: ncol
real(r8), intent(out) :: utgw(ncol,pver), vtgw(ncol,pver)
real(r8), intent(in) :: nn_output(2*pver, ncol)

integer :: i

! Extract data, denormalise, and deconcatenate from NN output tensor
do i = 1, ncol
utgw(i, :) = (nn_output(1:pver, i) * utgw_std(:)) + utgw_mean(:)
vtgw(i, :) = (nn_output(pver+1:2*pver, i) * vtgw_std(:)) + vtgw_mean(:)
end do

end subroutine denormalise_data

end module gw_ml

0 comments on commit fd4959b

Please sign in to comment.