Skip to content

Commit

Permalink
Add synapses NEBM -> SpikeIntegrators
Browse files Browse the repository at this point in the history
  • Loading branch information
AlessandroPierro committed Jan 12, 2024
1 parent 8bcfbe9 commit 6dc1607
Showing 1 changed file with 59 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,62 @@ def __init__(self, proc):

connection_config = proc.proc_params.get("connection_config")

weights_state_in = self._get_input_weights(
weights_state_in_0 = self._get_input_weights(
num_vars=num_bin_variables,
num_spike_int=num_spike_integrators,
num_vars_per_int=num_message_bits,
weight_exp=0
)
self.synapses_state_in = Sparse(
weights=weights_state_in,
self.synapses_state_in_0 = Sparse(
weights=weights_state_in_0,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=0,
)

weights_state_in_1 = self._get_input_weights(
num_vars=num_bin_variables,
num_spike_int=num_spike_integrators,
num_vars_per_int=num_message_bits,
weight_exp=8
)
self.synapses_state_in_1 = Sparse(
weights=weights_state_in_1,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=8,
)

weights_state_in_2 = self._get_input_weights(
num_vars=num_bin_variables,
num_spike_int=num_spike_integrators,
num_vars_per_int=num_message_bits,
weight_exp=16
)
self.synapses_state_in_2 = Sparse(
weights=weights_state_in_2,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=16,
)

weights_state_in_3 = self._get_input_weights(
num_vars=num_bin_variables,
num_spike_int=num_spike_integrators,
num_vars_per_int=num_message_bits,
weight_exp=24
)
self.synapses_state_in_3 = Sparse(
weights=weights_state_in_3,
#sign_mode=SignMode.EXCITATORY,
num_weight_bits=8,
num_message_bits=num_message_bits,
weight_exp=24,
)

#CAREFUL! Weights are negated here, since CostIn will always be < 0
# but SpikeIntegrators only deal with positive numbers.
# This is accounted for in self._decompress_state()
Expand Down Expand Up @@ -186,12 +229,18 @@ def __init__(self, proc):
)

# Connect the parent InPort to the InPort of the child-Process.
proc.in_ports.states_in.connect(self.synapses_state_in.s_in)
proc.in_ports.states_in.connect(self.synapses_state_in_0.s_in)
proc.in_ports.states_in.connect(self.synapses_state_in_1.s_in)
proc.in_ports.states_in.connect(self.synapses_state_in_2.s_in)
proc.in_ports.states_in.connect(self.synapses_state_in_3.s_in)
proc.in_ports.cost_in.connect(self.synapses_cost_in.s_in)
proc.in_ports.timestep_in.connect(self.synapses_timestep_in.s_in)

# Connect intermediate ports
self.synapses_state_in.a_out.connect(self.spike_integrators.a_in)
self.synapses_state_in_0.a_out.connect(self.spike_integrators.a_in)
self.synapses_state_in_1.a_out.connect(self.spike_integrators.a_in)
self.synapses_state_in_2.a_out.connect(self.spike_integrators.a_in)
self.synapses_state_in_3.a_out.connect(self.spike_integrators.a_in)
self.synapses_cost_in.a_out.connect(self.spike_integrators.a_in)
self.synapses_timestep_in.a_out.connect(self.spike_integrators.a_in)

Expand All @@ -204,21 +253,22 @@ def __init__(self, proc):
proc.vars.best_cost.alias(self.solution_receiver.best_cost)

@staticmethod
def _get_input_weights(num_vars, num_spike_int, num_vars_per_int):
def _get_input_weights(num_vars, num_spike_int, num_vars_per_int, weight_exp):
"""To be verified. Deprecated due to efficiency"""

weights = np.zeros((num_spike_int, num_vars), dtype=np.int8)
#print(f"{num_vars=}")
#print(f"{num_spike_int=}")
#print(f"{num_vars_per_int=}")
for spike_integrator in range(2, num_spike_int - 1):
variable_start = num_vars_per_int*spike_integrator
variable_start = 32 * (spike_integrator - 2) + weight_exp
weights[spike_integrator, variable_start:variable_start +
num_vars_per_int] = 1
9] = np.power(2, np.arange(8))
# The last spike integrator might be connected by less than
# num_vars_per_int neurons
# This happens when mod(num_variables, num_vars_per_int) != 0
weights[-1, num_vars_per_int*(num_spike_int - 3):] = 1
variable_start = 32 * (num_spike_int - 3) + weight_exp
weights[-1, variable_start:] = np.power(2, np.arange(weights.shape[1]-variable_start))

#print("=" * 20)
#print(f"{weights=}")
Expand Down

0 comments on commit 6dc1607

Please sign in to comment.