From ada5d124ce6bc391e7ff8c667cc1dcc7b97f8e5f Mon Sep 17 00:00:00 2001 From: Eduard Gash Date: Mon, 14 Jul 2025 13:52:16 +0000 Subject: [PATCH 1/2] feat: async aggregation improvements --- nebula/controller/scenarios.py | 64 ++-- .../updatehandlers/caffupdatehandler.py | 343 ++++++++++++++++++ .../updatehandlers/updatehandler.py | 29 +- nebula/core/engine.py | 103 +++++- nebula/core/nebulaevents.py | 27 +- nebula/core/network/actions.py | 1 + nebula/core/pb/nebula.proto | 1 + nebula/core/pb/nebula_pb2.py | 93 +++-- .../frontend/config/participant.json.example | 7 + .../static/js/deployment/help-content.js | 45 ++- .../frontend/static/js/deployment/scenario.js | 19 +- .../static/js/deployment/ui-controls.js | 80 +++- nebula/frontend/templates/deployment.html | 76 +++- 13 files changed, 758 insertions(+), 130 deletions(-) create mode 100644 nebula/core/aggregation/updatehandlers/caffupdatehandler.py diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py index 421ab793..41bcb404 100644 --- a/nebula/controller/scenarios.py +++ b/nebula/controller/scenarios.py @@ -6,15 +6,12 @@ import math import os import shutil -import subprocess -import sys import time from datetime import datetime from urllib.parse import quote -from aiohttp import FormData import docker -import tensorboard_reducer as tbr +from aiohttp import FormData from nebula.addons.topologymanager import TopologyManager from nebula.config.config import Config @@ -108,6 +105,9 @@ def __init__( sar_neighbor_policy, sar_training, sar_training_policy, + aggregator_args, # <-- NEW + communication_args, # <-- NEW + caff_args, # <-- NEW physical_ips=None, ): """ @@ -199,30 +199,30 @@ def __init__( self.mobile_participants_percent = mobile_participants_percent self.additional_participants = additional_participants self.with_trustworthiness = with_trustworthiness - self.robustness_pillar = robustness_pillar, - self.resilience_to_attacks = resilience_to_attacks, - self.algorithm_robustness = algorithm_robustness, - self.client_reliability = client_reliability, - self.privacy_pillar = privacy_pillar, - self.technique = technique, - self.uncertainty = uncertainty, - self.indistinguishability = indistinguishability, - self.fairness_pillar = fairness_pillar, - self.selection_fairness = selection_fairness, - self.performance_fairness = performance_fairness, - self.class_distribution = class_distribution, - self.explainability_pillar = explainability_pillar, - self.interpretability = interpretability, - self.post_hoc_methods = post_hoc_methods, - self.accountability_pillar = accountability_pillar, - self.factsheet_completeness = factsheet_completeness, - self.architectural_soundness_pillar = architectural_soundness_pillar, - self.client_management = client_management, - self.optimization = optimization, - self.sustainability_pillar = sustainability_pillar, - self.energy_source = energy_source, - self.hardware_efficiency = hardware_efficiency, - self.federation_complexity = federation_complexity, + self.robustness_pillar = (robustness_pillar,) + self.resilience_to_attacks = (resilience_to_attacks,) + self.algorithm_robustness = (algorithm_robustness,) + self.client_reliability = (client_reliability,) + self.privacy_pillar = (privacy_pillar,) + self.technique = (technique,) + self.uncertainty = (uncertainty,) + self.indistinguishability = (indistinguishability,) + self.fairness_pillar = (fairness_pillar,) + self.selection_fairness = (selection_fairness,) + self.performance_fairness = (performance_fairness,) + 
self.class_distribution = (class_distribution,) + self.explainability_pillar = (explainability_pillar,) + self.interpretability = (interpretability,) + self.post_hoc_methods = (post_hoc_methods,) + self.accountability_pillar = (accountability_pillar,) + self.factsheet_completeness = (factsheet_completeness,) + self.architectural_soundness_pillar = (architectural_soundness_pillar,) + self.client_management = (client_management,) + self.optimization = (optimization,) + self.sustainability_pillar = (sustainability_pillar,) + self.energy_source = (energy_source,) + self.hardware_efficiency = (hardware_efficiency,) + self.federation_complexity = (federation_complexity,) self.schema_additional_participants = schema_additional_participants self.random_topology_probability = random_topology_probability self.with_sa = with_sa @@ -234,6 +234,9 @@ def __init__( self.sar_training = sar_training self.sar_training_policy = sar_training_policy self.physical_ips = physical_ips + self.aggregator_args = aggregator_args # <-- NEW + self.communication_args = communication_args # <-- NEW + self.caff_args = caff_args # <-- NEW def attack_node_assign( self, @@ -692,6 +695,11 @@ def __init__(self, scenario, user=None): participant_config["device_args"]["gpu_id"] = self.scenario.gpu_id participant_config["device_args"]["logging"] = self.scenario.logginglevel participant_config["aggregator_args"]["algorithm"] = self.scenario.agg_algorithm + participant_config["aggregator_args"]["aggregation_timeout"] = int( + self.scenario.aggregator_args["aggregation_timeout"] + ) # <-- New Addition + participant_config["communication_args"] = self.scenario.communication_args # <-- New Addition + participant_config["caff_args"] = self.scenario.caff_args # <-- New Addition # To be sure that benign nodes have no attack parameters if node_config["malicious"]: participant_config["adversarial_args"]["attack_params"] = node_config["attack_params"] diff --git a/nebula/core/aggregation/updatehandlers/caffupdatehandler.py b/nebula/core/aggregation/updatehandlers/caffupdatehandler.py new file mode 100644 index 00000000..8da0f985 --- /dev/null +++ b/nebula/core/aggregation/updatehandlers/caffupdatehandler.py @@ -0,0 +1,343 @@ +import asyncio +import logging +from typing import TYPE_CHECKING, Any + +from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler +from nebula.core.eventmanager import EventManager +from nebula.core.nebulaevents import NodeTerminatedEvent, UpdateNeighborEvent, UpdateReceivedEvent +from nebula.core.utils.locker import Locker + +if TYPE_CHECKING: + from nebula.core.aggregation.aggregator import Aggregator + + +class Update: + def __init__(self, model, weight, source, round): + self.model = model + self.weight = weight + self.source = source + self.round = round + + +class CAFFUpdateHandler(UpdateHandler): + def __init__(self, aggregator: "Aggregator", addr: str): + self._addr = addr + self._aggregator = aggregator + + self._update_cache: dict[str, Update] = {} + self._terminated_peers: set[str] = set() + self._sources_expected: set[str] = set() + + self._cache_lock = Locker(name="caff_cache_lock", async_lock=True) + + self._fallback_task = None + self._notification_sent = False + + self._aggregation_fraction = 1.0 + self._staleness_threshold = 1 + + self._max_local_rounds = 1 + + self._local_round = 0 + self._K_node = 1 + + self._local_update: Update | None = None + self._local_training_done = False + + self._should_stop_training = False + + @property + def us(self): + return self._update_cache + + 
@property + def agg(self): + return self._aggregator + + async def init(self, config): + self._aggregation_fraction = config.participant["caff_args"]["aggregation_fraction"] + self._staleness_threshold = config.participant["caff_args"]["staleness_threshold"] + # self._fallback_timeout = config.participant["caff_args"]["fallback_timeout"] + self._max_local_rounds = config.participant["scenario_args"]["rounds"] + + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) + await EventManager.get_instance().subscribe_node_event(NodeTerminatedEvent, self.handle_node_termination) + await EventManager.get_instance().subscribe(("control", "terminated"), self.handle_peer_terminated) + + async def handle_peer_terminated(self, source: str, message: Any): + # 1) Guard against repeated termination handling + if self._should_stop_training: + logging.info(f"[CAFF] DEBUG should_stop_training = {self._should_stop_training}-----------------") + return + + logging.info(f"[CAFF] Received TERMINATED control message from {source}") + self._terminated_peers.add(source) + self._sources_expected.discard(source) + await self._check_k_node_satisfaction() + + async def _check_k_node_satisfaction(self): + # 2) If we’re already stopping, skip any further checks + if self._should_stop_training: + return + + active_peers_remaining = len(self._sources_expected - {self._addr} - self._terminated_peers) + required_external = self._K_node - 1 + logging.info( + f"[CAFF] Checking K_node feasibility: active_peers={active_peers_remaining}, required_external={required_external}" + ) + + if active_peers_remaining < required_external: + logging.warning("[CAFF] Not enough active peers to reach K_node. 
Triggering self-termination.") + # flip the stop flag so everyone knows we’re shutting down + self._should_stop_training = True + await self.terminate_self() + + # also let the engine know immediately + if hasattr(self._aggregator, "engine"): + logging.info("[CAFF] STOP: Notifying Engine to stop training") + self._aggregator.engine.force_stop_training() + + async def round_expected_updates(self, federation_nodes: set): + self._sources_expected = federation_nodes.copy() + # self._terminated_peers.clear() + self._update_cache.clear() + self._notification_sent = False + self._local_round = self._aggregator.engine.get_round() + + # Exclude self from expected peers, then add +1 later in aggregation threshold + expected_peers = self._sources_expected - {self._addr} + self._K_node = max( + 1 + 1, round(self._aggregation_fraction * len(expected_peers)) + 1 + ) # <-- fix here with local update once min reached + + self._local_update = None + self._local_training_done = False + + # if self._fallback_task: + # logging.info(f"[CAFF] FALLBACK TASK TRUE)") + # self._fallback_task.cancel() + # self._fallback_task = asyncio.create_task(self._start_fallback_timer()) + + async def storage_update(self, updt_received_event: UpdateReceivedEvent): + model, weight, source, round_received, local_flag = await updt_received_event.get_event_data() + await self._cache_lock.acquire_async() + + if self._addr == source: + update = Update(model, weight, source, round_received) + self._local_update = update + self._update_cache[source] = update # Make sure own update is in cache + self._local_training_done = True + logging.info(f"[CAFF] Received own update from {source} (local training done)") + logging.info(f"[CAFF] [CACHE-AFTER] Update cache for round {self._local_round}: {self._update_cache}") + await self._cache_lock.release_async() + await self._maybe_aggregate(force_check=True) + return + + if source in self._update_cache: + cached_update = self._update_cache[source] + if round_received > cached_update.round: + self._update_cache[source] = Update(model, weight, source, round_received) + logging.info(f"[CAFF] Replaced cached update from {source} with newer round {round_received}") + else: + logging.info(f"[CAFF] Discarded stale update from {source} (round {round_received})") + await self._cache_lock.release_async() + if self._local_training_done: + await self._maybe_aggregate() + return + + # if source in self._terminated_peers: + # logging.info(f"[CAFF] Ignoring update from terminated peer {source}") + # await self._cache_lock.release_async() + # return + + if self._local_round - round_received <= self._staleness_threshold: + self._update_cache[source] = Update(model, weight, source, round_received) + logging.info(f"[CAFF] Cached new update from {source} (round {round_received})") + else: + logging.info(f"[CAFF] Discarded stale update from {source} (round {round_received})") + + peer_count = len(self._update_cache) + has_local = self._local_training_done + + logging.info( + f"[CAFF] Round {self._local_round} aggregation check: {peer_count} / {self._K_node} (local ready: {has_local})" + ) + + await self._cache_lock.release_async() + + # Trigger aggregate only if local training is complete + if self._local_training_done: + await self._maybe_aggregate() + + async def get_round_updates(self): + await self._cache_lock.acquire_async() + updates = {peer: (u.model, u.weight) for peer, u in self._update_cache.items()} + if self._local_update: + updates[self._addr] = (self._local_update.model, self._local_update.weight) + 
logging.info(f"[CAFF] DEBUG local update = {self._local_update}-----------------") + await self._cache_lock.release_async() + return updates + + async def notify_federation_update(self, updt_nei_event: UpdateNeighborEvent): + source, remove = await updt_nei_event.get_event_data() + if remove: + self._terminated_peers.add(source) + logging.info(f"[CAFF] Peer {source} marked as terminated") + + async def get_round_missing_nodes(self): + return self._sources_expected.difference(self._update_cache.keys()).difference(self._terminated_peers) + + async def notify_if_all_updates_received(self): + logging.info("[CAFF] DEBUG NOTIFY ALL UPDATES RECEIVED---------------") + await self._maybe_aggregate() + + async def stop_notifying_updates(self): + if self._fallback_task: + self._fallback_task.cancel() + self._fallback_task = None + + # async def mark_self_terminated(self): + # source_id = self._aggregator.get_id() + # await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(source_id)) + + async def handle_node_termination(self, event: NodeTerminatedEvent): + source = await event.get_event_data() + if source in self._terminated_peers: + return + self._terminated_peers.add(source) + logging.info(f"[CAFF] Peer {source} terminated, remaining expected: {self._sources_expected}") + + # re-evaluate K_node feasibility immediately + await self._check_k_node_satisfaction() + + # if source not in self._terminated_peers: + # self._terminated_peers.add(source) + # #self._sources_expected.discard(source) + # logging.info(f"[CAFF] Node {source} terminated") + # logging.info(f"[CAFF] Updated terminated peers: {self._terminated_peers}") + # logging.info(f"[CAFF] Remaining active peers: {self._sources_expected - {self._addr} - self._terminated_peers}") + # if not self._local_training_done: + # await self. 
(force_check=True) + # else: + # logging.info("[CAFF] Ignoring termination impact – local training already done.") + + async def _maybe_aggregate(self, force_check=False): + logging.info( + f"[{self._addr}] ENTERED _maybe_aggregate (round={self._local_round}) | local_done={self._local_training_done} | cache={list(self._update_cache.keys())}" + ) + await self._cache_lock.acquire_async() + + if self._notification_sent and not force_check: + logging.info("CAFF DEBUGG got out of notification send") + await self._cache_lock.release_async() + return + + # if self._addr in self._update_cache: + # logging.info(f"----------------[CAFF] LOCAL UPDATE IS IN CACHE -------------------------") + # self._local_training_done = True + + peer_count = len(self._update_cache) + has_local = self._local_training_done + + logging.info( + f"[CAFF] Round {self._local_round} aggregation check: {peer_count} / {self._K_node} (local ready: {has_local})" + ) + + if self._addr in self._update_cache: + logging.info(f"[CAFF] [MAYBE_AGGREGATE] Local update is already cached for round {self._local_round}") + else: + logging.warning( + f"[CAFF] [MAYBE_AGGREGATE] Local update is NOT cached at aggregation time (round {self._local_round})" + ) + + if not has_local: + logging.debug("[CAFF] Waiting for local training to finish before aggregation.") + await self._cache_lock.release_async() + return + + # ✅ NEW: For round 0, enforce synchronous wait for all expected sources + if self._local_round == 0: + expected = self._sources_expected | {self._addr} + logging.info(f"Expected: {expected}") + received = set(self._update_cache.keys()) + logging.info(f"Received: {received}") + missing = expected - received + logging.info(f"Missing: {missing}") + if missing: + logging.info(f"[CAFF][ROUND 0 SYNC] Waiting for all updates in round 0. Still missing: {missing}") + await self._cache_lock.release_async() + return + else: + logging.info("[CAFF][ROUND 0 SYNC] All updates received for round 0 — proceeding with aggregation") + + if has_local and peer_count >= self._K_node: + logging.info(f"[CAFF] Aggregating with {peer_count} updates (incl. local)") + self._notification_sent = True + await self._cache_lock.release_async() + await asyncio.sleep(0.5) + await self.agg.notify_all_updates_received() + return + + active_peers_remaining = len(self._sources_expected - {self._addr} - self._terminated_peers) + required_external = self._K_node - 1 + + if active_peers_remaining < required_external: + logging.warning("[CAFF] Not enough active peers to reach K_node. 
Triggering self-termination.") + self._should_stop_training = True + logging.warning( + f"[CAFF] Triggered self-termination due to unsatisfiable K_node condition (active: {active_peers_remaining}) | (required: {required_external})" + ) + await self.terminate_self() + if hasattr(self._aggregator, "engine"): + self._aggregator.engine.force_stop_training() + return + + await self._cache_lock.release_async() + + # async def _start_fallback_timer(self): + # await asyncio.sleep(self._fallback_timeout) + # await self._cache_lock.acquire_async() + # if self._notification_sent: + # await self._cache_lock.release_async() + # return + # if self._update_cache and self._local_training_done: + # logging.warning("[CAFF] Fallback triggered — aggregating with partial cache.") + # self._notification_sent = True + # await self._cache_lock.release_async() + # await self.agg.notify_all_updates_received() + # else: + # await self._cache_lock.release_async() + + async def should_continue_training(self) -> bool: + logging.info( + f"[CAFF] should_continue_training = {not self._should_stop_training} (should_stop_training: {self._should_stop_training})" + ) + return not self._should_stop_training + + async def terminate_self(self): + # 3) Idempotency: only run this block once + if self._notification_sent: + return + + # mark that we've broadcast our termination + self._notification_sent = True + self._should_stop_training = True + + logging.warning(f"[CAFF] Node {self._addr} terminating itself.") + + # cancel the fallback timer if it’s still pending + # if self._fallback_task: + # self._fallback_task.cancel() + # self._fallback_task = None + + # publish our own termination event + await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(self._addr)) + + # send the control‐terminated message exactly once + terminate_msg = self._aggregator.engine.cm.create_message("control", "terminated") + await self._aggregator.engine.cm.send_message_to_neighbors(terminate_msg) + + # also flip the engine’s flag so the loop exits ASAP + if hasattr(self._aggregator, "engine"): + self._aggregator.engine.force_stop_training() diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index f3484923..67c24a6e 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -107,15 +107,32 @@ async def stop_notifying_updates(self): def factory_update_handler(updt_handler, aggregator, addr) -> UpdateHandler: + from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler from nebula.core.aggregation.updatehandlers.cflupdatehandler import CFLUpdateHandler from nebula.core.aggregation.updatehandlers.dflupdatehandler import DFLUpdateHandler from nebula.core.aggregation.updatehandlers.sdflupdatehandler import SDFLUpdateHandler - UPDATE_HANDLERS = {"DFL": DFLUpdateHandler, "CFL": CFLUpdateHandler, "SDFL": SDFLUpdateHandler} - - update_handler = UPDATE_HANDLERS.get(updt_handler) - - if update_handler: - return update_handler(aggregator, addr) + # added by Eduard to choose between communication strategy once DFL is selected + if updt_handler == "DFL": + mechanism = aggregator.config.participant.get("communication_args", {}).get( + "mechanism", "standard" + ) # strategy can be either standard or CAFF + if mechanism == "CAFF": + return CAFFUpdateHandler(aggregator, addr) + else: + return DFLUpdateHandler(aggregator, addr) + elif updt_handler == "CFL": + return 
CFLUpdateHandler(aggregator, addr) + elif updt_handler == "SDFL": + return SDFLUpdateHandler(aggregator, addr) else: raise UpdateHandlerException(f"Update Handler {updt_handler} not found") + + # UPDATE_HANDLERS = {"DFL": DFLUpdateHandler, "CFL": CFLUpdateHandler, "SDFL": SDFLUpdateHandler, "CAFF": CAFFUpdateHandler} + + # update_handler = UPDATE_HANDLERS.get(updt_handler) + + # if update_handler: + # return update_handler(aggregator, addr) + # else: + # raise UpdateHandlerException(f"Update Handler {updt_handler} not found") diff --git a/nebula/core/engine.py b/nebula/core/engine.py index fd0a24ff..21233b94 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -4,10 +4,9 @@ import random import socket import time + import docker -from nebula.core.role import Role, factory_node_role -from nebula.addons.attacks.attacks import create_attack from nebula.addons.functions import print_msg_box from nebula.addons.reporter import Reporter from nebula.addons.reputation.reputation import Reputation @@ -16,13 +15,15 @@ from nebula.core.eventmanager import EventManager from nebula.core.nebulaevents import ( AggregationEvent, + ExperimentFinishEvent, + NodeTerminatedEvent, RoundEndEvent, RoundStartEvent, UpdateNeighborEvent, UpdateReceivedEvent, - ExperimentFinishEvent, ) from nebula.core.network.communications import CommunicationsManager +from nebula.core.role import Role, factory_node_role from nebula.core.situationalawareness.situationalawareness import SituationalAwareness from nebula.core.utils.locker import Locker @@ -37,6 +38,7 @@ import sys from nebula.config.config import Config +from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler from nebula.core.training.lightning import Lightning @@ -118,6 +120,16 @@ def __init__( self._trainer = trainer(model, datamodule, config=self.config) self._aggregator = create_aggregator(config=self.config, engine=self) + self._termination_sent = False # <-- needed for CAFF mechanism + self._using_caff = self.config.participant.get("communication_args", {}).get("mechanism", "").lower() == "caff" + if self._using_caff: + logging.info(f"[{self.addr}] Communication mechanism: CAFF") + else: + logging.info( + f"[{self.addr}] Communication mechanism: {self.config.participant.get('communication_args', {}).get('mechanism', 'standard')}" + ) + self._caff_force_terminate = False + self._secure_neighbors = [] self._is_malicious = self.config.participant["adversarial_args"]["attack_params"]["attacks"] != "No Attack" @@ -227,6 +239,18 @@ def set_round(self, new_round): self.round = new_round self.trainer.set_current_round(new_round) + def _get_caff_handler(self) -> CAFFUpdateHandler | None: + if self.config.participant["scenario_args"]["federation"].lower() == "dfl" and self._using_caff: + from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler + + handler = getattr(self.aggregator, "_update_storage", None) + if isinstance(handler, CAFFUpdateHandler): + return handler + return None + + def force_stop_training(self): + self._caff_force_terminate = True + """ ############################## # MODEL CALLBACKS # ############################## @@ -314,10 +338,10 @@ async def _control_leadership_transfer_callback(self, source, message): await self.cm.send_message(source, message) logging.info(f"🔧 handle_control_message | Trigger | Leadership transfer ack message sent to {source}") else: - logging.info(f"🔧 handle_control_message | Trigger | Only one neighbor found, I am the leader") + logging.info("🔧 
handle_control_message | Trigger | Only one neighbor found, I am the leader") else: self.role = Role.AGGREGATOR - logging.info(f"🔧 handle_control_message | Trigger | I am now the leader") + logging.info("🔧 handle_control_message | Trigger | I am now the leader") message = self.cm.create_message("control", "leadership_transfer_ack") await self.cm.send_message(source, message) logging.info(f"🔧 handle_control_message | Trigger | Leadership transfer ack message sent to {source}") @@ -710,7 +734,7 @@ def learning_cycle_finished(self): if not self.round or not self.total_rounds: return False else: - return (self.round < self.total_rounds) + return self.round < self.total_rounds async def _learning_cycle(self): """ @@ -734,7 +758,8 @@ async def _learning_cycle(self): This function blocks (awaits) until the full FL process concludes. """ - while self.round is not None and self.round < self.total_rounds: + # while self.round is not None and self.round < self.total_rounds: + while await self._continue_training(): current_time = time.time() print_msg_box( msg=f"Round {self.round} of {self.total_rounds - 1} started (max. {self.total_rounds} rounds)", @@ -776,6 +801,7 @@ async def _learning_cycle(self): ) # Set current round in config (send to the controller) await self.get_round_lock().release_async() + logging.info(f"[{self.addr}] Training finished at round {self.round}, entering idle mode.") # End of the learning cycle self.trainer.on_learning_cycle_end() @@ -799,14 +825,25 @@ async def _learning_cycle(self): await asyncio.sleep(5) - # Kill itself - if self.config.participant["scenario_args"]["deployment"] == "docker": - try: - docker_id = socket.gethostname() - logging.info(f"📦 Killing docker container with ID {docker_id}") - self.client.containers.get(docker_id).kill() - except Exception as e: - logging.exception(f"📦 Error stopping Docker container with ID {docker_id}: {e}") + # If it's CAFF, log what should_continue_training() returns + caff_handler = self._get_caff_handler() + if caff_handler: + should_continue = await caff_handler.should_continue_training() + logging.info(f"[{self.addr}] CAFF should_continue_training() result: {should_continue}") + else: + should_continue = False + logging.info(f"[{self.addr}] Not using CAFF or not a DFL scenario — proceeding to terminate.") + + if ( + not caff_handler or not await caff_handler.should_continue_training() + ): # <--- comment out if training to stop as soon as first node reaches finish line + if self.config.participant["scenario_args"]["deployment"] == "docker": + try: + docker_id = socket.gethostname() + logging.info(f"📦 Killing docker container with ID {docker_id}") + self.client.containers.get(docker_id).kill() + except Exception as e: + logging.exception(f"📦 Error stopping Docker container with ID {docker_id}: {e}") async def _extended_learning_cycle(self): """ @@ -814,3 +851,39 @@ async def _extended_learning_cycle(self): functionalities. The method is called in the _learning_cycle method. """ pass + + async def _continue_training(self): + if self.round is None: + logging.info("LEAVING LEARNING CYCLE") + return False + + # If using CAFF: check if early termination has been triggered + if self._using_caff: + caff_handler = self._get_caff_handler() + if caff_handler: + should_continue = await caff_handler.should_continue_training() + if not should_continue: + logging.info(f"[{self.addr}] CAFF handler requested early stop. 
Terminating training.") + # Send termination alert only once + if not self._termination_sent: + await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(self.addr)) + terminate_msg = self.cm.create_message("control", "terminated") + await self.cm.send_message_to_neighbors(terminate_msg) + self._termination_sent = True + return False + + # Regular case: keep training if rounds left + if self.round < self.total_rounds: + return True + + # If max round is reached: send CAFF termination alert only once + if self._using_caff: + if not self._termination_sent: + logging.info("[CAFF] Max rounds reached. Announcing termination to peers.") + await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(self.addr)) + terminate_msg = self.cm.create_message("control", "terminated") + await self.cm.send_message_to_neighbors(terminate_msg) + self._termination_sent = True + + # Never train beyond max rounds — not even for CAFF + return False diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index e35c17fc..c97d6fb2 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -5,7 +5,7 @@ class AddonEvent(ABC): """ Abstract base class for all addon-related events in the system. """ - + @abstractmethod async def get_event_data(self): """ @@ -21,7 +21,7 @@ class NodeEvent(ABC): """ Abstract base class for all node-related events in the system. """ - + @abstractmethod async def get_event_data(self): """ @@ -52,7 +52,7 @@ class MessageEvent: source (str): Address or identifier of the message sender. message (Any): The actual message payload. """ - + def __init__(self, message_type, source, message): """ Initializes a MessageEvent instance. @@ -144,6 +144,20 @@ async def is_concurrent(self): return False +class NodeTerminatedEvent(NodeEvent): + def __init__(self, source_id: str): + self._source_id = source_id + + def __str__(self): + return f"NodeTerminatedEvent from {self._source_id}" + + async def get_event_data(self) -> str: + return self._source_id + + async def is_concurrent(self) -> bool: + return True + + class AggregationEvent(NodeEvent): def __init__(self, updates: dict, expected_nodes: set, missing_nodes: set): """Event triggered when model aggregation is ready. @@ -348,7 +362,7 @@ class GPSEvent(AddonEvent): Attributes: distances (dict): A dictionary mapping node addresses to their respective distances. """ - + def __init__(self, distances: dict): """ Initializes a GPSEvent. @@ -379,7 +393,7 @@ class ChangeLocationEvent(AddonEvent): latitude (float): New latitude of the node. longitude (float): New longitude of the node. """ - + def __init__(self, latitude, longitude): """ Initializes a ChangeLocationEvent. @@ -402,7 +416,8 @@ async def get_event_data(self): tuple: A tuple containing latitude and longitude. 
""" return (self.latitude, self.longitude) - + + class TestMetricsEvent(AddonEvent): def __init__(self, loss, accuracy): self._loss = loss diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py index 77e1997c..109bbab5 100644 --- a/nebula/core/network/actions.py +++ b/nebula/core/network/actions.py @@ -47,6 +47,7 @@ class ControlAction(Enum): WEAK_LINK = nebula_pb2.ControlMessage.Action.WEAK_LINK LEADERSHIP_TRANSFER = nebula_pb2.ControlMessage.Action.LEADERSHIP_TRANSFER LEADERSHIP_TRANSFER_ACK = nebula_pb2.ControlMessage.Action.LEADERSHIP_TRANSFER_ACK + TERMINATED = nebula_pb2.ControlMessage.Action.Value("TERMINATED") class DiscoverAction(Enum): diff --git a/nebula/core/pb/nebula.proto b/nebula/core/pb/nebula.proto index 3360196e..23fe0005 100755 --- a/nebula/core/pb/nebula.proto +++ b/nebula/core/pb/nebula.proto @@ -49,6 +49,7 @@ message ControlMessage { WEAK_LINK = 4; LEADERSHIP_TRANSFER = 5; LEADERSHIP_TRANSFER_ACK = 6; + TERMINATED = 7; } Action action = 1; string log = 2; diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index 448675b3..ecf3c3b3 100644 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -1,62 +1,61 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: nebula.proto -# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" + from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0cnebula.proto\x12\x06nebula"\xae\x04\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x37\n\x12reputation_message\x18\x08 \x01(\x0b\x32\x19.nebula.ReputationMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\t \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\n \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\x0b \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02"\xe1\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t"\x92\x01\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\x12\x17\n\x13LEADERSHIP_TRANSFER\x10\x05\x12\x1b\n\x17LEADERSHIP_TRANSFER_ACK\x10\x06\x12\x0e\n\nTERMINATED\x10\x07"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 
.nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03"\x95\x01\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action"R\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01"\x89\x01\n\x11ReputationMessage\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\r\n\x05round\x18\x03 \x01(\x05\x12\x30\n\x06\x61\x63tion\x18\x04 \x01(\x0e\x32 .nebula.ReputationMessage.Action"\x13\n\x06\x41\x63tion\x12\t\n\x05SHARE\x10\x00"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3' +) - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xae\x04\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x37\n\x12reputation_message\x18\x08 \x01(\x0b\x32\x19.nebula.ReputationMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\t \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\n \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\x0b \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\xd1\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 
\x01(\t\"\x82\x01\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\x12\x17\n\x13LEADERSHIP_TRANSFER\x10\x05\x12\x1b\n\x17LEADERSHIP_TRANSFER_ACK\x10\x06\"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"\x95\x01\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"R\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"\x89\x01\n\x11ReputationMessage\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\r\n\x05round\x18\x03 \x01(\x05\x12\x30\n\x06\x61\x63tion\x18\x04 \x01(\x0e\x32 .nebula.ReputationMessage.Action\"\x13\n\x06\x41\x63tion\x12\t\n\x05SHARE\x10\x00\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', _globals) +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "nebula_pb2", globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_WRAPPER']._serialized_start=25 - _globals['_WRAPPER']._serialized_end=583 - _globals['_DISCOVERYMESSAGE']._serialized_start=586 - _globals['_DISCOVERYMESSAGE']._serialized_end=744 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=692 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=744 - _globals['_CONTROLMESSAGE']._serialized_start=747 - _globals['_CONTROLMESSAGE']._serialized_end=956 - _globals['_CONTROLMESSAGE_ACTION']._serialized_start=826 - _globals['_CONTROLMESSAGE_ACTION']._serialized_end=956 - _globals['_FEDERATIONMESSAGE']._serialized_start=959 - _globals['_FEDERATIONMESSAGE']._serialized_end=1164 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=1064 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=1164 - _globals['_MODELMESSAGE']._serialized_start=1166 - 
_globals['_MODELMESSAGE']._serialized_end=1231 - _globals['_CONNECTIONMESSAGE']._serialized_start=1234 - _globals['_CONNECTIONMESSAGE']._serialized_end=1377 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1305 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1377 - _globals['_DISCOVERMESSAGE']._serialized_start=1380 - _globals['_DISCOVERMESSAGE']._serialized_end=1529 - _globals['_DISCOVERMESSAGE_ACTION']._serialized_start=1447 - _globals['_DISCOVERMESSAGE_ACTION']._serialized_end=1529 - _globals['_OFFERMESSAGE']._serialized_start=1532 - _globals['_OFFERMESSAGE']._serialized_end=1738 - _globals['_OFFERMESSAGE_ACTION']._serialized_start=1695 - _globals['_OFFERMESSAGE_ACTION']._serialized_end=1738 - _globals['_LINKMESSAGE']._serialized_start=1740 - _globals['_LINKMESSAGE']._serialized_end=1859 - _globals['_LINKMESSAGE_ACTION']._serialized_start=1814 - _globals['_LINKMESSAGE_ACTION']._serialized_end=1859 - _globals['_REPUTATIONMESSAGE']._serialized_start=1862 - _globals['_REPUTATIONMESSAGE']._serialized_end=1999 - _globals['_REPUTATIONMESSAGE_ACTION']._serialized_start=1980 - _globals['_REPUTATIONMESSAGE_ACTION']._serialized_end=1999 - _globals['_RESPONSEMESSAGE']._serialized_start=2001 - _globals['_RESPONSEMESSAGE']._serialized_end=2036 + DESCRIPTOR._options = None + _WRAPPER._serialized_start = 25 + _WRAPPER._serialized_end = 583 + _DISCOVERYMESSAGE._serialized_start = 586 + _DISCOVERYMESSAGE._serialized_end = 744 + _DISCOVERYMESSAGE_ACTION._serialized_start = 692 + _DISCOVERYMESSAGE_ACTION._serialized_end = 744 + _CONTROLMESSAGE._serialized_start = 747 + _CONTROLMESSAGE._serialized_end = 972 + _CONTROLMESSAGE_ACTION._serialized_start = 826 + _CONTROLMESSAGE_ACTION._serialized_end = 972 + _FEDERATIONMESSAGE._serialized_start = 975 + _FEDERATIONMESSAGE._serialized_end = 1180 + _FEDERATIONMESSAGE_ACTION._serialized_start = 1080 + _FEDERATIONMESSAGE_ACTION._serialized_end = 1180 + _MODELMESSAGE._serialized_start = 1182 + _MODELMESSAGE._serialized_end = 1247 + _CONNECTIONMESSAGE._serialized_start = 1250 + _CONNECTIONMESSAGE._serialized_end = 1393 + _CONNECTIONMESSAGE_ACTION._serialized_start = 1321 + _CONNECTIONMESSAGE_ACTION._serialized_end = 1393 + _DISCOVERMESSAGE._serialized_start = 1396 + _DISCOVERMESSAGE._serialized_end = 1545 + _DISCOVERMESSAGE_ACTION._serialized_start = 1463 + _DISCOVERMESSAGE_ACTION._serialized_end = 1545 + _OFFERMESSAGE._serialized_start = 1548 + _OFFERMESSAGE._serialized_end = 1754 + _OFFERMESSAGE_ACTION._serialized_start = 1711 + _OFFERMESSAGE_ACTION._serialized_end = 1754 + _LINKMESSAGE._serialized_start = 1756 + _LINKMESSAGE._serialized_end = 1875 + _LINKMESSAGE_ACTION._serialized_start = 1830 + _LINKMESSAGE_ACTION._serialized_end = 1875 + _REPUTATIONMESSAGE._serialized_start = 1878 + _REPUTATIONMESSAGE._serialized_end = 2015 + _REPUTATIONMESSAGE_ACTION._serialized_start = 1996 + _REPUTATIONMESSAGE_ACTION._serialized_end = 2015 + _RESPONSEMESSAGE._serialized_start = 2017 + _RESPONSEMESSAGE._serialized_end = 2052 # @@protoc_insertion_point(module_scope) diff --git a/nebula/frontend/config/participant.json.example b/nebula/frontend/config/participant.json.example index 4f4d9b26..2be5d196 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -149,5 +149,12 @@ "misc_args": { "grace_time_connection": 10, "grace_time_start_federation": 10 + }, + "communication_args": { + "mechanism": "standard" + }, + "caff_args": { + "aggregation_fraction": 0.5, + "staleness_threshold": 2 } } diff --git 
a/nebula/frontend/static/js/deployment/help-content.js b/nebula/frontend/static/js/deployment/help-content.js index 673cae88..34ac2de3 100644 --- a/nebula/frontend/static/js/deployment/help-content.js +++ b/nebula/frontend/static/js/deployment/help-content.js @@ -13,7 +13,10 @@ const HelpContent = (function() { 'partitionMethodsHelpIcon': partitionMethods, 'parameterSettingHelpIcon': parameterSetting, 'modelHelpIcon': model, - 'maliciousHelpIcon': malicious + 'maliciousHelpIcon': malicious, + 'aggregationFractionHelpIcon': caff.aggregationFraction, + 'stalenessThresholdHelpIcon': caff.stalenessThreshold, + 'fallbackTimeoutHelpIcon': caff.fallbackTimeout, }; Object.entries(tooltipElements).forEach(([id, content]) => { @@ -160,6 +163,43 @@ const HelpContent = (function() { ` }; + const caff = { + aggregationFraction: ` +
+            Aggregation Fraction (alpha)
+            The fraction of peers (N) whose updates you wait for before aggregating.
+            E.g. alpha=0.5 means “once 50% of other nodes have sent an update and local model is ready, start aggregation.”
+            K = max(1, round(alpha * N))
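+            For example, with N = 4 peers and alpha = 0.5: K = max(1, round(0.5 * 4)) = 2, i.e. aggregation can start once two peer updates are cached and the local model is ready.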
+        `,
+
+        stalenessThreshold: `
+            Staleness Threshold (Smax)
+            A peer’s update is only accepted if:
+            local_round - peer_update_round ≤ Smax.
+            Otherwise (i.e. if it’s older), it’s discarded as too stale.
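+            For example, with Smax = 2, a peer update from round 3 is still accepted while the local round is 5 (5 - 3 ≤ 2), but discarded once the local round reaches 6.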
+        `,
+
+        fallbackTimeout: `
+            Fallback Timeout
+            A timer that starts immediately after your local update is sent.
+            If aggregation hasn’t triggered within this many seconds, it forces aggregation with whatever updates are in the cache.
` + }; + return { initializePopovers, topology, @@ -171,7 +211,8 @@ const HelpContent = (function() { model, malicious, deployment, - reputation + reputation, + caff }; })(); diff --git a/nebula/frontend/static/js/deployment/scenario.js b/nebula/frontend/static/js/deployment/scenario.js index af2d9700..c1781cb5 100644 --- a/nebula/frontend/static/js/deployment/scenario.js +++ b/nebula/frontend/static/js/deployment/scenario.js @@ -47,7 +47,6 @@ const ScenarioManager = (function() { start: node.start }; }); - // Get topology type from select element const topologyType = document.getElementById('predefined-topology-select').value; @@ -131,7 +130,17 @@ const ScenarioManager = (function() { schema_additional_participants: document.getElementById("schemaAdditionalParticipantsSelect").value || "random", accelerator: "cpu", gpu_id: [], - physical_ips: physical_ips + physical_ips: physical_ips, + communication_args: { + mechanism: document.querySelector('input[name="mechanism"]:checked')?.value || "standard" + }, + caff_args: { + aggregation_fraction: parseFloat(document.getElementById("aggregationFraction")?.value || 0.5), + staleness_threshold: parseInt(document.getElementById("stalenessThreshold")?.value || 2) + }, + aggregator_args: { + aggregation_timeout: parseInt(document.getElementById("aggregationTimeout")?.value || 60) + } }; } @@ -353,16 +362,16 @@ const ScenarioManager = (function() { function setPhysicalIPs(ipList = []) { physical_ips = [...ipList]; } - + function setActualScenario(index) { actual_scenario = index; if (scenariosList[index]) { // Clear the current graph window.TopologyManager.clearGraph(); - + // Load new scenario data loadScenarioData(scenariosList[index]); - + // If physical deployment, set physical IPs if (scenariosList[index].deployment === 'physical' && scenariosList[index].physical_ips) { window.TopologyManager.setPhysicalIPs(scenariosList[index].physical_ips); diff --git a/nebula/frontend/static/js/deployment/ui-controls.js b/nebula/frontend/static/js/deployment/ui-controls.js index ed02efa7..e42ca158 100644 --- a/nebula/frontend/static/js/deployment/ui-controls.js +++ b/nebula/frontend/static/js/deployment/ui-controls.js @@ -11,11 +11,11 @@ const UIControls = (function() { /* === control Physical + Predefined => block input === */ document.querySelectorAll('input[name="deploymentRadioOptions"]') .forEach(r => r.addEventListener('change', togglePredefinedNodesInput)); - + ['custom-topology-btn', 'predefined-topology-btn'] .forEach(id => document.getElementById(id) .addEventListener('change', togglePredefinedNodesInput)); - + togglePredefinedNodesInput(); setupVpnDiscover(); setupParticipantDisplay(); @@ -24,6 +24,48 @@ const UIControls = (function() { // Initialize help icons window.HelpContent.initializePopovers(); setupDeploymentRadios(); + setupCommunicationMechanismControl(); + } + + function setupCommunicationMechanismControl() { + const mechanismContainer = document.getElementById("communicationMechanismContainer"); + const caffParams = document.getElementById("caffParametersContainer"); + + const federationSelect = document.getElementById('federationArchitecture'); + const radioStandard = document.getElementById("mechanismStandard"); + const radioCAFF = document.getElementById("mechanismCAFF"); + + // Show mechanism options only if DFL is selected + federationSelect.addEventListener('change', function () { + const federationType = this.value; + + if (federationType === 'DFL') { + mechanismContainer.style.display = "block"; + + // Show CAFF params if CAFF 
was already selected + if (radioCAFF.checked) { + caffParams.style.display = "block"; + } else { + caffParams.style.display = "none"; + } + } else { + mechanismContainer.style.display = "none"; + caffParams.style.display = "none"; + } + }); + + // When switching between mechanisms inside DFL + radioStandard.addEventListener("change", () => { + if (radioStandard.checked) { + caffParams.style.display = "none"; + } + }); + + radioCAFF.addEventListener("change", () => { + if (radioCAFF.checked) { + caffParams.style.display = "block"; + } + }); } function setupModeButton() { @@ -650,33 +692,33 @@ const UIControls = (function() { const radios = document.querySelectorAll('input[name="deploymentRadioOptions"]'); const discoverBtn = document.getElementById('discoverDevicesBtn'); if (!discoverBtn || !radios.length) return; - + const toggle = () => { const sel = document.querySelector('input[name="deploymentRadioOptions"]:checked'); discoverBtn.disabled = sel.value !== 'physical'; }; - + radios.forEach(r => r.addEventListener('change', toggle)); toggle(); } - + function setupVpnDiscover() { const discoverBtn = document.getElementById('discoverDevicesBtn'); if (!discoverBtn) return; - + discoverBtn.addEventListener('click', async () => { try { const res = await fetch('/platform/api/discover-vpn'); if (!res.ok) throw new Error(res.statusText); - + const { ips } = await res.json(); - + const form = document.getElementById('vpn-form'); form.innerHTML = ''; - + const currentScenario = window.ScenarioManager.getScenariosList()[window.ScenarioManager.getActualScenario()]; const selectedIPs = currentScenario?.physical_ips || []; - + ips.forEach(ip => { const wrapper = document.createElement('div'); wrapper.classList.add('form-check'); @@ -687,18 +729,18 @@ const UIControls = (function() { `; form.appendChild(wrapper); }); - + const modal = new bootstrap.Modal(document.getElementById('vpnModal')); modal.show(); - + document.getElementById('vpn-accept-btn').onclick = () => { const selected = Array.from(form.querySelectorAll('input:checked')) .map(i => i.value); - + window.ScenarioManager.setPhysicalIPs(selected); - + window.TopologyManager.setPhysicalIPs(selected); - + modal.hide(); }; } catch (err) { @@ -707,19 +749,19 @@ const UIControls = (function() { } }); } - + function togglePredefinedNodesInput() { const deployment = document.querySelector('input[name="deploymentRadioOptions"]:checked')?.value; const isPredefined = document.getElementById('predefined-topology-btn').checked; const nodesInput = document.getElementById('predefined-topology-nodes'); - + if (!nodesInput) return; - + const disable = deployment === 'physical' && isPredefined; nodesInput.disabled = disable; nodesInput.classList.toggle('disabled', disable); } - + function setupDeploymentRadios() { const radios = document.querySelectorAll('input[name="deploymentRadioOptions"]'); radios.forEach(radio => { diff --git a/nebula/frontend/templates/deployment.html b/nebula/frontend/templates/deployment.html index 375d74af..0a7b1f14 100755 --- a/nebula/frontend/templates/deployment.html +++ b/nebula/frontend/templates/deployment.html @@ -285,6 +285,78 @@
Number of rounds
+ <!-- Communication Mechanism selector: "Standard" / "CAFF" radio options and CAFF parameter inputs
+      (aggregation fraction, staleness threshold, aggregation timeout), shown only when the DFL
+      architecture is selected; element ids match those referenced in ui-controls.js and scenario.js -->
Network Topology @@ -961,7 +1033,7 @@
Fairness pillar
- % + %
@@ -1139,4 +1211,4 @@
Federation complexity notion
}); - {% endblock %} \ No newline at end of file + {% endblock %} From 09ff3fc8d27205773991a13576b4f9ab63f07917 Mon Sep 17 00:00:00 2001 From: Eduard Gash Date: Mon, 4 Aug 2025 17:04:08 +0000 Subject: [PATCH 2/2] Finalize CAFF (async aggregation mechanism) implementation and frontend integration --- .../updatehandlers/caffupdatehandler.py | 358 ++++++++---------- .../updatehandlers/updatehandler.py | 4 +- nebula/core/engine.py | 86 ----- nebula/core/nebulaevents.py | 14 - nebula/core/network/propagator.py | 22 +- nebula/frontend/templates/deployment.html | 2 +- 6 files changed, 180 insertions(+), 306 deletions(-) diff --git a/nebula/core/aggregation/updatehandlers/caffupdatehandler.py b/nebula/core/aggregation/updatehandlers/caffupdatehandler.py index 8da0f985..532b5c3e 100644 --- a/nebula/core/aggregation/updatehandlers/caffupdatehandler.py +++ b/nebula/core/aggregation/updatehandlers/caffupdatehandler.py @@ -1,10 +1,11 @@ import asyncio +import time import logging from typing import TYPE_CHECKING, Any from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler from nebula.core.eventmanager import EventManager -from nebula.core.nebulaevents import NodeTerminatedEvent, UpdateNeighborEvent, UpdateReceivedEvent +from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent from nebula.core.utils.locker import Locker if TYPE_CHECKING: @@ -12,20 +13,43 @@ class Update: - def __init__(self, model, weight, source, round): + """ + Holds a single model update along with metadata. + + Attributes: + model (object): The serialized model or weights. + weight (float): The importance weight of the update. + source (str): Identifier of the sending node. + round (int): Training round number of this update. + time_received (float): Timestamp when update arrived. + """ + def __init__(self, model, weight, source, round, time_received): self.model = model self.weight = weight self.source = source self.round = round + self.time_received = time_received class CAFFUpdateHandler(UpdateHandler): + """ + CAFF: Cache-based Aggregation with Fairness and Filterin Update Handler. + + Manages peer updates asynchronously, applies a staleness filter, + and triggers aggregation once a dynamic threshold K_node is met. + """ def __init__(self, aggregator: "Aggregator", addr: str): + """ + Initialize handler state and async locks. + + Args: + aggregator (Aggregator): The federation aggregator instance. + addr (str): This node's unique address. 
+ """ self._addr = addr self._aggregator = aggregator self._update_cache: dict[str, Update] = {} - self._terminated_peers: set[str] = set() self._sources_expected: set[str] = set() self._cache_lock = Locker(name="caff_cache_lock", async_lock=True) @@ -36,108 +60,117 @@ def __init__(self, aggregator: "Aggregator", addr: str): self._aggregation_fraction = 1.0 self._staleness_threshold = 1 - self._max_local_rounds = 1 - self._local_round = 0 self._K_node = 1 - self._local_update: Update | None = None self._local_training_done = False - self._should_stop_training = False + ### TEST FOR CAFF - Send model updates only to last contributors + self._last_contributors: list[str] = [] + ### TEST FOR CAFF - Send model updates only to last contributors + @property def us(self): + """Returns the internal update cache mapping source -> Update.""" return self._update_cache @property def agg(self): + """Returns the linked aggregator.""" return self._aggregator - async def init(self, config): - self._aggregation_fraction = config.participant["caff_args"]["aggregation_fraction"] - self._staleness_threshold = config.participant["caff_args"]["staleness_threshold"] - # self._fallback_timeout = config.participant["caff_args"]["fallback_timeout"] - self._max_local_rounds = config.participant["scenario_args"]["rounds"] - - await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) - await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) - await EventManager.get_instance().subscribe_node_event(NodeTerminatedEvent, self.handle_node_termination) - await EventManager.get_instance().subscribe(("control", "terminated"), self.handle_peer_terminated) - - async def handle_peer_terminated(self, source: str, message: Any): - # 1) Guard against repeated termination handling - if self._should_stop_training: - logging.info(f"[CAFF] DEBUG should_stop_training = {self._should_stop_training}-----------------") - return - - logging.info(f"[CAFF] Received TERMINATED control message from {source}") - self._terminated_peers.add(source) - self._sources_expected.discard(source) - await self._check_k_node_satisfaction() + @property + def last_contributors(self) -> list[str]: + """ + The list of peer-addresses who actually contributed in the previous round. + """ + return self._last_contributors + + def _recalc_K_node(self): + """ + Recalculate K_node based on current expected peers and fraction. + + Ensures at least one external peer and includes local update. + """ + # when only local node is still active and all others finished training + if len(self._sources_expected) == 1: + self._K_node = 1 + else: + # count only “external” peers + external = len(self._sources_expected - {self._addr}) + # round() to nearest int, Account for local update also stored with + 1 + self._K_node = max(1 + 1, round(self._aggregation_fraction * external) + 1) - async def _check_k_node_satisfaction(self): - # 2) If we’re already stopping, skip any further checks - if self._should_stop_training: - return + async def init(self, _role_name): + """ + Subscribe to federation events and load config args. - active_peers_remaining = len(self._sources_expected - {self._addr} - self._terminated_peers) - required_external = self._K_node - 1 - logging.info( - f"[CAFF] Checking K_node feasibility: active_peers={active_peers_remaining}, required_external={required_external}" - ) + Args: + _role_name (str): Unused placeholder for role. 
+ """ + cfg = self._aggregator.config - if active_peers_remaining < required_external: - logging.warning("[CAFF] Not enough active peers to reach K_node. Triggering self-termination.") - # flip the stop flag so everyone knows we’re shutting down - self._should_stop_training = True - await self.terminate_self() + self._aggregation_fraction = cfg.participant["caff_args"]["aggregation_fraction"] + self._staleness_threshold = cfg.participant["caff_args"]["staleness_threshold"] - # also let the engine know immediately - if hasattr(self._aggregator, "engine"): - logging.info("[CAFF] STOP: Notifying Engine to stop training") - self._aggregator.engine.force_stop_training() + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) async def round_expected_updates(self, federation_nodes: set): + """ + Reset round state when a new round starts. + + Args: + federation_nodes (set[str]): IDs of peers expected this round. + """ self._sources_expected = federation_nodes.copy() - # self._terminated_peers.clear() self._update_cache.clear() self._notification_sent = False - self._local_round = self._aggregator.engine.get_round() + self._local_round = await self._aggregator.engine.get_round() - # Exclude self from expected peers, then add +1 later in aggregation threshold - expected_peers = self._sources_expected - {self._addr} - self._K_node = max( - 1 + 1, round(self._aggregation_fraction * len(expected_peers)) + 1 - ) # <-- fix here with local update once min reached + # compute initial K_node + self._recalc_K_node() - self._local_update = None self._local_training_done = False - # if self._fallback_task: - # logging.info(f"[CAFF] FALLBACK TASK TRUE)") - # self._fallback_task.cancel() - # self._fallback_task = asyncio.create_task(self._start_fallback_timer()) - async def storage_update(self, updt_received_event: UpdateReceivedEvent): - model, weight, source, round_received, local_flag = await updt_received_event.get_event_data() + """ + Handle an incoming model update event. + + Args: + updt_received_event (UpdateReceivedEvent): Carries model, weight, source, and round. 
+ """ + time_received = time.time() + model, weight, source, round_received, _ = await updt_received_event.get_event_data() await self._cache_lock.acquire_async() + # reject any update not in expected set + if source not in self._sources_expected: + logging.info( + f"[CAFF] Discard update | source: {source} not in expected updates for round {self._local_round}" + ) + await self._cache_lock.release_async() + return + + # store local update if self._addr == source: - update = Update(model, weight, source, round_received) - self._local_update = update - self._update_cache[source] = update # Make sure own update is in cache + update = Update(model, weight, source, round_received, time_received) + self._update_cache[source] = update self._local_training_done = True logging.info(f"[CAFF] Received own update from {source} (local training done)") - logging.info(f"[CAFF] [CACHE-AFTER] Update cache for round {self._local_round}: {self._update_cache}") + logging.info( + f"[CAFF] Round {self._local_round} aggregation check: {len(self._update_cache)} / {self._K_node} (local ready: {self._local_training_done})" + ) await self._cache_lock.release_async() await self._maybe_aggregate(force_check=True) return + # replace or discard update from peer already in cache if source in self._update_cache: cached_update = self._update_cache[source] if round_received > cached_update.round: - self._update_cache[source] = Update(model, weight, source, round_received) + self._update_cache[source] = Update(model, weight, source, round_received, time_received) logging.info(f"[CAFF] Replaced cached update from {source} with newer round {round_received}") else: logging.info(f"[CAFF] Discarded stale update from {source} (round {round_received})") @@ -146,13 +179,9 @@ async def storage_update(self, updt_received_event: UpdateReceivedEvent): await self._maybe_aggregate() return - # if source in self._terminated_peers: - # logging.info(f"[CAFF] Ignoring update from terminated peer {source}") - # await self._cache_lock.release_async() - # return - + # staleness filter for new peer update if self._local_round - round_received <= self._staleness_threshold: - self._update_cache[source] = Update(model, weight, source, round_received) + self._update_cache[source] = Update(model, weight, source, round_received, time_received) logging.info(f"[CAFF] Cached new update from {source} (round {round_received})") else: logging.info(f"[CAFF] Discarded stale update from {source} (round {round_received})") @@ -165,179 +194,104 @@ async def storage_update(self, updt_received_event: UpdateReceivedEvent): ) await self._cache_lock.release_async() - - # Trigger aggregate only if local training is complete if self._local_training_done: await self._maybe_aggregate() async def get_round_updates(self): + """ + Return all cached updates for aggregation. + + Returns: + dict[str, tuple]: Mapping of node -> (model, weight) pairs. + """ await self._cache_lock.acquire_async() updates = {peer: (u.model, u.weight) for peer, u in self._update_cache.items()} - if self._local_update: - updates[self._addr] = (self._local_update.model, self._local_update.weight) - logging.info(f"[CAFF] DEBUG local update = {self._local_update}-----------------") await self._cache_lock.release_async() + logging.info(f"[CAFF] Aggregating with {len(self._update_cache)} updates (incl. 
local)") + + ### TEST FOR CAFF - Send model updates only to last contributors + # Snapshot the peers who actually contributed this round + self._last_contributors = list(set(self._update_cache.keys()).difference({self._addr})) + ### TEST FOR CAFF - Send model updates only to last contributors + return updates async def notify_federation_update(self, updt_nei_event: UpdateNeighborEvent): + """ + Handle peers joining or leaving mid-round. + + When a peer leaves, remove it from expected set and recalc K_node. + """ source, remove = await updt_nei_event.get_event_data() + if remove: - self._terminated_peers.add(source) - logging.info(f"[CAFF] Peer {source} marked as terminated") + # peer left / finished + if source in self._sources_expected: + self._sources_expected.discard(source) + logging.info(f"[CAFF] Peer {source} removed → now expecting {self._sources_expected}") + # recompute threshold + self._recalc_K_node() + logging.info(f"[CAFF] Recomputed K_node={self._K_node}") + # maybe we can now aggregate earlier + await self._maybe_aggregate() + else: + # peer joined mid-round + self._sources_expected.add(source) + logging.info(f"[CAFF] Peer {source} joined → now expecting {self._sources_expected}") + self._recalc_K_node() + logging.info(f"[CAFF] Recomputed K_node={self._K_node}") async def get_round_missing_nodes(self): - return self._sources_expected.difference(self._update_cache.keys()).difference(self._terminated_peers) + """ + List peers whose updates have not yet arrived. + + Returns: + set[str]: IDs of missing peers. + """ + return self._sources_expected.difference(self._update_cache.keys()) async def notify_if_all_updates_received(self): - logging.info("[CAFF] DEBUG NOTIFY ALL UPDATES RECEIVED---------------") + """ + Trigger aggregation check immediately when all conditions may be met. + """ + logging.info("[CAFF] Timer started and waiting for updates to reach threshold K_node") await self._maybe_aggregate() async def stop_notifying_updates(self): + """ + Cancel any pending fallback notification timers. + """ if self._fallback_task: self._fallback_task.cancel() self._fallback_task = None - # async def mark_self_terminated(self): - # source_id = self._aggregator.get_id() - # await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(source_id)) - - async def handle_node_termination(self, event: NodeTerminatedEvent): - source = await event.get_event_data() - if source in self._terminated_peers: - return - self._terminated_peers.add(source) - logging.info(f"[CAFF] Peer {source} terminated, remaining expected: {self._sources_expected}") - - # re-evaluate K_node feasibility immediately - await self._check_k_node_satisfaction() - - # if source not in self._terminated_peers: - # self._terminated_peers.add(source) - # #self._sources_expected.discard(source) - # logging.info(f"[CAFF] Node {source} terminated") - # logging.info(f"[CAFF] Updated terminated peers: {self._terminated_peers}") - # logging.info(f"[CAFF] Remaining active peers: {self._sources_expected - {self._addr} - self._terminated_peers}") - # if not self._local_training_done: - # await self. 
(force_check=True) - # else: - # logging.info("[CAFF] Ignoring termination impact – local training already done.") - async def _maybe_aggregate(self, force_check=False): - logging.info( - f"[{self._addr}] ENTERED _maybe_aggregate (round={self._local_round}) | local_done={self._local_training_done} | cache={list(self._update_cache.keys())}" - ) + """ + Check if K_node threshold and local training are satisfied; if so, notify aggregator. + """ await self._cache_lock.acquire_async() + # if aggregation has already been triggered, leave the function if self._notification_sent and not force_check: - logging.info("CAFF DEBUGG got out of notification send") await self._cache_lock.release_async() return - # if self._addr in self._update_cache: - # logging.info(f"----------------[CAFF] LOCAL UPDATE IS IN CACHE -------------------------") - # self._local_training_done = True peer_count = len(self._update_cache) has_local = self._local_training_done - logging.info( - f"[CAFF] Round {self._local_round} aggregation check: {peer_count} / {self._K_node} (local ready: {has_local})" - ) - - if self._addr in self._update_cache: - logging.info(f"[CAFF] [MAYBE_AGGREGATE] Local update is already cached for round {self._local_round}") - else: - logging.warning( - f"[CAFF] [MAYBE_AGGREGATE] Local update is NOT cached at aggregation time (round {self._local_round})" - ) - + # wait for local model if not has_local: logging.debug("[CAFF] Waiting for local training to finish before aggregation.") await self._cache_lock.release_async() return - # ✅ NEW: For round 0, enforce synchronous wait for all expected sources - if self._local_round == 0: - expected = self._sources_expected | {self._addr} - logging.info(f"Expected: {expected}") - received = set(self._update_cache.keys()) - logging.info(f"Received: {received}") - missing = expected - received - logging.info(f"Missing: {missing}") - if missing: - logging.info(f"[CAFF][ROUND 0 SYNC] Waiting for all updates in round 0. Still missing: {missing}") - await self._cache_lock.release_async() - return - else: - logging.info("[CAFF][ROUND 0 SYNC] All updates received for round 0 — proceeding with aggregation") - + # if enough peers are ready, start aggregation if has_local and peer_count >= self._K_node: - logging.info(f"[CAFF] Aggregating with {peer_count} updates (incl. local)") + logging.info(f"[CAFF] K_node threshold reached and local training finished: Starting aggregation process...") self._notification_sent = True await self._cache_lock.release_async() - await asyncio.sleep(0.5) + logging.info("[CAFF] 🔄 Notifying aggregator to release aggregation") await self.agg.notify_all_updates_received() return - active_peers_remaining = len(self._sources_expected - {self._addr} - self._terminated_peers) - required_external = self._K_node - 1 - - if active_peers_remaining < required_external: - logging.warning("[CAFF] Not enough active peers to reach K_node.
Triggering self-termination.") - self._should_stop_training = True - logging.warning( - f"[CAFF] Triggered self-termination due to unsatisfiable K_node condition (active: {active_peers_remaining}) | (required: {required_external})" - ) - await self.terminate_self() - if hasattr(self._aggregator, "engine"): - self._aggregator.engine.force_stop_training() - return - await self._cache_lock.release_async() - - # async def _start_fallback_timer(self): - # await asyncio.sleep(self._fallback_timeout) - # await self._cache_lock.acquire_async() - # if self._notification_sent: - # await self._cache_lock.release_async() - # return - # if self._update_cache and self._local_training_done: - # logging.warning("[CAFF] Fallback triggered — aggregating with partial cache.") - # self._notification_sent = True - # await self._cache_lock.release_async() - # await self.agg.notify_all_updates_received() - # else: - # await self._cache_lock.release_async() - - async def should_continue_training(self) -> bool: - logging.info( - f"[CAFF] should_continue_training = {not self._should_stop_training} (should_stop_training: {self._should_stop_training})" - ) - return not self._should_stop_training - - async def terminate_self(self): - # 3) Idempotency: only run this block once - if self._notification_sent: - return - - # mark that we've broadcast our termination - self._notification_sent = True - self._should_stop_training = True - - logging.warning(f"[CAFF] Node {self._addr} terminating itself.") - - # cancel the fallback timer if it’s still pending - # if self._fallback_task: - # self._fallback_task.cancel() - # self._fallback_task = None - - # publish our own termination event - await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(self._addr)) - - # send the control‐terminated message exactly once - terminate_msg = self._aggregator.engine.cm.create_message("control", "terminated") - await self._aggregator.engine.cm.send_message_to_neighbors(terminate_msg) - - # also flip the engine’s flag so the loop exits ASAP - if hasattr(self._aggregator, "engine"): - self._aggregator.engine.force_stop_training() diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index 67c24a6e..4d78a1d7 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -112,11 +112,11 @@ def factory_update_handler(updt_handler, aggregator, addr) -> UpdateHandler: from nebula.core.aggregation.updatehandlers.dflupdatehandler import DFLUpdateHandler from nebula.core.aggregation.updatehandlers.sdflupdatehandler import SDFLUpdateHandler - # added by Eduard to choose between communication strategy once DFL is selected + # Choose between communication strategy once DFL is selected if updt_handler == "DFL": mechanism = aggregator.config.participant.get("communication_args", {}).get( "mechanism", "standard" - ) # strategy can be either standard or CAFF + ) # Strategy for DFL can be either standard or CAFF if mechanism == "CAFF": return CAFFUpdateHandler(aggregator, addr) else: diff --git a/nebula/core/engine.py b/nebula/core/engine.py index ff781e6a..d11eaa0a 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -17,7 +17,6 @@ from nebula.core.nebulaevents import ( AggregationEvent, ExperimentFinishEvent, - NodeTerminatedEvent, RoundEndEvent, RoundStartEvent, UpdateNeighborEvent, @@ -41,7 +40,6 @@ import sys from nebula.config.config import Config -from 
nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler from nebula.core.training.lightning import Lightning @@ -116,16 +114,6 @@ def __init__( self._trainer = trainer(model, datamodule, config=self.config) self._aggregator = create_aggregator(config=self.config, engine=self) - self._termination_sent = False # <-- needed for CAFF mechanism - self._using_caff = self.config.participant.get("communication_args", {}).get("mechanism", "").lower() == "caff" - if self._using_caff: - logging.info(f"[{self.addr}] Communication mechanism: CAFF") - else: - logging.info( - f"[{self.addr}] Communication mechanism: {self.config.participant.get('communication_args', {}).get('mechanism', 'standard')}" - ) - self._caff_force_terminate = False - self._secure_neighbors = [] self._is_malicious = self.config.participant["adversarial_args"]["attack_params"]["attacks"] != "No Attack" @@ -255,18 +243,6 @@ def set_round(self, new_round): self.round = new_round self.trainer.set_current_round(new_round) - def _get_caff_handler(self) -> CAFFUpdateHandler | None: - if self.config.participant["scenario_args"]["federation"].lower() == "dfl" and self._using_caff: - from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler - - handler = getattr(self.aggregator, "_update_storage", None) - if isinstance(handler, CAFFUpdateHandler): - return handler - return None - - def force_stop_training(self): - self._caff_force_terminate = True - """ ############################## # MODEL CALLBACKS # ############################## @@ -941,68 +917,6 @@ async def _shutdown_protocol(self): await self.shutdown() return - # If it's CAFF, log what should_continue_training() returns - caff_handler = self._get_caff_handler() - if caff_handler: - should_continue = await caff_handler.should_continue_training() - logging.info(f"[{self.addr}] CAFF should_continue_training() result: {should_continue}") - else: - should_continue = False - logging.info(f"[{self.addr}] Not using CAFF or not a DFL scenario — proceeding to terminate.") - - if ( - not caff_handler or not await caff_handler.should_continue_training() - ): # <--- comment out if training to stop as soon as first node reaches finish line - if self.config.participant["scenario_args"]["deployment"] == "docker": - try: - docker_id = socket.gethostname() - logging.info(f"📦 Killing docker container with ID {docker_id}") - self.client.containers.get(docker_id).kill() - except Exception as e: - logging.exception(f"📦 Error stopping Docker container with ID {docker_id}: {e}") - - async def _extended_learning_cycle(self): - """ - This method is called in each round of the learning cycle. It is used to extend the learning cycle with additional - functionalities. The method is called in the _learning_cycle method. - """ - pass - - async def _continue_training(self): - if self.round is None: - logging.info("LEAVING LEARNING CYCLE") - return False - - # If using CAFF: check if early termination has been triggered - if self._using_caff: - caff_handler = self._get_caff_handler() - if caff_handler: - should_continue = await caff_handler.should_continue_training() - if not should_continue: - logging.info(f"[{self.addr}] CAFF handler requested early stop. 
Terminating training.") - # Send termination alert only once - if not self._termination_sent: - await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(self.addr)) - terminate_msg = self.cm.create_message("control", "terminated") - await self.cm.send_message_to_neighbors(terminate_msg) - self._termination_sent = True - return False - - # Regular case: keep training if rounds left - if self.round < self.total_rounds: - return True - - # If max round is reached: send CAFF termination alert only once - if self._using_caff: - if not self._termination_sent: - logging.info("[CAFF] Max rounds reached. Announcing termination to peers.") - await EventManager.get_instance().publish_node_event(NodeTerminatedEvent(self.addr)) - terminate_msg = self.cm.create_message("control", "terminated") - await self.cm.send_message_to_neighbors(terminate_msg) - self._termination_sent = True - - # Never train beyond max rounds — not even for CAFF - return False async def shutdown(self): logging.info("🚦 Engine shutdown initiated") diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py index bc207dcf..286be6b5 100644 --- a/nebula/core/nebulaevents.py +++ b/nebula/core/nebulaevents.py @@ -144,20 +144,6 @@ async def is_concurrent(self): return False -class NodeTerminatedEvent(NodeEvent): - def __init__(self, source_id: str): - self._source_id = source_id - - def __str__(self): - return f"NodeTerminatedEvent from {self._source_id}" - - async def get_event_data(self) -> str: - return self._source_id - - async def is_concurrent(self) -> bool: - return True - - class AggregationEvent(NodeEvent): def __init__(self, updates: dict, expected_nodes: set, missing_nodes: set): """Event triggered when model aggregation is ready. diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py index 717ea5f9..d9cbff3e 100755 --- a/nebula/core/network/propagator.py +++ b/nebula/core/network/propagator.py @@ -7,6 +7,10 @@ from nebula.core.eventmanager import EventManager from typing import TYPE_CHECKING, Any +### TEST +from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler +### TEST + from nebula.addons.functions import print_msg_box if TYPE_CHECKING: @@ -308,7 +312,9 @@ async def _propagate(self, mpe: ModelPropagationEvent): bool: True if propagation occurred (payload sent), False if halted early. 
""" eligible_neighbors, strategy_id = await mpe.get_event_data() - + + + self.reset_status_history() if strategy_id not in self.strategies: logging.info(f"Strategy {strategy_id} not found.") @@ -320,6 +326,20 @@ async def _propagate(self, mpe: ModelPropagationEvent): strategy = self.strategies[strategy_id] logging.info(f"Starting model propagation with strategy: {strategy_id}") + + + ### TEST FOR CAFF - Send model updates only to last contributors + #current_round = await self.get_round() + #handler = self.aggregator.us + #logging.info(f"HANDLER for CAFF: {handler}, current_round: {current_round}") + + #if isinstance(handler, CAFFUpdateHandler) and current_round > 0: + # logging.info("exchanging eligible neighbor list to last contributed peers") + # eligible_neighbors = handler.last_contributors.copy() + ### TEST FOR CAFF - Send model updates only to last contributors + + + # current_connections = await self.cm.get_addrs_current_connections(only_direct=True) # eligible_neighbors = [ # neighbor_addr for neighbor_addr in current_connections if await strategy.is_node_eligible(neighbor_addr) diff --git a/nebula/frontend/templates/deployment.html b/nebula/frontend/templates/deployment.html index dbd95a3e..9dbfa36a 100755 --- a/nebula/frontend/templates/deployment.html +++ b/nebula/frontend/templates/deployment.html @@ -290,7 +290,7 @@
Number of rounds
Communication Mechanism
- +
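# Illustrative sketch (an assumption for readability, not part of the patch itself):
# the K_node rule from CAFFUpdateHandler._recalc_K_node, reproduced standalone so the
# threshold arithmetic is easy to check. The fraction and peer count below are assumed
# example values, not values taken from any configuration file.
aggregation_fraction = 0.5   # assumed example for caff_args["aggregation_fraction"]
external_peers = 4           # peers in _sources_expected other than this node
K_node = max(1 + 1, round(aggregation_fraction * external_peers) + 1)
print(K_node)                # -> 3: the local update plus at least two peer updates
                             #    must be cached before aggregation is triggered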