Skip to content

Commit 913463d

Browse files
Added coupling between pySDC and Gusto (#516)
* Added coupling between pySDC and Gusto * Removed unused parameters * Update pySDC/implementations/problem_classes/GenericGusto.py Co-authored-by: Dr Jemma Shipton <j.shipton@exeter.ac.uk> * Update pySDC/implementations/problem_classes/GenericGusto.py Co-authored-by: Dr Jemma Shipton <j.shipton@exeter.ac.uk> * Update pySDC/implementations/problem_classes/GenericGusto.py Co-authored-by: Dr Jemma Shipton <j.shipton@exeter.ac.uk> * Update pySDC/tutorial/step_7/F_pySDC_with_Gusto.py Co-authored-by: Dr Jemma Shipton <j.shipton@exeter.ac.uk> * Minor refactor --------- Co-authored-by: Dr Jemma Shipton <j.shipton@exeter.ac.uk>
1 parent 5b92e94 commit 913463d

File tree

10 files changed

+1676
-1
lines changed

10 files changed

+1676
-1
lines changed

.github/workflows/ci_pipeline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ jobs:
173173
user_firedrake_tests:
174174
runs-on: ubuntu-latest
175175
container:
176-
image: firedrakeproject/firedrake-vanilla:latest
176+
image: firedrakeproject/firedrake-vanilla:2025-01
177177
options: --user root
178178
volumes:
179179
- ${{ github.workspace }}:/repositories

docs/source/tutorial/doc_step_7_E.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Full code: `pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py <https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py>`_
2+
3+
.. literalinclude:: ../../../pySDC/tutorial/step_7/E_pySDC_with_Firedrake.py

docs/source/tutorial/doc_step_7_F.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Full code: `pySDC/tutorial/step_7/F_pySDC_with_Gusto.py <https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/tutorial/step_7/F_pySDC_with_Gusto.py>`_
2+
Find a suitable plotting script here: `pySDC/tutorial/step_7/F_2_plot_pySDC_with_Gusto_result.py <https://github.com/Parallel-in-Time/pySDC/blob/master/pySDC/tutorial/step_7/F_2_plot_pySDC_with_Gusto_result.py>`_
3+
4+
.. literalinclude:: ../../../pySDC/tutorial/step_7/F_pySDC_with_Gusto.py
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import firedrake as fd
2+
3+
from gusto.time_discretisation.time_discretisation import TimeDiscretisation, wrapper_apply
4+
from gusto.core.labels import explicit
5+
6+
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
7+
from pySDC.implementations.problem_classes.GenericGusto import GenericGusto, GenericGustoImex
8+
from pySDC.core.hooks import Hooks
9+
from pySDC.helpers.stats_helper import get_sorted
10+
11+
12+
class LogTime(Hooks):
13+
"""
14+
Utility hook for knowing how far we got when using adaptive step size selection.
15+
"""
16+
17+
def post_step(self, step, level_number):
18+
L = step.levels[level_number]
19+
self.add_to_stats(
20+
process=step.status.slot,
21+
process_sweeper=L.sweep.rank,
22+
time=L.time,
23+
level=-1,
24+
iter=-1,
25+
sweep=-1,
26+
type='_time',
27+
value=L.time + L.dt,
28+
)
29+
30+
31+
class pySDC_integrator(TimeDiscretisation):
32+
"""
33+
This class can be entered into Gusto as a time discretization scheme and will solve steps using pySDC.
34+
It will construct a pySDC controller which can be used by itself and will be used within the time step when called
35+
from Gusto. Access the controller via `pySDC_integrator.controller`. This class also has `pySDC_integrator.stats`,
36+
which gathers all of the pySDC stats recorded in the hooks during every time step when used within Gusto.
37+
"""
38+
39+
def __init__(
40+
self,
41+
equation,
42+
description,
43+
controller_params,
44+
domain,
45+
field_name=None,
46+
solver_parameters=None,
47+
options=None,
48+
t0=0,
49+
imex=False,
50+
):
51+
"""
52+
Initialization
53+
54+
Args:
55+
equation (:class:`PrognosticEquation`): the prognostic equation.
56+
description (dict): pySDC description
57+
controller_params (dict): pySDC controller params
58+
domain (:class:`Domain`): the model's domain object, containing the
59+
mesh and the compatible function spaces.
60+
field_name (str, optional): name of the field to be evolved.
61+
Defaults to None.
62+
solver_parameters (dict, optional): dictionary of parameters to
63+
pass to the underlying solver. Defaults to None.
64+
options (:class:`AdvectionOptions`, optional): an object containing
65+
options to either be passed to the spatial discretisation, or
66+
to control the "wrapper" methods, such as Embedded DG or a
67+
recovery method. Defaults to None.
68+
"""
69+
70+
self._residual = None
71+
72+
super().__init__(
73+
domain=domain,
74+
field_name=field_name,
75+
solver_parameters=solver_parameters,
76+
options=options,
77+
)
78+
79+
self.description = description
80+
self.controller_params = controller_params
81+
self.timestepper = None
82+
self.dt_next = None
83+
self.imex = imex
84+
85+
def setup(self, equation, apply_bcs=True, *active_labels):
86+
super().setup(equation, apply_bcs, *active_labels)
87+
88+
# Check if any terms are explicit
89+
imex = any(t.has_label(explicit) for t in equation.residual) or self.imex
90+
if imex:
91+
self.description['problem_class'] = GenericGustoImex
92+
else:
93+
self.description['problem_class'] = GenericGusto
94+
95+
self.description['problem_params'] = {
96+
'equation': equation,
97+
'solver_parameters': self.solver_parameters,
98+
'residual': self._residual,
99+
}
100+
self.description['level_params']['dt'] = float(self.domain.dt)
101+
102+
# add utility hook required for step size adaptivity
103+
hook_class = self.controller_params.get('hook_class', [])
104+
if not type(hook_class) == list:
105+
hook_class = [hook_class]
106+
hook_class.append(LogTime)
107+
self.controller_params['hook_class'] = hook_class
108+
109+
# prepare controller and variables
110+
self.controller = controller_nonMPI(1, description=self.description, controller_params=self.controller_params)
111+
self.prob = self.level.prob
112+
self.sweeper = self.level.sweep
113+
self.x0_pySDC = self.prob.dtype_u(self.prob.init)
114+
self.t = 0
115+
self.stats = {}
116+
117+
@property
118+
def residual(self):
119+
"""Make sure the pySDC problem residual and this residual are the same"""
120+
if hasattr(self, 'prob'):
121+
return self.prob.residual
122+
else:
123+
return self._residual
124+
125+
@residual.setter
126+
def residual(self, value):
127+
"""Make sure the pySDC problem residual and this residual are the same"""
128+
if hasattr(self, 'prob'):
129+
self.prob.residual = value
130+
else:
131+
self._residual = value
132+
133+
@property
134+
def level(self):
135+
"""Get the finest pySDC level"""
136+
return self.controller.MS[0].levels[0]
137+
138+
@wrapper_apply
139+
def apply(self, x_out, x_in):
140+
"""
141+
Apply the time discretization to advance one whole time step.
142+
143+
Args:
144+
x_out (:class:`Function`): the output field to be computed.
145+
x_in (:class:`Function`): the input field.
146+
"""
147+
self.x0_pySDC.functionspace.assign(x_in)
148+
assert self.level.params.dt == float(self.dt), 'Step sizes have diverged between pySDC and Gusto'
149+
150+
if self.dt_next is not None:
151+
assert (
152+
self.timestepper is not None
153+
), 'You need to set self.timestepper to the timestepper in order to facilitate adaptive step size selection here!'
154+
self.timestepper.dt = fd.Constant(self.dt_next)
155+
self.t = self.timestepper.t
156+
157+
uend, _stats = self.controller.run(u0=self.x0_pySDC, t0=float(self.t), Tend=float(self.t + self.dt))
158+
159+
# update time variables
160+
if self.level.params.dt != float(self.dt):
161+
self.dt_next = self.level.params.dt
162+
163+
self.t = get_sorted(_stats, type='_time', recomputed=False)[-1][1]
164+
165+
# update time of the Gusto stepper.
166+
# After this step, the Gusto stepper updates its time again to arrive at the correct time
167+
if self.timestepper is not None:
168+
self.timestepper.t = fd.Constant(self.t - self.dt)
169+
170+
self.dt = self.level.params.dt
171+
172+
# update stats and output
173+
self.stats = {**self.stats, **_stats}
174+
x_out.assign(uend.functionspace)

0 commit comments

Comments
 (0)