diff --git a/python/fedml/core/security/attack/byzantine_attack.py b/python/fedml/core/security/attack/byzantine_attack.py
index c4f6b63257..7bf0605587 100644
--- a/python/fedml/core/security/attack/byzantine_attack.py
+++ b/python/fedml/core/security/attack/byzantine_attack.py
@@ -1,11 +1,12 @@
 from collections import OrderedDict
-
+import random
 import fedml
 import numpy as np
 import torch
 from .attack_base import BaseAttackMethod
 from ..common.utils import is_weight_param, sample_some_clients
 from typing import List, Tuple, Any
+import logging
 
 """
 attack @ server, added by Shanshan, 07/04/2022
@@ -17,9 +18,19 @@ def __init__(self, args):
         self.byzantine_client_num = args.byzantine_client_num
         self.attack_mode = args.attack_mode  # random: randomly generate a weight; zero: set the weight to 0
         self.device = fedml.device.get_device(args)
+        self.attack_training_rounds = None  # default: attack happens in every round
+        if hasattr(args, "attack_round_num") and isinstance(args.attack_round_num,
+                                                            int) and args.attack_round_num < args.comm_round:
+            random.seed(args.random_seed)
+            self.attack_training_rounds = random.sample(range(args.comm_round), args.attack_round_num)
+            logging.info(f"attack rounds: {self.attack_training_rounds}")
+        self.current_training_round = -1
 
     def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]],
-                     extra_auxiliary_info: Any = None):
+                     extra_auxiliary_info: Any = None):
+        self.current_training_round += 1
+        if self.attack_training_rounds is not None and self.current_training_round not in self.attack_training_rounds:
+            return raw_client_grad_list
         if len(raw_client_grad_list) < self.byzantine_client_num:
             self.byzantine_client_num = len(raw_client_grad_list)
         byzantine_idxs = sample_some_clients(len(raw_client_grad_list), self.byzantine_client_num)
@@ -29,7 +40,8 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]],
         elif self.attack_mode == "random":
             byzantine_local_w = self._attack_random_mode(raw_client_grad_list, byzantine_idxs)
         elif self.attack_mode == "flip":
-            byzantine_local_w = self._attack_flip_mode(raw_client_grad_list, byzantine_idxs, extra_auxiliary_info)  # extra_auxiliary_info: global model
+            byzantine_local_w = self._attack_flip_mode(raw_client_grad_list, byzantine_idxs,
+                                                       extra_auxiliary_info)  # extra_auxiliary_info: global model
         else:
             raise NotImplementedError("Method not implemented!")
         return byzantine_local_w
@@ -43,7 +55,8 @@ def _attack_zero_mode(self, model_list, byzantine_idxs):
                 local_sample_number, local_model_params = model_list[i]
                 for k in local_model_params.keys():
                     if is_weight_param(k):
-                        local_model_params[k] = torch.from_numpy(np.zeros(local_model_params[k].size())).float().to(self.device)
+                        local_model_params[k] = torch.from_numpy(np.zeros(local_model_params[k].size())).float().to(
+                            self.device)
                 new_model_list.append((local_sample_number, local_model_params))
         return new_model_list
 
@@ -57,11 +70,11 @@ def _attack_random_mode(self, model_list, byzantine_idxs):
                 local_sample_number, local_model_params = model_list[i]
                 for k in local_model_params.keys():
                     if is_weight_param(k):
-                        local_model_params[k] = torch.from_numpy(2*np.random.random_sample(size=local_model_params[k].size())-1).float().to(self.device)
+                        local_model_params[k] = torch.from_numpy(
+                            2 * np.random.random_sample(size=local_model_params[k].size()) - 1).float().to(self.device)
                 new_model_list.append((local_sample_number, local_model_params))
         return new_model_list
 
-
     def _attack_flip_mode(self, model_list, byzantine_idxs, global_model):
         new_model_list = []
         for i in range(0, len(model_list)):
@@ -71,6 +84,8 @@ def _attack_flip_mode(self, model_list, byzantine_idxs, global_model):
                 local_sample_number, local_model_params = model_list[i]
                 for k in local_model_params.keys():
                     if is_weight_param(k):
-                        local_model_params[k] = global_model[k].float().to(self.device) + (global_model[k].float().to(self.device) - local_model_params[k].float().to(self.device))
+                        local_model_params[k] = global_model[k].float().to(self.device) + (
+                                global_model[k].float().to(self.device) - local_model_params[k].float().to(
+                            self.device))
                 new_model_list.append((local_sample_number, local_model_params))
         return new_model_list
diff --git a/python/fedml/core/security/attack/model_replacement_backdoor_attack.py b/python/fedml/core/security/attack/model_replacement_backdoor_attack.py
index be4495d188..40d5e0bf5d 100644
--- a/python/fedml/core/security/attack/model_replacement_backdoor_attack.py
+++ b/python/fedml/core/security/attack/model_replacement_backdoor_attack.py
@@ -38,7 +38,7 @@ def __init__(self, args):
             self.scale_factor_S = args.scale_factor_S
         else:
             self.scale_factor_S = None
-        self.training_round = 1
+        self.training_round = -1
         self.device = fedml.device.get_device(args)
 
     def attack_model(
@@ -47,6 +47,7 @@ def attack_model(
         extra_auxiliary_info: Any = None,
     ):
         participant_num = len(raw_client_grad_list)
+        self.training_round += 1
         if self.attack_training_rounds is not None and self.training_round not in self.attack_training_rounds:
             return raw_client_grad_list
         if self.malicious_client_id is None:
@@ -67,7 +68,6 @@ def attack_model(
             if is_weight_param(k):
                 original_client_model[k] = torch.tensor(gamma * (original_client_model[k] - global_model[k]) + global_model[k]).float().to(self.device)
         raw_client_grad_list.insert(malicious_idx, (num, original_client_model))
-        self.training_round = self.training_round + 1
         return raw_client_grad_list
 
     def compute_gamma(self, global_model, original_client_model):
diff --git a/python/fedml/core/security/defense/robust_learning_rate_defense.py b/python/fedml/core/security/defense/robust_learning_rate_defense.py
index ccbb1c24df..1a3dcf545d 100644
--- a/python/fedml/core/security/defense/robust_learning_rate_defense.py
+++ b/python/fedml/core/security/defense/robust_learning_rate_defense.py
@@ -28,7 +28,13 @@ def __init__(self, config):
         self.robust_threshold = config.robust_threshold  # e.g., robust threshold = 4
         self.server_learning_rate = 1
 
-    def run(
+    def _compute_robust_learning_rates(self, client_update_sign):
+        client_lr = torch.abs(sum(client_update_sign))
+        client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate
+        client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate
+        return client_lr
+
+    def defend_on_aggregation(
         self,
         raw_client_grad_list: List[Tuple[float, OrderedDict]],
         base_aggregation_func: Callable = None,
@@ -51,9 +57,3 @@ def run(
             client_lr = self._compute_robust_learning_rates(client_update_sign)
             avg_params[k] = client_lr * avg_params[k]
         return avg_params
-
-    def _compute_robust_learning_rates(self, client_update_sign):
-        client_lr = torch.abs(sum(client_update_sign))
-        client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate
-        client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate
-        return client_lr
diff --git a/python/fedml/core/security/fedml_defender.py b/python/fedml/core/security/fedml_defender.py
index 6c2acc2072..7b86c24eef 100644
--- a/python/fedml/core/security/fedml_defender.py
+++ b/python/fedml/core/security/fedml_defender.py
@@ -129,7 +129,7 @@ def defend(
         )
 
     def is_defense_on_aggregation(self):
-        return self.is_defense_enabled() and self.defense_type in [DEFENSE_SLSGD, DEFENSE_RFA, DEFENSE_WISE_MEDIAN, DEFENSE_GEO_MEDIAN]
+        return self.is_defense_enabled() and self.defense_type in [DEFENSE_SLSGD, DEFENSE_RFA, DEFENSE_WISE_MEDIAN, DEFENSE_GEO_MEDIAN, DEFENSE_ROBUST_LEARNING_RATE]
 
     def is_defense_before_aggregation(self):
         return self.is_defense_enabled() and self.defense_type in [