diff --git a/nebula/controller/scenarios.py b/nebula/controller/scenarios.py
index acf1e98e..51b262e9 100644
--- a/nebula/controller/scenarios.py
+++ b/nebula/controller/scenarios.py
@@ -6,15 +6,12 @@
 import math
 import os
 import shutil
-import subprocess
-import sys
 import time
 from datetime import datetime
 from urllib.parse import quote
 
-from aiohttp import FormData
 import docker
-import tensorboard_reducer as tbr
+from aiohttp import FormData
 
 from nebula.addons.topologymanager import TopologyManager
 from nebula.config.config import Config
@@ -108,6 +105,9 @@ def __init__(
         sar_neighbor_policy,
         sar_training,
         sar_training_policy,
+        aggregator_args,  # <-- NEW
+        communication_args,  # <-- NEW
+        caff_args,  # <-- NEW
         physical_ips=None,
     ):
         """
@@ -199,30 +199,30 @@ def __init__(
         self.mobile_participants_percent = mobile_participants_percent
         self.additional_participants = additional_participants
         self.with_trustworthiness = with_trustworthiness
-        self.robustness_pillar = robustness_pillar,
-        self.resilience_to_attacks = resilience_to_attacks,
-        self.algorithm_robustness = algorithm_robustness,
-        self.client_reliability = client_reliability,
-        self.privacy_pillar = privacy_pillar,
-        self.technique = technique,
-        self.uncertainty = uncertainty,
-        self.indistinguishability = indistinguishability,
-        self.fairness_pillar = fairness_pillar,
-        self.selection_fairness = selection_fairness,
-        self.performance_fairness = performance_fairness,
-        self.class_distribution = class_distribution,
-        self.explainability_pillar = explainability_pillar,
-        self.interpretability = interpretability,
-        self.post_hoc_methods = post_hoc_methods,
-        self.accountability_pillar = accountability_pillar,
-        self.factsheet_completeness = factsheet_completeness,
-        self.architectural_soundness_pillar = architectural_soundness_pillar,
-        self.client_management = client_management,
-        self.optimization = optimization,
-        self.sustainability_pillar = sustainability_pillar,
-        self.energy_source = energy_source,
-        self.hardware_efficiency = hardware_efficiency,
-        self.federation_complexity = federation_complexity,
+        self.robustness_pillar = (robustness_pillar,)
+        self.resilience_to_attacks = (resilience_to_attacks,)
+        self.algorithm_robustness = (algorithm_robustness,)
+        self.client_reliability = (client_reliability,)
+        self.privacy_pillar = (privacy_pillar,)
+        self.technique = (technique,)
+        self.uncertainty = (uncertainty,)
+        self.indistinguishability = (indistinguishability,)
+        self.fairness_pillar = (fairness_pillar,)
+        self.selection_fairness = (selection_fairness,)
+        self.performance_fairness = (performance_fairness,)
+        self.class_distribution = (class_distribution,)
+        self.explainability_pillar = (explainability_pillar,)
+        self.interpretability = (interpretability,)
+        self.post_hoc_methods = (post_hoc_methods,)
+        self.accountability_pillar = (accountability_pillar,)
+        self.factsheet_completeness = (factsheet_completeness,)
+        self.architectural_soundness_pillar = (architectural_soundness_pillar,)
+        self.client_management = (client_management,)
+        self.optimization = (optimization,)
+        self.sustainability_pillar = (sustainability_pillar,)
+        self.energy_source = (energy_source,)
+        self.hardware_efficiency = (hardware_efficiency,)
+        self.federation_complexity = (federation_complexity,)
         self.schema_additional_participants = schema_additional_participants
         self.random_topology_probability = random_topology_probability
         self.with_sa = with_sa
@@ -234,6 +234,9 @@ def __init__(
         self.sar_training = sar_training
         self.sar_training_policy = sar_training_policy
         self.physical_ips = physical_ips
+        self.aggregator_args = aggregator_args  # <-- NEW
+        self.communication_args = communication_args  # <-- NEW
+        self.caff_args = caff_args  # <-- NEW
 
     def attack_node_assign(
         self,
@@ -694,6 +697,11 @@ def __init__(self, scenario, user=None):
             participant_config["device_args"]["gpu_id"] = self.scenario.gpu_id
             participant_config["device_args"]["logging"] = self.scenario.logginglevel
             participant_config["aggregator_args"]["algorithm"] = self.scenario.agg_algorithm
+            participant_config["aggregator_args"]["aggregation_timeout"] = int(
+                self.scenario.aggregator_args["aggregation_timeout"]
+            )  # <-- New Addition
+            participant_config["communication_args"] = self.scenario.communication_args  # <-- New Addition
+            participant_config["caff_args"] = self.scenario.caff_args  # <-- New Addition
             # To be sure that benign nodes have no attack parameters
             if node_config["role"] == "malicious":
                 participant_config["adversarial_args"]["fake_behavior"] = node_config["fake_behavior"]
diff --git a/nebula/core/aggregation/updatehandlers/caffupdatehandler.py b/nebula/core/aggregation/updatehandlers/caffupdatehandler.py
new file mode 100644
index 00000000..532b5c3e
--- /dev/null
+++ b/nebula/core/aggregation/updatehandlers/caffupdatehandler.py
@@ -0,0 +1,297 @@
+import asyncio
+import time
+import logging
+from typing import TYPE_CHECKING, Any
+
+from nebula.core.aggregation.updatehandlers.updatehandler import UpdateHandler
+from nebula.core.eventmanager import EventManager
+from nebula.core.nebulaevents import UpdateNeighborEvent, UpdateReceivedEvent
+from nebula.core.utils.locker import Locker
+
+if TYPE_CHECKING:
+    from nebula.core.aggregation.aggregator import Aggregator
+
+
+class Update:
+    """
+    Holds a single model update along with its metadata.
+
+    Attributes:
+        model (object): The serialized model or weights.
+        weight (float): The importance weight of the update.
+        source (str): Identifier of the sending node.
+        round (int): Training round number of this update.
+        time_received (float): Timestamp when the update arrived.
+    """
+
+    def __init__(self, model, weight, source, round, time_received):
+        self.model = model
+        self.weight = weight
+        self.source = source
+        self.round = round
+        self.time_received = time_received
+
+
+class CAFFUpdateHandler(UpdateHandler):
+    """
+    CAFF: Cache-based Aggregation with Fairness and Filtering update handler.
+
+    Manages peer updates asynchronously, applies a staleness filter,
+    and triggers aggregation once a dynamic threshold K_node is met.
+    """
+
+    def __init__(self, aggregator: "Aggregator", addr: str):
+        """
+        Initialize handler state and async locks.
+
+        Args:
+            aggregator (Aggregator): The federation aggregator instance.
+            addr (str): This node's unique address.
+ """ + self._addr = addr + self._aggregator = aggregator + + self._update_cache: dict[str, Update] = {} + self._sources_expected: set[str] = set() + + self._cache_lock = Locker(name="caff_cache_lock", async_lock=True) + + self._fallback_task = None + self._notification_sent = False + + self._aggregation_fraction = 1.0 + self._staleness_threshold = 1 + + self._local_round = 0 + self._K_node = 1 + + self._local_training_done = False + + ### TEST FOR CAFF - Send model updates only to last contributors + self._last_contributors: list[str] = [] + ### TEST FOR CAFF - Send model updates only to last contributors + + + @property + def us(self): + """Returns the internal update cache mapping source -> Update.""" + return self._update_cache + + @property + def agg(self): + """Returns the linked aggregator.""" + return self._aggregator + + @property + def last_contributors(self) -> list[str]: + """ + The list of peer-addresses who actually contributed in the previous round. + """ + return self._last_contributors + + def _recalc_K_node(self): + """ + Recalculate K_node based on current expected peers and fraction. + + Ensures at least one external peer and includes local update. + """ + # when only local node is still active and all others finished training + if len(self._sources_expected) == 1: + self._K_node = 1 + else: + # count only “external” peers + external = len(self._sources_expected - {self._addr}) + # round() to nearest int, Account for local update also stored with + 1 + self._K_node = max(1 + 1, round(self._aggregation_fraction * external) + 1) + + async def init(self, _role_name): + """ + Subscribe to federation events and load config args. + + Args: + _role_name (str): Unused placeholder for role. + """ + cfg = self._aggregator.config + + self._aggregation_fraction = cfg.participant["caff_args"]["aggregation_fraction"] + self._staleness_threshold = cfg.participant["caff_args"]["staleness_threshold"] + + await EventManager.get_instance().subscribe_node_event(UpdateNeighborEvent, self.notify_federation_update) + await EventManager.get_instance().subscribe_node_event(UpdateReceivedEvent, self.storage_update) + + async def round_expected_updates(self, federation_nodes: set): + """ + Reset round state when a new round starts. + + Args: + federation_nodes (set[str]): IDs of peers expected this round. + """ + self._sources_expected = federation_nodes.copy() + self._update_cache.clear() + self._notification_sent = False + self._local_round = await self._aggregator.engine.get_round() + + # compute initial K_node + self._recalc_K_node() + + self._local_training_done = False + + async def storage_update(self, updt_received_event: UpdateReceivedEvent): + """ + Handle an incoming model update event. + + Args: + updt_received_event (UpdateReceivedEvent): Carries model, weight, source, and round. 
+ """ + time_received = time.time() + model, weight, source, round_received, _ = await updt_received_event.get_event_data() + await self._cache_lock.acquire_async() + + # reject any update not in expected set + if source not in self._sources_expected: + logging.info( + f"[CAFF] Discard update | source: {source} not in expected updates for round {self._local_round}" + ) + await self._cache_lock.release_async() + return + + # store local update + if self._addr == source: + update = Update(model, weight, source, round_received, time_received) + self._update_cache[source] = update + self._local_training_done = True + logging.info(f"[CAFF] Received own update from {source} (local training done)") + logging.info( + f"[CAFF] Round {self._local_round} aggregation check: {len(self._update_cache)} / {self._K_node} (local ready: {self._local_training_done})" + ) + await self._cache_lock.release_async() + await self._maybe_aggregate(force_check=True) + return + + # replace or discard update from peer already in cache + if source in self._update_cache: + cached_update = self._update_cache[source] + if round_received > cached_update.round: + self._update_cache[source] = Update(model, weight, source, round_received, time_received) + logging.info(f"[CAFF] Replaced cached update from {source} with newer round {round_received}") + else: + logging.info(f"[CAFF] Discarded stale update from {source} (round {round_received})") + await self._cache_lock.release_async() + if self._local_training_done: + await self._maybe_aggregate() + return + + # staleness filter for new peer update + if self._local_round - round_received <= self._staleness_threshold: + self._update_cache[source] = Update(model, weight, source, round_received, time_received) + logging.info(f"[CAFF] Cached new update from {source} (round {round_received})") + else: + logging.info(f"[CAFF] Discarded stale update from {source} (round {round_received})") + + peer_count = len(self._update_cache) + has_local = self._local_training_done + + logging.info( + f"[CAFF] Round {self._local_round} aggregation check: {peer_count} / {self._K_node} (local ready: {has_local})" + ) + + await self._cache_lock.release_async() + if self._local_training_done: + await self._maybe_aggregate() + + async def get_round_updates(self): + """ + Return all cached updates for aggregation. + + Returns: + dict[str, tuple]: Mapping of node -> (model, weight) pairs. + """ + await self._cache_lock.acquire_async() + updates = {peer: (u.model, u.weight) for peer, u in self._update_cache.items()} + await self._cache_lock.release_async() + logging.info(f"[CAFF] Aggregating with {len(self._update_cache)} updates (incl. local)") + + ### TEST FOR CAFF - Send model updates only to last contributors + # Snapshot the peers who actually contributed this round + self._last_contributors = list(set(self._update_cache.keys()).difference({self._addr})) + ### TEST FOR CAFF - Send model updates only to last contributors + + return updates + + async def notify_federation_update(self, updt_nei_event: UpdateNeighborEvent): + """ + Handle peers joining or leaving mid-round. + + When a peer leaves, remove it from expected set and recalc K_node. 
+ """ + source, remove = await updt_nei_event.get_event_data() + + if remove: + # peer left / finished + if source in self._sources_expected: + self._sources_expected.discard(source) + logging.info(f"[CAFF] Peer {source} removed → now expecting {self._sources_expected}") + # recompute threshold + self._recalc_K_node() + logging.info(f"[CAFF] Recomputed K_node={self._K_node}") + # maybe we can now aggregate earlier + await self._maybe_aggregate() + else: + # peer joined mid-round + self._sources_expected.add(source) + logging.info(f"[CAFF] Peer {source} joined → now expecting {self._sources_expected}") + self._recalc_K_node() + logging.info(f"[CAFF] Recomputed K_node={self._K_node}") + + async def get_round_missing_nodes(self): + """ + List peers whose updates have not yet arrived. + + Returns: + set[str]: IDs of missing peers. + """ + return self._sources_expected.difference(self._update_cache.keys()) + + async def notify_if_all_updates_received(self): + """ + Trigger aggregation check immediately when all conditions may be met. + """ + logging.info("[CAFF] Timer started and waiting for updates to reach threshold K_node") + await self._maybe_aggregate() + + async def stop_notifying_updates(self): + """ + Cancel any pending fallback notification timers. + """ + if self._fallback_task: + self._fallback_task.cancel() + self._fallback_task = None + + async def _maybe_aggregate(self, force_check=False): + """ + Check if K_node threshold and local training are satisfied; if so, notify aggregator. + """ + await self._cache_lock.acquire_async() + + # if aggregation already started leave function + if self._notification_sent and not force_check: + await self._cache_lock.release_async() + return + + peer_count = len(self._update_cache) + has_local = self._local_training_done + + # wait for local model + if not has_local: + logging.debug("[CAFF] Waiting for local training to finish before aggregation.") + await self._cache_lock.release_async() + return + + # if enough peers ready, start aggregation + if has_local and peer_count >= self._K_node: + logging.info(f"[CAFF] K_node threshold reached and local training finished: Starting aggregation process..") + self._notification_sent = True + await self._cache_lock.release_async() + logging.info("[CAFF] 🔄 Notifying aggregator to release aggregation") + await self.agg.notify_all_updates_received() + return + + await self._cache_lock.release_async() diff --git a/nebula/core/aggregation/updatehandlers/updatehandler.py b/nebula/core/aggregation/updatehandlers/updatehandler.py index f3484923..4d78a1d7 100644 --- a/nebula/core/aggregation/updatehandlers/updatehandler.py +++ b/nebula/core/aggregation/updatehandlers/updatehandler.py @@ -107,15 +107,32 @@ async def stop_notifying_updates(self): def factory_update_handler(updt_handler, aggregator, addr) -> UpdateHandler: + from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler from nebula.core.aggregation.updatehandlers.cflupdatehandler import CFLUpdateHandler from nebula.core.aggregation.updatehandlers.dflupdatehandler import DFLUpdateHandler from nebula.core.aggregation.updatehandlers.sdflupdatehandler import SDFLUpdateHandler - UPDATE_HANDLERS = {"DFL": DFLUpdateHandler, "CFL": CFLUpdateHandler, "SDFL": SDFLUpdateHandler} - - update_handler = UPDATE_HANDLERS.get(updt_handler) - - if update_handler: - return update_handler(aggregator, addr) + # Choose between communication strategy once DFL is selected + if updt_handler == "DFL": + mechanism = 
aggregator.config.participant.get("communication_args", {}).get( + "mechanism", "standard" + ) # Strategy for DFL can be either standard or CAFF + if mechanism == "CAFF": + return CAFFUpdateHandler(aggregator, addr) + else: + return DFLUpdateHandler(aggregator, addr) + elif updt_handler == "CFL": + return CFLUpdateHandler(aggregator, addr) + elif updt_handler == "SDFL": + return SDFLUpdateHandler(aggregator, addr) else: raise UpdateHandlerException(f"Update Handler {updt_handler} not found") + + # UPDATE_HANDLERS = {"DFL": DFLUpdateHandler, "CFL": CFLUpdateHandler, "SDFL": SDFLUpdateHandler, "CAFF": CAFFUpdateHandler} + + # update_handler = UPDATE_HANDLERS.get(updt_handler) + + # if update_handler: + # return update_handler(aggregator, addr) + # else: + # raise UpdateHandlerException(f"Update Handler {updt_handler} not found") diff --git a/nebula/core/engine.py b/nebula/core/engine.py index 151ae6b2..d11eaa0a 100644 --- a/nebula/core/engine.py +++ b/nebula/core/engine.py @@ -94,7 +94,7 @@ def __init__( self.ip = config.participant["network_args"]["ip"] self.port = config.participant["network_args"]["port"] self.addr = config.participant["network_args"]["addr"] - + self.name = config.participant["device_args"]["name"] self.client = docker.from_env() @@ -187,7 +187,7 @@ def aggregator(self): def trainer(self): """Trainer""" return self._trainer - + @property def rb(self): """Role Behavior""" @@ -317,7 +317,7 @@ async def _control_alive_callback(self, source, message): async def _control_leadership_transfer_callback(self, source, message): logging.info(f"🔧 handle_control_message | Trigger | Received leadership transfer message from {source}") - + if await self._round_in_process_lock.locked_async(): logging.info("Learning cycle is executing, role behavior will be modified next round") await self.rb.set_next_role(Role.AGGREGATOR, source_to_notificate=source) @@ -354,7 +354,7 @@ async def _control_leadership_transfer_ack_callback(self, source, message): except TimeoutError: logging.info("Learning cycle is locked, role behavior will be modified next round") await self.rb.set_next_role(Role.TRAINER) - + async def _connection_connect_callback(self, source, message): logging.info(f"🔗 handle_connection_message | Trigger | Received connection message from {source}") @@ -710,10 +710,10 @@ async def _start_learning(self): await self.get_federation_ready_lock().acquire_async() if self.config.participant["device_args"]["start"]: logging.info("Propagate initial model updates.") - + mpe = ModelPropagationEvent(await self.cm.get_addrs_current_connections(only_direct=True, myself=False), "initialization") await EventManager.get_instance().publish_node_event(mpe) - + await self.get_federation_ready_lock().release_async() self.trainer.set_epochs(epochs) @@ -764,7 +764,7 @@ async def learning_cycle_finished(self): return False else: return current_round >= self.total_rounds - + async def resolve_missing_updates(self): """ Delegates the resolution strategy for missing updates to the current role behavior. @@ -778,7 +778,7 @@ async def resolve_missing_updates(self): """ logging.info(f"Using Role behavior: {self.rb.get_role_name()} conflict resolve strategy") return await self.rb.resolve_missing_updates() - + async def update_self_role(self): """ Checks whether a role update is required and performs the transition if necessary. 
@@ -806,7 +806,7 @@ async def update_self_role(self):
             logging.info(f"Sending role modification ACK to transferer: {source_to_notificate}")
             message = self.cm.create_message("control", "leadership_transfer_ack")
             asyncio.create_task(self.cm.send_message(source_to_notificate, message))
-    
+
     async def _learning_cycle(self):
         """
         Main asynchronous loop for executing the Federated Learning process across multiple rounds.
@@ -837,9 +837,9 @@ async def _learning_cycle(self):
                 indent=2,
                 title="Round information",
             )
-            
+
             await self.update_self_role()
-            
+
             logging.info(f"Federation nodes: {self.federation_nodes}")
             await self.update_federation_nodes(
                 await self.cm.get_addrs_current_connections(only_direct=True, myself=True)
@@ -851,10 +851,10 @@ async def _learning_cycle(self):
             logging.info(f"Expected nodes: {expected_nodes}")
             direct_connections = await self.cm.get_addrs_current_connections(only_direct=True)
             undirected_connections = await self.cm.get_addrs_current_connections(only_undirected=True)
-            
+
             logging.info(f"Direct connections: {direct_connections} | Undirected connections: {undirected_connections}")
             logging.info(f"[Role {self.rb.get_role_name()}] Starting learning cycle...")
-            
+
             await self.aggregator.update_federation_nodes(expected_nodes)
             async with self._role_behavior_performance_lock:
                 await self.rb.extended_learning_cycle()
@@ -878,17 +878,18 @@ async def _learning_cycle(self):
             )  # Set current round in config (send to the controller)
             await self.get_round_lock().release_async()
+            logging.info(f"[{self.addr}] Training finished at round {self.round}, entering idle mode.")
 
         # End of the learning cycle
         self.trainer.on_learning_cycle_end()
         await self.trainer.test()
-        
+
         # Shutdown protocol
         await self._shutdown_protocol()
-    
+
     async def _shutdown_protocol(self):
         logging.info("Starting graceful shutdown process...")
-        
+
         # 1.- Publish Experiment Finish Event to the last update on modules
         logging.info("Publishing Experiment Finish Event...")
         efe = ExperimentFinishEvent()
diff --git a/nebula/core/nebulaevents.py b/nebula/core/nebulaevents.py
index f2ec0883..286be6b5 100644
--- a/nebula/core/nebulaevents.py
+++ b/nebula/core/nebulaevents.py
@@ -5,7 +5,7 @@ class AddonEvent(ABC):
     """
     Abstract base class for all addon-related events in the system.
     """
-    
+
     @abstractmethod
     async def get_event_data(self):
         """
@@ -21,7 +21,7 @@ class NodeEvent(ABC):
     """
     Abstract base class for all node-related events in the system.
     """
-    
+
     @abstractmethod
     async def get_event_data(self):
         """
@@ -52,7 +52,7 @@ class MessageEvent:
         source (str): Address or identifier of the message sender.
         message (Any): The actual message payload.
     """
-    
+
     def __init__(self, message_type, source, message):
         """
         Initializes a MessageEvent instance.
@@ -264,7 +264,7 @@ async def get_event_data(self) -> tuple[str, bool]:
 
     async def is_concurrent(self) -> bool:
         return True
-    
+
 
 class ModelPropagationEvent(NodeEvent):
     def __init__(self, eligible_neighbors, strategy):
        """Event triggered when model propagation is ready.
@@ -275,7 +275,7 @@ def __init__(self, eligible_neighbors, strategy):
         """
         self.eligible_neighbors = eligible_neighbors
         self._strategy = strategy
-    
+
     def __str__(self):
         return f"Model propagation event, strategy: {self._strategy}"
 
@@ -291,8 +291,8 @@ async def get_event_data(self) -> tuple[set, str]:
         return (self.eligible_neighbors, self._strategy)
 
     async def is_concurrent(self) -> bool:
-        return False
-    
+        return False
+
 
 class UpdateReceivedEvent(NodeEvent):
@@ -362,7 +362,7 @@ async def get_event_data(self) -> tuple[str, tuple[float, float]]:
 
     async def is_concurrent(self) -> bool:
         return True
-    
+
 
 class DuplicatedMessageEvent(NodeEvent):
     """
     Event triggered when a message is received that has already been processed.
@@ -370,7 +370,7 @@ class DuplicatedMessageEvent(NodeEvent):
     Attributes:
         source (str): The address of the node that sent the duplicated message.
     """
-    
+
     def __init__(self, source: str, message_type: str):
         self.source = source
 
@@ -396,7 +396,7 @@ class GPSEvent(AddonEvent):
     Attributes:
         distances (dict): A dictionary mapping node addresses to their respective distances.
     """
-    
+
     def __init__(self, distances: dict):
         """
         Initializes a GPSEvent.
@@ -427,7 +427,7 @@ class ChangeLocationEvent(AddonEvent):
         latitude (float): New latitude of the node.
         longitude (float): New longitude of the node.
     """
-    
+
     def __init__(self, latitude, longitude):
         """
         Initializes a ChangeLocationEvent.
@@ -450,7 +450,8 @@ async def get_event_data(self):
             tuple: A tuple containing latitude and longitude.
         """
         return (self.latitude, self.longitude)
-    
+
+
 class TestMetricsEvent(AddonEvent):
     def __init__(self, loss, accuracy):
         self._loss = loss
diff --git a/nebula/core/network/actions.py b/nebula/core/network/actions.py
index 77e1997c..109bbab5 100644
--- a/nebula/core/network/actions.py
+++ b/nebula/core/network/actions.py
@@ -47,6 +47,7 @@ class ControlAction(Enum):
     WEAK_LINK = nebula_pb2.ControlMessage.Action.WEAK_LINK
     LEADERSHIP_TRANSFER = nebula_pb2.ControlMessage.Action.LEADERSHIP_TRANSFER
     LEADERSHIP_TRANSFER_ACK = nebula_pb2.ControlMessage.Action.LEADERSHIP_TRANSFER_ACK
+    TERMINATED = nebula_pb2.ControlMessage.Action.Value("TERMINATED")
 
 
 class DiscoverAction(Enum):
diff --git a/nebula/core/network/propagator.py b/nebula/core/network/propagator.py
index 717ea5f9..d9cbff3e 100755
--- a/nebula/core/network/propagator.py
+++ b/nebula/core/network/propagator.py
@@ -7,6 +7,10 @@
 from nebula.core.eventmanager import EventManager
 from typing import TYPE_CHECKING, Any
 
+### TEST
+from nebula.core.aggregation.updatehandlers.caffupdatehandler import CAFFUpdateHandler
+### TEST
+
 from nebula.addons.functions import print_msg_box
 
 if TYPE_CHECKING:
@@ -308,7 +312,9 @@ async def _propagate(self, mpe: ModelPropagationEvent):
             bool: True if propagation occurred (payload sent), False if halted early.
""" eligible_neighbors, strategy_id = await mpe.get_event_data() - + + + self.reset_status_history() if strategy_id not in self.strategies: logging.info(f"Strategy {strategy_id} not found.") @@ -320,6 +326,20 @@ async def _propagate(self, mpe: ModelPropagationEvent): strategy = self.strategies[strategy_id] logging.info(f"Starting model propagation with strategy: {strategy_id}") + + + ### TEST FOR CAFF - Send model updates only to last contributors + #current_round = await self.get_round() + #handler = self.aggregator.us + #logging.info(f"HANDLER for CAFF: {handler}, current_round: {current_round}") + + #if isinstance(handler, CAFFUpdateHandler) and current_round > 0: + # logging.info("exchanging eligible neighbor list to last contributed peers") + # eligible_neighbors = handler.last_contributors.copy() + ### TEST FOR CAFF - Send model updates only to last contributors + + + # current_connections = await self.cm.get_addrs_current_connections(only_direct=True) # eligible_neighbors = [ # neighbor_addr for neighbor_addr in current_connections if await strategy.is_node_eligible(neighbor_addr) diff --git a/nebula/core/pb/nebula.proto b/nebula/core/pb/nebula.proto index 3360196e..23fe0005 100755 --- a/nebula/core/pb/nebula.proto +++ b/nebula/core/pb/nebula.proto @@ -49,6 +49,7 @@ message ControlMessage { WEAK_LINK = 4; LEADERSHIP_TRANSFER = 5; LEADERSHIP_TRANSFER_ACK = 6; + TERMINATED = 7; } Action action = 1; string log = 2; diff --git a/nebula/core/pb/nebula_pb2.py b/nebula/core/pb/nebula_pb2.py index 448675b3..ecf3c3b3 100644 --- a/nebula/core/pb/nebula_pb2.py +++ b/nebula/core/pb/nebula_pb2.py @@ -1,62 +1,61 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: nebula.proto -# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" + from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0cnebula.proto\x12\x06nebula"\xae\x04\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x37\n\x12reputation_message\x18\x08 \x01(\x0b\x32\x19.nebula.ReputationMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\t \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\n \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\x0b \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02"\xe1\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 
\x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t"\x92\x01\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\x12\x17\n\x13LEADERSHIP_TRANSFER\x10\x05\x12\x1b\n\x17LEADERSHIP_TRANSFER_ACK\x10\x06\x12\x0e\n\nTERMINATED\x10\x07"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03"\x95\x01\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action"R\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01"\x89\x01\n\x11ReputationMessage\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\r\n\x05round\x18\x03 \x01(\x05\x12\x30\n\x06\x61\x63tion\x18\x04 \x01(\x0e\x32 .nebula.ReputationMessage.Action"\x13\n\x06\x41\x63tion\x12\t\n\x05SHARE\x10\x00"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3' +) - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cnebula.proto\x12\x06nebula\"\xae\x04\n\x07Wrapper\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x35\n\x11\x64iscovery_message\x18\x02 \x01(\x0b\x32\x18.nebula.DiscoveryMessageH\x00\x12\x31\n\x0f\x63ontrol_message\x18\x03 \x01(\x0b\x32\x16.nebula.ControlMessageH\x00\x12\x37\n\x12\x66\x65\x64\x65ration_message\x18\x04 \x01(\x0b\x32\x19.nebula.FederationMessageH\x00\x12-\n\rmodel_message\x18\x05 \x01(\x0b\x32\x14.nebula.ModelMessageH\x00\x12\x37\n\x12\x63onnection_message\x18\x06 \x01(\x0b\x32\x19.nebula.ConnectionMessageH\x00\x12\x33\n\x10response_message\x18\x07 \x01(\x0b\x32\x17.nebula.ResponseMessageH\x00\x12\x37\n\x12reputation_message\x18\x08 \x01(\x0b\x32\x19.nebula.ReputationMessageH\x00\x12\x33\n\x10\x64iscover_message\x18\t \x01(\x0b\x32\x17.nebula.DiscoverMessageH\x00\x12-\n\roffer_message\x18\n \x01(\x0b\x32\x14.nebula.OfferMessageH\x00\x12+\n\x0clink_message\x18\x0b \x01(\x0b\x32\x13.nebula.LinkMessageH\x00\x42\t\n\x07message\"\x9e\x01\n\x10\x44iscoveryMessage\x12/\n\x06\x61\x63tion\x18\x01 
\x01(\x0e\x32\x1f.nebula.DiscoveryMessage.Action\x12\x10\n\x08latitude\x18\x02 \x01(\x02\x12\x11\n\tlongitude\x18\x03 \x01(\x02\"4\n\x06\x41\x63tion\x12\x0c\n\x08\x44ISCOVER\x10\x00\x12\x0c\n\x08REGISTER\x10\x01\x12\x0e\n\nDEREGISTER\x10\x02\"\xd1\x01\n\x0e\x43ontrolMessage\x12-\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1d.nebula.ControlMessage.Action\x12\x0b\n\x03log\x18\x02 \x01(\t\"\x82\x01\n\x06\x41\x63tion\x12\t\n\x05\x41LIVE\x10\x00\x12\x0c\n\x08OVERHEAD\x10\x01\x12\x0c\n\x08MOBILITY\x10\x02\x12\x0c\n\x08RECOVERY\x10\x03\x12\r\n\tWEAK_LINK\x10\x04\x12\x17\n\x13LEADERSHIP_TRANSFER\x10\x05\x12\x1b\n\x17LEADERSHIP_TRANSFER_ACK\x10\x06\"\xcd\x01\n\x11\x46\x65\x64\x65rationMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.FederationMessage.Action\x12\x11\n\targuments\x18\x02 \x03(\t\x12\r\n\x05round\x18\x03 \x01(\x05\"d\n\x06\x41\x63tion\x12\x14\n\x10\x46\x45\x44\x45RATION_START\x10\x00\x12\x0e\n\nREPUTATION\x10\x01\x12\x1e\n\x1a\x46\x45\x44\x45RATION_MODELS_INCLUDED\x10\x02\x12\x14\n\x10\x46\x45\x44\x45RATION_READY\x10\x03\"A\n\x0cModelMessage\x12\x12\n\nparameters\x18\x01 \x01(\x0c\x12\x0e\n\x06weight\x18\x02 \x01(\x03\x12\r\n\x05round\x18\x03 \x01(\x05\"\x8f\x01\n\x11\x43onnectionMessage\x12\x30\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32 .nebula.ConnectionMessage.Action\"H\n\x06\x41\x63tion\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\x0e\n\nDISCONNECT\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"\x95\x01\n\x0f\x44iscoverMessage\x12.\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1e.nebula.DiscoverMessage.Action\"R\n\x06\x41\x63tion\x12\x11\n\rDISCOVER_JOIN\x10\x00\x12\x12\n\x0e\x44ISCOVER_NODES\x10\x01\x12\x10\n\x0cLATE_CONNECT\x10\x02\x12\x0f\n\x0bRESTRUCTURE\x10\x03\"\xce\x01\n\x0cOfferMessage\x12+\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1b.nebula.OfferMessage.Action\x12\x13\n\x0bn_neighbors\x18\x02 \x01(\x02\x12\x0c\n\x04loss\x18\x03 \x01(\x02\x12\x12\n\nparameters\x18\x04 \x01(\x0c\x12\x0e\n\x06rounds\x18\x05 \x01(\x05\x12\r\n\x05round\x18\x06 \x01(\x05\x12\x0e\n\x06\x65pochs\x18\x07 \x01(\x05\"+\n\x06\x41\x63tion\x12\x0f\n\x0bOFFER_MODEL\x10\x00\x12\x10\n\x0cOFFER_METRIC\x10\x01\"w\n\x0bLinkMessage\x12*\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\x1a.nebula.LinkMessage.Action\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x01(\t\"-\n\x06\x41\x63tion\x12\x0e\n\nCONNECT_TO\x10\x00\x12\x13\n\x0f\x44ISCONNECT_FROM\x10\x01\"\x89\x01\n\x11ReputationMessage\x12\x0f\n\x07node_id\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\r\n\x05round\x18\x03 \x01(\x05\x12\x30\n\x06\x61\x63tion\x18\x04 \x01(\x0e\x32 .nebula.ReputationMessage.Action\"\x13\n\x06\x41\x63tion\x12\t\n\x05SHARE\x10\x00\"#\n\x0fResponseMessage\x12\x10\n\x08response\x18\x01 \x01(\tb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nebula_pb2', _globals) +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "nebula_pb2", globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_WRAPPER']._serialized_start=25 - _globals['_WRAPPER']._serialized_end=583 - _globals['_DISCOVERYMESSAGE']._serialized_start=586 - _globals['_DISCOVERYMESSAGE']._serialized_end=744 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_start=692 - _globals['_DISCOVERYMESSAGE_ACTION']._serialized_end=744 - _globals['_CONTROLMESSAGE']._serialized_start=747 - _globals['_CONTROLMESSAGE']._serialized_end=956 - 
_globals['_CONTROLMESSAGE_ACTION']._serialized_start=826 - _globals['_CONTROLMESSAGE_ACTION']._serialized_end=956 - _globals['_FEDERATIONMESSAGE']._serialized_start=959 - _globals['_FEDERATIONMESSAGE']._serialized_end=1164 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_start=1064 - _globals['_FEDERATIONMESSAGE_ACTION']._serialized_end=1164 - _globals['_MODELMESSAGE']._serialized_start=1166 - _globals['_MODELMESSAGE']._serialized_end=1231 - _globals['_CONNECTIONMESSAGE']._serialized_start=1234 - _globals['_CONNECTIONMESSAGE']._serialized_end=1377 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_start=1305 - _globals['_CONNECTIONMESSAGE_ACTION']._serialized_end=1377 - _globals['_DISCOVERMESSAGE']._serialized_start=1380 - _globals['_DISCOVERMESSAGE']._serialized_end=1529 - _globals['_DISCOVERMESSAGE_ACTION']._serialized_start=1447 - _globals['_DISCOVERMESSAGE_ACTION']._serialized_end=1529 - _globals['_OFFERMESSAGE']._serialized_start=1532 - _globals['_OFFERMESSAGE']._serialized_end=1738 - _globals['_OFFERMESSAGE_ACTION']._serialized_start=1695 - _globals['_OFFERMESSAGE_ACTION']._serialized_end=1738 - _globals['_LINKMESSAGE']._serialized_start=1740 - _globals['_LINKMESSAGE']._serialized_end=1859 - _globals['_LINKMESSAGE_ACTION']._serialized_start=1814 - _globals['_LINKMESSAGE_ACTION']._serialized_end=1859 - _globals['_REPUTATIONMESSAGE']._serialized_start=1862 - _globals['_REPUTATIONMESSAGE']._serialized_end=1999 - _globals['_REPUTATIONMESSAGE_ACTION']._serialized_start=1980 - _globals['_REPUTATIONMESSAGE_ACTION']._serialized_end=1999 - _globals['_RESPONSEMESSAGE']._serialized_start=2001 - _globals['_RESPONSEMESSAGE']._serialized_end=2036 + DESCRIPTOR._options = None + _WRAPPER._serialized_start = 25 + _WRAPPER._serialized_end = 583 + _DISCOVERYMESSAGE._serialized_start = 586 + _DISCOVERYMESSAGE._serialized_end = 744 + _DISCOVERYMESSAGE_ACTION._serialized_start = 692 + _DISCOVERYMESSAGE_ACTION._serialized_end = 744 + _CONTROLMESSAGE._serialized_start = 747 + _CONTROLMESSAGE._serialized_end = 972 + _CONTROLMESSAGE_ACTION._serialized_start = 826 + _CONTROLMESSAGE_ACTION._serialized_end = 972 + _FEDERATIONMESSAGE._serialized_start = 975 + _FEDERATIONMESSAGE._serialized_end = 1180 + _FEDERATIONMESSAGE_ACTION._serialized_start = 1080 + _FEDERATIONMESSAGE_ACTION._serialized_end = 1180 + _MODELMESSAGE._serialized_start = 1182 + _MODELMESSAGE._serialized_end = 1247 + _CONNECTIONMESSAGE._serialized_start = 1250 + _CONNECTIONMESSAGE._serialized_end = 1393 + _CONNECTIONMESSAGE_ACTION._serialized_start = 1321 + _CONNECTIONMESSAGE_ACTION._serialized_end = 1393 + _DISCOVERMESSAGE._serialized_start = 1396 + _DISCOVERMESSAGE._serialized_end = 1545 + _DISCOVERMESSAGE_ACTION._serialized_start = 1463 + _DISCOVERMESSAGE_ACTION._serialized_end = 1545 + _OFFERMESSAGE._serialized_start = 1548 + _OFFERMESSAGE._serialized_end = 1754 + _OFFERMESSAGE_ACTION._serialized_start = 1711 + _OFFERMESSAGE_ACTION._serialized_end = 1754 + _LINKMESSAGE._serialized_start = 1756 + _LINKMESSAGE._serialized_end = 1875 + _LINKMESSAGE_ACTION._serialized_start = 1830 + _LINKMESSAGE_ACTION._serialized_end = 1875 + _REPUTATIONMESSAGE._serialized_start = 1878 + _REPUTATIONMESSAGE._serialized_end = 2015 + _REPUTATIONMESSAGE_ACTION._serialized_start = 1996 + _REPUTATIONMESSAGE_ACTION._serialized_end = 2015 + _RESPONSEMESSAGE._serialized_start = 2017 + _RESPONSEMESSAGE._serialized_end = 2052 # @@protoc_insertion_point(module_scope) diff --git a/nebula/frontend/config/participant.json.example 
b/nebula/frontend/config/participant.json.example index ca1d1cfd..876de32e 100755 --- a/nebula/frontend/config/participant.json.example +++ b/nebula/frontend/config/participant.json.example @@ -149,5 +149,12 @@ "misc_args": { "grace_time_connection": 10, "grace_time_start_federation": 10 + }, + "communication_args": { + "mechanism": "standard" + }, + "caff_args": { + "aggregation_fraction": 0.5, + "staleness_threshold": 2 } } diff --git a/nebula/frontend/static/js/deployment/help-content.js b/nebula/frontend/static/js/deployment/help-content.js index 673cae88..34ac2de3 100644 --- a/nebula/frontend/static/js/deployment/help-content.js +++ b/nebula/frontend/static/js/deployment/help-content.js @@ -13,7 +13,10 @@ const HelpContent = (function() { 'partitionMethodsHelpIcon': partitionMethods, 'parameterSettingHelpIcon': parameterSetting, 'modelHelpIcon': model, - 'maliciousHelpIcon': malicious + 'maliciousHelpIcon': malicious, + 'aggregationFractionHelpIcon': caff.aggregationFraction, + 'stalenessThresholdHelpIcon': caff.stalenessThreshold, + 'fallbackTimeoutHelpIcon': caff.fallbackTimeout, }; Object.entries(tooltipElements).forEach(([id, content]) => { @@ -160,6 +163,43 @@ const HelpContent = (function() { ` }; + const caff = { + aggregationFraction: ` +
+            The fraction of peers (N) whose updates you wait for before aggregating.<br>
+            E.g. alpha = 0.5 means “once 50% of the other nodes have sent an update and the local model is ready, start aggregation.”<br>
+            K = max(1, round(alpha * N))
+        `,
+        stalenessThreshold: `
+            A peer’s update is only accepted if:<br>
+            local_round - peer_update_round ≤ Smax.<br>
+            Otherwise (i.e. if it’s older), it’s discarded as too stale.
+        `,
+        fallbackTimeout: `
+            A timer that starts immediately after your local update is sent.<br>
+            If aggregation hasn’t triggered within this many seconds,
+            it forces aggregation with whatever updates are in the cache.
+        `
+    };
+
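
For reference, the sketch below restates the CAFF gating logic introduced in `caffupdatehandler.py` as standalone Python: `recalc_k_node` mirrors `CAFFUpdateHandler._recalc_K_node`, and `accept_update` mirrors the staleness filter in `storage_update`. It is illustrative only — the peer addresses and numeric values are hypothetical, and the real handler additionally serializes all cache access through `caff_cache_lock`. Note that the implemented threshold, `max(2, round(aggregation_fraction * external) + 1)`, is slightly stricter than the simplified `K = max(1, round(alpha * N))` shown in the help text, because the local update is cached alongside peer updates.

```python
# Standalone sketch of CAFF's aggregation gate (illustrative; values are hypothetical).


def recalc_k_node(sources_expected: set[str], addr: str, aggregation_fraction: float) -> int:
    """Mirror of _recalc_K_node: dynamic threshold, local update included."""
    if len(sources_expected) == 1:
        # Only the local node is still active: aggregate on the local update alone.
        return 1
    external = len(sources_expected - {addr})  # count only external peers
    # The +1 accounts for the local update, which is also stored in the cache;
    # the floor of 2 guarantees at least one external update is waited for.
    return max(2, round(aggregation_fraction * external) + 1)


def accept_update(local_round: int, peer_round: int, staleness_threshold: int) -> bool:
    """Mirror of the staleness filter: keep updates at most S_max rounds old."""
    return local_round - peer_round <= staleness_threshold


# With caff_args as in participant.json.example: fraction 0.5, threshold 2.
local = "192.168.50.1:45000"
expected = {f"192.168.50.{i}:45000" for i in range(2, 10)} | {local}  # 8 external peers + self

k_node = recalc_k_node(expected, local, 0.5)
assert k_node == 5  # max(2, round(0.5 * 8) + 1)

assert accept_update(local_round=10, peer_round=8, staleness_threshold=2)      # cached
assert not accept_update(local_round=10, peer_round=7, staleness_threshold=2)  # discarded
```

Aggregation then fires as soon as the cache holds `k_node` entries (local included) and local training has finished — exactly the check `_maybe_aggregate` performs under the cache lock before calling `notify_all_updates_received`.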