From 37c19a29526637f9952161f3af119bd78148d678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Gil?= Date: Thu, 20 Jun 2024 14:20:15 +0100 Subject: [PATCH] Fix: subthreshold dynamics equation of refractory lif (#842) * Fix: subthreshold dynamics equation of refractory lif * Fix: RefractoryLIF unit test to test the voltage dynamics --- src/lava/proc/lif/models.py | 5 ++--- tests/lava/proc/lif/test_models.py | 7 ++++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lava/proc/lif/models.py b/src/lava/proc/lif/models.py index b58ffe7db..72f02be9b 100644 --- a/src/lava/proc/lif/models.py +++ b/src/lava/proc/lif/models.py @@ -477,9 +477,8 @@ def subthr_dynamics(self, activation_in: np.ndarray): self.u[:] = self.u * (1 - self.du) self.u[:] += activation_in non_refractory = self.refractory_period_end < self.time_step - self.v[non_refractory] = (self.v[non_refractory] * ( - (1 - self.dv) + self.u[non_refractory]) - + self.bias_mant[non_refractory]) + self.v[non_refractory] = self.v[non_refractory] * (1 - self.dv) + ( + self.u[non_refractory] + self.bias_mant[non_refractory]) def process_spikes(self, spike_vector: np.ndarray): self.refractory_period_end[spike_vector] = (self.time_step diff --git a/tests/lava/proc/lif/test_models.py b/tests/lava/proc/lif/test_models.py index d9ec66528..1d9f29943 100644 --- a/tests/lava/proc/lif/test_models.py +++ b/tests/lava/proc/lif/test_models.py @@ -837,12 +837,13 @@ def test_float_model(self): refractory_period = 1 # Two neurons with different biases + # No Input current provided to make the voltage dependent on the bias lif_refractory = LIFRefractory(shape=(num_neurons,), - u=np.arange(num_neurons), + u=np.zeros(num_neurons), bias_mant=np.arange(num_neurons) + 1, bias_exp=np.ones( (num_neurons,), dtype=float), - vth=4., + vth=4, refractory_period=refractory_period) v_logger = io.sink.Read(buffer=num_steps) @@ -856,6 +857,6 @@ def test_float_model(self): # Voltage is expected to remain at reset level for two time steps v_expected = np.array([[1, 2, 3, 4, 0, 0, 1, 2], - [2, 0, 0, 2, 0, 0, 2, 0]], dtype=float) + [2, 4, 0, 0, 2, 4, 0, 0]], dtype=float) assert_almost_equal(v, v_expected)