From f5bc7832a6dce9eef79dbc9bdd23ce084cd13b51 Mon Sep 17 00:00:00 2001 From: Nicolas Cornu Date: Wed, 19 Apr 2023 14:55:01 +0200 Subject: [PATCH] Fix euler method resolution by computing in two times (#1031) First time, compute the derivative part. Second time, compute the value. This way if there is inter-depends it will still be correct. BREAKPOINT { SOLVE states METHOD euler } DERIVATIVE states { a' = a/n_tau n' = (n - a)/n_tau } Co-authored-by: Ioannis Magkanaris --- src/parser/diffeq_context.cpp | 6 ++-- src/parser/diffeq_context.hpp | 4 +-- src/visitors/neuron_solve_visitor.cpp | 32 ++++++++++++++++----- src/visitors/neuron_solve_visitor.hpp | 2 ++ test/unit/codegen/codegen_cpp_visitor.cpp | 35 +++++++++++++++++++++++ test/unit/utils/nmodl_constructs.cpp | 18 ++++++------ 6 files changed, 76 insertions(+), 21 deletions(-) diff --git a/src/parser/diffeq_context.cpp b/src/parser/diffeq_context.cpp index 7aa4478795..5c3a9f6963 100644 --- a/src/parser/diffeq_context.cpp +++ b/src/parser/diffeq_context.cpp @@ -139,8 +139,8 @@ std::string DiffEqContext::get_cnexp_solution() const { /** * Return solution for euler method */ -std::string DiffEqContext::get_euler_solution() const { - return state + " = " + state + "+dt*(" + rhs + ")"; +std::string DiffEqContext::get_euler_derivate() const { + return "D" + state + " = " + rhs; } @@ -171,7 +171,7 @@ std::string DiffEqContext::get_solution(bool& cnexp_possible) { std::string solution; if (method == "euler") { cnexp_possible = false; - solution = get_euler_solution(); + solution = get_euler_derivate(); } else if (method == "cnexp" && !(deriv_invalid && eqn_invalid)) { cnexp_possible = true; solution = get_cnexp_solution(); diff --git a/src/parser/diffeq_context.hpp b/src/parser/diffeq_context.hpp index caec0a5ad5..6fc6c13693 100644 --- a/src/parser/diffeq_context.hpp +++ b/src/parser/diffeq_context.hpp @@ -81,8 +81,8 @@ class DiffEqContext { /// return solution for cnexp method std::string get_cnexp_solution() const; - /// return solution for euler method - std::string get_euler_solution() const; + /// return only the derivate for euler method + std::string get_euler_derivate() const; /// return solution for non-cnexp method std::string get_non_cnexp_solution() const; diff --git a/src/visitors/neuron_solve_visitor.cpp b/src/visitors/neuron_solve_visitor.cpp index a687cd9160..80a1c05b2d 100644 --- a/src/visitors/neuron_solve_visitor.cpp +++ b/src/visitors/neuron_solve_visitor.cpp @@ -31,6 +31,12 @@ void NeuronSolveVisitor::visit_derivative_block(ast::DerivativeBlock& node) { derivative_block = true; node.visit_children(*this); derivative_block = false; + if (solve_blocks[derivative_block_name] == codegen::naming::EULER_METHOD) { + auto& statement_block = node.get_statement_block(); + for (auto& e: euler_solution_expressions) { + statement_block->emplace_back_statement(e); + } + } } @@ -70,13 +76,25 @@ void NeuronSolveVisitor::visit_binary_expression(ast::BinaryExpression& node) { to_nmodl(node)); } } else if (solve_method == codegen::naming::EULER_METHOD) { - std::string solution = parser::DiffeqDriver::solve(equation, solve_method); - auto statement = create_statement(solution); - auto expr_statement = std::dynamic_pointer_cast(statement); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); - node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + // computation of the derivative in place + { + std::string solution = parser::DiffeqDriver::solve(equation, solve_method); + auto statement = create_statement(solution); + auto expr_statement = std::dynamic_pointer_cast( + statement); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + } + + // create a new statement to compute the value based on the derivative + // this statement will be pushed at the end of the derivative block + { + std::string n = name->get_node_name(); + auto statement = create_statement(fmt::format("{} = {} + dt * D{}", n, n, n)); + euler_solution_expressions.emplace_back(statement); + } } else if (solve_method == codegen::naming::DERIVIMPLICIT_METHOD) { auto varname = "D" + name->get_node_name(); node.set_lhs(std::make_shared(new ast::String(varname))); diff --git a/src/visitors/neuron_solve_visitor.hpp b/src/visitors/neuron_solve_visitor.hpp index 1d8410af0b..6d4bd77b6a 100644 --- a/src/visitors/neuron_solve_visitor.hpp +++ b/src/visitors/neuron_solve_visitor.hpp @@ -59,6 +59,8 @@ class NeuronSolveVisitor: public AstVisitor { /// the derivative name currently being visited std::string derivative_block_name; + std::vector> euler_solution_expressions; + public: NeuronSolveVisitor() = default; diff --git a/test/unit/codegen/codegen_cpp_visitor.cpp b/test/unit/codegen/codegen_cpp_visitor.cpp index cfbb4e85e2..5e6d72cda1 100644 --- a/test/unit/codegen/codegen_cpp_visitor.cpp +++ b/test/unit/codegen/codegen_cpp_visitor.cpp @@ -1083,6 +1083,41 @@ SCENARIO("Some tests on derivimplicit", "[codegen][derivimplicit_solver]") { } +SCENARIO("Some tests on euler solver", "[codegen][euler_solver]") { + GIVEN("A mod file with euler") { + std::string const nmodl_text = R"( + NEURON { + RANGE inf + } + INITIAL { + inf = 2 + } + STATE { + n + m + } + BREAKPOINT { + SOLVE state METHOD euler + } + DERIVATIVE state { + m' = 2 * m + inf = inf * 3 + n' = (2 + m - inf) * n + } + )"; + THEN("Correct code is generated") { + auto const generated = get_cpp_code(nmodl_text); + std::string nrn_state_expected_code = R"(inst->Dm[id] = 2.0 * inst->m[id]; + inf = inf * 3.0; + inst->Dn[id] = (2.0 + inst->m[id] - inf) * inst->n[id]; + inst->m[id] = inst->m[id] + nt->_dt * inst->Dm[id]; + inst->n[id] = inst->n[id] + nt->_dt * inst->Dn[id];)"; + REQUIRE_THAT(generated, Contains(nrn_state_expected_code)); + } + } +} + + SCENARIO("Check codegen for MUTEX and PROTECT", "[codegen][mutex_protect]") { GIVEN("A mod file containing MUTEX & PROTECT") { std::string const nmodl_text = R"( diff --git a/test/unit/utils/nmodl_constructs.cpp b/test/unit/utils/nmodl_constructs.cpp index 1bad61e079..05b61eb35d 100644 --- a/test/unit/utils/nmodl_constructs.cpp +++ b/test/unit/utils/nmodl_constructs.cpp @@ -1468,7 +1468,7 @@ std::vector const diff_eq_constructs{ { "GluSynapse.mod", "A_AMPA' = A_AMPA*A_AMPA", - "A_AMPA = A_AMPA+dt*(A_AMPA*A_AMPA)", + "DA_AMPA = A_AMPA*A_AMPA", "euler" }, @@ -1525,42 +1525,42 @@ std::vector const diff_eq_constructs{ { "GluSynapse.mod", "A_AMPA' = -A_AMPA/tau_r_AMPA", - "A_AMPA = A_AMPA+dt*(-A_AMPA/tau_r_AMPA)", + "DA_AMPA = -A_AMPA/tau_r_AMPA", "euler" }, { "GluSynapse.mod", "m_VDCC' = (minf_VDCC-m_VDCC)/mtau_VDCC", - "m_VDCC = m_VDCC+dt*((minf_VDCC-m_VDCC)/mtau_VDCC)", + "Dm_VDCC = (minf_VDCC-m_VDCC)/mtau_VDCC", "euler" }, { "GluSynapse.mod", "cai_CR' = -(1e-9)*(ica_NMDA + ica_VDCC)*gamma_ca_CR/((1e-15)*volume_CR*2*FARADAY) - (cai_CR - min_ca_CR)/tau_ca_CR", - "cai_CR = cai_CR+dt*(-(1e-9)*(ica_NMDA + ica_VDCC)*gamma_ca_CR/((1e-15)*volume_CR*2*FARADAY) - (cai_CR - min_ca_CR)/tau_ca_CR)", + "Dcai_CR = -(1e-9)*(ica_NMDA + ica_VDCC)*gamma_ca_CR/((1e-15)*volume_CR*2*FARADAY) - (cai_CR - min_ca_CR)/tau_ca_CR", "euler" }, { "GluSynapse.mod", "effcai_GB' = -0.005*effcai_GB + (cai_CR - min_ca_CR)", - "effcai_GB = effcai_GB+dt*(-0.005*effcai_GB + (cai_CR - min_ca_CR))", + "Deffcai_GB = -0.005*effcai_GB + (cai_CR - min_ca_CR)", "euler" }, { "GluSynapse.mod", "Rho_GB' = ( - Rho_GB*(1-Rho_GB)*(rho_star_GB-Rho_GB) + potentiate_GB*gamma_p_GB*(1-Rho_GB) - depress_GB*gamma_d_GB*Rho_GB ) / ((1e3)*tau_GB)", - "Rho_GB = Rho_GB+dt*(( - Rho_GB*(1-Rho_GB)*(rho_star_GB-Rho_GB) + potentiate_GB*gamma_p_GB*(1-Rho_GB) - depress_GB*gamma_d_GB*Rho_GB ) / ((1e3)*tau_GB))", + "DRho_GB = ( - Rho_GB*(1-Rho_GB)*(rho_star_GB-Rho_GB) + potentiate_GB*gamma_p_GB*(1-Rho_GB) - depress_GB*gamma_d_GB*Rho_GB ) / ((1e3)*tau_GB)", "euler" }, { "GluSynapse.mod", "Use_GB' = (Use_d_GB + Rho_GB*(Use_p_GB-Use_d_GB) - Use_GB) / ((1e3)*tau_Use_GB)", - "Use_GB = Use_GB+dt*((Use_d_GB + Rho_GB*(Use_p_GB-Use_d_GB) - Use_GB) / ((1e3)*tau_Use_GB))", + "DUse_GB = (Use_d_GB + Rho_GB*(Use_p_GB-Use_d_GB) - Use_GB) / ((1e3)*tau_Use_GB)", "euler" }, @@ -1620,14 +1620,14 @@ std::vector const diff_eq_constructs{ { "syn_bip_gan.mod", "s' = (s_inf-s)/((1-s_inf)*tau*s)", - "s = s+dt*((s_inf-s)/((1-s_inf)*tau*s))", + "Ds = (s_inf-s)/((1-s_inf)*tau*s)", "euler" }, { "syn_rod_bip.mod", "s' = (s_inf-s)/((1-s_inf)*tau*s)", - "s = s+dt*((s_inf-s)/((1-s_inf)*tau*s))", + "Ds = (s_inf-s)/((1-s_inf)*tau*s)", "euler" },