Skip to content

Commit

Permalink
Add CvodeBlock and CvodeVisitor (#1467)
Browse files Browse the repository at this point in the history
  • Loading branch information
JCGoran authored Oct 28, 2024
1 parent edaa090 commit 6b6a630
Show file tree
Hide file tree
Showing 14 changed files with 529 additions and 2 deletions.
75 changes: 75 additions & 0 deletions docs/contents/cvode.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
Variable timestep integration (CVODE)
=====================================

As opposed to fixed timestep integration, variable timestep integration (CVODE
in NEURON parlance) uses the SUNDIALS package to solve a ``DERIVATIVE`` or
``KINETIC`` block using a variable timestep. This allows for faster computation
times if the function in question does not vary too wildly.

Implementation in NMODL
-----------------------

The code generation for CVODE is activated only if exactly one of the following
is satisfied:

1. there is one ``KINETIC`` block in the mod file
2. there is one ``DERIVATIVE`` block in the mod file
3. a ``PROCEDURE`` block is solved with the ``after_cvode``, ``cvode_t``, or
``cvode_t_v`` methods

In NMODL, all ``KINETIC`` blocks are internally first converted to
``DERIVATIVE`` blocks. The ``DERIVATIVE`` block is then converted to a
``CVODE`` block, which contains two parts; the first part contains the update
step for non-stiff systems (functional iteration), while the second part
contains the update step for stiff systems (additional step using the
Jacobian). For more information, see `CVODES documentation`_, eqs. (4.8) and
(4.9). Given a ``DERIVATIVE`` block of the form:

.. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html

.. code-block::
DERIVATIVE state {
x_i' = f(x_1, ..., x_n)
}
the structure of the ``CVODE`` block is then roughly:

.. code-block::
CVODE state[n] {
Dx_i = f_i(x_1, ..., x_n)
}{
Dx_i = Dx_i / (1 - dt * J_ii(f))
}
where ``N`` is the total number of ODEs to solve, and ``J_ii(f)`` is the
diagonal part of the Jacobian, i.e.

.. math::
J_{ii}(f) = \frac{ \partial f_i(x_1, \ldots, x_n) }{\partial x_i}
As an example, consider the following ``DERIVATIVE`` block:

.. code-block::
DERIVATIVE state {
X' = - X
}
Where ``X`` is a ``STATE`` variable with some initial value, specified in the
``INITIAL`` block. The corresponding ``CVODE`` block is then:

.. code-block::
CVODE state[1] {
DX = - X
}{
DX = DX / (1 - dt * (-1))
}
**NOTE**: in case there are ``CONSERVE`` statements in ``KINETIC`` blocks, as
they are merely hints to NMODL, and have no impact on the results, they are
removed from ``CVODE`` blocks before the codegen stage.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ About NMODL
contents/cable_equations
contents/globals
contents/longitudinal_diffusion
contents/cvode

.. toctree::
:maxdepth: 3
Expand Down
1 change: 1 addition & 0 deletions src/language/code_generator.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ set(AST_GENERATED_SOURCES
${PROJECT_BINARY_DIR}/src/ast/constant_statement.hpp
${PROJECT_BINARY_DIR}/src/ast/constant_var.hpp
${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp
${PROJECT_BINARY_DIR}/src/ast/cvode_block.hpp
${PROJECT_BINARY_DIR}/src/ast/define.hpp
${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp
${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp
Expand Down
20 changes: 20 additions & 0 deletions src/language/codegen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@
- finalize_block:
brief: "Statement block to be executed after calling linear solver"
type: StatementBlock
- CvodeBlock:
nmodl: "CVODE_BLOCK "
members:
- name:
brief: "Name of the block"
type: Name
node_name: true
suffix: {value: " "}
- n_odes:
brief: "number of ODEs to solve"
type: Integer
prefix: {value: "["}
suffix: {value: "]"}
- non_stiff_block:
brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems"
type: StatementBlock
- stiff_block:
brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f)), used for updating stiff systems"
type: StatementBlock
brief: "Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks"
- LongitudinalDiffusionBlock:
brief: "Extracts information required for LONGITUDINAL_DIFFUSION for each KINETIC block."
nmodl: "LONGITUDINAL_DIFFUSION_BLOCK"
Expand Down
11 changes: 11 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "visitors/after_cvode_to_cnexp_visitor.hpp"
#include "visitors/ast_visitor.hpp"
#include "visitors/constant_folder_visitor.hpp"
#include "visitors/cvode_visitor.hpp"
#include "visitors/function_callpath_visitor.hpp"
#include "visitors/global_var_visitor.hpp"
#include "visitors/implicit_argument_visitor.hpp"
Expand Down Expand Up @@ -516,6 +517,8 @@ int run_nmodl(int argc, const char* argv[]) {

enable_sympy(solver_exists(*ast, "derivimplicit"), "'SOLVE ... METHOD derivimplicit'");
enable_sympy(node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK), "'LINEAR' block");
enable_sympy(node_exists(*ast, ast::AstNodeType::DERIVATIVE_BLOCK),
"'DERIVATIVE' block");
enable_sympy(node_exists(*ast, ast::AstNodeType::NON_LINEAR_BLOCK),
"'NONLINEAR' block");
enable_sympy(solver_exists(*ast, "sparse"), "'SOLVE ... METHOD sparse'");
Expand All @@ -526,6 +529,14 @@ int run_nmodl(int argc, const char* argv[]) {
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance()
.api()
.initialize_interpreter();

if (neuron_code) {
logger->info("Running CVODE visitor");
CvodeVisitor().visit_program(*ast);
SymtabVisitor(update_symtab).visit_program(*ast);
ast_to_nmodl(*ast, filepath("cvode"));
}

if (sympy_conductance) {
logger->info("Running sympy conductance visitor");
SympyConductanceVisitor().visit_program(*ast);
Expand Down
49 changes: 47 additions & 2 deletions src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

#include "codegen/codegen_naming.hpp"
#include "pybind/pyembed.hpp"

#include <fmt/format.h>
#include <optional>
#include <pybind11/embed.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -186,6 +187,49 @@ except Exception as e:
return {std::move(solution), std::move(exception_message)};
}

std::tuple<std::string, std::string> call_diff2c(
const std::string& expression,
const std::pair<std::string, std::optional<int>>& variable,
const std::unordered_set<std::string>& indexed_vars) {
std::string statements;
// only indexed variables require special treatment
for (const auto& var: indexed_vars) {
statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var);
}
auto [name, property] = variable;
if (property.has_value()) {
name = fmt::format("sp.IndexedBase('{}', shape=[1])", name);
statements += fmt::format("_allvars.append({})", name);
} else {
name = fmt::format("'{}'", name);
}
auto locals = py::dict("expression"_a = expression);
std::string script =
fmt::format(R"(
_allvars = []
{}
variable = {}
exception_message = ""
try:
solution = differentiate2c(expression,
variable,
_allvars,
)
except Exception as e:
# if we fail, fail silently and return empty string
solution = ""
exception_message = str(e)
)",
statements,
property.has_value() ? fmt::format("{}[{}]", name, property.value()) : name);

py::exec(nmodl::pybind_wrappers::ode_py + script, locals);

auto solution = locals["solution"].cast<std::string>();
auto exception_message = locals["exception_message"].cast<std::string>();

return {std::move(solution), std::move(exception_message)};
}

void initialize_interpreter_func() {
pybind11::initialize_interpreter(true);
Expand All @@ -203,7 +247,8 @@ NMODL_EXPORT pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept {
&call_solve_nonlinear_system,
&call_solve_linear_system,
&call_diffeq_solver,
&call_analytic_diff};
&call_analytic_diff,
&call_diff2c};
}
}

Expand Down
14 changes: 14 additions & 0 deletions src/pybind/wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

#pragma once

#include <optional>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

namespace nmodl {
Expand Down Expand Up @@ -44,13 +46,25 @@ std::tuple<std::string, std::string> call_analytic_diff(
const std::vector<std::string>& expressions,
const std::set<std::string>& used_names_in_block);


/// \brief Differentiates an expression with respect to a variable
/// \param expression The expression we want to differentiate
/// \param variable The name of the independent variable we are differentiating against
/// \param index_vars A set of array (indexable) variables that appear in \ref expression
/// \return The tuple (solution, exception)
std::tuple<std::string, std::string> call_diff2c(
const std::string& expression,
const std::pair<std::string, std::optional<int>>& variable,
const std::unordered_set<std::string>& indexed_vars = {});

struct pybind_wrap_api {
decltype(&initialize_interpreter_func) initialize_interpreter;
decltype(&finalize_interpreter_func) finalize_interpreter;
decltype(&call_solve_nonlinear_system) solve_nonlinear_system;
decltype(&call_solve_linear_system) solve_linear_system;
decltype(&call_diffeq_solver) diffeq_solver;
decltype(&call_analytic_diff) analytic_diff;
decltype(&call_diff2c) diff2c;
};

#ifdef _WIN32
Expand Down
1 change: 1 addition & 0 deletions src/visitors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_library(
visitor STATIC
after_cvode_to_cnexp_visitor.cpp
constant_folder_visitor.cpp
cvode_visitor.cpp
defuse_analyze_visitor.cpp
function_callpath_visitor.cpp
global_var_visitor.cpp
Expand Down
Loading

0 comments on commit 6b6a630

Please sign in to comment.