Skip to content

Commit

Permalink
Implemented relative error estimates for adaptive step size selection (
Browse files Browse the repository at this point in the history
  • Loading branch information
brownbaerchen authored Jan 23, 2025
1 parent 5ac6441 commit 5b92e94
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 21 deletions.
27 changes: 22 additions & 5 deletions pySDC/implementations/convergence_controller_classes/adaptivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def setup(self, controller, params, description, **kwargs):
"""
defaults = {
"embedded_error_flavor": 'standard',
"rel_error": False,
}
return {**defaults, **super().setup(controller, params, description, **kwargs)}

Expand All @@ -328,6 +329,9 @@ def dependencies(self, controller, description, **kwargs):
controller.add_convergence_controller(
EstimateEmbeddedError.get_implementation(self.params.embedded_error_flavor, self.params.useMPI),
description=description,
params={
'rel_error': self.params.rel_error,
},
)

# load contraction factor estimator if necessary
Expand Down Expand Up @@ -837,6 +841,8 @@ def setup(self, controller, params, description, **kwargs):

defaults = {
'control_order': -50,
'problem_mesh_type': 'numpyesque',
'rel_error': False,
**super().setup(controller, params, description, **kwargs),
**params,
}
Expand All @@ -858,16 +864,27 @@ def dependencies(self, controller, description, **kwargs):
Returns:
None
"""
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
EstimatePolynomialError,
)
if self.params.problem_mesh_type.lower() == 'numpyesque':
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
EstimatePolynomialError as error_estimation_cls,
)
elif self.params.problem_mesh_type.lower() == 'firedrake':
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
EstimatePolynomialErrorFiredrake as error_estimation_cls,
)
else:
raise NotImplementedError(
f'Don\'t know what error estimation class to use for problems with mesh type {self.params.problem_mesh_type}'
)

super().dependencies(controller, description, **kwargs)

controller.add_convergence_controller(
EstimatePolynomialError,
error_estimation_cls,
description=description,
params={},
params={
'rel_error': self.params.rel_error,
},
)
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def setup(self, controller, params, description, **kwargs):
return {
"control_order": -80,
"sweeper_type": sweeper_type,
"rel_error": False,
**super().setup(controller, params, description, **kwargs),
}

Expand Down Expand Up @@ -94,13 +95,24 @@ def estimate_embedded_error_serial(self, L):
"""
if self.params.sweeper_type == "RK":
L.sweep.compute_end_point()
return abs(L.uend - L.sweep.u_secondary)
if self.params.rel_error:
return abs(L.uend - L.sweep.u_secondary) / abs(L.uend)
else:
return abs(L.uend - L.sweep.u_secondary)
elif self.params.sweeper_type == "SDC":
# order rises by one between sweeps
return abs(L.uold[-1] - L.u[-1])
if self.params.rel_error:
return abs(L.uold[-1] - L.u[-1]) / abs(L.u[-1])
else:
return abs(L.uold[-1] - L.u[-1])
elif self.params.sweeper_type == 'MPI':
comm = L.sweep.comm
return comm.bcast(abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]), root=comm.size - 1)
if self.params.rel_error:
return comm.bcast(
abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]) / abs(L.u[comm.rank + 1]), root=comm.size - 1
)
else:
return comm.bcast(abs(L.uold[comm.rank + 1] - L.u[comm.rank + 1]), root=comm.size - 1)
else:
raise NotImplementedError(
f"Don't know how to estimate embedded error for sweeper type \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def setup(self, controller, params, description, **kwargs):
defaults = {
'control_order': -75,
'estimate_on_node': num_nodes + 1 if quad_type == 'GAUSS' else num_nodes - 1,
'rel_error': False,
**super().setup(controller, params, description, **kwargs),
}
self.comm = description['sweeper_params'].get('comm', None)
Expand Down Expand Up @@ -103,6 +104,23 @@ def matmul(self, A, b, xp=np):
else:
return A @ xp.asarray(b)

def get_interpolated_solution(self, L, xp):
"""
Get the interpolated solution for numpy or cupy data types
Args:
u_vec (array): Vector of solutions
prob (pySDC.problem): Problem
"""
coll = L.sweep.coll

u = [
L.u[i].flatten() if L.u[i] is not None else L.u[i]
for i in range(coll.num_nodes + 1)
if i != self.params.estimate_on_node
]
return self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])

def post_iteration_processing(self, controller, S, **kwargs):
"""
Estimate the error
Expand All @@ -120,20 +138,19 @@ def post_iteration_processing(self, controller, S, **kwargs):
coll = L.sweep.coll
nodes = np.append(np.append(0, coll.nodes), 1.0)
estimate_on_node = self.params.estimate_on_node
xp = L.u[0].xp

if hasattr(L.u[0], 'xp'):
xp = L.u[0].xp
else:
xp = np

if self.interpolation_matrix is None:
interpolator = LagrangeApproximation(
points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
)
self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))

u = [
L.u[i].flatten() if L.u[i] is not None else L.u[i]
for i in range(coll.num_nodes + 1)
if i != estimate_on_node
]
u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
u_inter = self.get_interpolated_solution(L, xp)

# compute end point if needed
if estimate_on_node == len(nodes) - 1:
Expand All @@ -147,12 +164,14 @@ def post_iteration_processing(self, controller, S, **kwargs):
rank = estimate_on_node - 1
L.status.order_embedded_estimate = coll.num_nodes * 1

rescale = float(abs(u_inter)) if self.params.rel_error else 1

if self.comm:
buf = np.array(abs(u_inter - high_order_sol) if self.comm.rank == rank else 0.0)
buf = np.array(abs(u_inter - high_order_sol) / rescale if self.comm.rank == rank else 0.0)
self.comm.Bcast(buf, root=rank)
L.status.error_embedded_estimate = float(buf)
else:
L.status.error_embedded_estimate = abs(u_inter - high_order_sol)
L.status.error_embedded_estimate = abs(u_inter - high_order_sol) / rescale

self.debug(
f'Obtained error estimate: {L.status.error_embedded_estimate:.2e} of order {L.status.order_embedded_estimate}',
Expand All @@ -176,3 +195,59 @@ def check_parameters(self, controller, params, description, **kwargs):
return False, 'Need at least two collocation nodes to interpolate to one!'

return True, ""


class EstimatePolynomialErrorFiredrake(EstimatePolynomialError):
def matmul(self, A, b):
"""
Matrix vector multiplication, possibly MPI parallel.
The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
entire vector once could reduce the number of communications, this way we never need to store the entire vector
on any specific rank.
Args:
A (2d np.ndarray): Matrix
b (list): Vector
Returns:
List: Axb
"""

if self.comm:
res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])]
buf = 0 * b[0]
for i in range(0, A.shape[0]):
index = self.comm.rank + (1 if self.comm.rank < self.params.estimate_on_node - 1 else 0)
send_buf = (
(A[i, index] * b[index]) if self.comm.rank != self.params.estimate_on_node - 1 else 0 * res[0]
)
self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
res[i] += buf
return res
else:
res = []
for i in range(A.shape[0]):
res.append(A[i, 0] * b[0])
for j in range(1, A.shape[1]):
res[-1] += A[i, j] * b[j]

return res

def get_interpolated_solution(self, L):
"""
Get the interpolated solution for Firedrake data types
We are not 100% sure that you don't need to invert the mass matrix here, but should be fine.
Args:
u_vec (array): Vector of solutions
prob (pySDC.problem): Problem
"""
coll = L.sweep.coll

u = [
L.u[i] if L.u[i] is not None else L.u[i]
for i in range(coll.num_nodes + 1)
if i != self.params.estimate_on_node
]
return L.prob.dtype_u(self.matmul(self.interpolation_matrix, u)[0])
# return L.prob.invert_mass_matrix(self.matmul(self.interpolation_matrix, u)[0])
83 changes: 79 additions & 4 deletions pySDC/tests/test_convergence_controllers/test_polynomial_error.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
def get_controller(dt, num_nodes, quad_type, useMPI, useGPU, rel_error):
"""
Get a controller prepared for polynomial test equation
Expand Down Expand Up @@ -64,7 +64,7 @@ def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
description['sweeper_params'] = sweeper_params
description['level_params'] = level_params
description['step_params'] = step_params
description['convergence_controllers'] = {EstimatePolynomialError: {}}
description['convergence_controllers'] = {EstimatePolynomialError: {'rel_error': rel_error}}

controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
return controller
Expand Down Expand Up @@ -177,13 +177,15 @@ def check_order(dts, **kwargs):
@pytest.mark.base
@pytest.mark.parametrize('num_nodes', [2, 3, 4, 5])
@pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])
def test_interpolation_error(num_nodes, quad_type):
@pytest.mark.parametrize('rel_error', [True, False])
def test_interpolation_error(num_nodes, quad_type, rel_error):
import numpy as np

kwargs = {
'num_nodes': num_nodes,
'quad_type': quad_type,
'useMPI': False,
'rel_error': rel_error,
}
steps = np.logspace(-1, -4, 20)
check_order(steps, **kwargs)
Expand All @@ -200,6 +202,7 @@ def test_interpolation_error_GPU(num_nodes, quad_type):
'quad_type': quad_type,
'useMPI': False,
'useGPU': True,
'rel_error': False,
}
steps = np.logspace(-1, -4, 20)
check_order(steps, **kwargs)
Expand Down Expand Up @@ -228,6 +231,77 @@ def test_interpolation_error_MPI(num_nodes, quad_type):
)


@pytest.mark.firedrake
def test_polynomial_error_firedrake(dt=1.0, num_nodes=3, useMPI=False):
from pySDC.implementations.problem_classes.HeatFiredrake import Heat1DForcedFiredrake
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.convergence_controller_classes.estimate_polynomial_error import (
EstimatePolynomialErrorFiredrake,
LagrangeApproximation,
)
import numpy as np

if useMPI:
from pySDC.implementations.sweeper_classes.generic_implicit_MPI import generic_implicit_MPI as sweeper_class
from mpi4py import MPI

comm = MPI.COMM_WORLD
else:
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class

comm = None

level_params = {}
level_params['dt'] = dt
level_params['restol'] = 1.0

sweeper_params = {}
sweeper_params['quad_type'] = 'RADAU-RIGHT'
sweeper_params['num_nodes'] = num_nodes
sweeper_params['comm'] = comm

problem_params = {'n': 1}

step_params = {}
step_params['maxiter'] = 0

controller_params = {}
controller_params['logger_level'] = 30
controller_params['mssdc_jac'] = False

description = {}
description['problem_class'] = Heat1DForcedFiredrake
description['problem_params'] = problem_params
description['sweeper_class'] = sweeper_class
description['sweeper_params'] = sweeper_params
description['level_params'] = level_params
description['step_params'] = step_params
description['convergence_controllers'] = {EstimatePolynomialErrorFiredrake: {}}

controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)

L = controller.MS[0].levels[0]

cont = controller.convergence_controllers[
np.arange(len(controller.convergence_controllers))[
[type(me).__name__ == 'EstimatePolynomialErrorFiredrake' for me in controller.convergence_controllers]
][0]
]

nodes = np.append(np.append(0, L.sweep.coll.nodes), 1.0)
estimate_on_node = cont.params.estimate_on_node
interpolator = LagrangeApproximation(points=[nodes[i] for i in range(num_nodes + 1) if i != estimate_on_node])
cont.interpolation_matrix = np.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))

for i in range(num_nodes + 1):
L.u[i] = L.prob.u_init
L.u[i].functionspace.assign(nodes[i])

u_inter = cont.get_interpolated_solution(L)
error = abs(u_inter - L.u[estimate_on_node])
assert np.isclose(error, 0)


if __name__ == "__main__":
import sys
import numpy as np
Expand All @@ -238,7 +312,8 @@ def test_interpolation_error_MPI(num_nodes, quad_type):
kwargs = {
'num_nodes': int(sys.argv[1]),
'quad_type': sys.argv[2],
'rel_error': False,
}
check_order(steps, useMPI=True, **kwargs)
else:
check_order(steps, useMPI=False, num_nodes=3, quad_type='RADAU-RIGHT')
check_order(steps, useMPI=False, num_nodes=3, quad_type='RADAU-RIGHT', rel_error=False)

0 comments on commit 5b92e94

Please sign in to comment.