diff --git a/src/lava/lib/optimization/solvers/generic/solution_receiver/models.py b/src/lava/lib/optimization/solvers/generic/solution_receiver/models.py
index e9949e82..5c1eda10 100644
--- a/src/lava/lib/optimization/solvers/generic/solution_receiver/models.py
+++ b/src/lava/lib/optimization/solvers/generic/solution_receiver/models.py
@@ -136,19 +136,62 @@ def __init__(self, proc):
         connection_config = proc.proc_params.get("connection_config")
 
-        weights_state_in = self._get_input_weights(
+        weights_state_in_0 = self._get_input_weights(
             num_vars=num_bin_variables,
             num_spike_int=num_spike_integrators,
             num_vars_per_int=num_message_bits,
+            weight_exp=0
         )
-        self.synapses_state_in = Sparse(
-            weights=weights_state_in,
+        self.synapses_state_in_0 = Sparse(
+            weights=weights_state_in_0,
             #sign_mode=SignMode.EXCITATORY,
             num_weight_bits=8,
             num_message_bits=num_message_bits,
             weight_exp=0,
         )
 
+        weights_state_in_1 = self._get_input_weights(
+            num_vars=num_bin_variables,
+            num_spike_int=num_spike_integrators,
+            num_vars_per_int=num_message_bits,
+            weight_exp=8
+        )
+        self.synapses_state_in_1 = Sparse(
+            weights=weights_state_in_1,
+            #sign_mode=SignMode.EXCITATORY,
+            num_weight_bits=8,
+            num_message_bits=num_message_bits,
+            weight_exp=8,
+        )
+
+        weights_state_in_2 = self._get_input_weights(
+            num_vars=num_bin_variables,
+            num_spike_int=num_spike_integrators,
+            num_vars_per_int=num_message_bits,
+            weight_exp=16
+        )
+        self.synapses_state_in_2 = Sparse(
+            weights=weights_state_in_2,
+            #sign_mode=SignMode.EXCITATORY,
+            num_weight_bits=8,
+            num_message_bits=num_message_bits,
+            weight_exp=16,
+        )
+
+        weights_state_in_3 = self._get_input_weights(
+            num_vars=num_bin_variables,
+            num_spike_int=num_spike_integrators,
+            num_vars_per_int=num_message_bits,
+            weight_exp=24
+        )
+        self.synapses_state_in_3 = Sparse(
+            weights=weights_state_in_3,
+            #sign_mode=SignMode.EXCITATORY,
+            num_weight_bits=8,
+            num_message_bits=num_message_bits,
+            weight_exp=24,
+        )
+
         #CAREFUL! Weights are negated here, since CostIn will always be < 0
         # but SpikeIntegrators only deal with positive numbers.
         # This is accounted for in self._decompress_state()
@@ -186,12 +229,18 @@ def __init__(self, proc):
         )
 
         # Connect the parent InPort to the InPort of the child-Process.
-        proc.in_ports.states_in.connect(self.synapses_state_in.s_in)
+        proc.in_ports.states_in.connect(self.synapses_state_in_0.s_in)
+        proc.in_ports.states_in.connect(self.synapses_state_in_1.s_in)
+        proc.in_ports.states_in.connect(self.synapses_state_in_2.s_in)
+        proc.in_ports.states_in.connect(self.synapses_state_in_3.s_in)
         proc.in_ports.cost_in.connect(self.synapses_cost_in.s_in)
         proc.in_ports.timestep_in.connect(self.synapses_timestep_in.s_in)
 
         # Connect intermediate ports
-        self.synapses_state_in.a_out.connect(self.spike_integrators.a_in)
+        self.synapses_state_in_0.a_out.connect(self.spike_integrators.a_in)
+        self.synapses_state_in_1.a_out.connect(self.spike_integrators.a_in)
+        self.synapses_state_in_2.a_out.connect(self.spike_integrators.a_in)
+        self.synapses_state_in_3.a_out.connect(self.spike_integrators.a_in)
         self.synapses_cost_in.a_out.connect(self.spike_integrators.a_in)
         self.synapses_timestep_in.a_out.connect(self.spike_integrators.a_in)
 
@@ -204,7 +253,7 @@ def __init__(self, proc):
         proc.vars.best_cost.alias(self.solution_receiver.best_cost)
 
     @staticmethod
-    def _get_input_weights(num_vars, num_spike_int, num_vars_per_int):
+    def _get_input_weights(num_vars, num_spike_int, num_vars_per_int, weight_exp):
         """To be verified.
         Deprecated due to efficiency"""
         weights = np.zeros((num_spike_int, num_vars), dtype=np.int8)
@@ -212,13 +261,14 @@ def _get_input_weights(num_vars, num_spike_int, num_vars_per_int):
         #print(f"{num_spike_int=}")
         #print(f"{num_vars_per_int=}")
         for spike_integrator in range(2, num_spike_int - 1):
-            variable_start = num_vars_per_int*spike_integrator
+            variable_start = 32 * (spike_integrator - 2) + weight_exp
             weights[spike_integrator, variable_start:variable_start +
-                    num_vars_per_int] = 1
+                    8] = np.power(2, np.arange(8))
         # The last spike integrator might be connected by less than
         # num_vars_per_int neurons
         # This happens when mod(num_variables, num_vars_per_int) != 0
-        weights[-1, num_vars_per_int*(num_spike_int - 3):] = 1
+        variable_start = 32 * (num_spike_int - 3) + weight_exp
+        weights[-1, variable_start:] = np.power(2, np.arange(weights.shape[1]-variable_start))
 
         #print("=" * 20)
         #print(f"{weights=}")