Skip to content

Commit c9f7646

Browse files
committed
incorporate reviews
1 parent 4462206 commit c9f7646

File tree

5 files changed

+287
-263
lines changed

5 files changed

+287
-263
lines changed

src/lava/proc/s4d/models.py

Lines changed: 87 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
# Copyright (C) 2022 Intel Corporation
1+
# Copyright (C) 2024 Intel Corporation
22
# SPDX-License-Identifier: BSD-3-Clause
33
# See: https://spdx.org/licenses/
44

5-
from lava.proc.sdn.models import AbstractSigmaDeltaModel
5+
import numpy as np
66
from typing import Any, Dict
7+
from lava.proc.sdn.models import AbstractSigmaDeltaModel
78
from lava.magma.core.decorator import implements, requires, tag
8-
import numpy as np
99
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
10-
from lava.proc.s4d.process import SigmaS4Delta, SigmaS4DeltaLayer
10+
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer
1111
from lava.magma.core.resources import CPU
1212
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
1313
from lava.magma.core.model.py.type import LavaPyType
1414
from lava.magma.core.model.sub.model import AbstractSubProcessModel
1515
from lava.proc.sparse.process import Sparse
1616

1717

18-
class AbstractSigmaS4DeltaModel(AbstractSigmaDeltaModel):
18+
class AbstractSigmaS4dDeltaModel(AbstractSigmaDeltaModel):
1919
a_in = None
2020
s_out = None
2121

@@ -29,62 +29,87 @@ class AbstractSigmaS4DeltaModel(AbstractSigmaDeltaModel):
2929
bias = None
3030

3131
# S4 Variables
32-
A = None
33-
B = None
34-
C = None
35-
S4state = None
36-
S4_exp = None
32+
a = None
33+
b = None
34+
c = None
35+
s4_state = None
36+
s4_exp = None
3737

3838
def __init__(self, proc_params: Dict[str, Any]) -> None:
39-
super().__init__(proc_params)
40-
self.A = self.proc_params['A']
41-
self.B = self.proc_params['B']
42-
self.C = self.proc_params['C']
43-
self.S4state = self.proc_params['S4state']
44-
45-
def activation_dynamics(self, sigma_data: np.ndarray) -> np.ndarray:
46-
"""Sigma Delta activation dynamics. Performs S4D dynamics.
39+
"""
40+
Sigma delta neuron model that implements S4D
41+
(as described by Gu et al., 2022) dynamics as its activation function.
4742
4843
Parameters
4944
----------
50-
sigma_data: np.ndarray
51-
sigma decoded data
45+
shape: Tuple
46+
Shape of the sigma process.
47+
vth: int or float
48+
Threshold of the delta encoder.
49+
a: np.ndarray
50+
Diagonal elements of the state matrix of the S4D model.
51+
b: np.ndarray
52+
Diagonal elements of the input matrix of the S4D model.
53+
c: np.ndarray
54+
Diagonal elements of the output matrix of the S4D model.
55+
state_exp: int
56+
Scaling exponent with base 2 for the reconstructed sigma variables.
57+
Note: This should only be used for nc models.
58+
Default is 0.
59+
s4_exp: int
60+
Scaling exponent with base 2 for the S4 state variables.
61+
Note: This should only be used for nc models.
62+
Default is 0.
63+
"""
64+
super().__init__(proc_params)
65+
self.a = self.proc_params['a']
66+
self.b = self.proc_params['b']
67+
self.c = self.proc_params['c']
68+
self.s4_state = self.proc_params['s4_state']
5269

53-
Returns
54-
-------
55-
np.ndarray
56-
activation output
70+
def activation_dynamics(self, sigma_data: np.ndarray) -> np.ndarray:
71+
"""Sigma Delta activation dynamics. Performs S4D dynamics.
5772
58-
Notes
59-
-----
6073
This function simulates the behavior of a linear time-invariant system
61-
with diagonalized state-space representation. (S4D)
74+
with diagonalized state-space representation.
75+
(For reference see Gu et al., 2022)
76+
6277
The state-space equations are given by:
63-
x_{k+1} = A * x_k + B * u_k
64-
y_k = C * x_k
78+
s4_state_{k+1} = A * s4_state_k + B * input_k
79+
act_k = C * s4_state_k
6580
6681
where:
67-
- x_k is the state vector at time step k,
68-
- u_k is the input vector at time step k,
69-
- y_k is the output vector at time step k,
82+
- s4_state_k is the state vector at time step k,
83+
- input_k is the input vector at time step k,
84+
- act_k is the output vector at time step k,
7085
- A is the diagonal state matrix,
7186
- B is the diagonal input matrix,
7287
- C is the diagonal output matrix.
7388
7489
The function computes the next output step of the
7590
system for the given input signal.
91+
92+
Parameters
93+
----------
94+
sigma_data: np.ndarray
95+
sigma decoded data
96+
97+
Returns
98+
-------
99+
np.ndarray
100+
activation output
76101
"""
77102

78-
self.S4state = self.S4state * self.A + sigma_data * self.B
79-
act = self.C * self.S4state * 2
103+
self.s4_state = self.s4_state * self.a + sigma_data * self.b
104+
act = self.c * self.s4_state * 2
80105
return act
81106

82107

83-
@implements(proc=SigmaS4Delta, protocol=LoihiProtocol)
108+
@implements(proc=SigmaS4dDelta, protocol=LoihiProtocol)
84109
@requires(CPU)
85110
@tag('floating_pt')
86-
class PySigmaS4DeltaModelFloat(AbstractSigmaS4DeltaModel):
87-
"""Floating point implementation of SigmaS4Delta neuron."""
111+
class PySigmaS4dDeltaModelFloat(AbstractSigmaS4dDeltaModel):
112+
"""Floating point implementation of SigmaS4dDelta neuron."""
88113
a_in = LavaPyType(PyInPort.VEC_DENSE, float)
89114
s_out = LavaPyType(PyOutPort.VEC_DENSE, float)
90115

@@ -98,13 +123,13 @@ class PySigmaS4DeltaModelFloat(AbstractSigmaS4DeltaModel):
98123
state_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
99124
cum_error: np.ndarray = LavaPyType(np.ndarray, bool, precision=1)
100125
spike_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
101-
S4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
126+
s4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3)
102127

103-
# S4 stuff
104-
S4state: np.ndarray = LavaPyType(np.ndarray, float)
105-
A: np.ndarray = LavaPyType(np.ndarray, float)
106-
B: np.ndarray = LavaPyType(np.ndarray, float)
107-
C: np.ndarray = LavaPyType(np.ndarray, float)
128+
# S4 vaiables
129+
s4_state: np.ndarray = LavaPyType(np.ndarray, float)
130+
a: np.ndarray = LavaPyType(np.ndarray, float)
131+
b: np.ndarray = LavaPyType(np.ndarray, float)
132+
c: np.ndarray = LavaPyType(np.ndarray, float)
108133

109134
def run_spk(self) -> None:
110135
# Receive synaptic input
@@ -113,42 +138,42 @@ def run_spk(self) -> None:
113138
self.s_out.send(s_out)
114139

115140

116-
@implements(proc=SigmaS4DeltaLayer, protocol=LoihiProtocol)
141+
@implements(proc=SigmaS4dDeltaLayer, protocol=LoihiProtocol)
117142
class SubDenseLayerModel(AbstractSubProcessModel):
118143
def __init__(self, proc):
119144
"""Builds (Sparse -> S4D -> Sparse) connection of the process."""
120145
conn_weights = proc.proc_params.get("conn_weights")
121146
shape = proc.proc_params.get("shape")
122147
state_exp = proc.proc_params.get("state_exp")
123148
num_message_bits = proc.proc_params.get("num_message_bits")
124-
S4_exp = proc.proc_params.get("S4_exp")
149+
s4_exp = proc.proc_params.get("s4_exp")
125150
d_states = proc.proc_params.get("d_states")
126-
A = proc.proc_params.get("A")
127-
B = proc.proc_params.get("B")
128-
C = proc.proc_params.get("C")
151+
a = proc.proc_params.get("a")
152+
b = proc.proc_params.get("b")
153+
c = proc.proc_params.get("c")
129154
vth = proc.proc_params.get("vth")
130155

131156
# Instantiate processes
132157
self.sparse1 = Sparse(weights=conn_weights.T, weight_exp=state_exp,
133158
num_message_bits=num_message_bits)
134-
self.sigmaS4delta = SigmaS4Delta(shape=(shape[0] * d_states,),
135-
vth=vth,
136-
state_exp=state_exp,
137-
S4_exp=S4_exp,
138-
A=A,
139-
B=B,
140-
C=C)
159+
self.sigma_S4d_delta = SigmaS4dDelta(shape=(shape[0] * d_states,),
160+
vth=vth,
161+
state_exp=state_exp,
162+
s4_exp=s4_exp,
163+
a=a,
164+
b=b,
165+
c=c)
141166
self.sparse2 = Sparse(weights=conn_weights, weight_exp=state_exp,
142167
num_message_bits=num_message_bits)
143168

144169
# Make connections Sparse -> SigmaS4Delta -> Sparse
145170
proc.in_ports.s_in.connect(self.sparse1.in_ports.s_in)
146-
self.sparse1.out_ports.a_out.connect(self.sigmaS4delta.in_ports.a_in)
147-
self.sigmaS4delta.out_ports.s_out.connect(self.sparse2.s_in)
171+
self.sparse1.out_ports.a_out.connect(self.sigma_S4d_delta.in_ports.a_in)
172+
self.sigma_S4d_delta.out_ports.s_out.connect(self.sparse2.s_in)
148173
self.sparse2.out_ports.a_out.connect(proc.out_ports.a_out)
149174

150-
# Set aliasses
151-
proc.vars.A.alias(self.sigmaS4delta.vars.A)
152-
proc.vars.B.alias(self.sigmaS4delta.vars.B)
153-
proc.vars.C.alias(self.sigmaS4delta.vars.C)
154-
proc.vars.S4state.alias(self.sigmaS4delta.vars.S4state)
175+
# Set aliases
176+
proc.vars.a.alias(self.sigma_S4d_delta.vars.a)
177+
proc.vars.b.alias(self.sigma_S4d_delta.vars.b)
178+
proc.vars.c.alias(self.sigma_S4d_delta.vars.c)
179+
proc.vars.s4_state.alias(self.sigma_S4d_delta.vars.s4_state)

0 commit comments

Comments
 (0)