Lava execution caching #840

Closed
wants to merge 12 commits into from
82 changes: 81 additions & 1 deletion src/lava/magma/compiler/channel_map.py
@@ -3,6 +3,8 @@
# See: https://spdx.org/licenses/

import itertools
import os
import pickle
import typing as ty
from collections import defaultdict
from dataclasses import dataclass
@@ -11,6 +13,7 @@
from lava.magma.compiler.utils import PortInitializer
from lava.magma.core.process.ports.ports import AbstractPort
from lava.magma.core.process.ports.ports import AbstractSrcPort, AbstractDstPort
from lava.magma.core.process.process import AbstractProcess


@dataclass(eq=True, frozen=True)
@@ -27,6 +30,9 @@ class Payload:
dst_port_initializer: PortInitializer = None


# A named, module-level default factory (rather than a lambda) keeps the
# defaultdict below picklable, which the execution cache relies on.
def lmt_init_id():
    return -1


class ChannelMap(dict):
"""The ChannelMap is used by the SubCompilers during compilation to
communicate how they are planning to partition Processes onto their
@@ -35,7 +41,7 @@ class ChannelMap(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._initializers_lookup = dict()
self._lmt_allocation_dict: ty.Dict[int, int] = defaultdict(lambda: -1)
self._lmt_allocation_dict: ty.Dict[int, int] = defaultdict(lmt_init_id)

def __setitem__(
self, key: PortPair, value: Payload, dict_setitem=dict.__setitem__
@@ -125,3 +131,77 @@ def get_port_initializer(self, port):

def has_port_initializer(self, port) -> bool:
return port in self._initializers_lookup

def write_to_cache(self,
cache_object: ty.Dict[ty.Any, ty.Any],
proc_to_procname_map: ty.Dict[AbstractProcess, str]):
cache_object["lmt_allocation"] = self._lmt_allocation_dict

initializers_serializable: ty.List[ty.Tuple[str, str,
PortInitializer]] = []
port: AbstractPort
pi: PortInitializer
for port, pi in self._initializers_lookup.items():
procname = proc_to_procname_map[port.process]
if procname.startswith("Process_"):
msg = f"Unable to Cache. " \
f"Please give unique names to every process. " \
f"Violation Name: {procname=}"
raise Exception(msg)

initializers_serializable.append((procname, port.name, pi))
cache_object["initializers"] = initializers_serializable

cm_serializable: ty.List[ty.Tuple[ty.Tuple[str, str],
ty.Tuple[str, str],
Payload]] = []
port_pair: PortPair
payload: Payload
for port_pair, payload in self.items():
src_port: AbstractPort = ty.cast(AbstractPort, port_pair.src)
dst_port: AbstractPort = ty.cast(AbstractPort, port_pair.dst)
src_proc_name: str = proc_to_procname_map[src_port.process]
src_port_info = (src_proc_name, src_port.name)
dst_proc_name: str = proc_to_procname_map[dst_port.process]
dst_port_info = (dst_proc_name, dst_port.name)
if src_proc_name.startswith("Process_") or \
dst_proc_name.startswith("Process_"):
msg = f"Unable to Cache. " \
f"Please give unique names to every process. " \
f"Violation Name: {src_proc_name=} {dst_proc_name=}"
raise Exception(msg)

cm_serializable.append((src_port_info, dst_port_info, payload))
cache_object["channelmap_dict"] = cm_serializable

def read_from_cache(self,
cache_object: ty.Dict[ty.Any, ty.Any],
procname_to_proc_map: ty.Dict[str, AbstractProcess]):
self._lmt_allocation_dict = cache_object["lmt_allocation"]
initializers_serializable = cache_object["initializers"]
cm_serializable = cache_object["channelmap_dict"]

for procname, port_name, pi in initializers_serializable:
process: AbstractProcess = procname_to_proc_map[procname]
port: AbstractPort = getattr(process, port_name)
self._initializers_lookup[port] = pi

src_port_info: ty.Tuple[str, str]
dst_port_info: ty.Tuple[str, str]
payload: Payload
for src_port_info, dst_port_info, payload in cm_serializable:
src_port_process: AbstractProcess = procname_to_proc_map[
src_port_info[0]]
src: AbstractPort = getattr(src_port_process,
src_port_info[1])
dst_port_process: AbstractProcess = procname_to_proc_map[
dst_port_info[0]]
dst: AbstractPort = getattr(dst_port_process,
dst_port_info[1])
for port_pair, pld in self.items():
s, d = port_pair.src, port_pair.dst
if s.name == src.name and d.name == dst.name and \
s.process.name == src_port_process.name and \
d.process.name == dst_port_process.name:
pld.src_port_initializer = payload.src_port_initializer
pld.dst_port_initializer = payload.dst_port_initializer
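The two methods above round-trip the ChannelMap through a plain dict keyed by process and port names rather than live object references, which is what makes it picklable. A minimal sketch of that round-trip, assuming (hypothetically) that procs is an iterable of uniquely named processes and channel_map is a populated ChannelMap from a compilation run:

import pickle

# Hypothetical stand-ins: procs / channel_map come from a compilation run.
# Auto-generated "Process_*" names are rejected, so each process needs an
# explicit, unique name.
proc_to_procname_map = {p: p.name for p in procs}
procname_to_proc_map = {p.name: p for p in procs}

cache_object = {}
channel_map.write_to_cache(cache_object, proc_to_procname_map)
with open("cache", "wb") as cache_file:
    pickle.dump(cache_object, cache_file)

# Later, after the same named processes have been rebuilt:
with open("cache", "rb") as cache_file:
    cache_object = pickle.load(cache_file)
channel_map.read_from_cache(cache_object, procname_to_proc_map)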
49 changes: 49 additions & 0 deletions src/lava/magma/compiler/compiler.py
@@ -4,6 +4,8 @@

import itertools
import logging
import os
import pickle
import typing as ty
from collections import OrderedDict, defaultdict

@@ -219,6 +221,31 @@ def _compile_proc_groups(
The global dict-like ChannelMap given as input but with values
updated according to partitioning done by subcompilers.
"""
procname_to_proc_map: ty.Dict[str, AbstractProcess] = {}
proc_to_procname_map: ty.Dict[AbstractProcess, str] = {}
for proc_group in proc_groups:
for p in proc_group:
procname_to_proc_map[p.name] = p
proc_to_procname_map[p] = p.name

if self._compile_config.get("cache", False):
cache_dir = self._compile_config["cache_dir"]
if os.path.exists(cache_dir):
with open(os.path.join(cache_dir, "cache"), "rb") as cache_file:
cache_object = pickle.load(cache_file)

proc_builders_values = cache_object["procname_to_proc_builder"]
proc_builders = {}
for proc_name, pb in proc_builders_values.items():
proc = procname_to_proc_map[proc_name]
proc_builders[proc] = pb
pb.proc_params = proc.proc_params

channel_map.read_from_cache(cache_object, procname_to_proc_map)
print(f"\nBuilders and Channel Map loaded from " \
f"Cache {cache_dir}\n")
return proc_builders, channel_map

# Create the global ChannelMap that is passed between
# SubCompilers to communicate about Channels between Processes.

@@ -248,6 +275,28 @@
subcompilers, channel_map
)

if self._compile_config.get("cache", False):
cache_dir = self._compile_config["cache_dir"]
os.makedirs(cache_dir)
cache_object = {}
# Validate All Processes are Named
procname_to_proc_builder = {}
for p, pb in proc_builders.items():
if p.name in procname_to_proc_builder or \
"Process_" in p.name:
msg = f"Unable to Cache. " \
f"Please give unique names to every process. " \
f"Violation Name: {p.name=}"
raise Exception(msg)
procname_to_proc_builder[p.name] = pb
pb.proc_params = None  # stripped before pickling; restored below
cache_object["procname_to_proc_builder"] = procname_to_proc_builder
channel_map.write_to_cache(cache_object, proc_to_procname_map)
with open(os.path.join(cache_dir, "cache"), "wb") as cache_file:
pickle.dump(cache_object, cache_file)
for p, pb in proc_builders.items():
pb.proc_params = p.proc_params
print(f"\nBuilders and Channel Map stored to Cache {cache_dir}\n")
return proc_builders, channel_map

@staticmethod
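Together, the two blocks above give _compile_proc_groups a read-before-compile, write-after-compile cache keyed on cache_dir. A hedged usage sketch — proc and run_cfg are hypothetical stand-ins for a user's process and run configuration; "cache" and "cache_dir" are the config keys read in the diff:

from lava.magma.compiler.compiler import Compiler

# First compile: cache_dir does not exist yet, so compilation runs in full
# and the builders plus channel map are pickled to <cache_dir>/cache.
# Any later compile pointed at the same cache_dir loads them and returns
# early, skipping the subcompilers entirely.
compiler = Compiler(compile_config={"cache": True,
                                    "cache_dir": "/tmp/lava_cache"})
executable = compiler.compile(proc, run_cfg)  # hypothetical proc/run_cfg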
4 changes: 4 additions & 0 deletions src/lava/magma/compiler/mapper.py
@@ -162,6 +162,10 @@ def map_cores(self, executable: Executable,
address.update(chips)
break
if len(address) > 1 and hasattr(var_model, "address"):
print('=' * 50)
print('Note to JOYESH from the future:')
print('Add logic to make multichip conv input work for YOLO.')
print('=' * 50)
raise ValueError("Lava Compiler doesn't support port "
                 "splitting currently. MultiChip "
                 "Not Supported.")
14 changes: 12 additions & 2 deletions src/lava/magma/compiler/subcompilers/py/pyproc_compiler.py
@@ -30,6 +30,7 @@
ImplicitVarPort,
VarPort,
)
from lava.magma.core.process.ports.connection_config import ConnectionConfig
from lava.magma.core.process.process import AbstractProcess
from lava.magma.compiler.subcompilers.constants import SPIKE_BLOCK_CORE

@@ -189,7 +190,11 @@ def _create_inport_initializers(
pi.embedded_counters = \
np.arange(counter_start_idx,
counter_start_idx + num_counters, dtype=np.int32)
pi.connection_config = list(port.connection_configs.values())[0]
if port.connection_configs.values():
conn_config = list(port.connection_configs.values())[0]
else:
conn_config = ConnectionConfig()
pi.connection_config = conn_config
port_initializers.append(pi)
self._tmp_channel_map.set_port_initializer(port, pi)
else:
@@ -209,7 +214,7 @@ def _create_outport_initializers(
self, process: AbstractProcess
) -> ty.List[PortInitializer]:
port_initializers = []
for port in list(process.out_ports):
for k, port in enumerate(list(process.out_ports)):
pi = PortInitializer(
port.name,
port.shape,
@@ -218,6 +223,11 @@
self._compile_config["pypy_channel_size"],
port.get_incoming_transform_funcs(),
)
if port.connection_configs.values():
conn_config = list(port.connection_configs.values())[k]
else:
conn_config = ConnectionConfig()
pi.connection_config = conn_config
port_initializers.append(pi)
self._tmp_channel_map.set_port_initializer(port, pi)
return port_initializers
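Both initializer helpers now tolerate ports that were connected without an explicit ConnectionConfig. The fallback pattern in isolation, with port as a hypothetical port object:

from lava.magma.core.process.ports.connection_config import ConnectionConfig

configs = list(port.connection_configs.values())  # port: hypothetical
conn_config = configs[0] if configs else ConnectionConfig()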
8 changes: 8 additions & 0 deletions src/lava/magma/compiler/var_model.py
@@ -269,6 +269,14 @@ class NcSpikeIOVarModel(NcVarModel):
interface: SpikeIOInterface = SpikeIOInterface.ETHERNET
spike_io_port: SpikeIOPort = SpikeIOPort.ETHERNET
spike_io_mode: SpikeIOMode = SpikeIOMode.TIME_COMPARE
ethernet_chip_id: ty.Optional[ty.Tuple[int, int, int]] = None
ethernet_chip_idx: ty.Optional[int] = None
decode_config: ty.Optional[DecodeConfig] = None
time_compare: ty.Optional[TimeCompare] = None
spike_encoder: ty.Optional[SpikeEncoder] = None


@dataclass
class NcConvSpikeInVarModel(NcSpikeIOVarModel):
# Tuple will be in the order of [atom_payload, atom_axon, addr_idx]
region_map: ty.List[ty.List[ty.Tuple[int, int, int]]] = None
4 changes: 4 additions & 0 deletions src/lava/magma/core/model/py/model.py
@@ -759,6 +759,10 @@ def run_async(self) -> None:
if py_loihi_model.post_guard(self):
py_loihi_model.run_post_mgmt(self)
self.time_step += 1
# self.advance_to_time_step(self.time_step)
for port in self.py_ports:
if isinstance(port, PyOutPort):
port.advance_to_time_step(self.time_step)

py_async_model = type(
name,
3 changes: 3 additions & 0 deletions src/lava/magma/core/process/ports/connection_config.py
@@ -14,6 +14,7 @@
# expressly stated in the License.
from dataclasses import dataclass
from enum import IntEnum, Enum
import typing as ty


class SpikeIOInterface(IntEnum):
@@ -54,3 +55,5 @@ class ConnectionConfig:
spike_io_mode: SpikeIOMode = SpikeIOMode.TIME_COMPARE
num_time_buckets: int = 1 << 16
ethernet_mac_address: str = "0x90e2ba01214c"
ethernet_chip_id: ty.Optional[ty.Tuple[int, int, int]] = None
ethernet_chip_idx: ty.Optional[int] = None
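These two optional fields mirror the ones added to NcSpikeIOVarModel above and let a connection pin spike I/O to a specific ethernet chip. A minimal construction sketch — the coordinate values are invented:

from lava.magma.core.process.ports.connection_config import ConnectionConfig

# All fields carry dataclass defaults, so only the new ones are set here.
config = ConnectionConfig(ethernet_chip_id=(0, 0, 0),
                          ethernet_chip_idx=0)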
26 changes: 26 additions & 0 deletions src/lava/proc/io/extractor.py
@@ -10,7 +10,9 @@
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.model import PyAsyncProcessModel
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyInPort, PyRefPort
from lava.magma.compiler.channels.pypychannel import PyPyChannel
@@ -137,6 +139,30 @@ def __del__(self) -> None:
self._pm_to_p_src_port.join()


@implements(proc=Extractor, protocol=AsyncProtocol)
@requires(CPU)
class PyLoihiExtractorModelAsync(PyAsyncProcessModel):
    """PyAsyncProcessModel for the Extractor Process."""
    in_port: PyInPort = LavaPyType(PyInPort.VEC_DENSE, float)

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params=proc_params)

channel_config = self.proc_params["channel_config"]
self._pm_to_p_src_port = self.proc_params["pm_to_p_src_port"]
self._pm_to_p_src_port.start()

self._send = channel_config.get_send_full_function()
self.time_step = 1

def run_async(self) -> None:
while self.time_step != self.num_steps + 1:
self._send(self._pm_to_p_src_port, self.in_port.recv())
self.time_step += 1

def __del__(self) -> None:
self._pm_to_p_src_port.join()


class VarWire(AbstractProcess):
"""VarWire allows non-Lava code, such as a third-party Python library
to tap data from a Lava Process Variable (Var) while the Lava Runtime is
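With this model registered, an Extractor can also sit downstream of processes running under the AsyncProtocol. A hedged host-side sketch — source, its out port, and run_cfg are hypothetical, and run_cfg is assumed to select the async process models:

from lava.magma.core.run_conditions import RunSteps
from lava.proc.io.extractor import Extractor

extractor = Extractor(shape=(10,))
source.out.connect(extractor.in_port)  # source/out: hypothetical

# Run non-blocking so the host loop below can drain one time step of
# data per receive() call while the network executes asynchronously.
extractor.run(condition=RunSteps(num_steps=100, blocking=False),
              run_cfg=run_cfg)
for _ in range(100):
    data = extractor.receive()
extractor.stop()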
44 changes: 44 additions & 0 deletions src/lava/proc/io/injector.py
@@ -10,7 +10,9 @@
from lava.magma.core.resources import CPU
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.model import PyAsyncProcessModel
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.model.py.ports import PyOutPort
from lava.magma.runtime.message_infrastructure.multiprocessing import \
@@ -139,3 +141,45 @@ def run_spk(self) -> None:

def __del__(self) -> None:
self._p_to_pm_dst_port.join()


@implements(proc=Injector, protocol=AsyncProtocol)
@requires(CPU)
class PyLoihiInjectorModelAsync(PyAsyncProcessModel):
"""PyAsyncProcessModel for the Injector Process."""
out_port: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, float)

def __init__(self, proc_params: dict) -> None:
super().__init__(proc_params=proc_params)

shape = self.proc_params["shape"]
channel_config = self.proc_params["channel_config"]
self._p_to_pm_dst_port = self.proc_params["p_to_pm_dst_port"]
self._p_to_pm_dst_port.start()

self._zeros = np.zeros(shape)

self._receive_when_empty = channel_config.get_receive_empty_function()
self._receive_when_not_empty = \
channel_config.get_receive_not_empty_function()
self.time_step = 1

def run_async(self) -> None:
while self.time_step != self.num_steps + 1:
self._zeros.fill(0)
elements_in_buffer = self._p_to_pm_dst_port._queue.qsize()

if elements_in_buffer == 0:
data = self._receive_when_empty(
self._p_to_pm_dst_port,
self._zeros)
else:
data = self._receive_when_not_empty(
self._p_to_pm_dst_port,
self._zeros,
elements_in_buffer)
self.out_port.send(data)
self.time_step += 1

def __del__(self) -> None:
self._p_to_pm_dst_port.join()
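The async Injector model mirrors the Loihi-protocol one: each time step it inspects the host-side queue and applies the channel_config's receive-empty or receive-not-empty policy before sending. A hedged sketch of feeding it — sink, its inp port, and run_cfg are hypothetical stand-ins:

import numpy as np
from lava.magma.core.run_conditions import RunSteps
from lava.proc.io.injector import Injector

injector = Injector(shape=(10,))
injector.out_port.connect(sink.inp)  # sink/inp: hypothetical

injector.run(condition=RunSteps(num_steps=100, blocking=False),
             run_cfg=run_cfg)
for _ in range(100):
    injector.send(np.ones((10,)))  # queues one time step of input
injector.stop()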