Skip to content

Commit 881d87e

Browse files
authored
SigmaS4Delta Neuronmodel and Layer with Unittests (#830)
* first wokring version * S4D model cleaned * update license * fix imports * linting * incorporate reviews * update docstring
1 parent 73499c2 commit 881d87e

File tree

9 files changed

+713
-2
lines changed

9 files changed

+713
-2
lines changed

src/lava/proc/s4d/models.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
# See: https://spdx.org/licenses/
4+
5+
import numpy as np
6+
from typing import Any, Dict
7+
from lava.proc.sdn.models import AbstractSigmaDeltaModel
8+
from lava.magma.core.decorator import implements, requires, tag
9+
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
10+
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer
11+
from lava.magma.core.resources import CPU
12+
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
13+
from lava.magma.core.model.py.type import LavaPyType
14+
from lava.magma.core.model.sub.model import AbstractSubProcessModel
15+
from lava.proc.sparse.process import Sparse
16+
17+
18+
class AbstractSigmaS4dDeltaModel(AbstractSigmaDeltaModel):
19+
a_in = None
20+
s_out = None
21+
22+
# SigmaDelta Variables
23+
vth = None
24+
sigma = None
25+
act = None
26+
residue = None
27+
error = None
28+
state_exp = None
29+
bias = None
30+
31+
# S4 Variables
32+
a = None
33+
b = None
34+
c = None
35+
s4_state = None
36+
s4_exp = None
37+
38+
def __init__(self, proc_params: Dict[str, Any]) -> None:
39+
"""
40+
Sigma delta neuron model that implements S4D
41+
(as described by Gu et al., 2022) dynamics as its activation function.
42+
43+
Relevant parameters in proc_params
44+
--------------------------
45+
a: np.ndarray
46+
Diagonal elements of the state matrix of the S4D model.
47+
b: np.ndarray
48+
Diagonal elements of the input matrix of the S4D model.
49+
c: np.ndarray
50+
Diagonal elements of the output matrix of the S4D model.
51+
s4_state: np.ndarray
52+
State vector of the S4D model.
53+
"""
54+
super().__init__(proc_params)
55+
self.a = self.proc_params['a']
56+
self.b = self.proc_params['b']
57+
self.c = self.proc_params['c']
58+
self.s4_state = self.proc_params['s4_state']
59+
60+
def activation_dynamics(self, sigma_data: np.ndarray) -> np.ndarray:
61+
"""Sigma Delta activation dynamics. Performs S4D dynamics.
62+
63+
This function simulates the behavior of a linear time-invariant system
64+
with diagonalized state-space representation.
65+
(For reference see Gu et al., 2022)
66+
67+
The state-space equations are given by:
68+
s4_state_{k+1} = A * s4_state_k + B * input_k
69+
act_k = C * s4_state_k
70+
71+
where:
72+
- s4_state_k is the state vector at time step k,
73+
- input_k is the input vector at time step k,
74+
- act_k is the output vector at time step k,
75+
- A is the diagonal state matrix,
76+
- B is the diagonal input matrix,
77+
- C is the diagonal output matrix.
78+
79+
The function computes the next output step of the
80+
system for the given input signal.
81+
82+
Parameters
83+
----------
84+
sigma_data: np.ndarray
85+
sigma decoded data
86+
87+
Returns
88+
-------
89+
np.ndarray
90+
activation output
91+
"""
92+
93+
self.s4_state = self.s4_state * self.a + sigma_data * self.b
94+
act = self.c * self.s4_state * 2
95+
return act
96+
97+
98+
@implements(proc=SigmaS4dDelta, protocol=LoihiProtocol)
99+
@requires(CPU)
100+
@tag('floating_pt')
101+
class PySigmaS4dDeltaModelFloat(AbstractSigmaS4dDeltaModel):
102+
"""Floating point implementation of SigmaS4dDelta neuron."""
103+
a_in = LavaPyType(PyInPort.VEC_DENSE, float)
104+
s_out = LavaPyType(PyOutPort.VEC_DENSE, float)
105+
106+
vth: np.ndarray = LavaPyType(np.ndarray, float)
107+
sigma: np.ndarray = LavaPyType(np.ndarray, float)
108+
act: np.ndarray = LavaPyType(np.ndarray, float)
109+
residue: np.ndarray = LavaPyType(np.ndarray, float)
110+
error: np.ndarray = LavaPyType(np.ndarray, float)
111+
bias: np.ndarray = LavaPyType(np.ndarray, float)
112+
113+
state_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
114+
cum_error: np.ndarray = LavaPyType(np.ndarray, bool, precision=1)
115+
spike_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
116+
s4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
117+
118+
# S4 vaiables
119+
s4_state: np.ndarray = LavaPyType(np.ndarray, float)
120+
a: np.ndarray = LavaPyType(np.ndarray, float)
121+
b: np.ndarray = LavaPyType(np.ndarray, float)
122+
c: np.ndarray = LavaPyType(np.ndarray, float)
123+
124+
def run_spk(self) -> None:
125+
# Receive synaptic input
126+
a_in_data = self.a_in.recv()
127+
s_out = self.dynamics(a_in_data)
128+
self.s_out.send(s_out)
129+
130+
131+
@implements(proc=SigmaS4dDeltaLayer, protocol=LoihiProtocol)
132+
class SubDenseLayerModel(AbstractSubProcessModel):
133+
def __init__(self, proc):
134+
"""Builds (Sparse -> S4D -> Sparse) connection of the process."""
135+
conn_weights = proc.proc_params.get("conn_weights")
136+
shape = proc.proc_params.get("shape")
137+
state_exp = proc.proc_params.get("state_exp")
138+
num_message_bits = proc.proc_params.get("num_message_bits")
139+
s4_exp = proc.proc_params.get("s4_exp")
140+
d_states = proc.proc_params.get("d_states")
141+
a = proc.proc_params.get("a")
142+
b = proc.proc_params.get("b")
143+
c = proc.proc_params.get("c")
144+
vth = proc.proc_params.get("vth")
145+
146+
# Instantiate processes
147+
self.sparse1 = Sparse(weights=conn_weights.T, weight_exp=state_exp,
148+
num_message_bits=num_message_bits)
149+
self.sigma_S4d_delta = SigmaS4dDelta(shape=(shape[0] * d_states,),
150+
vth=vth,
151+
state_exp=state_exp,
152+
s4_exp=s4_exp,
153+
a=a,
154+
b=b,
155+
c=c)
156+
self.sparse2 = Sparse(weights=conn_weights, weight_exp=state_exp,
157+
num_message_bits=num_message_bits)
158+
159+
# Make connections Sparse -> SigmaS4Delta -> Sparse
160+
proc.in_ports.s_in.connect(self.sparse1.in_ports.s_in)
161+
self.sparse1.out_ports.a_out.connect(self.sigma_S4d_delta.in_ports.a_in)
162+
self.sigma_S4d_delta.out_ports.s_out.connect(self.sparse2.s_in)
163+
self.sparse2.out_ports.a_out.connect(proc.out_ports.a_out)
164+
165+
# Set aliases
166+
proc.vars.a.alias(self.sigma_S4d_delta.vars.a)
167+
proc.vars.b.alias(self.sigma_S4d_delta.vars.b)
168+
proc.vars.c.alias(self.sigma_S4d_delta.vars.c)
169+
proc.vars.s4_state.alias(self.sigma_S4d_delta.vars.s4_state)

src/lava/proc/s4d/process.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
# See: https://spdx.org/licenses/
4+
5+
import typing as ty
6+
import numpy as np
7+
from lava.magma.core.process.process import AbstractProcess
8+
from lava.magma.core.process.variable import Var
9+
from lava.magma.core.process.ports.ports import InPort, OutPort
10+
from lava.proc.sdn.process import ActivationMode, SigmaDelta
11+
12+
13+
class SigmaS4dDelta(SigmaDelta, AbstractProcess):
14+
def __init__(
15+
self,
16+
shape: ty.Tuple[int, ...],
17+
vth: ty.Union[int, float],
18+
a: float,
19+
b: float,
20+
c: float,
21+
state_exp: ty.Optional[int] = 0,
22+
s4_exp: ty.Optional[int] = 0) -> None:
23+
"""
24+
Sigma delta neuron process that implements S4D (described by
25+
Gu et al., 2022) dynamics as its activation function.
26+
27+
This process simulates the behavior of a linear time-invariant system
28+
with diagonal state-space representation.
29+
The state-space equations are given by:
30+
s4_state_{k+1} = A * s4_state_k + B * inp_k
31+
act_k = C * s4_state_k
32+
33+
where:
34+
- s4_state_k is the state vector at time step k,
35+
- inp_k is the input vector at time step k,
36+
- act_k is the output vector at time step k,
37+
- A is the diagonal state matrix,
38+
- B is the diagonal input matrix,
39+
- C is the diagonal output matrix.
40+
41+
Parameters
42+
----------
43+
shape: Tuple
44+
Shape of the sigma process.
45+
vth: int or float
46+
Threshold of the delta encoder.
47+
a: np.ndarray
48+
Diagonal elements of the state matrix of the S4D model.
49+
b: np.ndarray
50+
Diagonal elements of the input matrix of the S4D model.
51+
c: np.ndarray
52+
Diagonal elements of the output matrix of the S4D model.
53+
state_exp: int
54+
Scaling exponent with base 2 for the reconstructed sigma variables.
55+
Note: This should only be used for nc models.
56+
Default is 0.
57+
s4_exp: int
58+
Scaling exponent with base 2 for the S4 state variables.
59+
Note: This should only be used for nc models.
60+
Default is 0.
61+
"""
62+
63+
super().__init__(shape=shape,
64+
vth=vth,
65+
a=a,
66+
b=b,
67+
c=c,
68+
s4_state=0,
69+
state_exp=state_exp,
70+
s4_exp=s4_exp)
71+
72+
# Variables for S4
73+
self.a = Var(shape=shape, init=a)
74+
self.b = Var(shape=shape, init=b)
75+
self.c = Var(shape=shape, init=c)
76+
self.s4_state = Var(shape=shape, init=0)
77+
self.s4_exp = Var(shape=(1,), init=s4_exp)
78+
79+
80+
class SigmaS4dDeltaLayer(AbstractProcess):
81+
def __init__(
82+
self,
83+
shape: ty.Tuple[int, ...],
84+
vth: ty.Union[int, float],
85+
a: float,
86+
b: float,
87+
c: float,
88+
d_states: ty.Optional[int] = 1,
89+
s4_exp: ty.Optional[int] = 0,
90+
num_message_bits: ty.Optional[int] = 24,
91+
state_exp: ty.Optional[int] = 0) -> None:
92+
"""
93+
Combines S4D neuron with Sparse Processes that allow for multiple
94+
d_states.
95+
96+
Connectivity: Sparse -> SigmaS4dDelta -> Sparse.
97+
Relieves user from computing required connection weights for multiple
98+
d_states.
99+
100+
Parameters
101+
----------
102+
shape: Tuple
103+
Shape of the sigma process.
104+
vth: int or float
105+
Threshold of the delta encoder.
106+
a: np.ndarray
107+
Diagonal elements of the state matrix of the S4D model.
108+
b: np.ndarray
109+
Diagonal elements of the input matrix of the S4D model.
110+
c: np.ndarray
111+
Diagonal elements of the output matrix of the S4D model.
112+
d_states: int
113+
Number of hidden states of the S4D model.
114+
Default is 1.
115+
state_exp: int
116+
Scaling exponent with base 2 for the reconstructed sigma variables.
117+
Note: Only relevant for nc model.
118+
Default is 0.
119+
num_message_bits: int
120+
Number of message bits to be used in Sparse connection processes.
121+
Note: Only relevant for nc model.
122+
s4_exp: int
123+
Scaling exponent with base 2 for the S4 state variables.
124+
Note: Only relevant for nc model.
125+
Default is 0.
126+
"""
127+
128+
# Automatically takes care of expansion and reduction of dimensionality
129+
# for multiple hidden states (d_states)
130+
conn_weights = np.kron(np.eye(shape[0]), np.ones(d_states))
131+
s4_state = 0
132+
super().__init__(shape=shape,
133+
vth=vth,
134+
a=a,
135+
b=b,
136+
c=c,
137+
s4_exp=s4_exp,
138+
s4_state=s4_state,
139+
conn_weights=conn_weights,
140+
num_message_bits=num_message_bits,
141+
d_states=d_states,
142+
state_exp=state_exp,
143+
act_mode=ActivationMode.UNIT)
144+
145+
# Ports
146+
self.s_in = InPort(shape=shape)
147+
self.a_out = OutPort(shape=shape)
148+
149+
# General variables
150+
self.state_exp = Var(shape=(1,), init=state_exp)
151+
152+
# Variables for S4
153+
self.a = Var(shape=(shape[0] * d_states,), init=a)
154+
self.b = Var(shape=(shape[0] * d_states,), init=b)
155+
self.c = Var(shape=(shape[0] * d_states,), init=c)
156+
self.s4_state = Var(shape=(shape[0] * d_states,), init=0)
157+
self.S4_exp = Var(shape=(1,), init=s4_exp)
158+
159+
# Variables for connecting Dense processes
160+
# Project input_dim to input_dim * d_states
161+
self.conn_weights = Var(shape=shape, init=conn_weights)
162+
self.num_message_bits = Var(shape=(1,), init=num_message_bits)
163+
164+
@property
165+
def shape(self) -> ty.Tuple[int, ...]:
166+
"""Return shape of the Process."""
167+
return self.proc_params['shape']

src/lava/proc/sdn/process.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def __init__(
126126
act_mode: ty.Optional[ActivationMode] = ActivationMode.RELU,
127127
cum_error: ty.Optional[bool] = False,
128128
spike_exp: ty.Optional[int] = 0,
129-
state_exp: ty.Optional[int] = 0) -> None:
129+
state_exp: ty.Optional[int] = 0,
130+
**kwargs) -> None:
130131
"""Sigma delta neuron process. At the moment only ReLu activation is
131132
supported. Spike mechanism based on accumulated error is also supported.
132133
@@ -173,7 +174,7 @@ def __init__(
173174
"""
174175
super().__init__(shape=shape, vth=vth, bias=bias,
175176
act_mode=act_mode, cum_error=cum_error,
176-
spike_exp=spike_exp, state_exp=state_exp)
177+
spike_exp=spike_exp, state_exp=state_exp, **kwargs)
177178
# scaling factor for fixed precision scaling
178179
vth = vth * (1 << (spike_exp + state_exp))
179180
bias = bias * (1 << (spike_exp + state_exp))

tests/lava/proc/s4d/s4d_A.dat.npy

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:9e4c9f5f11d3139b86ccdfe3d8e2179566dc8ac8da31a4a23951dba174425663
3+
size 5248

tests/lava/proc/s4d/s4d_B.dat.npy

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:4b67f99c0a172c862abfc65b3aabeba9bc91fe1f2254d0df066d19c9b3e3b8fe
3+
size 5248

tests/lava/proc/s4d/s4d_C.dat.npy

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:055720b65a2eb0bf0043989b1a078cc028f7a105a0cb394ba03cdbf3adac8ac1
3+
size 5248

0 commit comments

Comments
 (0)