Skip to content

Commit

Permalink
Fix euler method resolution by computing in two times (#1031)
Browse files Browse the repository at this point in the history
First time, compute the derivative part.
Second time, compute the value.
This way if there is inter-depends it will still be correct.

BREAKPOINT {
  SOLVE states METHOD euler
}

DERIVATIVE states {
   a' = a/n_tau
   n' = (n - a)/n_tau
}

Co-authored-by: Ioannis Magkanaris <ioannis.magkanaris@epfl.ch>
  • Loading branch information
alkino and iomaganaris authored Apr 19, 2023
1 parent 5f864dc commit f5bc783
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 21 deletions.
6 changes: 3 additions & 3 deletions src/parser/diffeq_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ std::string DiffEqContext::get_cnexp_solution() const {
/**
* Return solution for euler method
*/
std::string DiffEqContext::get_euler_solution() const {
return state + " = " + state + "+dt*(" + rhs + ")";
std::string DiffEqContext::get_euler_derivate() const {
return "D" + state + " = " + rhs;
}


Expand Down Expand Up @@ -171,7 +171,7 @@ std::string DiffEqContext::get_solution(bool& cnexp_possible) {
std::string solution;
if (method == "euler") {
cnexp_possible = false;
solution = get_euler_solution();
solution = get_euler_derivate();
} else if (method == "cnexp" && !(deriv_invalid && eqn_invalid)) {
cnexp_possible = true;
solution = get_cnexp_solution();
Expand Down
4 changes: 2 additions & 2 deletions src/parser/diffeq_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class DiffEqContext {
/// return solution for cnexp method
std::string get_cnexp_solution() const;

/// return solution for euler method
std::string get_euler_solution() const;
/// return only the derivate for euler method
std::string get_euler_derivate() const;

/// return solution for non-cnexp method
std::string get_non_cnexp_solution() const;
Expand Down
32 changes: 25 additions & 7 deletions src/visitors/neuron_solve_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ void NeuronSolveVisitor::visit_derivative_block(ast::DerivativeBlock& node) {
derivative_block = true;
node.visit_children(*this);
derivative_block = false;
if (solve_blocks[derivative_block_name] == codegen::naming::EULER_METHOD) {
auto& statement_block = node.get_statement_block();
for (auto& e: euler_solution_expressions) {
statement_block->emplace_back_statement(e);
}
}
}


Expand Down Expand Up @@ -70,13 +76,25 @@ void NeuronSolveVisitor::visit_binary_expression(ast::BinaryExpression& node) {
to_nmodl(node));
}
} else if (solve_method == codegen::naming::EULER_METHOD) {
std::string solution = parser::DiffeqDriver::solve(equation, solve_method);
auto statement = create_statement(solution);
auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(statement);
const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
expr_statement->get_expression());
node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
// computation of the derivative in place
{
std::string solution = parser::DiffeqDriver::solve(equation, solve_method);
auto statement = create_statement(solution);
auto expr_statement = std::dynamic_pointer_cast<ast::ExpressionStatement>(
statement);
const auto bin_expr = std::dynamic_pointer_cast<const ast::BinaryExpression>(
expr_statement->get_expression());
node.set_lhs(std::shared_ptr<ast::Expression>(bin_expr->get_lhs()->clone()));
node.set_rhs(std::shared_ptr<ast::Expression>(bin_expr->get_rhs()->clone()));
}

// create a new statement to compute the value based on the derivative
// this statement will be pushed at the end of the derivative block
{
std::string n = name->get_node_name();
auto statement = create_statement(fmt::format("{} = {} + dt * D{}", n, n, n));
euler_solution_expressions.emplace_back(statement);
}
} else if (solve_method == codegen::naming::DERIVIMPLICIT_METHOD) {
auto varname = "D" + name->get_node_name();
node.set_lhs(std::make_shared<ast::Name>(new ast::String(varname)));
Expand Down
2 changes: 2 additions & 0 deletions src/visitors/neuron_solve_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class NeuronSolveVisitor: public AstVisitor {
/// the derivative name currently being visited
std::string derivative_block_name;

std::vector<std::shared_ptr<ast::Statement>> euler_solution_expressions;

public:
NeuronSolveVisitor() = default;

Expand Down
35 changes: 35 additions & 0 deletions test/unit/codegen/codegen_cpp_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,41 @@ SCENARIO("Some tests on derivimplicit", "[codegen][derivimplicit_solver]") {
}


SCENARIO("Some tests on euler solver", "[codegen][euler_solver]") {
GIVEN("A mod file with euler") {
std::string const nmodl_text = R"(
NEURON {
RANGE inf
}
INITIAL {
inf = 2
}
STATE {
n
m
}
BREAKPOINT {
SOLVE state METHOD euler
}
DERIVATIVE state {
m' = 2 * m
inf = inf * 3
n' = (2 + m - inf) * n
}
)";
THEN("Correct code is generated") {
auto const generated = get_cpp_code(nmodl_text);
std::string nrn_state_expected_code = R"(inst->Dm[id] = 2.0 * inst->m[id];
inf = inf * 3.0;
inst->Dn[id] = (2.0 + inst->m[id] - inf) * inst->n[id];
inst->m[id] = inst->m[id] + nt->_dt * inst->Dm[id];
inst->n[id] = inst->n[id] + nt->_dt * inst->Dn[id];)";
REQUIRE_THAT(generated, Contains(nrn_state_expected_code));
}
}
}


SCENARIO("Check codegen for MUTEX and PROTECT", "[codegen][mutex_protect]") {
GIVEN("A mod file containing MUTEX & PROTECT") {
std::string const nmodl_text = R"(
Expand Down
18 changes: 9 additions & 9 deletions test/unit/utils/nmodl_constructs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,7 @@ std::vector<DiffEqTestCase> const diff_eq_constructs{
{
"GluSynapse.mod",
"A_AMPA' = A_AMPA*A_AMPA",
"A_AMPA = A_AMPA+dt*(A_AMPA*A_AMPA)",
"DA_AMPA = A_AMPA*A_AMPA",
"euler"
},

Expand Down Expand Up @@ -1525,42 +1525,42 @@ std::vector<DiffEqTestCase> const diff_eq_constructs{
{
"GluSynapse.mod",
"A_AMPA' = -A_AMPA/tau_r_AMPA",
"A_AMPA = A_AMPA+dt*(-A_AMPA/tau_r_AMPA)",
"DA_AMPA = -A_AMPA/tau_r_AMPA",
"euler"
},

{
"GluSynapse.mod",
"m_VDCC' = (minf_VDCC-m_VDCC)/mtau_VDCC",
"m_VDCC = m_VDCC+dt*((minf_VDCC-m_VDCC)/mtau_VDCC)",
"Dm_VDCC = (minf_VDCC-m_VDCC)/mtau_VDCC",
"euler"
},

{
"GluSynapse.mod",
"cai_CR' = -(1e-9)*(ica_NMDA + ica_VDCC)*gamma_ca_CR/((1e-15)*volume_CR*2*FARADAY) - (cai_CR - min_ca_CR)/tau_ca_CR",
"cai_CR = cai_CR+dt*(-(1e-9)*(ica_NMDA + ica_VDCC)*gamma_ca_CR/((1e-15)*volume_CR*2*FARADAY) - (cai_CR - min_ca_CR)/tau_ca_CR)",
"Dcai_CR = -(1e-9)*(ica_NMDA + ica_VDCC)*gamma_ca_CR/((1e-15)*volume_CR*2*FARADAY) - (cai_CR - min_ca_CR)/tau_ca_CR",
"euler"
},

{
"GluSynapse.mod",
"effcai_GB' = -0.005*effcai_GB + (cai_CR - min_ca_CR)",
"effcai_GB = effcai_GB+dt*(-0.005*effcai_GB + (cai_CR - min_ca_CR))",
"Deffcai_GB = -0.005*effcai_GB + (cai_CR - min_ca_CR)",
"euler"
},

{
"GluSynapse.mod",
"Rho_GB' = ( - Rho_GB*(1-Rho_GB)*(rho_star_GB-Rho_GB) + potentiate_GB*gamma_p_GB*(1-Rho_GB) - depress_GB*gamma_d_GB*Rho_GB ) / ((1e3)*tau_GB)",
"Rho_GB = Rho_GB+dt*(( - Rho_GB*(1-Rho_GB)*(rho_star_GB-Rho_GB) + potentiate_GB*gamma_p_GB*(1-Rho_GB) - depress_GB*gamma_d_GB*Rho_GB ) / ((1e3)*tau_GB))",
"DRho_GB = ( - Rho_GB*(1-Rho_GB)*(rho_star_GB-Rho_GB) + potentiate_GB*gamma_p_GB*(1-Rho_GB) - depress_GB*gamma_d_GB*Rho_GB ) / ((1e3)*tau_GB)",
"euler"
},

{
"GluSynapse.mod",
"Use_GB' = (Use_d_GB + Rho_GB*(Use_p_GB-Use_d_GB) - Use_GB) / ((1e3)*tau_Use_GB)",
"Use_GB = Use_GB+dt*((Use_d_GB + Rho_GB*(Use_p_GB-Use_d_GB) - Use_GB) / ((1e3)*tau_Use_GB))",
"DUse_GB = (Use_d_GB + Rho_GB*(Use_p_GB-Use_d_GB) - Use_GB) / ((1e3)*tau_Use_GB)",
"euler"
},

Expand Down Expand Up @@ -1620,14 +1620,14 @@ std::vector<DiffEqTestCase> const diff_eq_constructs{
{
"syn_bip_gan.mod",
"s' = (s_inf-s)/((1-s_inf)*tau*s)",
"s = s+dt*((s_inf-s)/((1-s_inf)*tau*s))",
"Ds = (s_inf-s)/((1-s_inf)*tau*s)",
"euler"
},

{
"syn_rod_bip.mod",
"s' = (s_inf-s)/((1-s_inf)*tau*s)",
"s = s+dt*((s_inf-s)/((1-s_inf)*tau*s))",
"Ds = (s_inf-s)/((1-s_inf)*tau*s)",
"euler"
},

Expand Down

0 comments on commit f5bc783

Please sign in to comment.