Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CvodeBlock and CvodeVisitor #1467

Merged
merged 55 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
417acac
Add `DerivativeOriginalFunctionBlock` and `DerivativeVisitor`
JCGoran Sep 24, 2024
d33a594
Remove unused functions
JCGoran Sep 24, 2024
b9f08d0
Add test for DerivativeOriginalVisitor
JCGoran Sep 25, 2024
c5dc45e
Fmt
JCGoran Sep 25, 2024
1dadd7a
Fix leak
JCGoran Sep 25, 2024
1125fdf
Remove unused stuff
JCGoran Sep 25, 2024
0267fbd
Update block description
JCGoran Sep 25, 2024
e58070f
Rename DERIVATIVE_ORIGINAL to CVODE
JCGoran Sep 27, 2024
044dfd9
Finish renaming
JCGoran Sep 30, 2024
50f38ce
Add item with Jacobian
JCGoran Sep 30, 2024
bc68701
Merge branch 'master' into jelic/cvode_visitors
JCGoran Sep 30, 2024
f82fe1f
Do not use an int but an enum-wrapped int
JCGoran Sep 30, 2024
bd2fd36
Add support for diffing expressions with indexed vars
JCGoran Sep 30, 2024
b082f0d
Allow diffing implicit functions in `differentiate2c`
JCGoran Sep 30, 2024
edf33a7
Simplify condition
JCGoran Sep 30, 2024
565fa03
Better testing
JCGoran Oct 1, 2024
6bd6aed
Add suggestions from code review
JCGoran Oct 1, 2024
0eba407
Add `stepsize` param to `differentiate2c`
JCGoran Oct 2, 2024
405909e
Merge remote-tracking branch 'origin/master' into jelic/diff_implicit
JCGoran Oct 2, 2024
c1e7fd3
Try Python 3.9 maybe?
JCGoran Oct 2, 2024
0207373
Merge branch 'jelic/diff_indexed' into jelic/cvode_visitors
JCGoran Oct 2, 2024
32d36a8
Merge branch 'jelic/diff_implicit' into jelic/cvode_visitors
JCGoran Oct 2, 2024
4e9fb49
Merge remote-tracking branch 'origin/master' into jelic/cvode_visitors
JCGoran Oct 7, 2024
4fde929
Put back Python 3.8 for now
JCGoran Oct 7, 2024
a08df25
Remove remaining occurrences of `DerivativeOriginalVisitor`
JCGoran Oct 7, 2024
9fee9a8
WIP on CONSERVE
JCGoran Oct 8, 2024
d98fcc0
Ignore CONSERVE equations
JCGoran Oct 9, 2024
321cdb3
Add documentation
JCGoran Oct 10, 2024
2984e46
Really delete CONSERVE statements this time
JCGoran Oct 10, 2024
313330b
Add test for CONSERVE statement
JCGoran Oct 10, 2024
03b40e8
Fix variable naming
JCGoran Oct 10, 2024
836ec74
Update docstring
JCGoran Oct 10, 2024
9f6b751
Fix typo
JCGoran Oct 10, 2024
1348ab9
Update docstring
JCGoran Oct 10, 2024
96b1bb6
Merge branch 'master' into jelic/cvode_visitors
JCGoran Oct 14, 2024
ee9c187
Add option for diffing IndexedName
JCGoran Oct 14, 2024
54a480e
Refactor
JCGoran Oct 14, 2024
32aa0cb
Merge branch 'master' into jelic/cvode_visitors
JCGoran Oct 14, 2024
cefc159
Enable sympy if NEURON codegen
JCGoran Oct 14, 2024
a3e1c6c
Mark constructors as explicit
JCGoran Oct 14, 2024
7177fb6
Merge branch 'master' into jelic/cvode_visitors
JCGoran Oct 15, 2024
659b018
Update tests
JCGoran Oct 15, 2024
bd376c3
Remove code duplication
JCGoran Oct 15, 2024
a60577b
Remove unused class field
JCGoran Oct 15, 2024
8bc7d18
Only enable sympy if DERIVATIVE block exists
JCGoran Oct 16, 2024
40cb10b
Rename CVODE subblocks with more apt names
JCGoran Oct 16, 2024
9b58feb
`nonstiff` -> `non_stiff`
JCGoran Oct 16, 2024
a6ea5ab
Get # of ODEs to solve
JCGoran Oct 21, 2024
5d94f7a
der_block(s) -> derivative_block(s)
JCGoran Oct 21, 2024
bf9db7a
get_name_map -> get_indexed_variables
JCGoran Oct 21, 2024
fd0c619
Add check for multiple DERIVATIVE blocks
JCGoran Oct 21, 2024
5f4a00a
Update tests for CVODE
JCGoran Oct 21, 2024
e30c27c
Update docs
JCGoran Oct 21, 2024
167d38f
Address comments from review
JCGoran Oct 24, 2024
4ce3d85
Fix variable name
JCGoran Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/language/code_generator.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ set(AST_GENERATED_SOURCES
${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp
${PROJECT_BINARY_DIR}/src/ast/define.hpp
${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp
${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp
${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp
${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp
${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp
Expand Down
25 changes: 24 additions & 1 deletion src/language/codegen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,30 @@
type: StatementBlock
- finalize_block:
brief: "Statement block to be executed after calling linear solver"
type: StatementBlock
type: StatementBlock
- DerivativeOriginalFunctionBlock:
JCGoran marked this conversation as resolved.
Show resolved Hide resolved
nmodl: "DERIVATIVE_ORIGINAL_FUNCTION "
members:
- name:
brief: "Name of the derivative block"
type: Name
node_name: true
suffix: {value: " "}
- statement_block:
brief: "Block with statements vector"
type: StatementBlock
getter: {override: true}
brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL"
JCGoran marked this conversation as resolved.
Show resolved Hide resolved
description: |
The original `DERIVATIVE` block in NMODL is
replaced in-place if the system of ODEs is
solvable analytically. Therefore, this
block's sole purpose is to keep the
original, unsolved block in the AST. This is
primarily useful when we need to solve the
ODE system using implicit methods, for
instance, CVode.

- WrappedExpression:
brief: "Wrap any other expression type"
members:
Expand Down
8 changes: 8 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "visitors/after_cvode_to_cnexp_visitor.hpp"
#include "visitors/ast_visitor.hpp"
#include "visitors/constant_folder_visitor.hpp"
#include "visitors/derivative_original_visitor.hpp"
#include "visitors/function_callpath_visitor.hpp"
#include "visitors/global_var_visitor.hpp"
#include "visitors/implicit_argument_visitor.hpp"
Expand Down Expand Up @@ -497,6 +498,13 @@ int run_nmodl(int argc, const char* argv[]) {
const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK);
const bool sympy_sparse = solver_exists(*ast, "sparse");

if (neuron_code) {
logger->info("Running derivative visitor");
DerivativeOriginalVisitor().visit_program(*ast);
SymtabVisitor(update_symtab).visit_program(*ast);
ast_to_nmodl(*ast, filepath("derivative_original"));
}

if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit ||
sympy_linear) {
nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance()
Expand Down
1 change: 1 addition & 0 deletions src/visitors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_library(
visitor STATIC
after_cvode_to_cnexp_visitor.cpp
constant_folder_visitor.cpp
derivative_original_visitor.cpp
defuse_analyze_visitor.cpp
function_callpath_visitor.cpp
global_var_visitor.cpp
Expand Down
103 changes: 103 additions & 0 deletions src/visitors/derivative_original_visitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright 2023 Blue Brain Project, EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#include "visitors/derivative_original_visitor.hpp"

#include "ast/all.hpp"
#include "lexer/token_mapping.hpp"
#include "pybind/pyembed.hpp"
#include "utils/logger.hpp"
#include "visitors/visitor_utils.hpp"
#include <optional>
#include <utility>

namespace pywrap = nmodl::pybind_wrappers;

namespace nmodl {
namespace visitor {


void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) {
node.visit_children(*this);
der_block_function = node.clone();
}


void DerivativeOriginalVisitor::visit_derivative_original_function_block(
ast::DerivativeOriginalFunctionBlock& node) {
derivative_block = true;
node_type = node.get_node_type();
node.visit_children(*this);
node_type = ast::AstNodeType::NODE;
derivative_block = false;
}

void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) {
differential_equation = true;
node.visit_children(*this);
differential_equation = false;
}


void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) {
const auto& lhs = node.get_lhs();

/// we have to only solve ODEs under original derivative block where lhs is variable
if (!derivative_block || !differential_equation || !lhs->is_var_name()) {
return;
}

auto name = std::dynamic_pointer_cast<ast::VarName>(lhs)->get_name();

if (name->is_prime_name() || name->is_indexed_name()) {
JCGoran marked this conversation as resolved.
Show resolved Hide resolved
std::string varname;
if (name->is_prime_name()) {
varname = "D" + name->get_node_name();
logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}",
name->get_node_name(),
varname,
to_nmodl(node));
node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
if (program_symtab->lookup(varname) == nullptr) {
auto symbol = std::make_shared<symtab::Symbol>(varname, ModToken());
symbol->set_original_name(name->get_node_name());
program_symtab->insert(symbol);
}
} else {
varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\'');
// we discard the RHS here so it can be anything (as long as NMODL considers it valid)
auto statement = fmt::format("{} = {}", varname, varname);
logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}",
to_nmodl(node.get_lhs()),
varname,
to_nmodl(node));
auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
create_statement(statement));
const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
expr_statement->get_expression());
node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
// TODO add symbol?
JCGoran marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

void DerivativeOriginalVisitor::visit_program(ast::Program& node) {
program_symtab = node.get_symbol_table();
node.visit_children(*this);
if (der_block_function) {
auto der_node =
new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(),
der_block_function->get_statement_block());
node.emplace_back_node(der_node);
}

// re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block
node.visit_children(*this);
JCGoran marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace visitor
} // namespace nmodl
64 changes: 64 additions & 0 deletions src/visitors/derivative_original_visitor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2023 Blue Brain Project, EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

/**
* \file
* \brief \copybrief nmodl::visitor::DerivativeOriginalVisitor
*/

#include "symtab/decl.hpp"
#include "visitors/ast_visitor.hpp"
#include <string>

namespace nmodl {
namespace visitor {

/**
* \addtogroup visitor_classes
* \{
*/

/**
* \class DerivativeOriginalVisitor
JCGoran marked this conversation as resolved.
Show resolved Hide resolved
* \brief Make a copy of the `DERIVATIVE` block (if it exists), and insert back as
* `DERIVATIVE_ORIGINAL_FUNCTION` block.
*
* If \ref SympySolverVisitor runs successfully, it replaces the original
* solution. This block is inserted before that to prevent losing access to
* information about the block.
*/
class DerivativeOriginalVisitor: public AstVisitor {
private:
/// The copy of the derivative block we are solving
ast::DerivativeBlock* der_block_function = nullptr;
JCGoran marked this conversation as resolved.
Show resolved Hide resolved

/// true while visiting differential equation
bool differential_equation = false;
JCGoran marked this conversation as resolved.
Show resolved Hide resolved

/// global symbol table
symtab::SymbolTable* program_symtab = nullptr;

/// visiting derivative block
bool derivative_block = false;

ast::AstNodeType node_type = ast::AstNodeType::NODE;
JCGoran marked this conversation as resolved.
Show resolved Hide resolved

public:
void visit_derivative_block(ast::DerivativeBlock& node) override;
void visit_program(ast::Program& node) override;
void visit_derivative_original_function_block(
ast::DerivativeOriginalFunctionBlock& node) override;
void visit_diff_eq_expression(ast::DiffEqExpression& node) override;
void visit_binary_expression(ast::BinaryExpression& node) override;
};

/** \} */ // end of visitor_classes

} // namespace visitor
} // namespace nmodl
4 changes: 4 additions & 0 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) {
}
}

// Skip visiting DERIVATIVE_ORIGINAL block
void SympySolverVisitor::visit_derivative_original_function_block(
ast::DerivativeOriginalFunctionBlock& node) {}

void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) {
const auto& lhs = node.get_expression()->get_lhs();

Expand Down
2 changes: 2 additions & 0 deletions src/visitors/sympy_solver_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class SympySolverVisitor: public AstVisitor {
void visit_expression_statement(ast::ExpressionStatement& node) override;
void visit_statement_block(ast::StatementBlock& node) override;
void visit_program(ast::Program& node) override;
void visit_derivative_original_function_block(
ast::DerivativeOriginalFunctionBlock& node) override;
};

/** @} */ // end of visitor_classes
Expand Down
Loading