Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More simplifications after grudge array container support #579

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions examples/lump-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pyopencl.tools as cl_tools
from functools import partial

from arraycontext import flatten
from meshmode.array_context import (
PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext
Expand All @@ -43,7 +44,8 @@
from mirgecom.euler import euler_operator
from mirgecom.simutil import (
get_sim_timestep,
generate_and_distribute_mesh
generate_and_distribute_mesh,
componentwise_norms
)
from mirgecom.io import make_init_message
from mirgecom.mpi import mpi_entry_point
Expand Down Expand Up @@ -247,10 +249,10 @@ def my_health_check(dv, state, exact):
health_error = True
logger.info(f"{rank=}: Invalid pressure data found.")

from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
max_error = actx.to_numpy(actx.np.max(component_errors))
exittol = .09
if max(component_errors) > exittol:
if max_error > exittol:
health_error = True
if rank == 0:
logger.info("Solution diverged from exact soln.")
Expand Down Expand Up @@ -296,11 +298,11 @@ def my_pre_step(step, t, dt, state):
if do_status:
if exact is None:
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
flat_component_errors = actx.to_numpy(flatten(
componentwise_norms(discr, state - exact), actx))
status_msg = (
"------- errors="
+ ", ".join("%.3g" % en for en in component_errors))
+ ", ".join("%.3g" % en for en in flat_component_errors))
if rank == 0:
logger.info(status_msg)

Expand Down
16 changes: 9 additions & 7 deletions examples/mixture-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pyopencl.tools as cl_tools
from functools import partial

from arraycontext import flatten
from meshmode.array_context import (
PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext
Expand All @@ -43,7 +44,8 @@
from mirgecom.euler import euler_operator
from mirgecom.simutil import (
get_sim_timestep,
generate_and_distribute_mesh
generate_and_distribute_mesh,
componentwise_norms
)
from mirgecom.io import make_init_message
from mirgecom.mpi import mpi_entry_point
Expand Down Expand Up @@ -225,9 +227,10 @@ def main(ctx_factory=cl.create_some_context, use_logmgr=True,
logger.info(init_message)

def my_write_status(component_errors):
flat_component_errors = actx.to_numpy(flatten(component_errors, actx))
status_msg = (
"------- errors="
+ ", ".join("%.3g" % en for en in component_errors))
+ ", ".join("%.3g" % en for en in flat_component_errors))
if rank == 0:
logger.info(status_msg)

Expand Down Expand Up @@ -270,8 +273,9 @@ def my_health_check(dv, component_errors):
health_error = True
logger.info(f"{rank=}: Invalid pressure data found.")

max_error = actx.to_numpy(actx.np.max(component_errors))
exittol = .09
if max(component_errors) > exittol:
if max_error > exittol:
health_error = True
if rank == 0:
logger.info("Solution diverged from exact soln.")
Expand All @@ -296,8 +300,7 @@ def my_pre_step(step, t, dt, state):
if do_health:
dv = eos.dependent_vars(state)
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
health_errors = global_reduce(
my_health_check(dv, component_errors), op="lor")
if health_errors:
Expand All @@ -321,8 +324,7 @@ def my_pre_step(step, t, dt, state):
if component_errors is None:
if exact is None:
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
my_write_status(component_errors)

except MyRuntimeError:
Expand Down
16 changes: 9 additions & 7 deletions examples/scalar-lump-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from functools import partial
from pytools.obj_array import make_obj_array

from arraycontext import flatten
from meshmode.array_context import (
PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext
Expand All @@ -44,7 +45,8 @@
from mirgecom.euler import euler_operator
from mirgecom.simutil import (
get_sim_timestep,
generate_and_distribute_mesh
generate_and_distribute_mesh,
componentwise_norms
)
from mirgecom.io import make_init_message
from mirgecom.mpi import mpi_entry_point
Expand Down Expand Up @@ -215,10 +217,11 @@ def main(ctx_factory=cl.create_some_context, use_logmgr=True,
logger.info(init_message)

def my_write_status(component_errors):
flat_component_errors = actx.to_numpy(flatten(component_errors, actx))
if rank == 0:
logger.info(
"------- errors="
+ ", ".join("%.3g" % en for en in component_errors))
+ ", ".join("%.3g" % en for en in flat_component_errors))

def my_write_viz(step, t, state, dv=None, exact=None, resid=None):
if dv is None:
Expand Down Expand Up @@ -258,8 +261,9 @@ def my_health_check(pressure, component_errors):
health_error = True
logger.info(f"{rank=}: Invalid pressure data found.")

max_error = actx.to_numpy(actx.np.max(component_errors))
exittol = .09
if max(component_errors) > exittol:
if max_error > exittol:
health_error = True
if rank == 0:
logger.info("Solution diverged from exact soln.")
Expand All @@ -284,8 +288,7 @@ def my_pre_step(step, t, dt, state):
if do_health:
dv = eos.dependent_vars(state)
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
health_errors = global_reduce(
my_health_check(dv.pressure, component_errors), op="lor")
if health_errors:
Expand All @@ -309,8 +312,7 @@ def my_pre_step(step, t, dt, state):
if component_errors is None:
if exact is None:
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
my_write_status(component_errors)

except MyRuntimeError:
Expand Down
17 changes: 9 additions & 8 deletions examples/sod-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pyopencl.tools as cl_tools
from functools import partial

from arraycontext import flatten
from meshmode.array_context import (
PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext
Expand All @@ -43,7 +44,8 @@
from mirgecom.euler import euler_operator
from mirgecom.simutil import (
get_sim_timestep,
generate_and_distribute_mesh
generate_and_distribute_mesh,
componentwise_norms
)
from mirgecom.io import make_init_message
from mirgecom.mpi import mpi_entry_point
Expand Down Expand Up @@ -204,10 +206,11 @@ def main(ctx_factory=cl.create_some_context, use_logmgr=True,
logger.info(init_message)

def my_write_status(component_errors):
flat_component_errors = actx.to_numpy(flatten(component_errors, actx))
if rank == 0:
logger.info(
"------- errors="
+ ", ".join("%.3g" % en for en in component_errors)
+ ", ".join("%.3g" % en for en in flat_component_errors)
)

def my_write_viz(step, t, state, dv=None, exact=None, resid=None):
Expand Down Expand Up @@ -248,8 +251,9 @@ def my_health_check(pressure, component_errors):
health_error = True
logger.info(f"{rank=}: Invalid pressure data found.")

max_error = actx.to_numpy(actx.np.max(component_errors))
exittol = .09
if max(component_errors) > exittol:
if max_error > exittol:
health_error = True
if rank == 0:
logger.info("Solution diverged from exact soln.")
Expand All @@ -274,8 +278,7 @@ def my_pre_step(step, t, dt, state):
if do_health:
dv = eos.dependent_vars(state)
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
health_errors = global_reduce(
my_health_check(dv.pressure, component_errors), op="lor")
if health_errors:
Expand All @@ -299,9 +302,7 @@ def my_pre_step(step, t, dt, state):
if component_errors is None:
if exact is None:
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = \
compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
my_write_status(component_errors)

except MyRuntimeError:
Expand Down
16 changes: 9 additions & 7 deletions examples/vortex-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pyopencl.tools as cl_tools
from functools import partial

from arraycontext import flatten
from meshmode.array_context import (
PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext
Expand All @@ -43,7 +44,8 @@
from mirgecom.simutil import (
get_sim_timestep,
generate_and_distribute_mesh,
check_step
check_step,
componentwise_norms
)
from mirgecom.io import make_init_message
from mirgecom.mpi import mpi_entry_point
Expand Down Expand Up @@ -232,10 +234,11 @@ def my_write_status(state, component_errors, cfl=None):
discr, "vol",
get_inviscid_cfl(discr, eos, current_dt, cv=state)))[()]
if rank == 0:
flat_component_errors = actx.to_numpy(flatten(component_errors, actx))
logger.info(
f"------ {cfl=}\n"
"------- errors="
+ ", ".join("%.3g" % en for en in component_errors))
+ ", ".join("%.3g" % en for en in flat_component_errors))

def my_write_viz(step, t, state, dv=None, exact=None, resid=None):
if dv is None:
Expand Down Expand Up @@ -275,8 +278,9 @@ def my_health_check(pressure, component_errors):
health_error = True
logger.info(f"{rank=}: Invalid pressure data found.")

max_error = actx.to_numpy(actx.np.max(component_errors))
exittol = .1
if max(component_errors) > exittol:
if max_error > exittol:
health_error = True
if rank == 0:
logger.info("Solution diverged from exact soln.")
Expand All @@ -300,8 +304,7 @@ def my_pre_step(step, t, dt, state):
if do_health:
dv = eos.dependent_vars(state)
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
health_errors = global_reduce(
my_health_check(dv.pressure, component_errors), op="lor")
if health_errors:
Expand All @@ -316,8 +319,7 @@ def my_pre_step(step, t, dt, state):
if component_errors is None:
if exact is None:
exact = initializer(x_vec=nodes, eos=eos, time=t)
from mirgecom.simutil import compare_fluid_solutions
component_errors = compare_fluid_solutions(discr, state, exact)
component_errors = componentwise_norms(discr, state - exact)
my_write_status(state, component_errors)

if do_viz:
Expand Down
11 changes: 2 additions & 9 deletions mirgecom/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@
import numpy as np
import grudge.dof_desc as dof_desc

from arraycontext import map_array_container

from functools import partial

from arraycontext import multimapped_over_array_containers
from meshmode.dof_array import DOFArray

from pytools import keyed_memoize_in
Expand Down Expand Up @@ -167,6 +164,7 @@ def apply_spectral_filter(actx, modal_field, discr, cutoff,
)


@multimapped_over_array_containers(leaf_class=DOFArray)
def filter_modally(dcoll, dd, cutoff, mode_resp_func, field):
"""Stand-alone procedural interface to spectral filtering.

Expand Down Expand Up @@ -200,11 +198,6 @@ def filter_modally(dcoll, dd, cutoff, mode_resp_func, field):
result: :class:`mirgecom.fluid.ConservedVars`
An array container containing the filtered field(s).
"""
if not isinstance(field, DOFArray):
return map_array_container(
partial(filter_modally, dcoll, dd, cutoff, mode_resp_func), field
)

actx = field.array_context
dd = dof_desc.as_dofdesc(dd)
dd_modal = dof_desc.DD_VOLUME_MODAL
Expand Down
34 changes: 11 additions & 23 deletions mirgecom/simutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

.. autofunction:: compare_fluid_solutions
.. autofunction:: componentwise_norms
.. autofunction:: max_component_norm
.. autofunction:: check_naninf_local
.. autofunction:: check_range_local

Expand Down Expand Up @@ -50,9 +49,7 @@
import numpy as np
import grudge.op as op

from arraycontext import map_array_container, flatten

from functools import partial
from arraycontext import multimapped_over_array_containers

from meshmode.dof_array import DOFArray

Expand Down Expand Up @@ -294,42 +291,33 @@ def check_naninf_local(discr, dd, field):
def compare_fluid_solutions(discr, red_state, blue_state):
"""Return inf norm of (*red_state* - *blue_state*) for each component.

Deprecated. Do not use in new code.

.. note::
This is a collective routine and must be called by all MPI ranks.
"""
from warnings import warn
warn("compare_fluid_solutions is deprecated and will disappear in Q3 2022. "
"Use componentwise_norms instead.", DeprecationWarning, stacklevel=2)

actx = red_state.array_context
resid = red_state - blue_state
from arraycontext import flatten
resid_errs = actx.to_numpy(
flatten(componentwise_norms(discr, resid, order=np.inf), actx))

return resid_errs.tolist()


# FIXME: Add componentwise norm functionality to grudge?
@multimapped_over_array_containers(leaf_class=DOFArray)
def componentwise_norms(discr, fields, order=np.inf):
"""Return the *order*-norm for each component of *fields*.

.. note::
This is a collective routine and must be called by all MPI ranks.
"""
if not isinstance(fields, DOFArray):
return map_array_container(
partial(componentwise_norms, discr, order=order), fields)
if len(fields) > 0:
return discr.norm(fields, order)
else:
# FIXME: This work-around for #575 can go away after #569
return 0


def max_component_norm(discr, fields, order=np.inf):
"""Return the max *order*-norm over the components of *fields*.

.. note::
This is a collective routine and must be called by all MPI ranks.
"""
actx = fields.array_context
return max(actx.to_numpy(flatten(
componentwise_norms(discr, fields, order), actx)))
return discr.norm(fields, order)


def generate_and_distribute_mesh(comm, generate_mesh):
Expand Down
Loading