From 6028f069e47cd59d33b26175b06b6ee4bf11867c Mon Sep 17 00:00:00 2001 From: GaboFGuerra Date: Mon, 6 Nov 2023 13:21:05 -0800 Subject: [PATCH] Turn to async for SolutionReadout host process. Signed-off-by: GaboFGuerra --- .../optimization/solvers/generic/builder.py | 4 ++-- .../solution_readout/models.py | 10 +++++----- .../solvers/generic/read_gate/models.py | 1 + .../solvers/generic/solution_finder/models.py | 4 ++-- .../generic/solution_finder/process.py | 6 ++---- .../optimization/solvers/generic/solver.py | 19 ++++++++----------- 6 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/lava/lib/optimization/solvers/generic/builder.py b/src/lava/lib/optimization/solvers/generic/builder.py index 560f651d..b7e04e5e 100644 --- a/src/lava/lib/optimization/solvers/generic/builder.py +++ b/src/lava/lib/optimization/solvers/generic/builder.py @@ -277,13 +277,13 @@ def constructor(self, proc): setattr(self, f"finder_{idx}", finder) finders.append(finder) if not proc.is_continuous: - getattr(finder, f"cost_out_last_bytes_{idx}").connect( + getattr(finder, f"cost_out_last_bytes").connect( getattr( self.solution_reader, f"read_gate_in_port_last_bytes_{idx}", ) ) - getattr(finder, f"cost_out_first_byte_{idx}").connect( + getattr(finder, f"cost_out_first_byte").connect( getattr( self.solution_reader, f"read_gate_in_port_first_byte_{idx}", diff --git a/src/lava/lib/optimization/solvers/generic/monitoring_processes/solution_readout/models.py b/src/lava/lib/optimization/solvers/generic/monitoring_processes/solution_readout/models.py index 0fca02e7..727e3272 100644 --- a/src/lava/lib/optimization/solvers/generic/monitoring_processes/solution_readout/models.py +++ b/src/lava/lib/optimization/solvers/generic/monitoring_processes/solution_readout/models.py @@ -5,16 +5,16 @@ from lava.lib.optimization.solvers.generic.monitoring_processes\ .solution_readout.process import SolutionReadout from lava.magma.core.decorator import implements, requires -from lava.magma.core.model.py.model import PyLoihiProcessModel +from lava.magma.core.model.py.model import PyAsyncProcessModel from lava.magma.core.model.py.ports import PyInPort, PyOutPort from lava.magma.core.model.py.type import LavaPyType from lava.magma.core.resources import CPU -from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol +from lava.magma.core.sync.protocols.async_protocol import AsyncProtocol -@implements(SolutionReadout, protocol=LoihiProtocol) +@implements(SolutionReadout, protocol=AsyncProtocol) @requires(CPU) -class SolutionReadoutPyModel(PyLoihiProcessModel): +class SolutionReadoutPyModel(PyAsyncProcessModel): """CPU model for the SolutionReadout process. The process receives two types of messages, an updated cost and the state of @@ -43,7 +43,7 @@ class SolutionReadoutPyModel(PyLoihiProcessModel): def _is_multichip(self): return False - def run_spk(self): + def run_async(self): if self.stop: return raw_cost, min_cost_id = self.cost_in.recv() diff --git a/src/lava/lib/optimization/solvers/generic/read_gate/models.py b/src/lava/lib/optimization/solvers/generic/read_gate/models.py index 01dd4704..9142aa31 100644 --- a/src/lava/lib/optimization/solvers/generic/read_gate/models.py +++ b/src/lava/lib/optimization/solvers/generic/read_gate/models.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # See: https://spdx.org/licenses/ import numpy as np +import typing as ty from lava.lib.optimization.solvers.generic.read_gate.process import ReadGate from lava.magma.core.decorator import implements, requires diff --git a/src/lava/lib/optimization/solvers/generic/solution_finder/models.py b/src/lava/lib/optimization/solvers/generic/solution_finder/models.py index 901c7a24..bd05221b 100644 --- a/src/lava/lib/optimization/solvers/generic/solution_finder/models.py +++ b/src/lava/lib/optimization/solvers/generic/solution_finder/models.py @@ -115,10 +115,10 @@ def __init__(self, proc): self.cost_convergence_check.cost_first_byte ) self.cost_convergence_check.cost_out_last_bytes.connect( - getattr(proc.out_ports, f"cost_out_last_bytes_{idx}") + getattr(proc.out_ports, f"cost_out_last_bytes") ) self.cost_convergence_check.cost_out_first_byte.connect( - getattr(proc.out_ports, f"cost_out_first_byte_{idx}") + getattr(proc.out_ports, f"cost_out_first_byte") ) elif continuous_var_shape: diff --git a/src/lava/lib/optimization/solvers/generic/solution_finder/process.py b/src/lava/lib/optimization/solvers/generic/solution_finder/process.py index a31751d7..5a2fbefb 100644 --- a/src/lava/lib/optimization/solvers/generic/solution_finder/process.py +++ b/src/lava/lib/optimization/solvers/generic/solution_finder/process.py @@ -42,7 +42,5 @@ def __init__( ) self.cost_last_bytes = Var(shape=(1,), init=(0,)) self.cost_first_byte = Var(shape=(1,), init=(0,)) - setattr(self, f"cost_out_last_bytes_{idx}", OutPort(shape=(1,))) - setattr(self, f"cost_out_first_byte_{idx}", OutPort(shape=(1,))) - # self.cost_out_last_bytes = OutPort(shape=(1,)) - # self.cost_out_first_byte = OutPort(shape=(1,)) + self.cost_out_last_bytes = OutPort(shape=(1,)) + self.cost_out_first_byte = OutPort(shape=(1,)) diff --git a/src/lava/lib/optimization/solvers/generic/solver.py b/src/lava/lib/optimization/solvers/generic/solver.py index 4ab6e62e..baadb291 100644 --- a/src/lava/lib/optimization/solvers/generic/solver.py +++ b/src/lava/lib/optimization/solvers/generic/solver.py @@ -323,11 +323,9 @@ def _prepare_solver(self, config: SolverConfig): num_steps=config.timeout, ) probes.append(self._state_tracker) - run_cfg = self._get_run_config( - backend=config.backend, - probes=probes, - num_in_ports=num_in_ports, - ) + run_cfg = self._get_run_config(config=config, + probes=probes, + num_in_ports=num_in_ports) run_condition = RunSteps(num_steps=config.timeout) self._prepare_profiler(config=config, run_cfg=run_cfg) return run_condition, run_cfg @@ -417,14 +415,13 @@ def _get_probed_data(self, tracker, var_name): else: return tracker.time_series - def _get_run_config( - self, backend: BACKENDS, probes=None, num_in_ports: int = None - ): + def _get_run_config(self, config: SolverConfig, probes=None, + num_in_ports: int = None): from lava.lib.optimization.solvers.generic.read_gate.process import ( ReadGate ) - if backend in CPUS: + if config.backend in CPUS: from lava.lib.optimization.solvers.generic.read_gate.models import ( get_read_gate_model_class, ) @@ -442,7 +439,7 @@ def _get_run_config( } return Loihi1SimCfg(exception_proc_model_map=pdict, select_sub_proc_model=True) - elif backend in NEUROCORES: + elif config.backend in NEUROCORES: from lava.lib.optimization.solvers.generic.read_gate.ncmodels \ import get_read_gate_model_class_c pdict = { @@ -465,7 +462,7 @@ def _get_run_config( callback_fxs=probes, ) else: - raise NotImplementedError(str(backend) + BACKEND_MSG) + raise NotImplementedError(str(config.backend) + BACKEND_MSG) def _prepare_profiler(self, config: SolverConfig, run_cfg) -> None: if config.probe_time or config.probe_energy: