Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ethernet IO changes for Demo #828

Merged
merged 10 commits into from
Jan 17, 2024
4 changes: 4 additions & 0 deletions src/lava/magma/compiler/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def map_cores(self, executable: Executable,
address.update(chips)
break
if len(address) > 1 and hasattr(var_model, "address"):
print('=' * 50)
print('Note to JOYESH from the future:')
print('Add logic to make multichip conv input work for YOLO.')
print('=' * 50)
raise ValueError("Lava Compiler doesn't support port"
"splitting currently. MultiChip "
"Not Supported ")
Expand Down
14 changes: 12 additions & 2 deletions src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ImplicitVarPort,
VarPort,
)
from lava.magma.core.process.ports.connection_config import ConnectionConfig
from lava.magma.core.process.process import AbstractProcess
from lava.magma.compiler.subcompilers.constants import SPIKE_BLOCK_CORE

Expand Down Expand Up @@ -189,7 +190,11 @@ def _create_inport_initializers(
pi.embedded_counters = \
np.arange(counter_start_idx,
counter_start_idx + num_counters, dtype=np.int32)
pi.connection_config = list(port.connection_configs.values())[0]
if port.connection_configs.values():
conn_config = list(port.connection_configs.values())[0]
else:
conn_config = ConnectionConfig()
pi.connection_config = conn_config
port_initializers.append(pi)
self._tmp_channel_map.set_port_initializer(port, pi)
else:
Expand All @@ -209,7 +214,7 @@ def _create_outport_initializers(
self, process: AbstractProcess
) -> ty.List[PortInitializer]:
port_initializers = []
for port in list(process.out_ports):
for k, port in enumerate(list(process.out_ports)):
pi = PortInitializer(
port.name,
port.shape,
Expand All @@ -218,6 +223,11 @@ def _create_outport_initializers(
self._compile_config["pypy_channel_size"],
port.get_incoming_transform_funcs(),
)
if port.connection_configs.values():
conn_config = list(port.connection_configs.values())[k]
else:
conn_config = ConnectionConfig()
pi.connection_config = conn_config
port_initializers.append(pi)
self._tmp_channel_map.set_port_initializer(port, pi)
return port_initializers
Expand Down
8 changes: 8 additions & 0 deletions src/lava/magma/compiler/var_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,14 @@ class NcSpikeIOVarModel(NcVarModel):
interface: SpikeIOInterface = SpikeIOInterface.ETHERNET
spike_io_port: SpikeIOPort = SpikeIOPort.ETHERNET
spike_io_mode: SpikeIOMode = SpikeIOMode.TIME_COMPARE
ethernet_chip_id: ty.Optional[ty.Tuple[int, int, int]] = None
ethernet_chip_idx: ty.Optional[int] = None
decode_config: ty.Optional[DecodeConfig] = None
time_compare: ty.Optional[TimeCompare] = None
spike_encoder: ty.Optional[SpikeEncoder] = None


@dataclass
class NcConvSpikeInVarModel(NcSpikeIOVarModel):
# Tuple will be in the order of [atom_paylod, atom_axon, addr_idx]
region_map: ty.List[ty.List[ty.Tuple[int, int, int]]] = None
4 changes: 4 additions & 0 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,10 @@ def run_async(self) -> None:
if py_loihi_model.post_guard(self):
py_loihi_model.run_post_mgmt(self)
self.time_step += 1
# self.advance_to_time_step(self.time_step)
for port in self.py_ports:
if isinstance(port, PyOutPort):
port.advance_to_time_step(self.time_step)

py_async_model = type(
name,
Expand Down
3 changes: 3 additions & 0 deletions src/lava/magma/core/process/ports/connection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# expressly stated in the License.
from dataclasses import dataclass
from enum import IntEnum, Enum
import typing as ty


class SpikeIOInterface(IntEnum):
Expand Down Expand Up @@ -54,3 +55,5 @@ class ConnectionConfig:
spike_io_mode: SpikeIOMode = SpikeIOMode.TIME_COMPARE
num_time_buckets: int = 1 << 16
ethernet_mac_address: str = "0x90e2ba01214c"
ethernet_chip_id: ty.Optional[ty.Tuple[int, int, int]] = None
ethernet_chip_idx: ty.Optional[int] = None
26 changes: 26 additions & 0 deletions src/lava/proc/io/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.model import PyAsyncProcessModel
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyInPort, PyRefPort
from lava.magma.compiler.channels.pypychannel import PyPyChannel
Expand Down Expand Up @@ -137,6 +139,30 @@ def __del__(self) -> None:
self._pm_to_p_src_port.join()


@implements(proc=Extractor, protocol=AsyncProtocol)
@requires(CPU)
class PyLoihiExtractorModelAsync(PyAsyncProcessModel):
in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params=proc_params)

channel_config = self.proc_params["channel_config"]
self._pm_to_p_src_port = self.proc_params["pm_to_p_src_port"]
self._pm_to_p_src_port.start()

self._send = channel_config.get_send_full_function()
self.time_step = 1

def run_async(self) -> None:
while self.time_step != self.num_steps + 1:
self._send(self._pm_to_p_src_port, self.in_port.recv())
self.time_step += 1

def __del__(self) -> None:
self._pm_to_p_src_port.join()


class VarWire(AbstractProcess):
"""VarWire allows non-Lava code, such as a third-party Python library
to tap data from a Lava Process Variable (Var) while the Lava Runtime is
Expand Down
44 changes: 44 additions & 0 deletions src/lava/proc/io/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.model import PyAsyncProcessModel
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.runtime.message_infrastructure.multiprocessing import \
Expand Down Expand Up @@ -139,3 +141,45 @@ def run_spk(self) -> None:

def __del__(self) -> None:
self._p_to_pm_dst_port.join()


@implements(proc=Injector, protocol=AsyncProtocol)
@requires(CPU)
class PyLoihiInjectorModelAsync(PyAsyncProcessModel):
"""PyAsyncProcessModel for the Injector Process."""
out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params=proc_params)

shape = self.proc_params["shape"]
channel_config = self.proc_params["channel_config"]
self._p_to_pm_dst_port = self.proc_params["p_to_pm_dst_port"]
self._p_to_pm_dst_port.start()

self._zeros = np.zeros(shape)

self._receive_when_empty = channel_config.get_receive_empty_function()
self._receive_when_not_empty = \
channel_config.get_receive_not_empty_function()
self.time_step = 1

def run_async(self) -> None:
while self.time_step != self.num_steps + 1:
self._zeros.fill(0)
elements_in_buffer = self._p_to_pm_dst_port._queue.qsize()

if elements_in_buffer == 0:
data = self._receive_when_empty(
self._p_to_pm_dst_port,
self._zeros)
else:
data = self._receive_when_not_empty(
self._p_to_pm_dst_port,
self._zeros,
elements_in_buffer)
self.out_port.send(data)
self.time_step += 1

def __del__(self) -> None:
self._p_to_pm_dst_port.join()
Loading