diff --git a/doc/arkode/guide/source/Constants.rst b/doc/arkode/guide/source/Constants.rst index dcb293581d..3d7c7494e1 100644 --- a/doc/arkode/guide/source/Constants.rst +++ b/doc/arkode/guide/source/Constants.rst @@ -626,9 +626,10 @@ contains the ARKODE output constants. +-------------------------------------+------+------------------------------------------------------------+ | :index:`ARK_STEP_DIRECTION_ERR` | -52 | An error occurred changing the step direction. | +-------------------------------------+------+------------------------------------------------------------+ - | :index:`ARK_UNRECOGNIZED_ERROR` | -99 | An unknown error was encountered. | + | :index:`ARK_ADJ_RECOMPUTE_FAIL` | -53 | An occurred recomputing steps during the adjoint | + | | | integration. | +-------------------------------------+------+------------------------------------------------------------+ - | | + | :index:`ARK_UNRECOGNIZED_ERROR` | -99 | An unknown error was encountered. | +-------------------------------------+------+------------------------------------------------------------+ | **ARKLS linear solver module output constants** | +-------------------------------------+------+------------------------------------------------------------+ diff --git a/doc/arkode/guide/source/Mathematics.rst b/doc/arkode/guide/source/Mathematics.rst index 2e62c18ba7..09bb4ecb36 100644 --- a/doc/arkode/guide/source/Mathematics.rst +++ b/doc/arkode/guide/source/Mathematics.rst @@ -978,7 +978,7 @@ arise from the **separable** Hamiltonian system where .. math:: - f_1(t, q) \equiv -\frac{\partial V(t, q)}{\partial q}, \qquad + f_1(t, q) \equiv \frac{\partial V(t, q)}{\partial q}, \qquad f_2(t, p) \equiv \frac{\partial T(t, p)}{\partial p}. When *H* is autonomous, then *H* is a conserved quantity. Often this corresponds @@ -2697,3 +2697,111 @@ by :math:`\eta_\text{rf}`. For more information on utilizing relaxation Runge--Kutta methods, see :numref:`ARKODE.Usage.Relaxation`. + + +.. _ARKODE.Mathematics.ASA: + +Adjoint Sensitivity Analysis +============================ + +Consider :eq:`ARKODE_IVP_simple_explicit`, but where the ODE also depends on some parameters +:math:`p` (that is, we have :math:`f(t,y,p)`). Now, suppose we have a functional :math:`g(y(t_f),p)` +for which we would like to compute the gradients :math:`\partial g(y(t_f),p)/\partial y(t_0)` +and/or :math:`\partial g(y(t_f),p)/\partial p`. This most often arises in the form of an +optimization problem such as + +.. math:: + \min_{\xi} \bar{\Psi}(\xi) = g(y(t_f), p) + :label: ARKODE_OPTIMIZATION_PROBLEM + +where :math:`\xi \subset \{y(t_0), p\}`. The adjoint method is one approach to obtaining the +gradients that is particularly efficient when there are relatively few functionals and a large +number of parameters. While :ref:`CVODES ` and +:ref:`IDAS ` *continuous* adjoint methods +(differentiate-then-discretize), ARKODE provides *discrete* adjoint methods +(discretize-then-differentiate). For the continuous approach, we derive and solve the adjoint ODE +backwards in time + +.. math:: + \lambda'(t) &= -f_y^T(t, y, p) \lambda,\quad \lambda(t_F) = g_y^T(y(t_f), p), \\ + \mu'(t) &= -f_p^T(t, y, p) \mu,\quad \mu(t_F) = g_p^T(y(t_f), p), \quad t_f \geq t \geq t_0, \\ + :label: ARKODE_CONTINUOUS_ADJOINT_ODE + +where :math:`\lambda(t) \in \mathbb{R}^N`, :math:`\mu(t) \in \mathbb{R}^{N_s}` +:math:`f_y \equiv \partial f/\partial y \in \mathbb{R}^{N \times N}` is the Jacobian with respect to the dependent variable, +and :math:`f_p \equiv \partial f/\partial p \in \mathbb{R}^{N \times N_s}` is the Jacobian with respect to the parameters +(:math:`N` is the size of the original ODE, :math:`N_s` is the number of parameters). +When solved with a numerical time integration scheme, the solution to the continuous adjoint ODE +are numerical approximations of the continuous adjoint sensitivities + +.. math:: + \lambda(t_0) \approx g_y^T(y(t_0), p),\quad \mu(t_0) \approx g_p^T(y(t_0), p) + :label: ARKODE_CONTINUOUS_ADJOINT_SOLUTION + +For the discrete adjoint approach, we first numerically discretize the original ODE :eq:`ARKODE_IVP_simple_explicit`. +In the context of ARKODE, this is done with a one-step time integration scheme :math:`\varphi` so that + +.. math:: + y_0 = y(t_0),\quad y_n = \varphi(y_{n-1}). + :label: ARKODE_DISCRETE_ODE + +Reformulating the optimization problem for the discrete case, we have + +.. math:: + \min_{\xi} \Psi(\xi) = g(y_n, p) + :label: ARKODE_DISCRETE_OPTIMIZATION_PROBLEM + +The gradients of :eq:`ARKODE_DISCRETE_OPTIMIZATION_PROBLEM` can be computed using the transposed chain +rule backwards in time to obtain the discete adjoint variables :math:`\lambda_n, \lambda_{n-1}, \cdots, \lambda_0` +and :math:`\mu_n, \mu_{n-1}, \cdots, \mu_0`, + +.. math:: + \lambda_n &= g_y^T(y_n, p), \quad \lambda_k = \left(\frac{\partial \varphi}{\partial y_k}(y_k, p)\right)^T \lambda_{k+1} \\ + \mu_n &= g_p^T(y_n, p), \quad \mu_k = \left(\frac{\partial \varphi}{\partial p}(y_k, p)\right)^T \lambda_{k+1}, + \quad k = n - 1, \cdots, 0. + :label: ARKODE_DISCRETE_ADJOINT + +The solution of the discrete adjoint equations :eq:`ARKODE_DISCRETE_ADJOINT` is the sensitivities of the discrete cost function +:eq:`ARKODE_DISCRETE_OPTIMIZATION_PROBLEM` with respect to changes in the discretized ODE :eq:`ARKODE_DISCRETE_ODE`. + +.. math:: + \lambda_0 = g_y^T(y_0, p), \quad \mu_0 = g_p^T(y_0, p). + :label: ARKODE_DISCRETE_ADJOINT_SOLUTION + +Given an s-stage explicit Runge--Kutta method (as in :eq:`ARKODE_ERK`, but without the embedding), the discrete adjoint +to compute :math:`\lambda_n` and :math:`\mu_n` starting from :math:`\lambda_{n+1}` and +:math:`\mu_{n+1}` is given by + +.. math:: + \Lambda_i &= h_n f_y^T(t_{n,i}, z_i) \left(b_i \lambda_{n+1} + \sum_{j=i+1}^s a_{j,i} + \Lambda_j \right), \quad \quad i = s, \dots, 1,\\ + \nu_i &= h_n f_p^T(t_{n,i}, z_i, p) \left(b_i \lambda_{n+1} + \sum_{j=i}^{s} a_{ji} \Lambda_j \right), \\ + \lambda_n &= \lambda_{n+1} + \sum_{j=1}^{s} \Lambda_j, \\ + \mu_n &= \mu_{n+1} + \sum_{j=1}^{s} \nu_j. + :label: ARKODE_ERK_ADJOINT + +For more information on performing discrete adjoint sensitivity analysis using ARKODE see, +:numref:`ARKODE.Usage.ARKStep.ASA`. + +For a detailed derivation of the discrete adjoint methods see :cite:p:`hager2000runge,sanduDiscrete2006`. +For a detailed derivation of the continuous adjoint method see :ref:`CVODES `, +or :cite:p:`CLPS:03`. + + +Discrete vs. Continuous Adjoint Method +-------------------------------------- + +It is understood that the continuous adjoint method can be problematic in the context of +optimization problems because the continuous adjoint method provides an approximation to the +gradient of a continuous cost function while the optimizer is expecting the gradient of the discrete +cost function. The discrepancy means that the optimizer can fail to converge further once it is near +a local minimum :cite:p:`giles2000introduction`. On the other hand, the discrete adjoint method +provides the exact gradient of the discrete cost function allowing the optimizer to fully converge. +Consequently, the discrete adjoint method is often preferable in optimization despite its own +drawbacks -- such as its (relatively) increased memory usage and the possible introduction of +unphysical computational modes :cite:p:`sirkes1997finite`. This is not to say that the discrete +adjoint approach is always the better choice over the continuous adjoint approach in optimization. +Computational efficiency and stability of one approach over the other can be both problem and method +dependent. Section 8 in the paper :cite:p:`rackauckas2020universal` discusses the tradeoffs further +and provides numerous references that may help inform users in choosing between the discrete and +continuous adjoint approaches. diff --git a/doc/arkode/guide/source/Usage/ARKStep/ASA.rst b/doc/arkode/guide/source/Usage/ARKStep/ASA.rst new file mode 100644 index 0000000000..25e2b6454a --- /dev/null +++ b/doc/arkode/guide/source/Usage/ARKStep/ASA.rst @@ -0,0 +1,58 @@ +.. _ARKODE.Usage.ARKStep.ASA: + +Adjoint Sensitivity Analysis +============================ + +The previous sections discuss using ARKStep for the integration of forward ODE models. +This section discusses how to use ARKStep for adjoint sensitivity analysis as introduced +in :numref:`ARKODE.Mathematics.ASA`. To use ARKStep for adjoint sensitivity analysis (ASA), users simply setup the forward +integration as usual (following :numref:`ARKODE.Usage.Skeleton`) with one exception: +a :c:type:`SUNAdjointCheckpointScheme` object must be created and passed to +:c:func:`ARKodeSetAdjointCheckpointScheme` before the call to the :c:func:`ARKodeEvolve` +function. After the forward model integration code, a :c:type:`SUNAdjointStepper` object +can be created for the adjoint model integration by calling :c:func:`ARKStepCreateAdjointStepper`. +The code snippet below demonstrates these steps in brief and the example code +``examples/arkode/C_serial/ark_lotka_volterra_asa.c`` demonstrates these steps in detail. + +.. code-block:: C + + // 1. Create a SUNAdjointCheckpointScheme object + + // 2. Setup ARKStep for forward integration + + // 3. Attach the SUNAdjointCheckpointScheme + + // 4. Evolve the forward model + + // 5. Create the SUNAdjointStepper + + // 6. Setup the adjoint model + + // 7. Evolve the adjoint model + + // 8. Cleanup + + + +User Callable Functions +----------------------- + +This section describes ARKStep-specific user-callable functions for performing +adjoint sensitivity analysis with methods with ARKStep. + +.. c:function:: int ARKStepCreateAdjointStepper(void* arkode_mem, N_Vector sf, SUNAdjointStepper* adj_stepper_ptr) + + Creates a :c:type:`SUNAdjointStepper` object compatible with the provided ARKStep instance for + integrating the adjoint sensitivity system :eq:`ARKODE_DISCRETE_ADJOINT`. + + :param arkode_mem: a pointer to the ARKStep memory block. + :param sf: the sensitivity vector holding the adjoint system terminal condition. + This must be an instance of the ManyVector ``N_Vector`` implementation with at + least one subvector (depending on if sensitivities to parameters should be computed). + The first subvector must be :math:`\partial g_y(y(t_f)) \in \mathbb{R}^N`. If sensitivities to parameters should be computed, then the second subvector must be :math:`g_p(y(t_f), p) \in \mathbb{R}^{N_s}`. + :param adj_stepper_ptr: the newly created :c:type:`SUNAdjointStepper` object. + + :return: + * ``ARK_SUCCESS`` if successful + * ``ARK_MEM_FAIL`` if a memory allocation failed + * ``ARK_ILL_INPUT`` if an argument has an illegal value. diff --git a/doc/arkode/guide/source/Usage/ARKStep/index.rst b/doc/arkode/guide/source/Usage/ARKStep/index.rst index 15f97660d4..5d81c89d61 100644 --- a/doc/arkode/guide/source/Usage/ARKStep/index.rst +++ b/doc/arkode/guide/source/Usage/ARKStep/index.rst @@ -30,3 +30,4 @@ are specific to ARKStep. User_callable Relaxation XBraid + ASA diff --git a/doc/arkode/guide/source/Usage/ERKStep/ASA.rst b/doc/arkode/guide/source/Usage/ERKStep/ASA.rst new file mode 100644 index 0000000000..faacb124a2 --- /dev/null +++ b/doc/arkode/guide/source/Usage/ERKStep/ASA.rst @@ -0,0 +1,58 @@ +.. _ARKODE.Usage.ERKStep.ASA: + +Adjoint Sensitivity Analysis +============================ + +The previous sections discuss using ARKStep for the integration of forward ODE models. +This section discusses how to use ARKStep for adjoint sensitivity analysis as introduced +in :numref:`ARKODE.Mathematics.ASA`. To use ARKStep for adjoint sensitivity analysis (ASA), users simply setup the forward +integration as usual (following :numref:`ARKODE.Usage.Skeleton`) with one exception: +a :c:type:`SUNAdjointCheckpointScheme` object must be created and passed to +:c:func:`ARKodeSetAdjointCheckpointScheme` before the call to the :c:func:`ARKodeEvolve` +function. After the forward model integration code, a :c:type:`SUNAdjointStepper` object +can be created for the adjoint model integration by calling :c:func:`ERKStepCreateAdjointStepper`. +The code snippet below demonstrates these steps in brief and the example code +``examples/arkode/C_serial/ark_lotka_volterra_asa.c`` demonstrates these steps in detail. + +.. code-block:: C + + // 1. Create a SUNAdjointCheckpointScheme object + + // 2. Setup ERKStep for forward integration + + // 3. Attach the SUNAdjointCheckpointScheme + + // 4. Evolve the forward model + + // 5. Create the SUNAdjointStepper + + // 6. Setup the adjoint model + + // 7. Evolve the adjoint model + + // 8. Cleanup + + + +User Callable Functions +----------------------- + +This section describes ERKStep-specific user-callable functions for performing +adjoint sensitivity analysis with methods with ERKStep. + +.. c:function:: int ERKStepCreateAdjointStepper(void* arkode_mem, N_Vector sf, SUNAdjointStepper* adj_stepper_ptr) + + Creates a :c:type:`SUNAdjointStepper` object compatible with the provided ARKStep instance for + integrating the adjoint sensitivity system :eq:`ARKODE_DISCRETE_ADJOINT`. + + :param arkode_mem: a pointer to the ARKStep memory block. + :param sf: the sensitivity vector holding the adjoint system terminal condition. + This must be an instance of the ManyVector ``N_Vector`` implementation with at + least one subvector (depending on if sensitivities to parameters should be computed). + The first subvector must be :math:`\partial g_y(y(t_f)) \in \mathbb{R}^N`. If sensitivities to parameters should be computed, then the second subvector must be :math:`g_p(y(t_f), p) \in \mathbb{R}^{N_s}`. + :param adj_stepper_ptr: the newly created :c:type:`SUNAdjointStepper` object. + + :return: + * ``ARK_SUCCESS`` if successful + * ``ARK_MEM_FAIL`` if a memory allocation failed + * ``ARK_ILL_INPUT`` if an argument has an illegal value. diff --git a/doc/arkode/guide/source/Usage/ERKStep/index.rst b/doc/arkode/guide/source/Usage/ERKStep/index.rst index f50cf4cc96..45a2a11e00 100644 --- a/doc/arkode/guide/source/Usage/ERKStep/index.rst +++ b/doc/arkode/guide/source/Usage/ERKStep/index.rst @@ -29,3 +29,4 @@ are specific to ERKStep. User_callable Relaxation + ASA diff --git a/doc/arkode/guide/source/Usage/User_callable.rst b/doc/arkode/guide/source/Usage/User_callable.rst index a5798450e0..a49553119e 100644 --- a/doc/arkode/guide/source/Usage/User_callable.rst +++ b/doc/arkode/guide/source/Usage/User_callable.rst @@ -879,30 +879,31 @@ Optional inputs for ARKODE .. cssclass:: table-bordered -================================================ ======================================= ======================= -Optional input Function name Default -================================================ ======================================= ======================= -Return ARKODE parameters to their defaults :c:func:`ARKodeSetDefaults` internal -Set integrator method order :c:func:`ARKodeSetOrder` 4 -Set dense output interpolation type (SPRKStep) :c:func:`ARKodeSetInterpolantType` ``ARK_INTERP_LAGRANGE`` -Set dense output interpolation type (others) :c:func:`ARKodeSetInterpolantType` ``ARK_INTERP_HERMITE`` -Set dense output polynomial degree :c:func:`ARKodeSetInterpolantDegree` 5 -Disable time step adaptivity (fixed-step mode) :c:func:`ARKodeSetFixedStep` disabled -Set forward or backward integration direction :c:func:`ARKodeSetStepDirection` 0.0 -Supply an initial step size to attempt :c:func:`ARKodeSetInitStep` estimated -Maximum no. of warnings for :math:`t_n+h = t_n` :c:func:`ARKodeSetMaxHnilWarns` 10 -Maximum no. of internal steps before *tout* :c:func:`ARKodeSetMaxNumSteps` 500 -Maximum absolute step size :c:func:`ARKodeSetMaxStep` :math:`\infty` -Minimum absolute step size :c:func:`ARKodeSetMinStep` 0.0 -Set a value for :math:`t_{stop}` :c:func:`ARKodeSetStopTime` undefined -Interpolate at :math:`t_{stop}` :c:func:`ARKodeSetInterpolateStopTime` ``SUNFALSE`` -Disable the stop time :c:func:`ARKodeClearStopTime` N/A -Supply a pointer for user data :c:func:`ARKodeSetUserData` ``NULL`` -Maximum no. of ARKODE error test failures :c:func:`ARKodeSetMaxErrTestFails` 7 -Set inequality constraints on solution :c:func:`ARKodeSetConstraints` ``NULL`` -Set max number of constraint failures :c:func:`ARKodeSetMaxNumConstrFails` 10 -================================================ ======================================= ======================= - +================================================= ========================================== ======================= +Optional input Function name Default +================================================= ========================================== ======================= +Return ARKODE parameters to their defaults :c:func:`ARKodeSetDefaults` internal +Set integrator method order :c:func:`ARKodeSetOrder` 4 +Set dense output interpolation type (SPRKStep) :c:func:`ARKodeSetInterpolantType` ``ARK_INTERP_LAGRANGE`` +Set dense output interpolation type (others) :c:func:`ARKodeSetInterpolantType` ``ARK_INTERP_HERMITE`` +Set dense output polynomial degree :c:func:`ARKodeSetInterpolantDegree` 5 +Disable time step adaptivity (fixed-step mode) :c:func:`ARKodeSetFixedStep` disabled +Set forward or backward integration direction :c:func:`ARKodeSetStepDirection` 0.0 +Supply an initial step size to attempt :c:func:`ARKodeSetInitStep` estimated +Maximum no. of warnings for :math:`t_n+h = t_n` :c:func:`ARKodeSetMaxHnilWarns` 10 +Maximum no. of internal steps before *tout* :c:func:`ARKodeSetMaxNumSteps` 500 +Maximum absolute step size :c:func:`ARKodeSetMaxStep` :math:`\infty` +Minimum absolute step size :c:func:`ARKodeSetMinStep` 0.0 +Set a value for :math:`t_{stop}` :c:func:`ARKodeSetStopTime` undefined +Interpolate at :math:`t_{stop}` :c:func:`ARKodeSetInterpolateStopTime` ``SUNFALSE`` +Disable the stop time :c:func:`ARKodeClearStopTime` N/A +Supply a pointer for user data :c:func:`ARKodeSetUserData` ``NULL`` +Maximum no. of ARKODE error test failures :c:func:`ARKodeSetMaxErrTestFails` 7 +Set inequality constraints on solution :c:func:`ARKodeSetConstraints` ``NULL`` +Set max number of constraint failures :c:func:`ARKodeSetMaxNumConstrFails` 10 +Set the checkpointing scheme to use (for adjoint) :c:func:`ARKodeSetAdjointCheckpointScheme` ``NULL`` +Set the checkpointing step index (for adjoint) :c:func:`ARKodeSetAdjointCheckpointIndex` 0 +================================================= ========================================== ======================= @@ -1114,7 +1115,7 @@ Set max number of constraint failures :c:func:`ARKodeSetMaxNumConstr selects forward integration, a negative value selects backward integration, and zero leaves the current direction unchanged. - + :retval ARK_SUCCESS: the function exited successfully. :retval ARK_MEM_NULL: ``arkode_mem`` was ``NULL``. @@ -1437,6 +1438,34 @@ Set max number of constraint failures :c:func:`ARKodeSetMaxNumConstr .. versionadded:: 6.1.0 +.. c:function:: int ARKodeSetAdjointCheckpointScheme(void* arkode_mem, SUNAdjointCheckpointScheme checkpoint_scheme) + + Specifies the :c:type:`SUNAdjointCheckpointScheme` to use for saving states + during the forward integration, and loading states during backward integration + of an adjoint system. + + :param arkode_mem: pointer to the ARKODE memory block. + :param checkpoint_scheme: the checkpoint scheme to use. + + :retval ARK_SUCCESS: the function exited successfully. + :retval ARK_MEM_NULL: ``arkode_mem`` was ``NULL``. + + .. versionadded:: x.y.z + +.. c:function:: int ARKodeSetAdjointCheckpointIndex(void* arkode_mem, int64_t step_index) + + Specifies the step index (that is step number) to insert the next checkpoint at. + This is incremented along with the step count, but it is useful to be able to reset + this index during recomputations of missing states during the backward adjoint integration. + + :param arkode_mem: pointer to the ARKODE memory block. + :param step_idx: the step to insert the next checkpoint at. + + :retval ARK_SUCCESS: the function exited successfully. + :retval ARK_MEM_NULL: ``arkode_mem`` was ``NULL``. + + .. versionadded:: x.y.z + .. _ARKODE.Usage.ARKodeAdaptivityInputTable: @@ -4887,6 +4916,8 @@ rescale the upcoming time step by the specified factor. If a value +<<<<<<< HEAD +======= .. _ARKODE.Usage.MRIStepInterface: Using an ARKODE solver as an MRIStep "inner" solver @@ -4938,6 +4969,7 @@ wrap the ARKODE memory block as an :c:type:`MRIStepInnerStepper`. functions and the initial condition */ outer_arkode_mem = MRIStepCreate(fse, fsi, t0, y0, stepper, sunctx) +>>>>>>> origin/develop .. _ARKODE.Usage.SUNStepperInterface: Using an ARKODE solver as a SUNStepper diff --git a/doc/arkode/guide/source/index.rst b/doc/arkode/guide/source/index.rst index 466eba8ae2..2df0cef963 100644 --- a/doc/arkode/guide/source/index.rst +++ b/doc/arkode/guide/source/index.rst @@ -66,6 +66,7 @@ with support by the `US Department of Energy `_, sunnonlinsol/index.rst sunadaptcontroller/index.rst sunstepper/index.rst + sunadjoint/index.rst sunmemory/index.rst sundials/Install_link.rst Constants diff --git a/doc/arkode/guide/source/sunadjoint/SUNAdjoint_links.rst b/doc/arkode/guide/source/sunadjoint/SUNAdjoint_links.rst new file mode 100644 index 0000000000..142352fa02 --- /dev/null +++ b/doc/arkode/guide/source/sunadjoint/SUNAdjoint_links.rst @@ -0,0 +1,14 @@ +.. ---------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2002-2025, Lawrence Livermore National Security + and Southern Methodist University. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ---------------------------------------------------------------- + +.. include:: ../../../../shared/sunadjoint/SUNAdjointCheckpointScheme.rst +.. include:: ../../../../shared/sunadjoint/SUNAdjointStepper.rst diff --git a/doc/arkode/guide/source/sunadjoint/index.rst b/doc/arkode/guide/source/sunadjoint/index.rst new file mode 100644 index 0000000000..d41499c18c --- /dev/null +++ b/doc/arkode/guide/source/sunadjoint/index.rst @@ -0,0 +1,19 @@ +.. + ---------------------------------------------------------------- + SUNDIALS Copyright Start + Copyright (c) 2002-2025, Lawrence Livermore National Security + and Southern Methodist University. + All rights reserved. + + See the top-level LICENSE and NOTICE files for details. + + SPDX-License-Identifier: BSD-3-Clause + SUNDIALS Copyright End + ---------------------------------------------------------------- + +.. include:: ../../../../shared/sunadjoint/SUNAdjoint_Introduction.rst + +.. toctree:: + :maxdepth: 1 + + SUNAdjoint_links.rst diff --git a/examples/arkode/C_serial/CMakeLists.txt b/examples/arkode/C_serial/CMakeLists.txt index 176d7ca272..e1956028cd 100644 --- a/examples/arkode/C_serial/CMakeLists.txt +++ b/examples/arkode/C_serial/CMakeLists.txt @@ -86,6 +86,8 @@ set(ARKODE_examples "ark_KrylovDemo_prec\;\;exclude-single" "ark_KrylovDemo_prec\;1\;exclude-single" "ark_KrylovDemo_prec\;2\;exclude-single" + "ark_lotka_volterra_ASA\;--check-freq 1\;develop" + "ark_lotka_volterra_ASA\;--check-freq 5\;develop" "ark_onewaycouple_mri\;\;develop" "ark_reaction_diffusion_mri\;\;develop" "ark_robertson_constraints\;\;exclude-single" diff --git a/examples/arkode/C_serial/ark_lotka_volterra_ASA.c b/examples/arkode/C_serial/ark_lotka_volterra_ASA.c new file mode 100644 index 0000000000..232157a599 --- /dev/null +++ b/examples/arkode/C_serial/ark_lotka_volterra_ASA.c @@ -0,0 +1,359 @@ +/* ----------------------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2002-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------------------- + * This example solves the Lotka-Volterra ODE with four parameters, + * + * u = [dx/dt] = [ p_0*x - p_1*x*y ] + * [dy/dt] [ -p_2*y + p_3*x*y ]. + * + * The initial condition is u(t_0) = 1.0 and we use the parameters + * p = [1.5, 1.0, 3.0, 1.0]. The integration interval can be controlled via + * the --tf command line argument, but by default it is t \in [0, 10.]. + * An explicit Runge--Kutta method is employed via the ARKStep time stepper + * provided by ARKODE. After solving the forward problem, adjoint sensitivity + * analysis (ASA) is performed using the discrete adjoint method available with + * with ARKStep in order to obtain the gradient of the scalar cost function, + * + * g(u(t_f), p) = || 1 - u(t_f, p) ||^2 / 2 + * + * with respect to the initial condition and the parameters. + * + * ./ark_lotka_volterra_adj options: + * --tf the final simulation time + * --dt the timestep size + * --order the order of the RK method + * --check-freq how often to checkpoint (in steps) + * --no-stages don't checkpoint stages + * --dont-keep don't keep checkpoints around after loading + * --help print these options + * ---------------------------------------------------------------------------*/ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "sundials/sundials_nvector.h" + +typedef struct +{ + sunrealtype tf; + sunrealtype dt; + int order; + int check_freq; + sunbooleantype save_stages; + sunbooleantype keep_checks; +} ProgramArgs; + +static sunrealtype params[4] = {SUN_RCONST(1.5), SUN_RCONST(1.0), + SUN_RCONST(3.0), SUN_RCONST(1.0)}; +static void parse_args(int argc, char* argv[], ProgramArgs* args); +static void print_help(int argc, char* argv[], int exit_code); +static int check_retval(void* retval_ptr, const char* funcname, int opt); +static int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, + void* user_data); +static int vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + N_Vector udotvec, void* user_data, N_Vector tmp); +static int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, + N_Vector uvec, N_Vector udotvec, void* user_data, + N_Vector tmp); +static void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p); +static void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p); + +int main(int argc, char* argv[]) +{ + int retval = 0; + SUNContext sunctx = NULL; + SUNContext_Create(SUN_COMM_NULL, &sunctx); + + ProgramArgs args; + args.tf = SUN_RCONST(10.0); + args.dt = SUN_RCONST(1e-3); + args.order = 4; + args.save_stages = SUNTRUE; + args.keep_checks = SUNTRUE; + args.check_freq = 2; + parse_args(argc, argv, &args); + + // + // Create the initial conditions vector + // + + sunindextype neq = 2; + N_Vector u = N_VNew_Serial(neq, sunctx); + N_Vector u0 = N_VClone(u); + N_VConst(SUN_RCONST(1.0), u0); + N_VConst(SUN_RCONST(1.0), u); + + // + // Create the ARKODE stepper that will be used for the forward evolution. + // + + const sunrealtype dt = args.dt; + sunrealtype t0 = SUN_RCONST(0.0); + sunrealtype tf = args.tf; + const int nsteps = (int)ceil(((tf - t0) / dt + 1)); + const int order = args.order; + void* arkode_mem = ARKStepCreate(lotka_volterra, NULL, t0, u, sunctx); + + retval = ARKodeSetOrder(arkode_mem, order); + if (check_retval(&retval, "ARKodeSetOrder", 1)) { return 1; } + + retval = ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2); + if (check_retval(&retval, "ARKodeSetMaxNumSteps", 1)) { return 1; } + + // Enable checkpointing during the forward solution. + const int check_interval = args.check_freq; + const int ncheck = nsteps * order; + const sunbooleantype save_stages = args.save_stages; + const sunbooleantype keep_check = args.keep_checks; + SUNAdjointCheckpointScheme checkpoint_scheme = NULL; + SUNMemoryHelper mem_helper = SUNMemoryHelper_Sys(sunctx); + + retval = SUNAdjointCheckpointScheme_Create_Fixed(SUNDATAIOMODE_INMEM, + mem_helper, check_interval, + ncheck, save_stages, + keep_check, sunctx, + &checkpoint_scheme); + if (check_retval(&retval, "SUNAdjointCheckpointScheme_Create_Fixed", 1)) + { + return 1; + } + + retval = ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme); + if (check_retval(&retval, "ARKodeSetAdjointCheckpointScheme", 1)) + { + return 1; + } + + // + // Compute the forward solution + // + + printf("Initial condition:\n"); + N_VPrint(u); + + retval = ARKodeSetUserData(arkode_mem, (void*)params); + if (check_retval(&retval, "ARKodeSetUserData", 1)) { return 1; } + + retval = ARKodeSetFixedStep(arkode_mem, dt); + if (check_retval(&retval, "ARKodeSetFixedStep", 1)) { return 1; } + + sunrealtype tret = t0; + while (tret < tf) + { + retval = ARKodeEvolve(arkode_mem, tf, u, &tret, ARK_NORMAL); + if (retval < 0) + { + fprintf(stderr, ">>> ERROR: ARKodeEvolve returned %d\n", retval); + return -1; + } + } + + printf("Forward Solution:\n"); + N_VPrint(u); + + printf("ARKODE Stats for Forward Solution:\n"); + retval = ARKodePrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE); + if (check_retval(&retval, "ARKodePrintAllStats", 1)) { return 1; } + printf("\n"); + + // + // Create the adjoint stepper + // + + sunindextype num_params = 4; + N_Vector sensu0 = N_VClone(u); + N_Vector sensp = N_VNew_Serial(num_params, sunctx); + N_Vector sens[2] = {sensu0, sensp}; + N_Vector sf = N_VNew_ManyVector(2, sens, sunctx); + + // Set the terminal condition for the adjoint system, which + // should be the the gradient of our cost function at tf. + dgdu(u, sensu0, params); + dgdp(u, sensp, params); + + printf("Adjoint terminal condition:\n"); + N_VPrint(sf); + + SUNAdjointStepper adj_stepper; + retval = ARKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper); + if (check_retval(&retval, "ARKStepCreateAdjointStepper", 1)) { return 1; } + + retval = SUNAdjointStepper_SetVecTimesJacFn(adj_stepper, vjp, parameter_vjp); + if (check_retval(&retval, "SUNAdjointStepper_SetVecTimesJacFn", 1)) + { + return 1; + } + + // + // Now compute the adjoint solution + // + + retval = SUNAdjointStepper_Evolve(adj_stepper, t0, sf, &tret); + if (check_retval(&retval, "SUNAdjointStepper_Evolve", 1)) { return 1; } + + printf("Adjoint Solution:\n"); + N_VPrint(sf); + + printf("\nSUNAdjointStepper Stats:\n"); + retval = SUNAdjointStepper_PrintAllStats(adj_stepper, stdout, + SUN_OUTPUTFORMAT_TABLE); + if (check_retval(&retval, "SUNAdjointStepper_PrintAllStats", 1)) { return 1; } + printf("\n"); + + // + // Cleanup + // + + N_VDestroy(u); + N_VDestroy(sf); + SUNAdjointCheckpointScheme_Destroy(&checkpoint_scheme); + SUNAdjointStepper_Destroy(&adj_stepper); + ARKodeFree(&arkode_mem); + SUNContext_Free(&sunctx); + + return 0; +} + +int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, void* user_data) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* udot = N_VGetArrayPointer(udotvec); + + udot[0] = p[0] * u[0] - p[1] * u[0] * u[1]; + udot[1] = -p[2] * u[1] + p[3] * u[0] * u[1]; + + return 0; +} + +int vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + N_Vector udotvec, void* user_data, N_Vector tmp) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = (p[0] - p[1] * u[1]) * v[0] + p[3] * u[1] * v[1]; + Jv[1] = -p[1] * u[0] * v[0] + (-p[2] + p[3] * u[0]) * v[1]; + + return 0; +} + +int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + N_Vector udotvec, void* user_data, N_Vector tmp) +{ + if (user_data != params) { return -1; } + + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = u[0] * v[0]; + Jv[1] = -u[0] * u[1] * v[0]; + Jv[2] = -u[1] * v[1]; + Jv[3] = u[0] * u[1] * v[1]; + + return 0; +} + +void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = -SUN_RCONST(1.0) + u[0]; + dg[1] = -SUN_RCONST(1.0) + u[1]; +} + +void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p) +{ + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = SUN_RCONST(0.0); + dg[1] = SUN_RCONST(0.0); + dg[2] = SUN_RCONST(0.0); + dg[3] = SUN_RCONST(0.0); +} + +void print_help(int argc, char* argv[], int exit_code) +{ + if (exit_code) { fprintf(stderr, "%s: option not recognized\n", argv[0]); } + else { fprintf(stderr, "%s ", argv[0]); } + fprintf(stderr, "options:\n"); + fprintf(stderr, "--tf the final simulation time\n"); + fprintf(stderr, "--dt the timestep size\n"); + fprintf(stderr, "--order the order of the RK method\n"); + fprintf(stderr, "--check-freq how often to checkpoint (in steps)\n"); + fprintf(stderr, "--no-stages don't checkpoint stages\n"); + fprintf(stderr, + "--dont-keep don't keep checkpoints around after loading\n"); + fprintf(stderr, "--help print these options\n"); + exit(exit_code); +} + +void parse_args(int argc, char* argv[], ProgramArgs* args) +{ + for (int argi = 1; argi < argc; ++argi) + { + const char* arg = argv[argi]; + if (!strcmp(arg, "--tf")) { args->tf = atof(argv[++argi]); } + else if (!strcmp(arg, "--dt")) { args->dt = atof(argv[++argi]); } + else if (!strcmp(arg, "--order")) { args->order = atoi(argv[++argi]); } + else if (!strcmp(arg, "--check-freq")) + { + args->check_freq = atoi(argv[++argi]); + } + else if (!strcmp(arg, "--no-stages")) { args->save_stages = SUNFALSE; } + else if (!strcmp(arg, "--dont-keep")) { args->keep_checks = SUNFALSE; } + else if (!strcmp(arg, "--help")) { print_help(argc, argv, 0); } + else { print_help(argc, argv, 1); } + } +} + +int check_retval(void* retval_ptr, const char* funcname, int opt) +{ + int* retval; + + /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ + if (opt == 0 && retval_ptr == NULL) + { + fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", + funcname); + return 1; + } + + /* Check if retval < 0 */ + else if (opt == 1) + { + retval = (int*)retval_ptr; + if (*retval < 0) + { + fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n", + funcname, *retval); + return 1; + } + } + + return (0); +} diff --git a/examples/arkode/C_serial/ark_lotka_volterra_ASA_--check-freq_1.out b/examples/arkode/C_serial/ark_lotka_volterra_ASA_--check-freq_1.out new file mode 100644 index 0000000000..512207bd66 --- /dev/null +++ b/examples/arkode/C_serial/ark_lotka_volterra_ASA_--check-freq_1.out @@ -0,0 +1,52 @@ +Initial condition: +1.0000000000000000e+00 +1.0000000000000000e+00 + +Forward Solution: +1.0263447675712871e+00 +9.0969107814004269e-01 + +ARKODE Stats for Forward Solution: +Current time = 10.0009999999999 +Steps = 10001 +Step attempts = 10001 +Stability limited steps = 0 +Accuracy limited steps = 0 +Error test fails = 0 +NLS step fails = 0 +Inequality constraint fails = 0 +Initial step size = 0.001 +Last step size = 0.001 +Current step size = 0.001 +Explicit RHS fn evals = 50006 +Implicit RHS fn evals = 0 +NLS iters = 0 +NLS fails = 0 +NLS iters per step = 0 +LS setups = 0 + +Adjoint terminal condition: +1.9360358457113298e+00 +1.9360358457113298e+00 + +0.0000000000000000e+00 +0.0000000000000000e+00 +0.0000000000000000e+00 +0.0000000000000000e+00 + +Adjoint Solution: +-1.4926513048964769e+00 +7.5866467767938539e-01 + +-7.7848751398398628e+00 +-1.0490727707562078e+00 +-2.1661090848372893e+00 +-3.5256325524648946e+00 + + +SUNAdjointStepper Stats: +Num backwards steps = 10001 +Num recompute passes = 0 +v-times-Jac evals = 50005 +v-times-Jacp evals = 50005 + diff --git a/examples/arkode/C_serial/ark_lotka_volterra_ASA_--check-freq_5.out b/examples/arkode/C_serial/ark_lotka_volterra_ASA_--check-freq_5.out new file mode 100644 index 0000000000..80031ffbf5 --- /dev/null +++ b/examples/arkode/C_serial/ark_lotka_volterra_ASA_--check-freq_5.out @@ -0,0 +1,52 @@ +Initial condition: +1.0000000000000000e+00 +1.0000000000000000e+00 + +Forward Solution: +1.0263447675712871e+00 +9.0969107814004269e-01 + +ARKODE Stats for Forward Solution: +Current time = 10.0009999999999 +Steps = 10001 +Step attempts = 10001 +Stability limited steps = 0 +Accuracy limited steps = 0 +Error test fails = 0 +NLS step fails = 0 +Inequality constraint fails = 0 +Initial step size = 0.001 +Last step size = 0.001 +Current step size = 0.001 +Explicit RHS fn evals = 50006 +Implicit RHS fn evals = 0 +NLS iters = 0 +NLS fails = 0 +NLS iters per step = 0 +LS setups = 0 + +Adjoint terminal condition: +1.9360358457113298e+00 +1.9360358457113298e+00 + +0.0000000000000000e+00 +0.0000000000000000e+00 +0.0000000000000000e+00 +0.0000000000000000e+00 + +Adjoint Solution: +-1.4947639071731134e+00 +7.5873448439158087e-01 + +-7.7894652372843973e+00 +-1.0491640634028794e+00 +-2.1671821513999858e+00 +-3.5275142622843187e+00 + + +SUNAdjointStepper Stats: +Num backwards steps = 10001 +Num recompute passes = 3080 +v-times-Jac evals = 50005 +v-times-Jacp evals = 50005 + diff --git a/examples/cvodes/serial/CMakeLists.txt b/examples/cvodes/serial/CMakeLists.txt index 41f480d65c..ae7921ea1b 100644 --- a/examples/cvodes/serial/CMakeLists.txt +++ b/examples/cvodes/serial/CMakeLists.txt @@ -37,6 +37,7 @@ set(CVODES_examples "cvsKrylovDemo_ls\;1\;develop" "cvsKrylovDemo_ls\;2\;develop" "cvsKrylovDemo_prec\;\;develop" + "cvsLotkaVolterra_ASA\;\;develop" "cvsParticle_dns\;\;develop" "cvsPendulum_dns\;\;exclude-single" "cvsRoberts_ASAi_dns\;\;exclude-single" diff --git a/examples/cvodes/serial/cvsLotkaVolterra_ASA.c b/examples/cvodes/serial/cvsLotkaVolterra_ASA.c new file mode 100644 index 0000000000..259d6370df --- /dev/null +++ b/examples/cvodes/serial/cvsLotkaVolterra_ASA.c @@ -0,0 +1,326 @@ +/* ----------------------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2002-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------------------- + * This example solves the Lotka-Volterra ODE with four parameters, + * + * u' = [dx/dt] = [ p_0*x - p_1*x*y ] + * [dy/dt] [ -p_2*y + p_3*x*y ]. + * + * The initial condition is u(t_0) = 1.0 and we use the parameters + * p = [1.5, 1.0, 3.0, 1.0]. The integration interval is t \in [0, 10.]. + * The implicit BDF method from CVODES is used to solve the forward problem. + * Afterwards, the continuous adjoint sensitivity analysis capabilities of CVODES + * are used to obtain the gradient of the cost function, + * + * g(u(t_f), p) = || 1 - u(t_f, p) ||^2 / 2 + * + * with respect to the initial condition and the parameters. + * ----------------------------------------------------------------------------- + */ + +#include +#include +#include +#include +#include +#include "cvodes/cvodes_ls.h" +#include "sundials/sundials_context.h" +#include "sundials/sundials_iterative.h" +#include "sundials/sundials_nvector.h" +#include "sundials/sundials_types.h" +#include "sunlinsol/sunlinsol_dense.h" +#include "sunlinsol/sunlinsol_spgmr.h" + +/* Problem Constants */ +#define NEQ 2 /* number of equations */ +#define NP 4 /* number of params */ +#define T0 SUN_RCONST(0.0) /* initial time */ +#define TF SUN_RCONST(1.0) /* final time */ +#if defined(SUNDIALS_SINGLE_PRECISION) +#define RTOL SUN_RCONST(1.0e-5) /* relative tolerance */ +#define ATOL SUN_RCONST(1.0e-8) /* absolute tolerance */ +#else +#define RTOL SUN_RCONST(1.0e-10) /* relative tolerance */ +#define ATOL SUN_RCONST(1.0e-14) /* absolute tolerance */ +#endif +#define STEPS 5 /* checkpoint interval */ + +static int check_retval(void* retval_ptr, const char* funcname, int opt); + +static sunrealtype params[4] = {1.5, 1.0, 3.0, 1.0}; + +static int vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + void* user_data); + +static int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, + void* user_data); + +static int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, + N_Vector uvec, void* user_data); + +static void dgdu(N_Vector uvec, N_Vector dgvec); + +static int adjoint_rhs(sunrealtype t, N_Vector uvec, N_Vector lvec, + N_Vector ldotvec, void* user_data); + +static int quad_rhs(sunrealtype t, N_Vector uvec, N_Vector muvec, + N_Vector qBdotvec, void* user_dataB); + +int main(int argc, char* argv[]) +{ + SUNContext sunctx; + sunrealtype reltol, abstol, t, tout; + N_Vector u, uB, qB; + void* cvode_mem; + int which, retval; + + SUNContext_Create(SUN_COMM_NULL, &sunctx); + + /* Allocate memory for the solution vector */ + u = N_VNew_Serial(NEQ, sunctx); + if (check_retval((void*)u, "N_VNew_Serial", 0)) { return 1; } + + /* Initialize the solution vector */ + N_VConst(1.0, u); + + /* Set the tolerances */ + reltol = RTOL; + abstol = ATOL; + + /* Create the CVODES object */ + cvode_mem = CVodeCreate(CV_BDF, sunctx); + if (check_retval((void*)cvode_mem, "CVodeCreate", 0)) { return 1; } + + /* Initialize the CVODES solver */ + retval = CVodeInit(cvode_mem, lotka_volterra, T0, u); + if (check_retval(&retval, "CVodeInit", 1)) { return 1; } + + /* Set the user data */ + retval = CVodeSetUserData(cvode_mem, (void*)params); + if (check_retval(&retval, "CVodeSetUserData", 1)) { return 1; } + + /* Set the tolerances */ + retval = CVodeSStolerances(cvode_mem, reltol, abstol); + if (check_retval(&retval, "CVodeSStolerances", 1)) { return 1; } + + // SUNLinearSolver LS = SUNLinSol_Dense(y, NULL, sunctx); + SUNLinearSolver LS = SUNLinSol_SPGMR(u, SUN_PREC_NONE, 3, sunctx); + + retval = CVodeSetLinearSolver(cvode_mem, LS, NULL); + if (check_retval(&retval, "CVodeSetLinearSolver", 1)) { return 1; } + + retval = CVodeSetMaxNumSteps(cvode_mem, 100000); + if (check_retval(&retval, "CVodeSetMaxNumSteps", 1)) { return 1; } + + /* Initialize ASA */ + retval = CVodeAdjInit(cvode_mem, STEPS, CV_HERMITE); + if (check_retval(&retval, "CVodeAdjInit", 1)) { return 1; } + + /* Integrate the ODE */ + tout = TF; + int ncheck; + retval = CVodeF(cvode_mem, tout, u, &t, CV_NORMAL, &ncheck); + if (check_retval(&retval, "CVode", 1)) { return 1; } + + /* Print the final solution */ + printf("Forward Solution at t = %g:\n", t); + N_VPrint(u); + + /* Allocate memory for the adjoint solution vector */ + uB = N_VNew_Serial(NEQ, sunctx); + if (check_retval((void*)uB, "N_VNew_Serial", 0)) { return 1; } + + /* Allocate memory for the quadrature equations and initialize it to zero */ + qB = N_VNew_Serial(NP, sunctx); + N_VConst(SUN_RCONST(0.0), qB); + + /* Initialize the adjoint solution vector */ + dgdu(u, uB); + + printf("Adjoint terminal condition:\n"); + N_VPrint(uB); + N_VPrint(qB); + + /* Create the CVODES object for the backward problem */ + retval = CVodeCreateB(cvode_mem, CV_BDF, &which); + + /* Initialize the CVODES solver for the backward problem */ + retval = CVodeInitB(cvode_mem, which, adjoint_rhs, TF, uB); + if (check_retval(&retval, "CVodeInitB", 1)) { return 1; } + + /* Set the user data for the backward problem */ + retval = CVodeSetUserDataB(cvode_mem, which, (void*)params); + if (check_retval(&retval, "CVodeSetUserDataB", 1)) { return 1; } + + /* Set the tolerances for the backward problem */ + retval = CVodeSStolerancesB(cvode_mem, which, reltol, abstol); + if (check_retval(&retval, "CVodeSStolerancesB", 1)) { return 1; } + + /* Create the linear solver for the backward problem */ + SUNLinearSolver LSB = SUNLinSol_SPGMR(uB, SUN_PREC_NONE, 3, sunctx); + + retval = CVodeSetLinearSolverB(cvode_mem, which, LSB, NULL); + if (check_retval(&retval, "CVodeSetLinearSolver", 1)) { return 1; } + + /* Call CVodeQuadInitB to allocate internal memory and initialize backward + quadrature integration. This gives the sensitivities w.r.t. the parameters. */ + retval = CVodeQuadInitB(cvode_mem, which, quad_rhs, qB); + if (check_retval(&retval, "CVodeQuadInitB", 1)) { return (1); } + + /* Call CVodeSetQuadErrCon to specify whether or not the quadrature variables + are to be used in the step size control mechanism within CVODES. Call + CVodeQuadSStolerances or CVodeQuadSVtolerances to specify the integration + tolerances for the quadrature variables. */ + retval = CVodeSetQuadErrConB(cvode_mem, which, SUNTRUE); + if (check_retval(&retval, "CVodeSetQuadErrConB", 1)) { return (1); } + + /* Call CVodeQuadSStolerancesB to specify the scalar relative and absolute tolerances + for the backward problem. */ + retval = CVodeQuadSStolerancesB(cvode_mem, which, reltol, abstol); + if (check_retval(&retval, "CVodeQuadSStolerancesB", 1)) { return (1); } + + /* Integrate the adjoint ODE */ + retval = CVodeB(cvode_mem, T0, CV_NORMAL); + if (check_retval(&retval, "CVodeB", 1)) { return 1; } + + /* Get the final adjoint solution */ + retval = CVodeGetB(cvode_mem, which, &t, uB); + if (check_retval(&retval, "CVodeGetB", 1)) { return 1; } + + /* Call CVodeGetQuadB to get the quadrature solution vector after a + successful return from CVodeB. */ + retval = CVodeGetQuadB(cvode_mem, which, &t, qB); + if (check_retval(&retval, "CVodeGetQuadB", 1)) { return (1); } + + /* Print the final adjoint solution */ + printf("Adjoint Solution at t = %g:\n", t); + N_VPrint(uB); + N_VPrint(qB); + + /* Free memory */ + N_VDestroy(u); + N_VDestroy(uB); + N_VDestroy(qB); + SUNLinSolFree(LS); + SUNLinSolFree(LSB); + CVodeFree(&cvode_mem); + SUNContext_Free(&sunctx); + + return 0; +} + +/* Function to compute the ODE right-hand side */ +int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, void* user_data) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* udot = N_VGetArrayPointer(udotvec); + + udot[0] = p[0] * u[0] - p[1] * u[0] * u[1]; + udot[1] = -p[2] * u[1] + p[3] * u[0] * u[1]; + + return 0; +} + +/* Function to compute v^T (df/du) */ +int vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + void* user_data) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = (p[0] - p[1] * u[1]) * v[0] + p[3] * u[1] * v[1]; + Jv[1] = -p[1] * u[0] * v[0] + (-p[2] + p[3] * u[0]) * v[1]; + + return 0; +} + +/* Function to compute v^T (df/dp) */ +int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + void* user_data) +{ + if (user_data != params) { return -1; } + + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = u[0] * v[0]; + Jv[1] = -u[0] * u[1] * v[0]; + Jv[2] = -u[1] * v[1]; + Jv[3] = u[0] * u[1] * v[1]; + + return 0; +} + +/* Gradient of the cost function w.r.t to u. + The gradient w.r.t to p is zero since the cost function + does not depend on the parameters. */ +void dgdu(N_Vector uvec, N_Vector dgvec) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = -SUN_RCONST(1.0) + u[0]; + dg[1] = -SUN_RCONST(1.0) + u[1]; +} + +/* Function to compute the adjoint ODE right-hand side: + -mu^T (df/du) + */ +int adjoint_rhs(sunrealtype t, N_Vector uvec, N_Vector lvec, N_Vector ldotvec, + void* user_data) +{ + vjp(lvec, ldotvec, t, uvec, user_data); + N_VScale(-1.0, ldotvec, ldotvec); + + return 0; +} + +/* Function to compute the quadrature right-hand side: + mu^T (df/dp) + */ +int quad_rhs(sunrealtype t, N_Vector uvec, N_Vector muvec, N_Vector qBdotvec, + void* user_dataB) +{ + parameter_vjp(muvec, qBdotvec, t, uvec, user_dataB); + return 0; +} + +/* Check function return value */ +int check_retval(void* retval_ptr, const char* funcname, int opt) +{ + int* retval; + + /* Check if SUNDIALS function returned NULL pointer - no memory allocated */ + if (opt == 0 && retval_ptr == NULL) + { + fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed - returned NULL pointer\n\n", + funcname); + return 1; + } + + /* Check if retval < 0 */ + else if (opt == 1) + { + retval = (int*)retval_ptr; + if (*retval < 0) + { + fprintf(stderr, "\nSUNDIALS_ERROR: %s() failed with retval = %d\n\n", + funcname, *retval); + return 1; + } + } + + return (0); +} diff --git a/examples/cvodes/serial/cvsLotkaVolterra_ASA.out b/examples/cvodes/serial/cvsLotkaVolterra_ASA.out new file mode 100644 index 0000000000..57495742f8 --- /dev/null +++ b/examples/cvodes/serial/cvsLotkaVolterra_ASA.out @@ -0,0 +1,22 @@ +Forward Solution at t = 10 +1.0263448015893779e+00 +9.0969106392430898e-01 + +Adjoint terminal condition: +1.9360358655136869e+00 +1.9360358655136869e+00 + +0.0000000000000000e+00 +0.0000000000000000e+00 +0.0000000000000000e+00 +0.0000000000000000e+00 + +Adjoint Solution at t = 0: +-1.5062028945464851e+00 +7.7496083813398497e-01 + +7.9304147529003224e+00 +9.8623365223362636e-01 +2.2183773121596766e+00 +3.4932432378604874e+00 + diff --git a/include/arkode/arkode.h b/include/arkode/arkode.h index eec76a520a..eb68184425 100644 --- a/include/arkode/arkode.h +++ b/include/arkode/arkode.h @@ -30,6 +30,7 @@ #include #include +#include #include #include @@ -144,10 +145,12 @@ extern "C" { #define ARK_DOMEIG_FAIL -49 #define ARK_MAX_STAGE_LIMIT_FAIL -50 -#define ARK_SUNSTEPPER_ERR -51 - +#define ARK_SUNSTEPPER_ERR -51 #define ARK_STEP_DIRECTION_ERR -52 +#define ARK_ADJ_CHECKPOINT_FAIL -53 +#define ARK_ADJ_RECOMPUTE_FAIL -54 + #define ARK_UNRECOGNIZED_ERROR -99 /* ------------------------------ @@ -305,6 +308,11 @@ SUNDIALS_EXPORT int ARKodeSetInitStep(void* arkode_mem, sunrealtype hin); SUNDIALS_EXPORT int ARKodeSetMinStep(void* arkode_mem, sunrealtype hmin); SUNDIALS_EXPORT int ARKodeSetMaxStep(void* arkode_mem, sunrealtype hmax); SUNDIALS_EXPORT int ARKodeSetMaxNumConstrFails(void* arkode_mem, int maxfails); +SUNDIALS_EXPORT +int ARKodeSetAdjointCheckpointScheme(void* arkode_mem, + SUNAdjointCheckpointScheme checkpoint_scheme); +SUNDIALS_EXPORT +int ARKodeSetAdjointCheckpointIndex(void* arkode_mem, int64_t step_index); SUNDIALS_EXPORT int ARKodeSetAccumulatedErrorType(void* arkode_mem, ARKAccumError accum_type); SUNDIALS_EXPORT int ARKodeResetAccumulatedError(void* arkode_mem); diff --git a/include/arkode/arkode_arkstep.h b/include/arkode/arkode_arkstep.h index 2be3d59332..362fe7704b 100644 --- a/include/arkode/arkode_arkstep.h +++ b/include/arkode/arkode_arkstep.h @@ -23,6 +23,8 @@ #include #include #include +#include +#include #ifdef __cplusplus /* wrapper to enable C++ usage */ extern "C" { @@ -396,6 +398,13 @@ SUNDIALS_DEPRECATED_EXPORT_MSG("use ARKodeFree instead") void ARKStepFree(void** arkode_mem); SUNDIALS_DEPRECATED_EXPORT_MSG("use ARKodePrintMem instead") void ARKStepPrintMem(void* arkode_mem, FILE* outfile); + +/* Adjoint solver functions */ +SUNDIALS_EXPORT +int ARKStepCreateAdjointStepper(void* arkode_mem, N_Vector sf, + SUNAdjointStepper* adj_stepper_ptr); + +/* Relaxation functions */ SUNDIALS_DEPRECATED_EXPORT_MSG("use ARKodeSetRelaxFn instead") int ARKStepSetRelaxFn(void* arkode_mem, ARKRelaxFn rfn, ARKRelaxJacFn rjac); SUNDIALS_DEPRECATED_EXPORT_MSG("use ARKodeSetRelaxEtaFail instead") diff --git a/include/arkode/arkode_erkstep.h b/include/arkode/arkode_erkstep.h index 6c045753b5..70481af451 100644 --- a/include/arkode/arkode_erkstep.h +++ b/include/arkode/arkode_erkstep.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #ifdef __cplusplus /* wrapper to enable C++ usage */ extern "C" { @@ -67,6 +69,11 @@ SUNDIALS_EXPORT int ERKStepGetTimestepperStats( void* arkode_mem, long int* expsteps, long int* accsteps, long int* step_attempts, long int* nfevals, long int* netfails); +/* Adjoint solver functions */ +SUNDIALS_EXPORT +int ERKStepCreateAdjointStepper(void* arkode_mem, N_Vector sf, + SUNAdjointStepper* adj_stepper_ptr); + /* -------------------------------------------------------------------------- * Deprecated Functions -- all are superseded by shared ARKODE-level routines * -------------------------------------------------------------------------- */ diff --git a/src/arkode/CMakeLists.txt b/src/arkode/CMakeLists.txt index 3e3634f7b1..ff57c62a3c 100644 --- a/src/arkode/CMakeLists.txt +++ b/src/arkode/CMakeLists.txt @@ -82,6 +82,7 @@ sundials_add_library( OBJECT_LIBRARIES sundials_sunmemsys_obj sundials_nvecserial_obj + sundials_nvecmanyvector_obj sundials_sunadaptcontrollersoderlind_obj sundials_sunadaptcontrollerimexgus_obj sundials_sunadaptcontrollermrihtol_obj @@ -97,6 +98,7 @@ sundials_add_library( sundials_sunlinsolpcg_obj sundials_sunnonlinsolnewton_obj sundials_sunnonlinsolfixedpoint_obj + sundials_adjointcheckpointscheme_fixed_obj OUTPUT_NAME sundials_arkode VERSION ${arkodelib_VERSION} SOVERSION ${arkodelib_SOVERSION}) diff --git a/src/arkode/arkode.c b/src/arkode/arkode.c index eefba1aee8..64cdf041bd 100644 --- a/src/arkode/arkode.c +++ b/src/arkode/arkode.c @@ -35,6 +35,8 @@ #include "sundials/sundials_logger.h" #include "sundials_utils.h" +#include "sundials_macros.h" + /*=============================================================== Exported functions ===============================================================*/ @@ -1325,6 +1327,7 @@ void ARKodePrintMem(void* arkode_mem, FILE* outfile) fprintf(outfile, "fixedstep = %i\n", ark_mem->fixedstep); fprintf(outfile, "tolsf = " SUN_FORMAT_G "\n", ark_mem->tolsf); fprintf(outfile, "call_fullrhs = %i\n", ark_mem->call_fullrhs); + fprintf(outfile, "do_adjoint = %i\n", ark_mem->do_adjoint); /* output counters */ fprintf(outfile, "nhnil = %i\n", ark_mem->nhnil); @@ -1639,6 +1642,8 @@ ARKodeMem arkCreate(SUNContext sunctx) return (NULL); } + ark_mem->do_adjoint = SUNFALSE; + /* Return pointer to ARKODE memory block */ return (ark_mem); } @@ -1835,6 +1840,9 @@ int arkInit(ARKodeMem ark_mem, sunrealtype t0, N_Vector y0, int init_type) and/or the stepper initialization function in arkInitialSetup */ ark_mem->call_fullrhs = SUNFALSE; + /* Adjoint related */ + ark_mem->checkpoint_step_idx = 0; + /* Indicate that initialization has not been done before */ ark_mem->initialized = SUNFALSE; } @@ -1988,6 +1996,14 @@ int arkInitialSetup(ARKodeMem ark_mem, sunrealtype tout) /* Test input tstop for legality (correct direction of integration) */ if (ark_mem->tstopset) { +#if SUNDIALS_LOGGING_LEVEL >= SUNDIALS_LOGGING_DEBUG + SUNLogger_QueueMsg(ARK_LOGGER, SUN_LOGLEVEL_DEBUG, + "ARKODE::arkInitialSetup", "test-tstop", + "h = %" RSYM ", tcur = %" RSYM ", tout = %" RSYM + ", tstop = %" RSYM, + ark_mem->h, ark_mem->tcur, tout, ark_mem->tstop); +#endif + htmp = (ark_mem->h == ZERO) ? tout - ark_mem->tcur : ark_mem->h; if ((ark_mem->tstop - ark_mem->tcur) * htmp <= ZERO) { @@ -2716,8 +2732,10 @@ int arkCompleteStep(ARKodeMem ark_mem, sunrealtype dsm) /* update interpolation structure NOTE: This must be called before updating yn with ycur as the interpolation - module may need to save tn, yn from the start of this step */ - if (ark_mem->interp != NULL) + module may need to save tn, yn from the start of this step + + NOTE: When doing adjoint integration interpolation is disabled, so we skip this */ + if (ark_mem->interp != NULL && !ark_mem->do_adjoint) { retval = arkInterpUpdate(ark_mem, ark_mem->interp, ark_mem->tcur); if (retval != ARK_SUCCESS) { return (retval); } @@ -2742,6 +2760,7 @@ int arkCompleteStep(ARKodeMem ark_mem, sunrealtype dsm) /* update scalar quantities */ ark_mem->nst++; + ark_mem->checkpoint_step_idx++; ark_mem->hold = ark_mem->h; ark_mem->tn = ark_mem->tcur; ark_mem->hprime = ark_mem->h * ark_mem->eta; @@ -2869,6 +2888,10 @@ int arkHandleFailure(ARKodeMem ark_mem, int flag) arkProcessError(ark_mem, ARK_RELAX_JAC_FAIL, __LINE__, __func__, __FILE__, "The relaxation Jacobian failed unrecoverably"); break; + case ARK_ADJ_RECOMPUTE_FAIL: + arkProcessError(ark_mem, ARK_ADJ_RECOMPUTE_FAIL, __LINE__, __func__, __FILE__, + "The forward recomputation of step failed unrecoverably"); + break; case ARK_DOMEIG_FAIL: arkProcessError(ark_mem, ARK_DOMEIG_FAIL, __LINE__, __func__, __FILE__, "The dominant eigenvalue function failed unrecoverably"); diff --git a/src/arkode/arkode_arkstep.c b/src/arkode/arkode_arkstep.c index 0093728575..a2f4b3bf8e 100644 --- a/src/arkode/arkode_arkstep.c +++ b/src/arkode/arkode_arkstep.c @@ -18,15 +18,23 @@ #include #include #include + #include +#include + +#include +#include + #include +#include + #include "arkode/arkode.h" +#include "arkode/arkode_arkstep.h" #include "arkode/arkode_butcher.h" #include "arkode_arkstep_impl.h" #include "arkode_impl.h" #include "arkode_interp_impl.h" -#include "sundials/sundials_types.h" /*=============================================================== Exported functions @@ -1138,8 +1146,8 @@ int arkStep_Init(ARKodeMem ark_mem, SUNDIALS_MAYBE_UNUSED sunrealtype tout, } /* set appropriate TakeStep routine based on problem configuration */ - /* (only one choice for now) */ - ark_mem->step = arkStep_TakeStep_Z; + if (ark_mem->do_adjoint) { ark_mem->step = arkStep_TakeStep_ERK_Adjoint; } + else { ark_mem->step = arkStep_TakeStep_Z; } /* Check for consistency between mass system and system linear system modules (e.g., if lsolve is direct, msolve needs to match) */ @@ -1715,10 +1723,42 @@ int arkStep_TakeStep_Z(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) } } - /* explicit first stage -- store stage if necessary for relaxation */ - if (is_start == 1 && save_stages) + /* explicit first stage -- store stage if necessary for relaxation or checkpointing */ + if (is_start == 1) { - N_VScale(ONE, ark_mem->yn, step_mem->z[0]); + if (save_stages) { N_VScale(ONE, ark_mem->yn, step_mem->z[0]); } + + if (ark_mem->checkpoint_scheme) + { + sunbooleantype do_save; + SUNErrCode errcode = + SUNAdjointCheckpointScheme_ShouldWeSave(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, 0, + ark_mem->tcur, &do_save); + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_ShouldWeSave returned %d", + errcode); + } + + if (do_save) + { + errcode = + SUNAdjointCheckpointScheme_InsertVector(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, 0, + ark_mem->tcur, ark_mem->ycur); + + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_InsertVector returned %d", + errcode); + } + } + } } /* check if the method is Stiffly Accurate (SA) */ @@ -1983,6 +2023,38 @@ int arkStep_TakeStep_Z(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) /* store stage (if necessary for relaxation) */ if (save_stages) { N_VScale(ONE, ark_mem->ycur, step_mem->z[is]); } + /* checkpoint stage for adjoint (if necessary) */ + if (ark_mem->checkpoint_scheme) + { + sunbooleantype do_save; + SUNErrCode errcode = + SUNAdjointCheckpointScheme_ShouldWeSave(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, + is, ark_mem->tcur, &do_save); + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_ShouldWeSave returned %d", + errcode); + } + + if (do_save) + { + SUNAdjointCheckpointScheme_InsertVector(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, is, + ark_mem->tcur, ark_mem->ycur); + + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_InsertVector returned %d", + errcode); + } + } + } + /* store implicit RHS (value in Fi[is] is from preceding nonlinear iteration) */ if (step_mem->implicit) { @@ -2099,12 +2171,212 @@ int arkStep_TakeStep_Z(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) if (*nflagPtr < 0) { return (*nflagPtr); } if (*nflagPtr > 0) { return (TRY_AGAIN); } + if (ark_mem->checkpoint_scheme) + { + sunbooleantype do_save; + SUNAdjointCheckpointScheme_ShouldWeSave(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, + step_mem->Be->stages, + ark_mem->tn + ark_mem->h, &do_save); + if (do_save) + { + SUNAdjointCheckpointScheme_InsertVector(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, + step_mem->Be->stages, + ark_mem->tn + ark_mem->h, + ark_mem->ycur); + } + } + SUNLogExtraDebugVec(ARK_LOGGER, "updated solution", ark_mem->ycur, "ycur(:) ="); SUNLogInfo(ARK_LOGGER, "end-compute-solution", "status = success"); return (ARK_SUCCESS); } +/*--------------------------------------------------------------- + arkStep_TakeStep_ERK_Adjoint: + + This routine performs a single backwards step of the discrete + adjoint of the ERK method. + + Since we are not doing error control during the adjoint integration, + the output variable dsmPtr should should be 0. + + The input/output variable nflagPtr is used to gauge convergence + of any algebraic solvers within the step. In this case, it should + always be 0 since we do not do any algebraic solves. + + The return value from this routine is: + 0 => step completed successfully + >0 => step encountered recoverable failure; + reduce step and retry (if possible) + <0 => step encountered unrecoverable failure + ---------------------------------------------------------------*/ +int arkStep_TakeStep_ERK_Adjoint(ARKodeMem ark_mem, sunrealtype* dsmPtr, + int* nflagPtr) +{ + int retval = ARK_SUCCESS; + + ARKodeARKStepMem step_mem; + + /* access ARKodeARKStepMem structure */ + retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem); + if (retval != ARK_SUCCESS) { return (retval); } + + SUNLogDebug(ARK_LOGGER, "ARKODE::arkStep_TakeStep_ERK_Adjoint", "start-step", + "step = %li, h = %" RSYM ", dsm = %" RSYM ", nflag = %d", + ark_mem->nst, ark_mem->h, *dsmPtr, *nflagPtr); + + /* local shortcuts for readability */ + SUNAdjointStepper adj_stepper = (SUNAdjointStepper)ark_mem->user_data; + sunrealtype* cvals = step_mem->cvals; + N_Vector* Xvecs = step_mem->Xvecs; + N_Vector sens_np1 = ark_mem->yn; + N_Vector sens_n = ark_mem->ycur; + N_Vector sens_tmp = step_mem->sdata; + N_Vector Lambda_tmp = N_VGetSubvector_ManyVector(sens_tmp, 0); + N_Vector lambda_np1 = N_VGetSubvector_ManyVector(sens_np1, 0); + N_Vector* stage_values = step_mem->Fe; + + /* determine if method has fsal property */ + sunbooleantype fsal = (SUNRabs(step_mem->Be->A[0][0]) <= TINY) && + ARKodeButcherTable_IsStifflyAccurate(step_mem->Be); + + /* Loop over stages */ + if (fsal) { N_VConst(SUN_RCONST(0.0), stage_values[step_mem->stages - 1]); } + for (int is = step_mem->stages - (fsal ? 2 : 1); is >= 0; --is) + { + /* which stage is being processed -- needed for loading checkpoints */ + ark_mem->adj_stage_idx = is; + + /* Set current stage time(s) and index */ + ark_mem->tcur = ark_mem->tn + + ark_mem->h * (SUN_RCONST(1.0) - step_mem->Be->c[is]); + + /* + * Compute partial current stage value \Lambda + */ + int nvec = 0; + for (int js = is + 1; js < step_mem->stages; ++js) + { + /* h sum_{j=i}^{s} A_{ji}/b_i \Lambda_{j} */ + if (step_mem->Be->b[is] > SUN_UNIT_ROUNDOFF) + { + cvals[nvec] = -ark_mem->h * step_mem->Be->A[js][is] / step_mem->Be->b[is]; + } + else { cvals[nvec] = -ark_mem->h * step_mem->Be->A[js][is]; } + Xvecs[nvec] = N_VGetSubvector_ManyVector(stage_values[js], 0); + nvec++; + } + cvals[nvec] = -ark_mem->h * step_mem->Be->b[is]; + Xvecs[nvec] = lambda_np1; + nvec++; + + /* h b_i \lambda_{n+1} + h sum_{j=i}^{s} A_{ji} \Lambda_{j} */ + retval = N_VLinearCombination(nvec, cvals, Xvecs, Lambda_tmp); + if (retval != 0) { return (ARK_VECTOROP_ERR); } + + /* Compute stage values \Lambda_i, \nu_i by applying f_{y,p}^T (which is what fe does in this case) */ + retval = step_mem->fe(ark_mem->tcur, sens_tmp, stage_values[is], + ark_mem->user_data); + step_mem->nfe++; + + /* The checkpoint was not found, so we need to recompute at least + this step forward in time. We first seek the last checkpointed step + solution, then recompute from there. */ + if (retval > 0) + { + N_Vector checkpoint = N_VGetSubvector_ManyVector(ark_mem->tempv2, 0); + int64_t start_step = adj_stepper->step_idx; + + SUNErrCode errcode = SUN_ERR_CHECKPOINT_NOT_FOUND; + for (int64_t i = 0; i <= adj_stepper->step_idx; ++i, --start_step) + { + SUNDIALS_MAYBE_UNUSED int64_t stop_step = adj_stepper->step_idx + 1; + SUNLogDebug(ARK_LOGGER, "ARKODE::arkStep_TakeStep_ERK_Adjoint", + "searching-for-checkpoint", + "start_step = %li, stop_step = %li", start_step, stop_step); + sunrealtype checkpoint_t; + errcode = + SUNAdjointCheckpointScheme_LoadVector(ark_mem->checkpoint_scheme, + start_step, step_mem->stages, 1, + &checkpoint, &checkpoint_t); + if (errcode == SUN_SUCCESS) + { + /* OK, now we have the last checkpoint that stored as (start_step, stages). + This represents the last step solution that was checkpointed. As such, we + want to recompute from start_step+1 to stop_step. */ + start_step++; + sunrealtype t0 = checkpoint_t; + sunrealtype tf = ark_mem->tn; + SUNLogDebug(ARK_LOGGER, "ARKODE::arkStep_TakeStep_ERK_Adjoint", + "start-recompute", + "start_step = %li, stop_step = %li, t0 = %" RSYM + ", tf = %" RSYM "", + start_step, stop_step, t0, tf); + if (SUNAdjointStepper_RecomputeFwd(adj_stepper, start_step, t0, tf, + checkpoint)) + { + return (ARK_ADJ_RECOMPUTE_FAIL); + } + SUNLogDebug(ARK_LOGGER, "ARKODE::arkStep_TakeStep_ERK_Adjoint", + "end-recompute", + "start_step = %li, stop_step = %li, t0 = %" RSYM + ", tf = %" RSYM "", + start_step, stop_step, t0, tf); + return arkStep_TakeStep_ERK_Adjoint(ark_mem, dsmPtr, nflagPtr); + } + } + if (errcode != SUN_SUCCESS) { return (ARK_RHSFUNC_FAIL); } + } + else if (retval < 0) { return (ARK_RHSFUNC_FAIL); } + } + + /* Throw away the step solution */ + sunrealtype checkpoint_t = 0.0; + N_Vector checkpoint = N_VGetSubvector_ManyVector(ark_mem->tempv2, 0); + SUNErrCode errcode = + SUNAdjointCheckpointScheme_LoadVector(ark_mem->checkpoint_scheme, + adj_stepper->step_idx, 0, 0, + &checkpoint, &checkpoint_t); + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_LoadVector returned %d", errcode); + } + + /* Now compute the time step solution. We cannot use arkStep_ComputeSolutions because the + adjoint calculation for the time step solution is different than the forward case. */ + + int nvec = 0; + for (int j = 0; j < step_mem->stages; j++) + { + cvals[nvec] = ONE; + Xvecs[nvec] = + stage_values[j]; // this needs to be the stage values [Lambda_i, nu_i] + nvec++; + } + cvals[nvec] = ONE; + Xvecs[nvec] = sens_np1; + nvec++; + + /* \lambda_n = \lambda_{n+1} + \sum_{j=1}^{s} \Lambda_j + \mu_n = \mu_{n+1} + \sum_{j=1}^{s} \nu_j */ + retval = N_VLinearCombination(nvec, cvals, Xvecs, sens_n); + if (retval != 0) { return (ARK_VECTOROP_ERR); } + + *dsmPtr = ZERO; + *nflagPtr = 0; + + SUNLogDebug(ARK_LOGGER, "ARKODE::arkStep_TakeStep_ERK_Adjoint", "end-step", + "step = %li, h = %" RSYM ", dsm = %" RSYM ", nflag = %d", + ark_mem->nst, ark_mem->h, *dsmPtr, *nflagPtr); + + return (ARK_SUCCESS); +} + /*=============================================================== Internal utility routines ===============================================================*/ @@ -3146,6 +3418,266 @@ int arkStep_ComputeSolutions_MassFixed(ARKodeMem ark_mem, sunrealtype* dsmPtr) return (ARK_SUCCESS); } +/*--------------------------------------------------------------- + Utility routines for interfacing with SUNAdjointStepper + ---------------------------------------------------------------*/ + +int arkStep_fe_Adj(sunrealtype t, N_Vector sens_partial_stage, + N_Vector sens_complete_stage, void* content) +{ + SUNErrCode errcode = SUN_SUCCESS; + + SUNAdjointStepper adj_stepper = (SUNAdjointStepper)content; + SUNAdjointCheckpointScheme check_scheme = adj_stepper->checkpoint_scheme; + ARKodeMem ark_mem = (ARKodeMem)adj_stepper->adj_sunstepper->content; + void* user_data = adj_stepper->user_data; + + N_Vector Lambda_part = N_VGetSubvector_ManyVector(sens_partial_stage, 0); + N_Vector Lambda = N_VGetSubvector_ManyVector(sens_complete_stage, 0); + N_Vector checkpoint = N_VGetSubvector_ManyVector(ark_mem->tempv2, 0); + sunrealtype checkpoint_t = SUN_RCONST(0.0); + + errcode = SUNAdjointCheckpointScheme_LoadVector(check_scheme, + adj_stepper->step_idx, + ark_mem->adj_stage_idx, 0, + &checkpoint, &checkpoint_t); + + // Checkpoint was not found, recompute the missing step + if (errcode == SUN_ERR_CHECKPOINT_NOT_FOUND) { return +1; } + + if (adj_stepper->JacFn) + { + adj_stepper->JacFn(t, checkpoint, NULL, adj_stepper->Jac, user_data, NULL, + NULL, NULL); + adj_stepper->njeval++; + if (SUNMatMatTransposeVec(adj_stepper->Jac, Lambda_part, Lambda)) + { + return -1; + }; + } + else if (adj_stepper->JvpFn) + { + adj_stepper->JvpFn(Lambda_part, Lambda, t, checkpoint, NULL, user_data, NULL); + + adj_stepper->njtimesv++; + } + else if (adj_stepper->vJpFn) + { + adj_stepper->vJpFn(Lambda_part, Lambda, t, checkpoint, NULL, user_data, NULL); + adj_stepper->nvtimesj++; + } + + if (adj_stepper->JacPFn) + { + if (N_VGetNumSubvectors_ManyVector(sens_complete_stage) < 2) { return -1; } + N_Vector nu = N_VGetSubvector_ManyVector(sens_complete_stage, 1); + adj_stepper->JacPFn(t, checkpoint, NULL, adj_stepper->JacP, user_data, NULL, + NULL, NULL); + adj_stepper->njpeval++; + if (SUNMatMatTransposeVec(adj_stepper->JacP, Lambda_part, nu)) + { + return -1; + } + } + else if (adj_stepper->JPvpFn) + { + if (N_VGetNumSubvectors_ManyVector(sens_complete_stage) < 2) { return -1; } + N_Vector nu = N_VGetSubvector_ManyVector(sens_complete_stage, 1); + adj_stepper->JPvpFn(Lambda_part, nu, t, checkpoint, NULL, user_data, NULL); + adj_stepper->njptimesv++; + } + else if (adj_stepper->vJPpFn) + { + if (N_VGetNumSubvectors_ManyVector(sens_complete_stage) < 2) { return -1; } + N_Vector nu = N_VGetSubvector_ManyVector(sens_complete_stage, 1); + adj_stepper->vJPpFn(Lambda_part, nu, t, checkpoint, NULL, user_data, NULL); + adj_stepper->nvtimesjp++; + } + + return 0; +} + +int arkStepCompatibleWithAdjointSolver(ARKodeMem ark_mem, + ARKodeARKStepMem step_mem, int lineno, + const char* fname, const char* filename) +{ + if (!ark_mem->fixedstep) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, + filename, "ARKStep must be using a fixed step to work with SUNAdjointStepper"); + return ARK_ILL_INPUT; + } + + if (step_mem->fi) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, + filename, "SUNAdjointStepper requires fi = NULL (it only supports explicit RK methods)"); + return ARK_ILL_INPUT; + } + + if (!step_mem->fe) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, + filename, "fe must have been provided to ARKStepCreate to create a SUNAdjointStepper"); + return ARK_ILL_INPUT; + } + + if (ark_mem->relax_enabled) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, filename, + "SUNAdjointStepper is not compatible with relaxation"); + return ARK_ILL_INPUT; + } + + if (step_mem->mass_type != MASS_IDENTITY) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, + filename, "SUNAdjointStepper is not compatible with non-identity mass matrices"); + return ARK_ILL_INPUT; + } + + return ARK_SUCCESS; +} + +int ARKStepCreateAdjointStepper(void* arkode_mem, N_Vector sf, + SUNAdjointStepper* adj_stepper_ptr) +{ + ARKodeMem ark_mem; + ARKodeARKStepMem step_mem; + int retval = arkStep_AccessARKODEStepMem(arkode_mem, + "ARKStepCreateAdjointStepper", + &ark_mem, &step_mem); + if (retval) + { + arkProcessError(NULL, ARK_ILL_INPUT, __LINE__, __func__, __FILE__, + "The ARKStep memory pointer is NULL"); + return ARK_ILL_INPUT; + } + + if (arkStepCompatibleWithAdjointSolver(ark_mem, step_mem, __LINE__, __func__, + __FILE__)) + { + return ARK_ILL_INPUT; + } + + /** + Create and configure the ARKStep stepper for the adjoint system + */ + long nst = 0; + retval = ARKodeGetNumSteps(arkode_mem, &nst); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeGetNumSteps failed"); + return retval; + } + + void* arkode_mem_adj = ARKStepCreate(arkStep_fe_Adj, NULL, ark_mem->tretlast, + sf, ark_mem->sunctx); + ARKodeMem ark_mem_adj = (ARKodeMem)arkode_mem_adj; + + ark_mem_adj->do_adjoint = SUNTRUE; + + retval = ARKodeSetFixedStep(arkode_mem_adj, -ark_mem->h); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetFixedStep failed"); + return retval; + } + + retval = ARKStepSetTables(arkode_mem_adj, step_mem->Be->q, step_mem->Be->p, + step_mem->Bi, step_mem->Be); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKStepSetTables failed"); + return retval; + } + + retval = ARKodeSetMaxNumSteps(arkode_mem_adj, nst); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetMaxNumSteps failed"); + return retval; + } + + retval = ARKodeSetAdjointCheckpointScheme(arkode_mem_adj, + ark_mem->checkpoint_scheme); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetAdjointCheckpointScheme failed"); + return retval; + } + + /* SUNAdjointStepper will own the SUNSteppers and destroy them */ + SUNStepper fwd_stepper; + retval = ARKodeCreateSUNStepper(arkode_mem, &fwd_stepper); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeCreateSUNStepper failed"); + return retval; + } + + SUNStepper adj_stepper; + retval = ARKodeCreateSUNStepper(arkode_mem_adj, &adj_stepper); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeCreateSUNStepper failed"); + return retval; + } + + SUNErrCode errcode = SUN_SUCCESS; + + /* Setting this ensures that the ARKodeMem underneath the adj_stepper + is destroyed with the SUNStepper_Destroy call. */ + errcode = SUNStepper_SetDestroyFn(adj_stepper, arkSUNStepperSelfDestruct); + if (errcode) + { + retval = ARK_UNRECOGNIZED_ERROR; + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "SUNStepper_SetDestroyFn failed"); + return retval; + } + + errcode = SUNAdjointStepper_Create(fwd_stepper, adj_stepper, nst - 1, sf, + ark_mem->tretlast, + ark_mem->checkpoint_scheme, + ark_mem->sunctx, adj_stepper_ptr); + if (errcode) + { + retval = ARK_UNRECOGNIZED_ERROR; + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "SUNAdjointStepper_Create failed"); + return retval; + } + + errcode = SUNAdjointStepper_SetUserData(*adj_stepper_ptr, ark_mem->user_data); + if (errcode) + { + retval = ARK_UNRECOGNIZED_ERROR; + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "SUNAdjointStepper_SetUserData failed"); + return retval; + } + + /* We need access to the adjoint solver to access the parameter Jacobian inside of ARKStep's + backwards integration of the the adjoint problem. */ + retval = ARKodeSetUserData(arkode_mem_adj, *adj_stepper_ptr); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetUserData failed"); + return retval; + } + + return ARK_SUCCESS; +} + /*=============================================================== Internal utility routines for interacting with MRIStep ===============================================================*/ @@ -3230,10 +3762,9 @@ int arkStep_SetInnerForcing(ARKodeMem ark_mem, sunrealtype tshift, sunrealtype tscale, N_Vector* forcing, int nvecs) { ARKodeARKStepMem step_mem; - int retval; /* access ARKodeARKStepMem structure */ - retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem); + int retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem); if (retval != ARK_SUCCESS) { return (retval); } if (nvecs > 0) diff --git a/src/arkode/arkode_arkstep_impl.h b/src/arkode/arkode_arkstep_impl.h index 9dcc375bbf..95baa06a0a 100644 --- a/src/arkode/arkode_arkstep_impl.h +++ b/src/arkode/arkode_arkstep_impl.h @@ -189,6 +189,8 @@ int arkStep_GetGammas(ARKodeMem ark_mem, sunrealtype* gamma, sunrealtype* gamrat sunbooleantype** jcur, sunbooleantype* dgamma_fail); int arkStep_FullRHS(ARKodeMem ark_mem, sunrealtype t, N_Vector y, N_Vector f, int mode); +int arkStep_TakeStep_ERK_Adjoint(ARKodeMem ark_mem, sunrealtype* dsmPtr, + int* nflagPtr); int arkStep_TakeStep_Z(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr); int arkStep_SetUserData(ARKodeMem ark_mem, void* user_data); int arkStep_SetDefaults(ARKodeMem ark_mem); @@ -277,6 +279,14 @@ int arkStep_RelaxDeltaE(ARKodeMem ark_mem, ARKRelaxJacFn relax_jac_fn, long int* relax_jac_fn_evals, sunrealtype* delta_e_out); int arkStep_GetOrder(ARKodeMem ark_mem); +/* private functions for adjoints */ +int arkStep_fe_Adj(sunrealtype t, N_Vector sens_partial_stage, + N_Vector sens_complete_stage, void* content); + +int arkStepCompatibleWithAdjointSolver(ARKodeMem ark_mem, + ARKodeARKStepMem step_mem, int lineno, + const char* fname, const char* filename); + /*=============================================================== Reusable ARKStep Error Messages ===============================================================*/ diff --git a/src/arkode/arkode_erkstep.c b/src/arkode/arkode_erkstep.c index ab2f150b63..a575b526d5 100644 --- a/src/arkode/arkode_erkstep.c +++ b/src/arkode/arkode_erkstep.c @@ -18,9 +18,15 @@ #include #include #include + #include #include +#include +#include + +#include + #include "arkode/arkode_butcher.h" #include "arkode_erkstep_impl.h" #include "arkode_impl.h" @@ -512,6 +518,10 @@ int erkStep_Init(ARKodeMem ark_mem, SUNDIALS_MAYBE_UNUSED sunrealtype tout, ark_mem->interp_degree = 1; } + /* set appropriate TakeStep routine based on problem configuration */ + if (ark_mem->do_adjoint) { ark_mem->step = erkStep_TakeStep_Adjoint; } + else { ark_mem->step = erkStep_TakeStep; } + /* Signal to shared arkode module that full RHS evaluations are required */ ark_mem->call_fullrhs = SUNTRUE; @@ -758,6 +768,39 @@ int erkStep_TakeStep(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) } SUNLogExtraDebugVec(ARK_LOGGER, "stage RHS", step_mem->F[0], "F_0(:) ="); + + if (ark_mem->checkpoint_scheme) + { + sunbooleantype do_save; + SUNErrCode errcode = + SUNAdjointCheckpointScheme_ShouldWeSave(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, 0, + ark_mem->tcur, &do_save); + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_ShouldWeSave returned %d", + errcode); + } + + if (do_save) + { + errcode = + SUNAdjointCheckpointScheme_InsertVector(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, 0, + ark_mem->tcur, ark_mem->ycur); + + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_InsertVector returned %d", + errcode); + } + } + } + SUNLogInfo(ARK_LOGGER, "end-stage", "status = success"); /* Loop over internal stages to the step; since the method is explicit @@ -829,6 +872,38 @@ int erkStep_TakeStep(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) if (retval < 0) { return (ARK_RHSFUNC_FAIL); } if (retval > 0) { return (ARK_UNREC_RHSFUNC_ERR); } + /* checkpoint stage for adjoint (if necessary) */ + if (ark_mem->checkpoint_scheme) + { + sunbooleantype do_save; + SUNErrCode errcode = + SUNAdjointCheckpointScheme_ShouldWeSave(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, + is, ark_mem->tcur, &do_save); + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_ShouldWeSave returned %d", + errcode); + } + + if (do_save) + { + SUNAdjointCheckpointScheme_InsertVector(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, is, + ark_mem->tcur, ark_mem->ycur); + + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_InsertVector returned %d", + errcode); + } + } + } + SUNLogInfo(ARK_LOGGER, "end-stage", "status = success"); } /* loop over stages */ @@ -847,6 +922,205 @@ int erkStep_TakeStep(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) SUNLogExtraDebugVec(ARK_LOGGER, "updated solution", ark_mem->ycur, "ycur(:) ="); SUNLogInfo(ARK_LOGGER, "end-compute-solution", "status = success"); + if (ark_mem->checkpoint_scheme) + { + sunbooleantype do_save; + SUNAdjointCheckpointScheme_ShouldWeSave(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, + step_mem->B->stages, + ark_mem->tn + ark_mem->h, &do_save); + if (do_save) + { + SUNAdjointCheckpointScheme_InsertVector(ark_mem->checkpoint_scheme, + ark_mem->checkpoint_step_idx, + step_mem->B->stages, + ark_mem->tn + ark_mem->h, + ark_mem->ycur); + } + } + + return (ARK_SUCCESS); +} + +/*--------------------------------------------------------------- + erkStep_TakeStep_Adjoint: + + This routine performs a single backwards step of the discrete + adjoint of the ERK method. + + Since we are not doing error control during the adjoint integration, + the output variable dsmPtr should should be 0. + + The input/output variable nflagPtr is used to gauge convergence + of any algebraic solvers within the step. In this case, it should + always be 0 since we do not do any algebraic solves. + + The return value from this routine is: + 0 => step completed successfully + >0 => step encountered recoverable failure; + reduce step and retry (if possible) + <0 => step encountered unrecoverable failure + ---------------------------------------------------------------*/ +int erkStep_TakeStep_Adjoint(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr) +{ + int retval = ARK_SUCCESS; + + ARKodeERKStepMem step_mem; + + /* access ARKodeERKStepMem structure */ + retval = erkStep_AccessStepMem(ark_mem, __func__, &step_mem); + if (retval != ARK_SUCCESS) { return (retval); } + + SUNLogDebug(ARK_LOGGER, "ARKODE::erkStep_TakeStep_ERK_Adjoint", "start-step", + "step = %li, h = %" RSYM ", dsm = %" RSYM ", nflag = %d", + ark_mem->nst, ark_mem->h, *dsmPtr, *nflagPtr); + + /* local shortcuts for readability */ + SUNAdjointStepper adj_stepper = (SUNAdjointStepper)ark_mem->user_data; + sunrealtype* cvals = step_mem->cvals; + N_Vector* Xvecs = step_mem->Xvecs; + N_Vector sens_np1 = ark_mem->yn; + N_Vector sens_n = ark_mem->ycur; + N_Vector sens_tmp = ark_mem->tempv2; + N_Vector Lambda_tmp = N_VGetSubvector_ManyVector(sens_tmp, 0); + N_Vector lambda_np1 = N_VGetSubvector_ManyVector(sens_np1, 0); + N_Vector* stage_values = step_mem->F; + + /* determine if method has fsal property */ + sunbooleantype fsal = (SUNRabs(step_mem->B->A[0][0]) <= TINY) && + ARKodeButcherTable_IsStifflyAccurate(step_mem->B); + + /* Loop over stages */ + if (fsal) { N_VConst(SUN_RCONST(0.0), stage_values[step_mem->stages - 1]); } + for (int is = step_mem->stages - (fsal ? 2 : 1); is >= 0; --is) + { + /* which stage is being processed -- needed for loading checkpoints */ + ark_mem->adj_stage_idx = is; + + /* Set current stage time(s) and index */ + ark_mem->tcur = ark_mem->tn + + ark_mem->h * (SUN_RCONST(1.0) - step_mem->B->c[is]); + + /* + * Compute partial current stage value \Lambda + */ + int nvec = 0; + for (int js = is + 1; js < step_mem->stages; ++js) + { + /* h sum_{j=i}^{s} A_{ji}/b_i \Lambda_{j} */ + if (step_mem->B->b[is] > SUN_UNIT_ROUNDOFF) + { + cvals[nvec] = -ark_mem->h * step_mem->B->A[js][is] / step_mem->B->b[is]; + } + else { cvals[nvec] = -ark_mem->h * step_mem->B->A[js][is]; } + Xvecs[nvec] = N_VGetSubvector_ManyVector(stage_values[js], 0); + nvec++; + } + cvals[nvec] = -ark_mem->h * step_mem->B->b[is]; + Xvecs[nvec] = lambda_np1; + nvec++; + + /* h b_i \lambda_{n+1} + h sum_{j=i}^{s} A_{ji} \Lambda_{j} */ + retval = N_VLinearCombination(nvec, cvals, Xvecs, Lambda_tmp); + if (retval != 0) { return (ARK_VECTOROP_ERR); } + + /* Compute stage values \Lambda_i, \nu_i by applying f_{y,p}^T (which is what fe does in this case) */ + retval = step_mem->f(ark_mem->tcur, sens_tmp, stage_values[is], + ark_mem->user_data); + step_mem->nfe++; + + /* The checkpoint was not found, so we need to recompute at least + this step forward in time. We first seek the last checkpointed step + solution, then recompute from there. */ + if (retval > 0) + { + N_Vector checkpoint = N_VGetSubvector_ManyVector(ark_mem->tempv3, 0); + int64_t start_step = adj_stepper->step_idx; + + SUNErrCode errcode = SUN_ERR_CHECKPOINT_NOT_FOUND; + for (int64_t i = 0; i <= adj_stepper->step_idx; ++i, --start_step) + { + SUNDIALS_MAYBE_UNUSED int64_t stop_step = adj_stepper->step_idx + 1; + SUNLogDebug(ARK_LOGGER, "ARKODE::erkStep_TakeStep_Adjoint", + "searching-for-checkpoint", + "start_step = %li, stop_step = %li", start_step, stop_step); + sunrealtype checkpoint_t; + errcode = + SUNAdjointCheckpointScheme_LoadVector(ark_mem->checkpoint_scheme, + start_step, step_mem->stages, 1, + &checkpoint, &checkpoint_t); + if (errcode == SUN_SUCCESS) + { + /* OK, now we have the last checkpoint that stored as (start_step, stages). + This represents the last step solution that was checkpointed. As such, we + want to recompute from start_step+1 to stop_step. */ + start_step++; + sunrealtype t0 = checkpoint_t; + sunrealtype tf = ark_mem->tn; + SUNLogDebug(ARK_LOGGER, "ARKODE::erkStep_TakeStep_ERK_Adjoint", + "start-recompute", + "start_step = %li, stop_step = %li, t0 = %" RSYM + ", tf = %" RSYM "", + start_step, stop_step, t0, tf); + if (SUNAdjointStepper_RecomputeFwd(adj_stepper, start_step, t0, tf, + checkpoint)) + { + return (ARK_ADJ_RECOMPUTE_FAIL); + } + SUNLogDebug(ARK_LOGGER, "ARKODE::erkStep_TakeStep_ERK_Adjoint", + "end-recompute", + "start_step = %li, stop_step = %li, t0 = %" RSYM + ", tf = %" RSYM "", + start_step, stop_step, t0, tf); + return erkStep_TakeStep_Adjoint(ark_mem, dsmPtr, nflagPtr); + } + } + if (errcode != SUN_SUCCESS) { return (ARK_RHSFUNC_FAIL); } + } + else if (retval < 0) { return (ARK_RHSFUNC_FAIL); } + } + + /* Throw away the step solution */ + sunrealtype checkpoint_t = 0.0; + N_Vector checkpoint = N_VGetSubvector_ManyVector(ark_mem->tempv2, 0); + SUNErrCode errcode = + SUNAdjointCheckpointScheme_LoadVector(ark_mem->checkpoint_scheme, + adj_stepper->step_idx, 0, 0, + &checkpoint, &checkpoint_t); + if (errcode) + { + arkProcessError(ark_mem, ARK_ADJ_CHECKPOINT_FAIL, __LINE__, __func__, + __FILE__, + "SUNAdjointCheckpointScheme_LoadVector returned %d", errcode); + } + + /* Now compute the time step solution. We cannot use erkStep_ComputeSolutions because the + adjoint calculation for the time step solution is different than the forward case. */ + + int nvec = 0; + for (int j = 0; j < step_mem->stages; j++) + { + cvals[nvec] = ONE; + Xvecs[nvec] = + stage_values[j]; // this needs to be the stage values [Lambda_i, nu_i] + nvec++; + } + cvals[nvec] = ONE; + Xvecs[nvec] = sens_np1; + nvec++; + + /* \lambda_n = \lambda_{n+1} + \sum_{j=1}^{s} \Lambda_j + \mu_n = \mu_{n+1} + \sum_{j=1}^{s} \nu_j */ + retval = N_VLinearCombination(nvec, cvals, Xvecs, sens_n); + if (retval != 0) { return (ARK_VECTOROP_ERR); } + + *dsmPtr = ZERO; + *nflagPtr = 0; + + SUNLogDebug(ARK_LOGGER, "ARKODE::erkStep_TakeStep_ERK_Adjoint", "end-step", + "step = %li, h = %" RSYM ", dsm = %" RSYM ", nflag = %d", + ark_mem->nst, ark_mem->h, *dsmPtr, *nflagPtr); + return (ARK_SUCCESS); } @@ -1273,6 +1547,244 @@ int erkStep_GetOrder(ARKodeMem ark_mem) return step_mem->q; } +/*--------------------------------------------------------------- + Utility routines for interfacing with SUNAdjointStepper + ---------------------------------------------------------------*/ + +int erkStep_fe_Adj(sunrealtype t, N_Vector sens_partial_stage, + N_Vector sens_complete_stage, void* content) +{ + SUNErrCode errcode = SUN_SUCCESS; + + SUNAdjointStepper adj_stepper = (SUNAdjointStepper)content; + SUNAdjointCheckpointScheme check_scheme = adj_stepper->checkpoint_scheme; + ARKodeMem ark_mem = (ARKodeMem)adj_stepper->adj_sunstepper->content; + void* user_data = adj_stepper->user_data; + + N_Vector Lambda_part = N_VGetSubvector_ManyVector(sens_partial_stage, 0); + N_Vector Lambda = N_VGetSubvector_ManyVector(sens_complete_stage, 0); + N_Vector checkpoint = N_VGetSubvector_ManyVector(ark_mem->tempv3, 0); + sunrealtype checkpoint_t = SUN_RCONST(0.0); + + errcode = SUNAdjointCheckpointScheme_LoadVector(check_scheme, + adj_stepper->step_idx, + ark_mem->adj_stage_idx, 0, + &checkpoint, &checkpoint_t); + + // Checkpoint was not found, recompute the missing step + if (errcode == SUN_ERR_CHECKPOINT_NOT_FOUND) { return +1; } + + if (adj_stepper->JacFn) + { + adj_stepper->JacFn(t, checkpoint, NULL, adj_stepper->Jac, user_data, NULL, + NULL, NULL); + adj_stepper->njeval++; + if (SUNMatMatTransposeVec(adj_stepper->Jac, Lambda_part, Lambda)) + { + return -1; + }; + } + else if (adj_stepper->JvpFn) + { + adj_stepper->JvpFn(Lambda_part, Lambda, t, checkpoint, NULL, user_data, NULL); + + adj_stepper->njtimesv++; + } + else if (adj_stepper->vJpFn) + { + adj_stepper->vJpFn(Lambda_part, Lambda, t, checkpoint, NULL, user_data, NULL); + adj_stepper->nvtimesj++; + } + + if (adj_stepper->JacPFn) + { + if (N_VGetNumSubvectors_ManyVector(sens_complete_stage) < 2) { return -1; } + N_Vector nu = N_VGetSubvector_ManyVector(sens_complete_stage, 1); + adj_stepper->JacPFn(t, checkpoint, NULL, adj_stepper->JacP, user_data, NULL, + NULL, NULL); + adj_stepper->njpeval++; + if (SUNMatMatTransposeVec(adj_stepper->JacP, Lambda_part, nu)) + { + return -1; + } + } + else if (adj_stepper->JPvpFn) + { + if (N_VGetNumSubvectors_ManyVector(sens_complete_stage) < 2) { return -1; } + N_Vector nu = N_VGetSubvector_ManyVector(sens_complete_stage, 1); + adj_stepper->JPvpFn(Lambda_part, nu, t, checkpoint, NULL, user_data, NULL); + adj_stepper->njptimesv++; + } + else if (adj_stepper->vJPpFn) + { + if (N_VGetNumSubvectors_ManyVector(sens_complete_stage) < 2) { return -1; } + N_Vector nu = N_VGetSubvector_ManyVector(sens_complete_stage, 1); + adj_stepper->vJPpFn(Lambda_part, nu, t, checkpoint, NULL, user_data, NULL); + adj_stepper->nvtimesjp++; + } + + return 0; +} + +int erkStepCompatibleWithAdjointSolver( + ARKodeMem ark_mem, SUNDIALS_MAYBE_UNUSED ARKodeERKStepMem step_mem, + int lineno, const char* fname, const char* filename) +{ + if (!ark_mem->fixedstep) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, + filename, "ERKStep must be using a fixed step to work with SUNAdjointStepper"); + return ARK_ILL_INPUT; + } + + if (ark_mem->relax_enabled) + { + arkProcessError(ark_mem, ARK_ILL_INPUT, lineno, fname, filename, + "SUNAdjointStepper is not compatible with relaxation"); + return ARK_ILL_INPUT; + } + + return ARK_SUCCESS; +} + +int ERKStepCreateAdjointStepper(void* arkode_mem, N_Vector sf, + SUNAdjointStepper* adj_stepper_ptr) +{ + ARKodeMem ark_mem; + ARKodeERKStepMem step_mem; + int retval = erkStep_AccessARKODEStepMem(arkode_mem, + "ERKStepCreateAdjointStepper", + &ark_mem, &step_mem); + if (retval) + { + arkProcessError(NULL, ARK_ILL_INPUT, __LINE__, __func__, __FILE__, + "The ERKStep memory pointer is NULL"); + return ARK_ILL_INPUT; + } + + if (erkStepCompatibleWithAdjointSolver(ark_mem, step_mem, __LINE__, __func__, + __FILE__)) + { + return ARK_ILL_INPUT; + } + + /** + Create and configure the ERKStep stepper for the adjoint system + */ + long nst = 0; + retval = ARKodeGetNumSteps(arkode_mem, &nst); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeGetNumSteps failed"); + return retval; + } + + void* arkode_mem_adj = ERKStepCreate(erkStep_fe_Adj, ark_mem->tretlast, sf, + ark_mem->sunctx); + ARKodeMem ark_mem_adj = (ARKodeMem)arkode_mem_adj; + + ark_mem_adj->do_adjoint = SUNTRUE; + + retval = ARKodeSetFixedStep(arkode_mem_adj, -ark_mem->h); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetFixedStep failed"); + return retval; + } + + retval = ERKStepSetTable(arkode_mem_adj, step_mem->B); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ERKStepSetTables failed"); + return retval; + } + + retval = ARKodeSetMaxNumSteps(arkode_mem_adj, nst); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetMaxNumSteps failed"); + return retval; + } + + retval = ARKodeSetAdjointCheckpointScheme(arkode_mem_adj, + ark_mem->checkpoint_scheme); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetAdjointCheckpointScheme failed"); + return retval; + } + + /* SUNAdjointStepper will own the SUNSteppers and destroy them */ + SUNStepper fwd_stepper; + retval = ARKodeCreateSUNStepper(arkode_mem, &fwd_stepper); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeCreateSUNStepper failed"); + return retval; + } + + SUNStepper adj_stepper; + retval = ARKodeCreateSUNStepper(arkode_mem_adj, &adj_stepper); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeCreateSUNStepper failed"); + return retval; + } + + SUNErrCode errcode = SUN_SUCCESS; + + /* Setting this ensures that the ARKodeMem underneath the adj_stepper + is destroyed with the SUNStepper_Destroy call. */ + errcode = SUNStepper_SetDestroyFn(adj_stepper, arkSUNStepperSelfDestruct); + if (errcode) + { + retval = ARK_UNRECOGNIZED_ERROR; + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "SUNStepper_SetDestroyFn failed"); + return retval; + } + + errcode = SUNAdjointStepper_Create(fwd_stepper, adj_stepper, nst - 1, sf, + ark_mem->tretlast, + ark_mem->checkpoint_scheme, + ark_mem->sunctx, adj_stepper_ptr); + if (errcode) + { + retval = ARK_UNRECOGNIZED_ERROR; + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "SUNAdjointStepper_Create failed"); + return retval; + } + + errcode = SUNAdjointStepper_SetUserData(*adj_stepper_ptr, ark_mem->user_data); + if (errcode) + { + retval = ARK_UNRECOGNIZED_ERROR; + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "SUNAdjointStepper_SetUserData failed"); + return retval; + } + + /* We need access to the adjoint solver to access the parameter Jacobian inside of ERKStep's + backwards integration of the the adjoint problem. */ + retval = ARKodeSetUserData(arkode_mem_adj, *adj_stepper_ptr); + if (retval) + { + arkProcessError(ark_mem, retval, __LINE__, __func__, __FILE__, + "ARKodeSetUserData failed"); + return retval; + } + + return ARK_SUCCESS; +} + /*--------------------------------------------------------------- Utility routines for ERKStep to serve as an MRIStepInnerStepper ---------------------------------------------------------------*/ diff --git a/src/arkode/arkode_erkstep_impl.h b/src/arkode/arkode_erkstep_impl.h index 8c6de3b7d4..519f64a9b5 100644 --- a/src/arkode/arkode_erkstep_impl.h +++ b/src/arkode/arkode_erkstep_impl.h @@ -81,6 +81,8 @@ int erkStep_Init(ARKodeMem ark_mem, sunrealtype tout, int init_type); int erkStep_FullRHS(ARKodeMem ark_mem, sunrealtype t, N_Vector y, N_Vector f, int mode); int erkStep_TakeStep(ARKodeMem ark_mem, sunrealtype* dsmPtr, int* nflagPtr); +int erkStep_TakeStep_Adjoint(ARKodeMem ark_mem, sunrealtype* dsmPtr, + int* nflagPtr); int erkStep_SetDefaults(ARKodeMem ark_mem); int erkStep_SetOrder(ARKodeMem ark_mem, int ord); int erkStep_PrintAllStats(ARKodeMem ark_mem, FILE* outfile, SUNOutputFormat fmt); @@ -113,6 +115,14 @@ int erkStep_RelaxDeltaE(ARKodeMem ark_mem, ARKRelaxJacFn relax_jac_fn, long int* relax_jac_fn_evals, sunrealtype* delta_e_out); int erkStep_GetOrder(ARKodeMem ark_mem); +/* private functions for adjoints */ +int erkStep_fe_Adj(sunrealtype t, N_Vector sens_partial_stage, + N_Vector sens_complete_stage, void* content); + +int erkStepCompatibleWithAdjointSolver(ARKodeMem ark_mem, + ARKodeERKStepMem step_mem, int lineno, + const char* fname, const char* filename); + /*=============================================================== Reusable ERKStep Error Messages ===============================================================*/ diff --git a/src/arkode/arkode_impl.h b/src/arkode/arkode_impl.h index 4ef09d44c4..cbc55dfd0e 100644 --- a/src/arkode/arkode_impl.h +++ b/src/arkode/arkode_impl.h @@ -24,11 +24,15 @@ #include #include #include + #include #include #include +#include +#include #include #include +#include #include "arkode_adapt_impl.h" #include "arkode_relaxation_impl.h" @@ -564,6 +568,14 @@ struct ARKodeMemRec sunbooleantype use_compensated_sums; + /* Adjoint solver data */ + sunbooleantype do_adjoint; + long int adj_stage_idx; /* current stage index (only valid in adjoint context)*/ + + /* Checkpointing data */ + SUNAdjointCheckpointScheme checkpoint_scheme; + int64_t checkpoint_step_idx; /* the step number for checkpointing */ + /* XBraid interface variables */ sunbooleantype force_pass; /* when true the step attempt loop will ignore the return value (kflag) from arkCheckTemporalError @@ -673,6 +685,9 @@ int ark_MRIStepInnerGetAccumulatedError(MRIStepInnerStepper stepper, int ark_MRIStepInnerResetAccumulatedError(MRIStepInnerStepper stepper); int ark_MRIStepInnerSetRTol(MRIStepInnerStepper stepper, sunrealtype rtol); +/* utility functions for wrapping ARKODE as a SUNStepper */ +SUNErrCode arkSUNStepperSelfDestruct(SUNStepper stepper); + /* XBraid interface functions */ int arkSetForcePass(void* arkode_mem, sunbooleantype force_pass); int arkGetLastKFlag(void* arkode_mem, int* last_kflag); diff --git a/src/arkode/arkode_io.c b/src/arkode/arkode_io.c index a06feb1d13..909565e23d 100644 --- a/src/arkode/arkode_io.c +++ b/src/arkode/arkode_io.c @@ -2082,6 +2082,41 @@ int ARKodeResetAccumulatedError(void* arkode_mem) /* Reset value and counter, and return */ ark_mem->AccumErrorStart = ark_mem->tn; ark_mem->AccumError = ZERO; + + return (ARK_SUCCESS); +} + +int ARKodeSetAdjointCheckpointScheme(void* arkode_mem, + SUNAdjointCheckpointScheme checkpoint_scheme) + +{ + ARKodeMem ark_mem; + if (arkode_mem == NULL) + { + arkProcessError(NULL, ARK_MEM_NULL, __LINE__, __func__, __FILE__, + MSG_ARK_NO_MEM); + return (ARK_MEM_NULL); + } + ark_mem = (ARKodeMem)arkode_mem; + + ark_mem->checkpoint_scheme = checkpoint_scheme; + + return (ARK_SUCCESS); +} + +int ARKodeSetAdjointCheckpointIndex(void* arkode_mem, int64_t step_index) +{ + ARKodeMem ark_mem; + if (arkode_mem == NULL) + { + arkProcessError(NULL, ARK_MEM_NULL, __LINE__, __func__, __FILE__, + MSG_ARK_NO_MEM); + return (ARK_MEM_NULL); + } + ark_mem = (ARKodeMem)arkode_mem; + + ark_mem->checkpoint_step_idx = step_index; + return (ARK_SUCCESS); } @@ -3045,6 +3080,7 @@ char* ARKodeGetReturnFlagName(long int flag) case ARK_RELAX_JAC_FAIL: sprintf(name, "ARK_RELAX_JAC_FAIL"); break; case ARK_CONTROLLER_ERR: sprintf(name, "ARK_CONTROLLER_ERR"); break; case ARK_STEPPER_UNSUPPORTED: sprintf(name, "ARK_STEPPER_UNSUPPORTED"); break; + case ARK_ADJ_RECOMPUTE_FAIL: sprintf(name, "ARK_ADJ_RECOMPUTE_FAIL"); break; case ARK_DOMEIG_FAIL: sprintf(name, "ARK_DOMEIG_FAIL"); break; case ARK_MAX_STAGE_LIMIT_FAIL: sprintf(name, "ARK_MAX_STAGE_LIMIT_FAIL"); diff --git a/src/arkode/arkode_sunstepper.c b/src/arkode/arkode_sunstepper.c index b5c4829e4a..2aed1504da 100644 --- a/src/arkode/arkode_sunstepper.c +++ b/src/arkode/arkode_sunstepper.c @@ -88,6 +88,7 @@ static SUNErrCode arkSUNStepperReset(SUNStepper stepper, sunrealtype tR, N_Vector yR) { SUNFunctionBegin(stepper->sunctx); + /* extract the ARKODE memory struct */ void* arkode_mem; SUNCheckCall(SUNStepper_GetContent(stepper, &arkode_mem)); @@ -146,6 +147,20 @@ static SUNErrCode arkSUNStepperSetForcing(SUNStepper stepper, sunrealtype tshift return SUN_SUCCESS; } +SUNErrCode arkSUNStepperSelfDestruct(SUNStepper stepper) +{ + /* This function is useful when we create a ARKodeMem/SUNStepper internally, + and want it to be destroyed with the SUNStepper. */ + ARKodeMem ark_mem; + + SUNErrCode errcode = SUNStepper_GetContent(stepper, (void**)&ark_mem); + if (errcode) { return errcode; } + + ARKodeFree((void**)&ark_mem); + + return SUN_SUCCESS; +} + int ARKodeCreateSUNStepper(void* arkode_mem, SUNStepper* stepper) { /* unpack ark_mem */ diff --git a/src/arkode/fmod_int32/farkode_arkstep_mod.c b/src/arkode/fmod_int32/farkode_arkstep_mod.c index d41ed70c4d..3122f2caec 100644 --- a/src/arkode/fmod_int32/farkode_arkstep_mod.c +++ b/src/arkode/fmod_int32/farkode_arkstep_mod.c @@ -2427,6 +2427,22 @@ SWIGEXPORT void _wrap_FARKStepPrintMem(void *farg1, void *farg2) { } +SWIGEXPORT int _wrap_FARKStepCreateAdjointStepper(void *farg1, N_Vector farg2, void *farg3) { + int fresult ; + void *arg1 = (void *) 0 ; + N_Vector arg2 = (N_Vector) 0 ; + SUNAdjointStepper *arg3 = (SUNAdjointStepper *) 0 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (N_Vector)(farg2); + arg3 = (SUNAdjointStepper *)(farg3); + result = (int)ARKStepCreateAdjointStepper(arg1,arg2,arg3); + fresult = (int)(result); + return fresult; +} + + SWIGEXPORT int _wrap_FARKStepSetRelaxFn(void *farg1, ARKRelaxFn farg2, ARKRelaxJacFn farg3) { int fresult ; void *arg1 = (void *) 0 ; diff --git a/src/arkode/fmod_int32/farkode_arkstep_mod.f90 b/src/arkode/fmod_int32/farkode_arkstep_mod.f90 index b06c17d1ed..9f74d5a201 100644 --- a/src/arkode/fmod_int32/farkode_arkstep_mod.f90 +++ b/src/arkode/fmod_int32/farkode_arkstep_mod.f90 @@ -202,6 +202,7 @@ module farkode_arkstep_mod public :: FARKStepGetLinReturnFlagName public :: FARKStepFree public :: FARKStepPrintMem + public :: FARKStepCreateAdjointStepper public :: FARKStepSetRelaxFn public :: FARKStepSetRelaxEtaFail public :: FARKStepSetRelaxLowerBound @@ -1625,6 +1626,16 @@ subroutine swigc_FARKStepPrintMem(farg1, farg2) & type(C_PTR), value :: farg2 end subroutine +function swigc_FARKStepCreateAdjointStepper(farg1, farg2, farg3) & +bind(C, name="_wrap_FARKStepCreateAdjointStepper") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +type(C_PTR), value :: farg2 +type(C_PTR), value :: farg3 +integer(C_INT) :: fresult +end function + function swigc_FARKStepSetRelaxFn(farg1, farg2, farg3) & bind(C, name="_wrap_FARKStepSetRelaxFn") & result(fresult) @@ -4356,6 +4367,25 @@ subroutine FARKStepPrintMem(arkode_mem, outfile) call swigc_FARKStepPrintMem(farg1, farg2) end subroutine +function FARKStepCreateAdjointStepper(arkode_mem, sf, adj_stepper_ptr) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +type(N_Vector), target, intent(inout) :: sf +type(C_PTR), target, intent(inout) :: adj_stepper_ptr +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +type(C_PTR) :: farg2 +type(C_PTR) :: farg3 + +farg1 = arkode_mem +farg2 = c_loc(sf) +farg3 = c_loc(adj_stepper_ptr) +fresult = swigc_FARKStepCreateAdjointStepper(farg1, farg2, farg3) +swig_result = fresult +end function + function FARKStepSetRelaxFn(arkode_mem, rfn, rjac) & result(swig_result) use, intrinsic :: ISO_C_BINDING diff --git a/src/arkode/fmod_int32/farkode_erkstep_mod.c b/src/arkode/fmod_int32/farkode_erkstep_mod.c index b3e74677f2..0c9c4ead32 100644 --- a/src/arkode/fmod_int32/farkode_erkstep_mod.c +++ b/src/arkode/fmod_int32/farkode_erkstep_mod.c @@ -347,6 +347,22 @@ SWIGEXPORT int _wrap_FERKStepGetTimestepperStats(void *farg1, long *farg2, long } +SWIGEXPORT int _wrap_FERKStepCreateAdjointStepper(void *farg1, N_Vector farg2, void *farg3) { + int fresult ; + void *arg1 = (void *) 0 ; + N_Vector arg2 = (N_Vector) 0 ; + SUNAdjointStepper *arg3 = (SUNAdjointStepper *) 0 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (N_Vector)(farg2); + arg3 = (SUNAdjointStepper *)(farg3); + result = (int)ERKStepCreateAdjointStepper(arg1,arg2,arg3); + fresult = (int)(result); + return fresult; +} + + SWIGEXPORT int _wrap_FERKStepResize(void *farg1, N_Vector farg2, double const *farg3, double const *farg4, ARKVecResizeFn farg5, void *farg6) { int fresult ; void *arg1 = (void *) 0 ; diff --git a/src/arkode/fmod_int32/farkode_erkstep_mod.f90 b/src/arkode/fmod_int32/farkode_erkstep_mod.f90 index 7d719bee04..1e28565da5 100644 --- a/src/arkode/fmod_int32/farkode_erkstep_mod.f90 +++ b/src/arkode/fmod_int32/farkode_erkstep_mod.f90 @@ -46,6 +46,7 @@ module farkode_erkstep_mod public :: FERKStepSetTableName public :: FERKStepGetCurrentButcherTable public :: FERKStepGetTimestepperStats + public :: FERKStepCreateAdjointStepper public :: FERKStepResize public :: FERKStepReset public :: FERKStepSStolerances @@ -205,6 +206,16 @@ function swigc_FERKStepGetTimestepperStats(farg1, farg2, farg3, farg4, farg5, fa integer(C_INT) :: fresult end function +function swigc_FERKStepCreateAdjointStepper(farg1, farg2, farg3) & +bind(C, name="_wrap_FERKStepCreateAdjointStepper") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +type(C_PTR), value :: farg2 +type(C_PTR), value :: farg3 +integer(C_INT) :: fresult +end function + function swigc_FERKStepResize(farg1, farg2, farg3, farg4, farg5, farg6) & bind(C, name="_wrap_FERKStepResize") & result(fresult) @@ -1145,6 +1156,25 @@ function FERKStepGetTimestepperStats(arkode_mem, expsteps, accsteps, step_attemp swig_result = fresult end function +function FERKStepCreateAdjointStepper(arkode_mem, sf, adj_stepper_ptr) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +type(N_Vector), target, intent(inout) :: sf +type(C_PTR), target, intent(inout) :: adj_stepper_ptr +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +type(C_PTR) :: farg2 +type(C_PTR) :: farg3 + +farg1 = arkode_mem +farg2 = c_loc(sf) +farg3 = c_loc(adj_stepper_ptr) +fresult = swigc_FERKStepCreateAdjointStepper(farg1, farg2, farg3) +swig_result = fresult +end function + function FERKStepResize(arkode_mem, ynew, hscale, t0, resize, resize_data) & result(swig_result) use, intrinsic :: ISO_C_BINDING diff --git a/src/arkode/fmod_int32/farkode_mod.c b/src/arkode/fmod_int32/farkode_mod.c index cf23905ff6..9c4481a93c 100644 --- a/src/arkode/fmod_int32/farkode_mod.c +++ b/src/arkode/fmod_int32/farkode_mod.c @@ -1167,6 +1167,34 @@ SWIGEXPORT int _wrap_FARKodeSetMaxNumConstrFails(void *farg1, int const *farg2) } +SWIGEXPORT int _wrap_FARKodeSetAdjointCheckpointScheme(void *farg1, SUNAdjointCheckpointScheme farg2) { + int fresult ; + void *arg1 = (void *) 0 ; + SUNAdjointCheckpointScheme arg2 = (SUNAdjointCheckpointScheme) 0 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (SUNAdjointCheckpointScheme)(farg2); + result = (int)ARKodeSetAdjointCheckpointScheme(arg1,arg2); + fresult = (int)(result); + return fresult; +} + + +SWIGEXPORT int _wrap_FARKodeSetAdjointCheckpointIndex(void *farg1, int64_t const *farg2) { + int fresult ; + void *arg1 = (void *) 0 ; + int64_t arg2 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (int64_t)(*farg2); + result = (int)ARKodeSetAdjointCheckpointIndex(arg1,arg2); + fresult = (int)(result); + return fresult; +} + + SWIGEXPORT int _wrap_FARKodeSetAccumulatedErrorType(void *farg1, int const *farg2) { int fresult ; void *arg1 = (void *) 0 ; diff --git a/src/arkode/fmod_int32/farkode_mod.f90 b/src/arkode/fmod_int32/farkode_mod.f90 index 6aa2552cd8..29d6de6482 100644 --- a/src/arkode/fmod_int32/farkode_mod.f90 +++ b/src/arkode/fmod_int32/farkode_mod.f90 @@ -98,6 +98,8 @@ module farkode_mod integer(C_INT), parameter, public :: ARK_MAX_STAGE_LIMIT_FAIL = -50_C_INT integer(C_INT), parameter, public :: ARK_SUNSTEPPER_ERR = -51_C_INT integer(C_INT), parameter, public :: ARK_STEP_DIRECTION_ERR = -52_C_INT + integer(C_INT), parameter, public :: ARK_ADJ_CHECKPOINT_FAIL = -53_C_INT + integer(C_INT), parameter, public :: ARK_ADJ_RECOMPUTE_FAIL = -54_C_INT integer(C_INT), parameter, public :: ARK_UNRECOGNIZED_ERROR = -99_C_INT ! typedef enum ARKRelaxSolver enum, bind(c) @@ -175,6 +177,8 @@ module farkode_mod public :: FARKodeSetMinStep public :: FARKodeSetMaxStep public :: FARKodeSetMaxNumConstrFails + public :: FARKodeSetAdjointCheckpointScheme + public :: FARKodeSetAdjointCheckpointIndex public :: FARKodeSetAccumulatedErrorType public :: FARKodeResetAccumulatedError public :: FARKodeEvolve @@ -1025,6 +1029,24 @@ function swigc_FARKodeSetMaxNumConstrFails(farg1, farg2) & integer(C_INT) :: fresult end function +function swigc_FARKodeSetAdjointCheckpointScheme(farg1, farg2) & +bind(C, name="_wrap_FARKodeSetAdjointCheckpointScheme") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +type(C_PTR), value :: farg2 +integer(C_INT) :: fresult +end function + +function swigc_FARKodeSetAdjointCheckpointIndex(farg1, farg2) & +bind(C, name="_wrap_FARKodeSetAdjointCheckpointIndex") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +integer(C_INT64_T), intent(in) :: farg2 +integer(C_INT) :: fresult +end function + function swigc_FARKodeSetAccumulatedErrorType(farg1, farg2) & bind(C, name="_wrap_FARKodeSetAccumulatedErrorType") & result(fresult) @@ -3440,6 +3462,38 @@ function FARKodeSetMaxNumConstrFails(arkode_mem, maxfails) & swig_result = fresult end function +function FARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +type(SUNAdjointCheckpointScheme), target, intent(inout) :: checkpoint_scheme +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +type(C_PTR) :: farg2 + +farg1 = arkode_mem +farg2 = c_loc(checkpoint_scheme) +fresult = swigc_FARKodeSetAdjointCheckpointScheme(farg1, farg2) +swig_result = fresult +end function + +function FARKodeSetAdjointCheckpointIndex(arkode_mem, step_index) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +integer(C_INT64_T), intent(in) :: step_index +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +integer(C_INT64_T) :: farg2 + +farg1 = arkode_mem +farg2 = step_index +fresult = swigc_FARKodeSetAdjointCheckpointIndex(farg1, farg2) +swig_result = fresult +end function + function FARKodeSetAccumulatedErrorType(arkode_mem, accum_type) & result(swig_result) use, intrinsic :: ISO_C_BINDING diff --git a/src/arkode/fmod_int64/farkode_arkstep_mod.c b/src/arkode/fmod_int64/farkode_arkstep_mod.c index d41ed70c4d..3122f2caec 100644 --- a/src/arkode/fmod_int64/farkode_arkstep_mod.c +++ b/src/arkode/fmod_int64/farkode_arkstep_mod.c @@ -2427,6 +2427,22 @@ SWIGEXPORT void _wrap_FARKStepPrintMem(void *farg1, void *farg2) { } +SWIGEXPORT int _wrap_FARKStepCreateAdjointStepper(void *farg1, N_Vector farg2, void *farg3) { + int fresult ; + void *arg1 = (void *) 0 ; + N_Vector arg2 = (N_Vector) 0 ; + SUNAdjointStepper *arg3 = (SUNAdjointStepper *) 0 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (N_Vector)(farg2); + arg3 = (SUNAdjointStepper *)(farg3); + result = (int)ARKStepCreateAdjointStepper(arg1,arg2,arg3); + fresult = (int)(result); + return fresult; +} + + SWIGEXPORT int _wrap_FARKStepSetRelaxFn(void *farg1, ARKRelaxFn farg2, ARKRelaxJacFn farg3) { int fresult ; void *arg1 = (void *) 0 ; diff --git a/src/arkode/fmod_int64/farkode_arkstep_mod.f90 b/src/arkode/fmod_int64/farkode_arkstep_mod.f90 index b06c17d1ed..9f74d5a201 100644 --- a/src/arkode/fmod_int64/farkode_arkstep_mod.f90 +++ b/src/arkode/fmod_int64/farkode_arkstep_mod.f90 @@ -202,6 +202,7 @@ module farkode_arkstep_mod public :: FARKStepGetLinReturnFlagName public :: FARKStepFree public :: FARKStepPrintMem + public :: FARKStepCreateAdjointStepper public :: FARKStepSetRelaxFn public :: FARKStepSetRelaxEtaFail public :: FARKStepSetRelaxLowerBound @@ -1625,6 +1626,16 @@ subroutine swigc_FARKStepPrintMem(farg1, farg2) & type(C_PTR), value :: farg2 end subroutine +function swigc_FARKStepCreateAdjointStepper(farg1, farg2, farg3) & +bind(C, name="_wrap_FARKStepCreateAdjointStepper") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +type(C_PTR), value :: farg2 +type(C_PTR), value :: farg3 +integer(C_INT) :: fresult +end function + function swigc_FARKStepSetRelaxFn(farg1, farg2, farg3) & bind(C, name="_wrap_FARKStepSetRelaxFn") & result(fresult) @@ -4356,6 +4367,25 @@ subroutine FARKStepPrintMem(arkode_mem, outfile) call swigc_FARKStepPrintMem(farg1, farg2) end subroutine +function FARKStepCreateAdjointStepper(arkode_mem, sf, adj_stepper_ptr) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +type(N_Vector), target, intent(inout) :: sf +type(C_PTR), target, intent(inout) :: adj_stepper_ptr +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +type(C_PTR) :: farg2 +type(C_PTR) :: farg3 + +farg1 = arkode_mem +farg2 = c_loc(sf) +farg3 = c_loc(adj_stepper_ptr) +fresult = swigc_FARKStepCreateAdjointStepper(farg1, farg2, farg3) +swig_result = fresult +end function + function FARKStepSetRelaxFn(arkode_mem, rfn, rjac) & result(swig_result) use, intrinsic :: ISO_C_BINDING diff --git a/src/arkode/fmod_int64/farkode_erkstep_mod.c b/src/arkode/fmod_int64/farkode_erkstep_mod.c index b3e74677f2..0c9c4ead32 100644 --- a/src/arkode/fmod_int64/farkode_erkstep_mod.c +++ b/src/arkode/fmod_int64/farkode_erkstep_mod.c @@ -347,6 +347,22 @@ SWIGEXPORT int _wrap_FERKStepGetTimestepperStats(void *farg1, long *farg2, long } +SWIGEXPORT int _wrap_FERKStepCreateAdjointStepper(void *farg1, N_Vector farg2, void *farg3) { + int fresult ; + void *arg1 = (void *) 0 ; + N_Vector arg2 = (N_Vector) 0 ; + SUNAdjointStepper *arg3 = (SUNAdjointStepper *) 0 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (N_Vector)(farg2); + arg3 = (SUNAdjointStepper *)(farg3); + result = (int)ERKStepCreateAdjointStepper(arg1,arg2,arg3); + fresult = (int)(result); + return fresult; +} + + SWIGEXPORT int _wrap_FERKStepResize(void *farg1, N_Vector farg2, double const *farg3, double const *farg4, ARKVecResizeFn farg5, void *farg6) { int fresult ; void *arg1 = (void *) 0 ; diff --git a/src/arkode/fmod_int64/farkode_erkstep_mod.f90 b/src/arkode/fmod_int64/farkode_erkstep_mod.f90 index 7d719bee04..1e28565da5 100644 --- a/src/arkode/fmod_int64/farkode_erkstep_mod.f90 +++ b/src/arkode/fmod_int64/farkode_erkstep_mod.f90 @@ -46,6 +46,7 @@ module farkode_erkstep_mod public :: FERKStepSetTableName public :: FERKStepGetCurrentButcherTable public :: FERKStepGetTimestepperStats + public :: FERKStepCreateAdjointStepper public :: FERKStepResize public :: FERKStepReset public :: FERKStepSStolerances @@ -205,6 +206,16 @@ function swigc_FERKStepGetTimestepperStats(farg1, farg2, farg3, farg4, farg5, fa integer(C_INT) :: fresult end function +function swigc_FERKStepCreateAdjointStepper(farg1, farg2, farg3) & +bind(C, name="_wrap_FERKStepCreateAdjointStepper") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +type(C_PTR), value :: farg2 +type(C_PTR), value :: farg3 +integer(C_INT) :: fresult +end function + function swigc_FERKStepResize(farg1, farg2, farg3, farg4, farg5, farg6) & bind(C, name="_wrap_FERKStepResize") & result(fresult) @@ -1145,6 +1156,25 @@ function FERKStepGetTimestepperStats(arkode_mem, expsteps, accsteps, step_attemp swig_result = fresult end function +function FERKStepCreateAdjointStepper(arkode_mem, sf, adj_stepper_ptr) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +type(N_Vector), target, intent(inout) :: sf +type(C_PTR), target, intent(inout) :: adj_stepper_ptr +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +type(C_PTR) :: farg2 +type(C_PTR) :: farg3 + +farg1 = arkode_mem +farg2 = c_loc(sf) +farg3 = c_loc(adj_stepper_ptr) +fresult = swigc_FERKStepCreateAdjointStepper(farg1, farg2, farg3) +swig_result = fresult +end function + function FERKStepResize(arkode_mem, ynew, hscale, t0, resize, resize_data) & result(swig_result) use, intrinsic :: ISO_C_BINDING diff --git a/src/arkode/fmod_int64/farkode_mod.c b/src/arkode/fmod_int64/farkode_mod.c index 0c4024af1f..11f19fabf1 100644 --- a/src/arkode/fmod_int64/farkode_mod.c +++ b/src/arkode/fmod_int64/farkode_mod.c @@ -1167,6 +1167,34 @@ SWIGEXPORT int _wrap_FARKodeSetMaxNumConstrFails(void *farg1, int const *farg2) } +SWIGEXPORT int _wrap_FARKodeSetAdjointCheckpointScheme(void *farg1, SUNAdjointCheckpointScheme farg2) { + int fresult ; + void *arg1 = (void *) 0 ; + SUNAdjointCheckpointScheme arg2 = (SUNAdjointCheckpointScheme) 0 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (SUNAdjointCheckpointScheme)(farg2); + result = (int)ARKodeSetAdjointCheckpointScheme(arg1,arg2); + fresult = (int)(result); + return fresult; +} + + +SWIGEXPORT int _wrap_FARKodeSetAdjointCheckpointIndex(void *farg1, int64_t const *farg2) { + int fresult ; + void *arg1 = (void *) 0 ; + int64_t arg2 ; + int result; + + arg1 = (void *)(farg1); + arg2 = (int64_t)(*farg2); + result = (int)ARKodeSetAdjointCheckpointIndex(arg1,arg2); + fresult = (int)(result); + return fresult; +} + + SWIGEXPORT int _wrap_FARKodeSetAccumulatedErrorType(void *farg1, int const *farg2) { int fresult ; void *arg1 = (void *) 0 ; diff --git a/src/arkode/fmod_int64/farkode_mod.f90 b/src/arkode/fmod_int64/farkode_mod.f90 index f29b57f5c3..383074a657 100644 --- a/src/arkode/fmod_int64/farkode_mod.f90 +++ b/src/arkode/fmod_int64/farkode_mod.f90 @@ -98,6 +98,8 @@ module farkode_mod integer(C_INT), parameter, public :: ARK_MAX_STAGE_LIMIT_FAIL = -50_C_INT integer(C_INT), parameter, public :: ARK_SUNSTEPPER_ERR = -51_C_INT integer(C_INT), parameter, public :: ARK_STEP_DIRECTION_ERR = -52_C_INT + integer(C_INT), parameter, public :: ARK_ADJ_CHECKPOINT_FAIL = -53_C_INT + integer(C_INT), parameter, public :: ARK_ADJ_RECOMPUTE_FAIL = -54_C_INT integer(C_INT), parameter, public :: ARK_UNRECOGNIZED_ERROR = -99_C_INT ! typedef enum ARKRelaxSolver enum, bind(c) @@ -175,6 +177,8 @@ module farkode_mod public :: FARKodeSetMinStep public :: FARKodeSetMaxStep public :: FARKodeSetMaxNumConstrFails + public :: FARKodeSetAdjointCheckpointScheme + public :: FARKodeSetAdjointCheckpointIndex public :: FARKodeSetAccumulatedErrorType public :: FARKodeResetAccumulatedError public :: FARKodeEvolve @@ -1025,6 +1029,24 @@ function swigc_FARKodeSetMaxNumConstrFails(farg1, farg2) & integer(C_INT) :: fresult end function +function swigc_FARKodeSetAdjointCheckpointScheme(farg1, farg2) & +bind(C, name="_wrap_FARKodeSetAdjointCheckpointScheme") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +type(C_PTR), value :: farg2 +integer(C_INT) :: fresult +end function + +function swigc_FARKodeSetAdjointCheckpointIndex(farg1, farg2) & +bind(C, name="_wrap_FARKodeSetAdjointCheckpointIndex") & +result(fresult) +use, intrinsic :: ISO_C_BINDING +type(C_PTR), value :: farg1 +integer(C_INT64_T), intent(in) :: farg2 +integer(C_INT) :: fresult +end function + function swigc_FARKodeSetAccumulatedErrorType(farg1, farg2) & bind(C, name="_wrap_FARKodeSetAccumulatedErrorType") & result(fresult) @@ -3440,6 +3462,38 @@ function FARKodeSetMaxNumConstrFails(arkode_mem, maxfails) & swig_result = fresult end function +function FARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +type(SUNAdjointCheckpointScheme), target, intent(inout) :: checkpoint_scheme +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +type(C_PTR) :: farg2 + +farg1 = arkode_mem +farg2 = c_loc(checkpoint_scheme) +fresult = swigc_FARKodeSetAdjointCheckpointScheme(farg1, farg2) +swig_result = fresult +end function + +function FARKodeSetAdjointCheckpointIndex(arkode_mem, step_index) & +result(swig_result) +use, intrinsic :: ISO_C_BINDING +integer(C_INT) :: swig_result +type(C_PTR) :: arkode_mem +integer(C_INT64_T), intent(in) :: step_index +integer(C_INT) :: fresult +type(C_PTR) :: farg1 +integer(C_INT64_T) :: farg2 + +farg1 = arkode_mem +farg2 = step_index +fresult = swigc_FARKodeSetAdjointCheckpointIndex(farg1, farg2) +swig_result = fresult +end function + function FARKodeSetAccumulatedErrorType(arkode_mem, accum_type) & result(swig_result) use, intrinsic :: ISO_C_BINDING diff --git a/test/unit_tests/arkode/CXX_serial/CMakeLists.txt b/test/unit_tests/arkode/CXX_serial/CMakeLists.txt index 1fd8a4e07e..e5dd8a1bdf 100644 --- a/test/unit_tests/arkode/CXX_serial/CMakeLists.txt +++ b/test/unit_tests/arkode/CXX_serial/CMakeLists.txt @@ -54,7 +54,26 @@ set(unit_tests "ark_test_slowerror_brusselator.cpp\;\;exclude-single" "ark_test_slowerror_kpr.cpp\;\;exclude-single" "ark_test_slowerror_polynomial.cpp\;\;exclude-single" - "ark_test_splittingstep.cpp\;\;") + "ark_test_splittingstep.cpp\;\;" + "ark_test_adjoint_erk.cpp\;--check-freq 1\;" + "ark_test_adjoint_erk.cpp\;--check-freq 2\;" + "ark_test_adjoint_erk.cpp\;--check-freq 5\;" + "ark_test_adjoint_erk.cpp\;--check-freq 1 --dont-keep\;" + "ark_test_adjoint_erk.cpp\;--check-freq 2 --dont-keep\;" + "ark_test_adjoint_erk.cpp\;--check-freq 5 --dont-keep\;" + "ark_test_adjoint_ark.cpp\;--check-freq 1\;" + "ark_test_adjoint_ark.cpp\;--check-freq 2\;" + "ark_test_adjoint_ark.cpp\;--check-freq 5\;" + "ark_test_adjoint_ark.cpp\;--check-freq 1 --dont-keep\;" + "ark_test_adjoint_ark.cpp\;--check-freq 2 --dont-keep\;" + "ark_test_adjoint_ark.cpp\;--check-freq 5 --dont-keep\;" + # "ark_test_adjoint_erk.cpp\;--check-freq 1 --no-stages\;" + # "ark_test_adjoint_erk.cpp\;--check-freq 2 --no-stages\;" + # "ark_test_adjoint_erk.cpp\;--check-freq 5 --no-stages\;" + # "ark_test_adjoint_erk.cpp\;--check-freq 1 --dont-keep --no-stages\;" + # "ark_test_adjoint_erk.cpp\;--check-freq 2 --dont-keep --no-stages\;" + # "ark_test_adjoint_erk.cpp\;--check-freq 5 --dont-keep --no-stages\;" +) # Add the build and install targets for each test foreach(test_tuple ${unit_tests}) @@ -81,7 +100,8 @@ foreach(test_tuple ${unit_tests}) target_include_directories( ${test_target} PRIVATE $ - ${CMAKE_SOURCE_DIR}/include ${CMAKE_SOURCE_DIR}/src) + ${CMAKE_SOURCE_DIR}/include ${CMAKE_SOURCE_DIR}/src + ${CMAKE_SOURCE_DIR}/test/unit_tests) # We explicitly choose which object libraries to link to and link in the # arkode objects so that we have access to private functions w/o changing @@ -91,6 +111,7 @@ foreach(test_tuple ${unit_tests}) $ sundials_sunmemsys_obj sundials_nvecserial_obj + sundials_nvecmanyvector_obj sundials_sunlinsolband_obj sundials_sunlinsoldense_obj sundials_sunnonlinsolnewton_obj @@ -98,6 +119,7 @@ foreach(test_tuple ${unit_tests}) sundials_sunadaptcontrollerimexgus_obj sundials_sunadaptcontrollersoderlind_obj sundials_sunadaptcontrollermrihtol_obj + sundials_adjointcheckpointscheme_fixed_obj ${EXE_EXTRA_LINK_LIBS}) # Tell CMake that we depend on the ARKODE library since it does not pick diff --git a/test/unit_tests/arkode/CXX_serial/ark_test_adjoint_ark.cpp b/test/unit_tests/arkode/CXX_serial/ark_test_adjoint_ark.cpp new file mode 100644 index 0000000000..b834940786 --- /dev/null +++ b/test/unit_tests/arkode/CXX_serial/ark_test_adjoint_ark.cpp @@ -0,0 +1,542 @@ +/* ----------------------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2002-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------------------- + * Program to test the SUNAdjoint capability with ARKODE. The test uses the + * implements the four parameter Lotka-Volterra problem + * + * u' = [dx/dt] = [ p_0*x - p_1*x*y ] + * [dy/dt] [ -p_2*y + p_3*x*y ]. + * + * The initial condition is u(t_0) = 1.0 and we use the parameters + * p = [1.5, 1.0, 3.0, 1.0]. We compute the sensitivities for the scalar cost + * function, + * + * g(u(t_f), p) = || 1 - u(t_f, p) ||^2 / 2 + * + * with respect to the initial condition and the parameters. + * ---------------------------------------------------------------------------*/ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "problems/lotka_volterra.hpp" + +#if defined(SUNDIALS_SINGLE_PRECISION) +#define FWD_TOL SUN_RCONST(1e-2) +#elif defined(SUNDIALS_DOUBLE_PRECISION) +#define FWD_TOL SUN_RCONST(1e-4) +#elif defined(SUNDIALS_EXTENDED_PRECISION) +#define FWD_TOL SUN_RCONST(1e-6) +#endif + +#define ADJ_TOL SUN_RCONST(1e-2) + +using namespace problems::lotka_volterra; + +typedef struct +{ + sunrealtype tf; + sunrealtype dt; + int order; + int check_freq; + sunbooleantype save_stages; + sunbooleantype keep_checks; +} ProgramArgs; + +static sunrealtype params[4] = {SUN_RCONST(1.5), SUN_RCONST(1.0), + SUN_RCONST(3.0), SUN_RCONST(1.0)}; + +static int check_forward_answer(N_Vector answer) +{ + const sunrealtype u1 = SUN_RCONST(2.77266836); + const sunrealtype u2 = SUN_RCONST(0.258714765); + sunrealtype* ans = N_VGetArrayPointer(answer); + + if (SUNRCompareTol(ans[0], u1, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[0] = %g, should be %g\n", ans[0], u1); + return -1; + }; + if (SUNRCompareTol(ans[1], u2, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[1] = %g, should be %g\n", ans[1], u2); + return -1; + }; + + return 0; +} + +static int check_forward_backward_answer(N_Vector answer) +{ + const sunrealtype u1 = SUN_RCONST(1.0); + const sunrealtype u2 = SUN_RCONST(1.0); + sunrealtype* ans = N_VGetArrayPointer(answer); + + if (SUNRCompareTol(ans[0], u1, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[0] = %g, should be %g\n", ans[0], u1); + return -1; + }; + if (SUNRCompareTol(ans[1], u2, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[1] = %g, should be %g\n", ans[1], u2); + return -1; + }; + + return 0; +} + +static int check_sensitivities(N_Vector answer) +{ + // The correct answer was generated with the Julia ForwardDiff.jl + // automatic differentiation package. + + const sunrealtype lambda[2] = { + SUN_RCONST(3.5202568952661544), + -SUN_RCONST(2.19271337646507), + }; + + const sunrealtype mu[4] = {SUN_RCONST(4.341147542533404), + -SUN_RCONST(2.000933816791803), + SUN_RCONST(1.010120676762905), + -SUN_RCONST(1.3955943267337996)}; + + sunrealtype* ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 0); + + for (sunindextype i = 0; i < 2; ++i) + { + if (SUNRCompareTol(ans[i], lambda[i], ADJ_TOL)) + { + fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, + ans[i], lambda[i]); + return -1; + }; + } + + ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 1); + + for (sunindextype i = 0; i < 4; ++i) + { + if (SUNRCompareTol(ans[i], mu[i], ADJ_TOL)) + { + fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, + ans[i], mu[i]); + return -1; + }; + } + + return 0; +} + +// static inline int check_sensitivities_backward(N_Vector answer) +// { +// // The correct answer was generated with the Julia ForwardDiff.jl +// // automatic differentiation package. + +// const sunrealtype lambda[2] = { +// SUN_RCONST(1.772850901841113), +// -SUN_RCONST(0.7412891218574361), +// }; + +// const sunrealtype mu[4] = {SUN_RCONST(0.0), SUN_RCONST(0.0), SUN_RCONST(0.0), +// SUN_RCONST(0.0)}; + +// sunrealtype* ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 0); + +// for (sunindextype i = 0; i < 2; ++i) +// { +// if (SUNRCompareTol(ans[i], lambda[i], ADJ_TOL)) +// { +// fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, +// ans[i], lambda[i]); +// return -1; +// }; +// } + +// ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 1); + +// for (sunindextype i = 0; i < 4; ++i) +// { +// if (SUNRCompareTol(ans[i], mu[i], ADJ_TOL)) +// { +// fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, +// ans[i], mu[i]); +// return -1; +// }; +// } + +// return 0; +// } + +static void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, + sunrealtype t) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = -SUN_RCONST(1.0) + u[0]; + dg[1] = -SUN_RCONST(1.0) + u[1]; +} + +static void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, + sunrealtype t) +{ + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = SUN_RCONST(0.0); + dg[1] = SUN_RCONST(0.0); + dg[2] = SUN_RCONST(0.0); + dg[3] = SUN_RCONST(0.0); +} + +static int forward_solution(SUNContext sunctx, void* arkode_mem, + SUNAdjointCheckpointScheme checkpoint_scheme, + const sunrealtype t0, const sunrealtype tf, + const sunrealtype dt, N_Vector u) +{ + int retval = 0; + + retval = ARKodeSetUserData(arkode_mem, (void*)params); + retval = ARKodeSetFixedStep(arkode_mem, dt); + + sunrealtype t = t0; + retval = ARKodeEvolve(arkode_mem, tf, u, &t, ARK_NORMAL); + if (retval < 0) + { + fprintf(stderr, ">>> ERROR: ARKodeEvolve returned %d\n", retval); + return -1; + } + + printf("Forward Solution:\n"); + N_VPrint(u); + + printf("ARKODE Stats for Forward Solution:\n"); + ARKodePrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE); + printf("\n"); + + return 0; +} + +static int adjoint_solution(SUNContext sunctx, SUNAdjointStepper adj_stepper, + SUNAdjointCheckpointScheme checkpoint_scheme, + const sunrealtype tf, const sunrealtype tout, + N_Vector sf) +{ + sunrealtype t = tf; + SUNAdjointStepper_Evolve(adj_stepper, tout, sf, &t); + + printf("Adjoint Solution:\n"); + N_VPrint(sf); + + printf("\nSUNAdjointStepper Stats:\n"); + SUNAdjointStepper_PrintAllStats(adj_stepper, stdout, SUN_OUTPUTFORMAT_TABLE); + printf("\n"); + + return 0; +} + +static void print_help(int argc, char* argv[], int exit_code) +{ + if (exit_code) { fprintf(stderr, "%s: option not recognized\n", argv[0]); } + else { fprintf(stderr, "%s ", argv[0]); } + fprintf(stderr, "options:\n"); + fprintf(stderr, "--tf the final simulation time\n"); + fprintf(stderr, "--dt the timestep size\n"); + fprintf(stderr, "--order the order of the RK method\n"); + fprintf(stderr, "--check-freq how often to checkpoint (in steps)\n"); + fprintf(stderr, "--no-stages don't checkpoint stages\n"); + fprintf(stderr, + "--dont-keep don't keep checkpoints around after loading\n"); + fprintf(stderr, "--help print these options\n"); + exit(exit_code); +} + +static void parse_args(int argc, char* argv[], ProgramArgs* args) +{ + for (int argi = 1; argi < argc; ++argi) + { + const char* arg = argv[argi]; + if (!strcmp(arg, "--tf")) { args->tf = atof(argv[++argi]); } + else if (!strcmp(arg, "--dt")) { args->dt = atof(argv[++argi]); } + else if (!strcmp(arg, "--order")) { args->order = atoi(argv[++argi]); } + else if (!strcmp(arg, "--check-freq")) + { + args->check_freq = atoi(argv[++argi]); + } + else if (!strcmp(arg, "--no-stages")) { args->save_stages = SUNFALSE; } + else if (!strcmp(arg, "--dont-keep")) { args->keep_checks = SUNFALSE; } + else if (!strcmp(arg, "--help")) { print_help(argc, argv, 0); } + else { print_help(argc, argv, 1); } + } +} + +int main(int argc, char* argv[]) +{ + SUNContext sunctx = NULL; + SUNContext_Create(SUN_COMM_NULL, &sunctx); + + // Since this a unit test, we want to abort immediately on any internal error + SUNContext_PushErrHandler(sunctx, SUNAbortErrHandlerFn, NULL); + + ProgramArgs args; + args.tf = SUN_RCONST(1.0); + args.dt = SUN_RCONST(1e-4); + args.order = 4; + args.save_stages = SUNTRUE; + args.keep_checks = SUNTRUE; + args.check_freq = 2; + parse_args(argc, argv, &args); + + // + // Create the initial conditions vector + // + + sunindextype neq = 2; + N_Vector u = N_VNew_Serial(neq, sunctx); + N_VConst(SUN_RCONST(1.0), u); + + // + // Create the ARKODE stepper that will be used for the forward evolution. + // + + const sunrealtype dt = args.dt; + sunrealtype t0 = SUN_RCONST(0.0); + sunrealtype tf = args.tf; + const int nsteps = (int)ceil(((tf - t0) / dt + 1)); + const int order = args.order; + + void* arkode_mem = ARKStepCreate(ode_rhs, NULL, t0, u, sunctx); + ARKodeSetOrder(arkode_mem, order); + ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2); + + // Enable checkpointing during the forward solution. + // ncheck will be more than nsteps, but for testing purposes we try setting it + // to nsteps and allow things to be resized automatically. + const int check_interval = args.check_freq; + const int ncheck = nsteps; + const sunbooleantype save_stages = args.save_stages; + const sunbooleantype keep_check = args.keep_checks; + SUNAdjointCheckpointScheme checkpoint_scheme = NULL; + SUNMemoryHelper mem_helper = SUNMemoryHelper_Sys(sunctx); + SUNAdjointCheckpointScheme_Create_Fixed(SUNDATAIOMODE_INMEM, mem_helper, + check_interval, ncheck, save_stages, + keep_check, sunctx, &checkpoint_scheme); + ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme); + + // + // Compute the forward solution + // + + printf("\n-- Do forward problem --\n\n"); + + printf("Initial condition:\n"); + N_VPrint(u); + + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u); + if (check_forward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + }; + printf(">>> PASS\n"); + + // + // Create the adjoint stepper + // + + printf("\n-- Do adjoint problem using Jacobian matrix --\n\n"); + + sunindextype num_params = 4; + N_Vector sensu0 = N_VClone(u); + N_Vector sensp = N_VNew_Serial(num_params, sunctx); + N_Vector sens[2] = {sensu0, sensp}; + N_Vector sf = N_VNew_ManyVector(2, sens, sunctx); + + // Set the terminal condition for the adjoint system, which + // should be the the gradient of our cost function at tf. + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + + printf("Adjoint terminal condition:\n"); + N_VPrint(sf); + + SUNAdjointStepper adj_stepper; + ARKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper); + + // + // Now compute the adjoint solution + // + + SUNMatrix jac = SUNDenseMatrix(neq, neq, sunctx); + SUNMatrix jacp = SUNDenseMatrix(neq, num_params, sunctx); + + SUNAdjointStepper_SetJacFn(adj_stepper, ode_jac, jac, parameter_jacobian, jacp); + + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + if (check_sensitivities(sf)) + { + fprintf(stderr, + ">>> FAILURE: adjoint solution does not match correct answer\n"); + return -1; + } + printf("\n>>> PASS\n"); + + // + // Now compute the adjoint solution using Jvp + // + + printf("\n-- Redo adjoint problem using JVP --\n\n"); + if (!keep_check) + { + N_VConst(SUN_RCONST(1.0), u); + printf("Initial condition:\n"); + N_VPrint(u); + ARKStepReInit(arkode_mem, ode_rhs, NULL, t0, u); + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u); + if (check_forward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + } + } + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + SUNAdjointStepper_ReInit(adj_stepper, u, t0, sf, tf); + SUNAdjointStepper_SetJacFn(adj_stepper, NULL, NULL, NULL, NULL); + SUNAdjointStepper_SetJacTimesVecFn(adj_stepper, ode_jvp, parameter_jvp); + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + if (check_sensitivities(sf)) + { + fprintf(stderr, + ">>> FAILURE: adjoint solution does not match correct answer\n"); + return -1; + }; + printf("\n>>> PASS\n"); + + // + // Now compute the adjoint solution using vJp + // + + printf("\n-- Redo adjoint problem using VJP --\n\n"); + if (!keep_check) + { + N_VConst(SUN_RCONST(1.0), u); + printf("Initial condition:\n"); + N_VPrint(u); + ARKStepReInit(arkode_mem, ode_rhs, NULL, t0, u); + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u); + if (check_forward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + }; + } + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + SUNAdjointStepper_ReInit(adj_stepper, u, t0, sf, tf); + SUNAdjointStepper_SetJacTimesVecFn(adj_stepper, NULL, NULL); + SUNAdjointStepper_SetVecTimesJacFn(adj_stepper, ode_vjp, parameter_vjp); + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + if (check_sensitivities(sf)) + { + fprintf(stderr, + ">>> FAILURE: adjoint solution does not match correct answer\n"); + return -1; + }; + printf(">>> PASS\n"); + + // + // Now compute the adjoint solution but for when forward problem done backwards + // starting with the forward solution. + // + + printf("\n-- Redo adjoint problem of forward problem done backwards --\n\n"); + + // Swap the start and end times + sunrealtype tmp = t0; + t0 = tf; + tf = tmp; + + // Cleanup from the original forward problem and then recreate the integrator + // for the forward problem done backwards. + SUNAdjointCheckpointScheme_Destroy(&checkpoint_scheme); + SUNAdjointStepper_Destroy(&adj_stepper); + ARKodeFree(&arkode_mem); + arkode_mem = ARKStepCreate(ode_rhs, NULL, t0, u, sunctx); + ARKodeSetOrder(arkode_mem, order); + ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2); + SUNAdjointCheckpointScheme_Create_Fixed(SUNDATAIOMODE_INMEM, mem_helper, + check_interval, ncheck, save_stages, + keep_check, sunctx, &checkpoint_scheme); + ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme); + + printf("Initial condition:\n"); + N_VPrint(u); + + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, -dt, u); + if (check_forward_backward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + }; + + ARKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper); + SUNAdjointStepper_SetJacFn(adj_stepper, ode_jac, jac, parameter_jacobian, jacp); + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + // TODO(CJB): figure out why ForwardDiff, CVODES, and ERK adjoint all differ + // if (check_sensitivities_backward(sf)) + // { + // fprintf(stderr, + // ">>> FAILURE: adjoint solution does not match correct answer\n"); + // return -1; + // }; + // printf(">>> PASS\n"); + + // + // Cleanup + // + + // adjoint related + SUNMatDestroy(jac); + SUNMatDestroy(jacp); + N_VDestroy(sensu0); + N_VDestroy(sensp); + N_VDestroy(sf); + SUNAdjointCheckpointScheme_Destroy(&checkpoint_scheme); + SUNAdjointStepper_Destroy(&adj_stepper); + SUNMemoryHelper_Destroy(mem_helper); + // forward and adjoint related + N_VDestroy(u); + ARKodeFree(&arkode_mem); + SUNContext_Free(&sunctx); + + return 0; +} diff --git a/test/unit_tests/arkode/CXX_serial/ark_test_adjoint_erk.cpp b/test/unit_tests/arkode/CXX_serial/ark_test_adjoint_erk.cpp new file mode 100644 index 0000000000..bdf359aa98 --- /dev/null +++ b/test/unit_tests/arkode/CXX_serial/ark_test_adjoint_erk.cpp @@ -0,0 +1,542 @@ +/* ----------------------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2002-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------------------- + * Program to test the SUNAdjoint capability with ARKODE. The test uses the + * implements the four parameter Lotka-Volterra problem + * + * u' = [dx/dt] = [ p_0*x - p_1*x*y ] + * [dy/dt] [ -p_2*y + p_3*x*y ]. + * + * The initial condition is u(t_0) = 1.0 and we use the parameters + * p = [1.5, 1.0, 3.0, 1.0]. We compute the sensitivities for the scalar cost + * function, + * + * g(u(t_f), p) = || 1 - u(t_f, p) ||^2 / 2 + * + * with respect to the initial condition and the parameters. + * ---------------------------------------------------------------------------*/ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "problems/lotka_volterra.hpp" + +#if defined(SUNDIALS_SINGLE_PRECISION) +#define FWD_TOL SUN_RCONST(1e-2) +#elif defined(SUNDIALS_DOUBLE_PRECISION) +#define FWD_TOL SUN_RCONST(1e-4) +#elif defined(SUNDIALS_EXTENDED_PRECISION) +#define FWD_TOL SUN_RCONST(1e-6) +#endif + +#define ADJ_TOL SUN_RCONST(1e-2) + +using namespace problems::lotka_volterra; + +typedef struct +{ + sunrealtype tf; + sunrealtype dt; + int order; + int check_freq; + sunbooleantype save_stages; + sunbooleantype keep_checks; +} ProgramArgs; + +static sunrealtype params[4] = {SUN_RCONST(1.5), SUN_RCONST(1.0), + SUN_RCONST(3.0), SUN_RCONST(1.0)}; + +static int check_forward_answer(N_Vector answer) +{ + const sunrealtype u1 = SUN_RCONST(2.77266836); + const sunrealtype u2 = SUN_RCONST(0.258714765); + sunrealtype* ans = N_VGetArrayPointer(answer); + + if (SUNRCompareTol(ans[0], u1, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[0] = %g, should be %g\n", ans[0], u1); + return -1; + }; + if (SUNRCompareTol(ans[1], u2, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[1] = %g, should be %g\n", ans[1], u2); + return -1; + }; + + return 0; +} + +static int check_forward_backward_answer(N_Vector answer) +{ + const sunrealtype u1 = SUN_RCONST(1.0); + const sunrealtype u2 = SUN_RCONST(1.0); + sunrealtype* ans = N_VGetArrayPointer(answer); + + if (SUNRCompareTol(ans[0], u1, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[0] = %g, should be %g\n", ans[0], u1); + return -1; + }; + if (SUNRCompareTol(ans[1], u2, FWD_TOL)) + { + fprintf(stdout, "\n>>> ans[1] = %g, should be %g\n", ans[1], u2); + return -1; + }; + + return 0; +} + +static int check_sensitivities(N_Vector answer) +{ + // The correct answer was generated with the Julia ForwardDiff.jl + // automatic differentiation package. + + const sunrealtype lambda[2] = { + SUN_RCONST(3.5202568952661544), + -SUN_RCONST(2.19271337646507), + }; + + const sunrealtype mu[4] = {SUN_RCONST(4.341147542533404), + -SUN_RCONST(2.000933816791803), + SUN_RCONST(1.010120676762905), + -SUN_RCONST(1.3955943267337996)}; + + sunrealtype* ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 0); + + for (sunindextype i = 0; i < 2; ++i) + { + if (SUNRCompareTol(ans[i], lambda[i], ADJ_TOL)) + { + fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, + ans[i], lambda[i]); + return -1; + }; + } + + ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 1); + + for (sunindextype i = 0; i < 4; ++i) + { + if (SUNRCompareTol(ans[i], mu[i], ADJ_TOL)) + { + fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, + ans[i], mu[i]); + return -1; + }; + } + + return 0; +} + +// static inline int check_sensitivities_backward(N_Vector answer) +// { +// // The correct answer was generated with the Julia ForwardDiff.jl +// // automatic differentiation package. + +// const sunrealtype lambda[2] = { +// SUN_RCONST(1.772850901841113), +// -SUN_RCONST(0.7412891218574361), +// }; + +// const sunrealtype mu[4] = {SUN_RCONST(0.0), SUN_RCONST(0.0), SUN_RCONST(0.0), +// SUN_RCONST(0.0)}; + +// sunrealtype* ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 0); + +// for (sunindextype i = 0; i < 2; ++i) +// { +// if (SUNRCompareTol(ans[i], lambda[i], ADJ_TOL)) +// { +// fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, +// ans[i], lambda[i]); +// return -1; +// }; +// } + +// ans = N_VGetSubvectorArrayPointer_ManyVector(answer, 1); + +// for (sunindextype i = 0; i < 4; ++i) +// { +// if (SUNRCompareTol(ans[i], mu[i], ADJ_TOL)) +// { +// fprintf(stdout, "\n>>> ans[%lld] = %g, should be %g\n", (long long)i, +// ans[i], mu[i]); +// return -1; +// }; +// } + +// return 0; +// } + +static void dgdu(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, + sunrealtype t) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = -SUN_RCONST(1.0) + u[0]; + dg[1] = -SUN_RCONST(1.0) + u[1]; +} + +static void dgdp(N_Vector uvec, N_Vector dgvec, const sunrealtype* p, + sunrealtype t) +{ + sunrealtype* dg = N_VGetArrayPointer(dgvec); + + dg[0] = SUN_RCONST(0.0); + dg[1] = SUN_RCONST(0.0); + dg[2] = SUN_RCONST(0.0); + dg[3] = SUN_RCONST(0.0); +} + +static int forward_solution(SUNContext sunctx, void* arkode_mem, + SUNAdjointCheckpointScheme checkpoint_scheme, + const sunrealtype t0, const sunrealtype tf, + const sunrealtype dt, N_Vector u) +{ + int retval = 0; + + retval = ARKodeSetUserData(arkode_mem, (void*)params); + retval = ARKodeSetFixedStep(arkode_mem, dt); + + sunrealtype t = t0; + retval = ARKodeEvolve(arkode_mem, tf, u, &t, ARK_NORMAL); + if (retval < 0) + { + fprintf(stderr, ">>> ERROR: ARKodeEvolve returned %d\n", retval); + return -1; + } + + printf("Forward Solution:\n"); + N_VPrint(u); + + printf("ARKODE Stats for Forward Solution:\n"); + ARKodePrintAllStats(arkode_mem, stdout, SUN_OUTPUTFORMAT_TABLE); + printf("\n"); + + return 0; +} + +static int adjoint_solution(SUNContext sunctx, SUNAdjointStepper adj_stepper, + SUNAdjointCheckpointScheme checkpoint_scheme, + const sunrealtype tf, const sunrealtype tout, + N_Vector sf) +{ + sunrealtype t = tf; + SUNAdjointStepper_Evolve(adj_stepper, tout, sf, &t); + + printf("Adjoint Solution:\n"); + N_VPrint(sf); + + printf("\nSUNAdjointStepper Stats:\n"); + SUNAdjointStepper_PrintAllStats(adj_stepper, stdout, SUN_OUTPUTFORMAT_TABLE); + printf("\n"); + + return 0; +} + +static void print_help(int argc, char* argv[], int exit_code) +{ + if (exit_code) { fprintf(stderr, "%s: option not recognized\n", argv[0]); } + else { fprintf(stderr, "%s ", argv[0]); } + fprintf(stderr, "options:\n"); + fprintf(stderr, "--tf the final simulation time\n"); + fprintf(stderr, "--dt the timestep size\n"); + fprintf(stderr, "--order the order of the RK method\n"); + fprintf(stderr, "--check-freq how often to checkpoint (in steps)\n"); + fprintf(stderr, "--no-stages don't checkpoint stages\n"); + fprintf(stderr, + "--dont-keep don't keep checkpoints around after loading\n"); + fprintf(stderr, "--help print these options\n"); + exit(exit_code); +} + +static void parse_args(int argc, char* argv[], ProgramArgs* args) +{ + for (int argi = 1; argi < argc; ++argi) + { + const char* arg = argv[argi]; + if (!strcmp(arg, "--tf")) { args->tf = atof(argv[++argi]); } + else if (!strcmp(arg, "--dt")) { args->dt = atof(argv[++argi]); } + else if (!strcmp(arg, "--order")) { args->order = atoi(argv[++argi]); } + else if (!strcmp(arg, "--check-freq")) + { + args->check_freq = atoi(argv[++argi]); + } + else if (!strcmp(arg, "--no-stages")) { args->save_stages = SUNFALSE; } + else if (!strcmp(arg, "--dont-keep")) { args->keep_checks = SUNFALSE; } + else if (!strcmp(arg, "--help")) { print_help(argc, argv, 0); } + else { print_help(argc, argv, 1); } + } +} + +int main(int argc, char* argv[]) +{ + SUNContext sunctx = NULL; + SUNContext_Create(SUN_COMM_NULL, &sunctx); + + // Since this a unit test, we want to abort immediately on any internal error + SUNContext_PushErrHandler(sunctx, SUNAbortErrHandlerFn, NULL); + + ProgramArgs args; + args.tf = SUN_RCONST(1.0); + args.dt = SUN_RCONST(1e-4); + args.order = 4; + args.save_stages = SUNTRUE; + args.keep_checks = SUNTRUE; + args.check_freq = 2; + parse_args(argc, argv, &args); + + // + // Create the initial conditions vector + // + + sunindextype neq = 2; + N_Vector u = N_VNew_Serial(neq, sunctx); + N_VConst(SUN_RCONST(1.0), u); + + // + // Create the ARKODE stepper that will be used for the forward evolution. + // + + const sunrealtype dt = args.dt; + sunrealtype t0 = SUN_RCONST(0.0); + sunrealtype tf = args.tf; + const int nsteps = (int)ceil(((tf - t0) / dt + 1)); + const int order = args.order; + + void* arkode_mem = ERKStepCreate(ode_rhs, t0, u, sunctx); + ARKodeSetOrder(arkode_mem, order); + ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2); + + // Enable checkpointing during the forward solution. + // ncheck will be more than nsteps, but for testing purposes we try setting it + // to nsteps and allow things to be resized automatically. + const int check_interval = args.check_freq; + const int ncheck = nsteps; + const sunbooleantype save_stages = args.save_stages; + const sunbooleantype keep_check = args.keep_checks; + SUNAdjointCheckpointScheme checkpoint_scheme = NULL; + SUNMemoryHelper mem_helper = SUNMemoryHelper_Sys(sunctx); + SUNAdjointCheckpointScheme_Create_Fixed(SUNDATAIOMODE_INMEM, mem_helper, + check_interval, ncheck, save_stages, + keep_check, sunctx, &checkpoint_scheme); + ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme); + + // + // Compute the forward solution + // + + printf("\n-- Do forward problem --\n\n"); + + printf("Initial condition:\n"); + N_VPrint(u); + + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u); + if (check_forward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + }; + printf(">>> PASS\n"); + + // + // Create the adjoint stepper + // + + printf("\n-- Do adjoint problem using Jacobian matrix --\n\n"); + + sunindextype num_params = 4; + N_Vector sensu0 = N_VClone(u); + N_Vector sensp = N_VNew_Serial(num_params, sunctx); + N_Vector sens[2] = {sensu0, sensp}; + N_Vector sf = N_VNew_ManyVector(2, sens, sunctx); + + // Set the terminal condition for the adjoint system, which + // should be the the gradient of our cost function at tf. + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + + printf("Adjoint terminal condition:\n"); + N_VPrint(sf); + + SUNAdjointStepper adj_stepper; + ERKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper); + + // + // Now compute the adjoint solution + // + + SUNMatrix jac = SUNDenseMatrix(neq, neq, sunctx); + SUNMatrix jacp = SUNDenseMatrix(neq, num_params, sunctx); + + SUNAdjointStepper_SetJacFn(adj_stepper, ode_jac, jac, parameter_jacobian, jacp); + + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + if (check_sensitivities(sf)) + { + fprintf(stderr, + ">>> FAILURE: adjoint solution does not match correct answer\n"); + return -1; + } + printf("\n>>> PASS\n"); + + // + // Now compute the adjoint solution using Jvp + // + + printf("\n-- Redo adjoint problem using JVP --\n\n"); + if (!keep_check) + { + N_VConst(SUN_RCONST(1.0), u); + printf("Initial condition:\n"); + N_VPrint(u); + ERKStepReInit(arkode_mem, ode_rhs, t0, u); + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u); + if (check_forward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + } + } + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + SUNAdjointStepper_ReInit(adj_stepper, u, t0, sf, tf); + SUNAdjointStepper_SetJacFn(adj_stepper, NULL, NULL, NULL, NULL); + SUNAdjointStepper_SetJacTimesVecFn(adj_stepper, ode_jvp, parameter_jvp); + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + if (check_sensitivities(sf)) + { + fprintf(stderr, + ">>> FAILURE: adjoint solution does not match correct answer\n"); + return -1; + }; + printf("\n>>> PASS\n"); + + // + // Now compute the adjoint solution using vJp + // + + printf("\n-- Redo adjoint problem using VJP --\n\n"); + if (!keep_check) + { + N_VConst(SUN_RCONST(1.0), u); + printf("Initial condition:\n"); + N_VPrint(u); + ERKStepReInit(arkode_mem, ode_rhs, t0, u); + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, dt, u); + if (check_forward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + }; + } + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + SUNAdjointStepper_ReInit(adj_stepper, u, t0, sf, tf); + SUNAdjointStepper_SetJacTimesVecFn(adj_stepper, NULL, NULL); + SUNAdjointStepper_SetVecTimesJacFn(adj_stepper, ode_vjp, parameter_vjp); + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + if (check_sensitivities(sf)) + { + fprintf(stderr, + ">>> FAILURE: adjoint solution does not match correct answer\n"); + return -1; + }; + printf(">>> PASS\n"); + + // + // Now compute the adjoint solution but for when forward problem done backwards + // starting with the forward solution. + // + + printf("\n-- Redo adjoint problem of forward problem done backwards --\n\n"); + + // Swap the start and end times + sunrealtype tmp = t0; + t0 = tf; + tf = tmp; + + // Cleanup from the original forward problem and then recreate the integrator + // for the forward problem done backwards. + SUNAdjointCheckpointScheme_Destroy(&checkpoint_scheme); + SUNAdjointStepper_Destroy(&adj_stepper); + ARKodeFree(&arkode_mem); + arkode_mem = ERKStepCreate(ode_rhs, t0, u, sunctx); + ARKodeSetOrder(arkode_mem, order); + ARKodeSetMaxNumSteps(arkode_mem, nsteps * 2); + SUNAdjointCheckpointScheme_Create_Fixed(SUNDATAIOMODE_INMEM, mem_helper, + check_interval, ncheck, save_stages, + keep_check, sunctx, &checkpoint_scheme); + ARKodeSetAdjointCheckpointScheme(arkode_mem, checkpoint_scheme); + + printf("Initial condition:\n"); + N_VPrint(u); + + forward_solution(sunctx, arkode_mem, checkpoint_scheme, t0, tf, -dt, u); + if (check_forward_backward_answer(u)) + { + fprintf(stderr, + ">>> FAILURE: forward solution does not match correct answer\n"); + return -1; + }; + + ERKStepCreateAdjointStepper(arkode_mem, sf, &adj_stepper); + SUNAdjointStepper_SetJacFn(adj_stepper, ode_jac, jac, parameter_jacobian, jacp); + dgdu(u, sensu0, params, tf); + dgdp(u, sensp, params, tf); + + adjoint_solution(sunctx, adj_stepper, checkpoint_scheme, tf, t0, sf); + // TODO(CJB): figure out why ForwardDiff, CVODES, and ERK adjoint all differ + // if (check_sensitivities_backward(sf)) + // { + // fprintf(stderr, + // ">>> FAILURE: adjoint solution does not match correct answer\n"); + // return -1; + // }; + // printf(">>> PASS\n"); + + // + // Cleanup + // + + // adjoint related + SUNMatDestroy(jac); + SUNMatDestroy(jacp); + N_VDestroy(sensu0); + N_VDestroy(sensp); + N_VDestroy(sf); + SUNAdjointCheckpointScheme_Destroy(&checkpoint_scheme); + SUNAdjointStepper_Destroy(&adj_stepper); + SUNMemoryHelper_Destroy(mem_helper); + // forward and adjoint related + N_VDestroy(u); + ARKodeFree(&arkode_mem); + SUNContext_Free(&sunctx); + + return 0; +} diff --git a/test/unit_tests/arkode/C_serial/CMakeLists.txt b/test/unit_tests/arkode/C_serial/CMakeLists.txt index e9515a0604..a8cfd9e4f5 100644 --- a/test/unit_tests/arkode/C_serial/CMakeLists.txt +++ b/test/unit_tests/arkode/C_serial/CMakeLists.txt @@ -65,11 +65,13 @@ foreach(test_tuple ${ARKODE_unit_tests}) $ sundials_sunmemsys_obj sundials_nvecserial_obj + sundials_nvecmanyvector_obj sundials_sunlinsolband_obj sundials_sunlinsoldense_obj sundials_sunnonlinsolnewton_obj sundials_sunadaptcontrollerimexgus_obj sundials_sunadaptcontrollersoderlind_obj + sundials_adjointcheckpointscheme_fixed_obj ${EXE_EXTRA_LINK_LIBS}) # Tell CMake that we depend on the ARKODE library since it does not pick diff --git a/test/unit_tests/problems/lotka_volterra.hpp b/test/unit_tests/problems/lotka_volterra.hpp new file mode 100644 index 0000000000..3079f4618d --- /dev/null +++ b/test/unit_tests/problems/lotka_volterra.hpp @@ -0,0 +1,139 @@ +/* ----------------------------------------------------------------------------- + * SUNDIALS Copyright Start + * Copyright (c) 2002-2025, Lawrence Livermore National Security + * and Southern Methodist University. + * All rights reserved. + * + * See the top-level LICENSE and NOTICE files for details. + * + * SPDX-License-Identifier: BSD-3-Clause + * SUNDIALS Copyright End + * ----------------------------------------------------------------------------- + * This header provides right-hand-side and related functions (e.g., Jacobian) + * for the four parameter Lotka-Volterra problem, + * + * u = [dx/dt] = [ p_0*x - p_1*x*y ] + * [dy/dt] [ -p_2*y + p_3*x*y ]. + * + * with parameters p. + * ---------------------------------------------------------------------------*/ + +#ifndef _LOTKA_VOLTERRA_HPP +#define _LOTKA_VOLTERRA_HPP + +#include +#include + +namespace problems { +namespace lotka_volterra { + +inline int ode_rhs(sunrealtype t, N_Vector uvec, N_Vector udotvec, void* user_data) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* udot = N_VGetArrayPointer(udotvec); + + udot[0] = p[0] * u[0] - p[1] * u[0] * u[1]; + udot[1] = -p[2] * u[1] + p[3] * u[0] * u[1]; + + return 0; +} + +inline int ode_jac(sunrealtype t, N_Vector uvec, N_Vector udotvec, SUNMatrix Jac, + void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* J = SUNDenseMatrix_Data(Jac); + + J[0] = p[0] - p[1] * u[1]; + J[2] = -p[1] * u[0]; + J[1] = p[3] * u[1]; + J[3] = p[3] * u[0] - p[2]; + + return 0; +} + +inline int ode_jvp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + N_Vector udotvec, void* user_data, N_Vector tmp) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = (p[0] - p[1] * u[1]) * v[0] + p[3] * u[1] * v[1]; + Jv[1] = -p[1] * u[0] * v[0] + (-p[2] + p[3] * u[0]) * v[1]; + + return 0; +} + +inline int ode_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, N_Vector uvec, + N_Vector udotvec, void* user_data, N_Vector tmp) +{ + sunrealtype* p = (sunrealtype*)user_data; + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = (p[0] - p[1] * u[1]) * v[0] + p[3] * u[1] * v[1]; + Jv[1] = -p[1] * u[0] * v[0] + (-p[2] + p[3] * u[0]) * v[1]; + + return 0; +} + +inline int parameter_jacobian(sunrealtype t, N_Vector uvec, N_Vector udotvec, + SUNMatrix Jac, void* user_data, N_Vector tmp1, + N_Vector tmp2, N_Vector tmp3) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* J = SUNDenseMatrix_Data(Jac); + + J[0] = u[0]; + J[1] = SUN_RCONST(0.0); + J[2] = -u[0] * u[1]; + J[3] = SUN_RCONST(0.0); + J[4] = SUN_RCONST(0.0); + J[5] = -u[1]; + J[6] = SUN_RCONST(0.0); + J[7] = u[0] * u[1]; + + return 0; +} + +inline int parameter_jvp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, + N_Vector uvec, N_Vector udotvec, void* user_data, + N_Vector tmp) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = u[0] * v[0]; + Jv[1] = -u[0] * u[1] * v[0]; + Jv[2] = -u[1] * v[1]; + Jv[3] = u[0] * u[1] * v[1]; + + return 0; +} + +inline int parameter_vjp(N_Vector vvec, N_Vector Jvvec, sunrealtype t, + N_Vector uvec, N_Vector udotvec, void* user_data, + N_Vector tmp) +{ + sunrealtype* u = N_VGetArrayPointer(uvec); + sunrealtype* v = N_VGetArrayPointer(vvec); + sunrealtype* Jv = N_VGetArrayPointer(Jvvec); + + Jv[0] = u[0] * v[0]; + Jv[1] = -u[0] * u[1] * v[0]; + Jv[2] = -u[1] * v[1]; + Jv[3] = u[0] * u[1] * v[1]; + + return 0; +} + +} // namespace lotka_volterra +} // namespace problems + +#endif \ No newline at end of file