|
| 1 | +# Copyright (C) 2024 Intel Corporation |
| 2 | +# SPDX-License-Identifier: BSD-3-Clause |
| 3 | +# See: https://spdx.org/licenses/ |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from typing import Any, Dict |
| 7 | +from lava.proc.sdn.models import AbstractSigmaDeltaModel |
| 8 | +from lava.magma.core.decorator import implements, requires, tag |
| 9 | +from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol |
| 10 | +from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer |
| 11 | +from lava.magma.core.resources import CPU |
| 12 | +from lava.magma.core.model.py.ports import PyInPort, PyOutPort |
| 13 | +from lava.magma.core.model.py.type import LavaPyType |
| 14 | +from lava.magma.core.model.sub.model import AbstractSubProcessModel |
| 15 | +from lava.proc.sparse.process import Sparse |
| 16 | + |
| 17 | + |
| 18 | +class AbstractSigmaS4dDeltaModel(AbstractSigmaDeltaModel): |
| 19 | + a_in = None |
| 20 | + s_out = None |
| 21 | + |
| 22 | + # SigmaDelta Variables |
| 23 | + vth = None |
| 24 | + sigma = None |
| 25 | + act = None |
| 26 | + residue = None |
| 27 | + error = None |
| 28 | + state_exp = None |
| 29 | + bias = None |
| 30 | + |
| 31 | + # S4 Variables |
| 32 | + a = None |
| 33 | + b = None |
| 34 | + c = None |
| 35 | + s4_state = None |
| 36 | + s4_exp = None |
| 37 | + |
| 38 | + def __init__(self, proc_params: Dict[str, Any]) -> None: |
| 39 | + """ |
| 40 | + Sigma delta neuron model that implements S4D |
| 41 | + (as described by Gu et al., 2022) dynamics as its activation function. |
| 42 | +
|
| 43 | + Relevant parameters in proc_params |
| 44 | + -------------------------- |
| 45 | + a: np.ndarray |
| 46 | + Diagonal elements of the state matrix of the S4D model. |
| 47 | + b: np.ndarray |
| 48 | + Diagonal elements of the input matrix of the S4D model. |
| 49 | + c: np.ndarray |
| 50 | + Diagonal elements of the output matrix of the S4D model. |
| 51 | + s4_state: np.ndarray |
| 52 | + State vector of the S4D model. |
| 53 | + """ |
| 54 | + super().__init__(proc_params) |
| 55 | + self.a = self.proc_params['a'] |
| 56 | + self.b = self.proc_params['b'] |
| 57 | + self.c = self.proc_params['c'] |
| 58 | + self.s4_state = self.proc_params['s4_state'] |
| 59 | + |
| 60 | + def activation_dynamics(self, sigma_data: np.ndarray) -> np.ndarray: |
| 61 | + """Sigma Delta activation dynamics. Performs S4D dynamics. |
| 62 | +
|
| 63 | + This function simulates the behavior of a linear time-invariant system |
| 64 | + with diagonalized state-space representation. |
| 65 | + (For reference see Gu et al., 2022) |
| 66 | +
|
| 67 | + The state-space equations are given by: |
| 68 | + s4_state_{k+1} = A * s4_state_k + B * input_k |
| 69 | + act_k = C * s4_state_k |
| 70 | +
|
| 71 | + where: |
| 72 | + - s4_state_k is the state vector at time step k, |
| 73 | + - input_k is the input vector at time step k, |
| 74 | + - act_k is the output vector at time step k, |
| 75 | + - A is the diagonal state matrix, |
| 76 | + - B is the diagonal input matrix, |
| 77 | + - C is the diagonal output matrix. |
| 78 | +
|
| 79 | + The function computes the next output step of the |
| 80 | + system for the given input signal. |
| 81 | +
|
| 82 | + Parameters |
| 83 | + ---------- |
| 84 | + sigma_data: np.ndarray |
| 85 | + sigma decoded data |
| 86 | +
|
| 87 | + Returns |
| 88 | + ------- |
| 89 | + np.ndarray |
| 90 | + activation output |
| 91 | + """ |
| 92 | + |
| 93 | + self.s4_state = self.s4_state * self.a + sigma_data * self.b |
| 94 | + act = self.c * self.s4_state * 2 |
| 95 | + return act |
| 96 | + |
| 97 | + |
| 98 | +@implements(proc=SigmaS4dDelta, protocol=LoihiProtocol) |
| 99 | +@requires(CPU) |
| 100 | +@tag('floating_pt') |
| 101 | +class PySigmaS4dDeltaModelFloat(AbstractSigmaS4dDeltaModel): |
| 102 | + """Floating point implementation of SigmaS4dDelta neuron.""" |
| 103 | + a_in = LavaPyType(PyInPort.VEC_DENSE, float) |
| 104 | + s_out = LavaPyType(PyOutPort.VEC_DENSE, float) |
| 105 | + |
| 106 | + vth: np.ndarray = LavaPyType(np.ndarray, float) |
| 107 | + sigma: np.ndarray = LavaPyType(np.ndarray, float) |
| 108 | + act: np.ndarray = LavaPyType(np.ndarray, float) |
| 109 | + residue: np.ndarray = LavaPyType(np.ndarray, float) |
| 110 | + error: np.ndarray = LavaPyType(np.ndarray, float) |
| 111 | + bias: np.ndarray = LavaPyType(np.ndarray, float) |
| 112 | + |
| 113 | + state_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) |
| 114 | + cum_error: np.ndarray = LavaPyType(np.ndarray, bool, precision=1) |
| 115 | + spike_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) |
| 116 | + s4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) |
| 117 | + |
| 118 | + # S4 vaiables |
| 119 | + s4_state: np.ndarray = LavaPyType(np.ndarray, float) |
| 120 | + a: np.ndarray = LavaPyType(np.ndarray, float) |
| 121 | + b: np.ndarray = LavaPyType(np.ndarray, float) |
| 122 | + c: np.ndarray = LavaPyType(np.ndarray, float) |
| 123 | + |
| 124 | + def run_spk(self) -> None: |
| 125 | + # Receive synaptic input |
| 126 | + a_in_data = self.a_in.recv() |
| 127 | + s_out = self.dynamics(a_in_data) |
| 128 | + self.s_out.send(s_out) |
| 129 | + |
| 130 | + |
| 131 | +@implements(proc=SigmaS4dDeltaLayer, protocol=LoihiProtocol) |
| 132 | +class SubDenseLayerModel(AbstractSubProcessModel): |
| 133 | + def __init__(self, proc): |
| 134 | + """Builds (Sparse -> S4D -> Sparse) connection of the process.""" |
| 135 | + conn_weights = proc.proc_params.get("conn_weights") |
| 136 | + shape = proc.proc_params.get("shape") |
| 137 | + state_exp = proc.proc_params.get("state_exp") |
| 138 | + num_message_bits = proc.proc_params.get("num_message_bits") |
| 139 | + s4_exp = proc.proc_params.get("s4_exp") |
| 140 | + d_states = proc.proc_params.get("d_states") |
| 141 | + a = proc.proc_params.get("a") |
| 142 | + b = proc.proc_params.get("b") |
| 143 | + c = proc.proc_params.get("c") |
| 144 | + vth = proc.proc_params.get("vth") |
| 145 | + |
| 146 | + # Instantiate processes |
| 147 | + self.sparse1 = Sparse(weights=conn_weights.T, weight_exp=state_exp, |
| 148 | + num_message_bits=num_message_bits) |
| 149 | + self.sigma_S4d_delta = SigmaS4dDelta(shape=(shape[0] * d_states,), |
| 150 | + vth=vth, |
| 151 | + state_exp=state_exp, |
| 152 | + s4_exp=s4_exp, |
| 153 | + a=a, |
| 154 | + b=b, |
| 155 | + c=c) |
| 156 | + self.sparse2 = Sparse(weights=conn_weights, weight_exp=state_exp, |
| 157 | + num_message_bits=num_message_bits) |
| 158 | + |
| 159 | + # Make connections Sparse -> SigmaS4Delta -> Sparse |
| 160 | + proc.in_ports.s_in.connect(self.sparse1.in_ports.s_in) |
| 161 | + self.sparse1.out_ports.a_out.connect(self.sigma_S4d_delta.in_ports.a_in) |
| 162 | + self.sigma_S4d_delta.out_ports.s_out.connect(self.sparse2.s_in) |
| 163 | + self.sparse2.out_ports.a_out.connect(proc.out_ports.a_out) |
| 164 | + |
| 165 | + # Set aliases |
| 166 | + proc.vars.a.alias(self.sigma_S4d_delta.vars.a) |
| 167 | + proc.vars.b.alias(self.sigma_S4d_delta.vars.b) |
| 168 | + proc.vars.c.alias(self.sigma_S4d_delta.vars.c) |
| 169 | + proc.vars.s4_state.alias(self.sigma_S4d_delta.vars.s4_state) |
0 commit comments