Skip to content

Commit

Permalink
Turn to async for SolutionReadout host process.
Browse files Browse the repository at this point in the history
Signed-off-by: GaboFGuerra <gabriel.fonseca.guerra@intel.com>
  • Loading branch information
GaboFGuerra committed Nov 6, 2023
1 parent d61b2c8 commit 6028f06
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/lava/lib/optimization/solvers/generic/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def constructor(self, proc):
setattr(self, f"finder_{idx}", finder)
finders.append(finder)
if not proc.is_continuous:
getattr(finder, f"cost_out_last_bytes_{idx}").connect(
getattr(finder, f"cost_out_last_bytes").connect(
getattr(
self.solution_reader,
f"read_gate_in_port_last_bytes_{idx}",
)
)
getattr(finder, f"cost_out_first_byte_{idx}").connect(
getattr(finder, f"cost_out_first_byte").connect(
getattr(
self.solution_reader,
f"read_gate_in_port_first_byte_{idx}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from lava.lib.optimization.solvers.generic.monitoring_processes\
.solution_readout.process import SolutionReadout
from lava.magma.core.decorator import implements, requires
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.model import PyAsyncProcessModel
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.resources import CPU
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol


@implements(SolutionReadout, protocol=LoihiProtocol)
@implements(SolutionReadout, protocol=AsyncProtocol)
@requires(CPU)
class SolutionReadoutPyModel(PyLoihiProcessModel):
class SolutionReadoutPyModel(PyAsyncProcessModel):
"""CPU model for the SolutionReadout process.
The process receives two types of messages, an updated cost and the
state of
Expand Down Expand Up @@ -43,7 +43,7 @@ class SolutionReadoutPyModel(PyLoihiProcessModel):
def _is_multichip(self):
return False

def run_spk(self):
def run_async(self):
if self.stop:
return
raw_cost, min_cost_id = self.cost_in.recv()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/
import numpy as np
import typing as ty

from lava.lib.optimization.solvers.generic.read_gate.process import ReadGate
from lava.magma.core.decorator import implements, requires
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ def __init__(self, proc):
self.cost_convergence_check.cost_first_byte
)
self.cost_convergence_check.cost_out_last_bytes.connect(
getattr(proc.out_ports, f"cost_out_last_bytes_{idx}")
getattr(proc.out_ports, f"cost_out_last_bytes")
)
self.cost_convergence_check.cost_out_first_byte.connect(
getattr(proc.out_ports, f"cost_out_first_byte_{idx}")
getattr(proc.out_ports, f"cost_out_first_byte")
)

elif continuous_var_shape:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,5 @@ def __init__(
)
self.cost_last_bytes = Var(shape=(1,), init=(0,))
self.cost_first_byte = Var(shape=(1,), init=(0,))
setattr(self, f"cost_out_last_bytes_{idx}", OutPort(shape=(1,)))
setattr(self, f"cost_out_first_byte_{idx}", OutPort(shape=(1,)))
# self.cost_out_last_bytes = OutPort(shape=(1,))
# self.cost_out_first_byte = OutPort(shape=(1,))
self.cost_out_last_bytes = OutPort(shape=(1,))
self.cost_out_first_byte = OutPort(shape=(1,))
19 changes: 8 additions & 11 deletions src/lava/lib/optimization/solvers/generic/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,9 @@ def _prepare_solver(self, config: SolverConfig):
num_steps=config.timeout,
)
probes.append(self._state_tracker)
run_cfg = self._get_run_config(
backend=config.backend,
probes=probes,
num_in_ports=num_in_ports,
)
run_cfg = self._get_run_config(config=config,
probes=probes,
num_in_ports=num_in_ports)
run_condition = RunSteps(num_steps=config.timeout)
self._prepare_profiler(config=config, run_cfg=run_cfg)
return run_condition, run_cfg
Expand Down Expand Up @@ -417,14 +415,13 @@ def _get_probed_data(self, tracker, var_name):
else:
return tracker.time_series

def _get_run_config(
self, backend: BACKENDS, probes=None, num_in_ports: int = None
):
def _get_run_config(self, config: SolverConfig, probes=None,
num_in_ports: int = None):
from lava.lib.optimization.solvers.generic.read_gate.process import (
ReadGate
)

if backend in CPUS:
if config.backend in CPUS:
from lava.lib.optimization.solvers.generic.read_gate.models import (
get_read_gate_model_class,
)
Expand All @@ -442,7 +439,7 @@ def _get_run_config(
}
return Loihi1SimCfg(exception_proc_model_map=pdict,
select_sub_proc_model=True)
elif backend in NEUROCORES:
elif config.backend in NEUROCORES:
from lava.lib.optimization.solvers.generic.read_gate.ncmodels \
import get_read_gate_model_class_c
pdict = {
Expand All @@ -465,7 +462,7 @@ def _get_run_config(
callback_fxs=probes,
)
else:
raise NotImplementedError(str(backend) + BACKEND_MSG)
raise NotImplementedError(str(config.backend) + BACKEND_MSG)

def _prepare_profiler(self, config: SolverConfig, run_cfg) -> None:
if config.probe_time or config.probe_energy:
Expand Down

0 comments on commit 6028f06

Please sign in to comment.