Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions python/fedml/core/security/attack/byzantine_attack.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import OrderedDict

import random
import fedml
import numpy as np
import torch
from .attack_base import BaseAttackMethod
from ..common.utils import is_weight_param, sample_some_clients
from typing import List, Tuple, Any
import logging

"""
attack @ server, added by Shanshan, 07/04/2022
Expand All @@ -17,9 +18,19 @@ def __init__(self, args):
self.byzantine_client_num = args.byzantine_client_num
self.attack_mode = args.attack_mode # random: randomly generate a weight; zero: set the weight to 0
self.device = fedml.device.get_device(args)
self.attack_training_rounds = None # default: attack happens in every round
if hasattr(args, "attack_round_num") and isinstance(args.attack_round_num,
int) and args.attack_round_num < args.comm_round:
random.seed(args.random_seed)
self.attack_training_rounds = random.sample(range(args.comm_round), args.attack_round_num)
logging.info(f"attack rounds: {self.attack_training_rounds}")
self.current_training_round = -1

def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]],
extra_auxiliary_info: Any = None):
extra_auxiliary_info: Any = None):
self.current_training_round += 1
if self.attack_training_rounds is not None and self.current_training_round not in self.attack_training_rounds:
return raw_client_grad_list
if len(raw_client_grad_list) < self.byzantine_client_num:
self.byzantine_client_num = len(raw_client_grad_list)
byzantine_idxs = sample_some_clients(len(raw_client_grad_list), self.byzantine_client_num)
Expand All @@ -29,7 +40,8 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]],
elif self.attack_mode == "random":
byzantine_local_w = self._attack_random_mode(raw_client_grad_list, byzantine_idxs)
elif self.attack_mode == "flip":
byzantine_local_w = self._attack_flip_mode(raw_client_grad_list, byzantine_idxs, extra_auxiliary_info) # extra_auxiliary_info: global model
byzantine_local_w = self._attack_flip_mode(raw_client_grad_list, byzantine_idxs,
extra_auxiliary_info) # extra_auxiliary_info: global model
else:
raise NotImplementedError("Method not implemented!")
return byzantine_local_w
Expand All @@ -43,7 +55,8 @@ def _attack_zero_mode(self, model_list, byzantine_idxs):
local_sample_number, local_model_params = model_list[i]
for k in local_model_params.keys():
if is_weight_param(k):
local_model_params[k] = torch.from_numpy(np.zeros(local_model_params[k].size())).float().to(self.device)
local_model_params[k] = torch.from_numpy(np.zeros(local_model_params[k].size())).float().to(
self.device)
new_model_list.append((local_sample_number, local_model_params))
return new_model_list

Expand All @@ -57,11 +70,11 @@ def _attack_random_mode(self, model_list, byzantine_idxs):
local_sample_number, local_model_params = model_list[i]
for k in local_model_params.keys():
if is_weight_param(k):
local_model_params[k] = torch.from_numpy(2*np.random.random_sample(size=local_model_params[k].size())-1).float().to(self.device)
local_model_params[k] = torch.from_numpy(
2 * np.random.random_sample(size=local_model_params[k].size()) - 1).float().to(self.device)
new_model_list.append((local_sample_number, local_model_params))
return new_model_list


def _attack_flip_mode(self, model_list, byzantine_idxs, global_model):
new_model_list = []
for i in range(0, len(model_list)):
Expand All @@ -71,6 +84,8 @@ def _attack_flip_mode(self, model_list, byzantine_idxs, global_model):
local_sample_number, local_model_params = model_list[i]
for k in local_model_params.keys():
if is_weight_param(k):
local_model_params[k] = global_model[k].float().to(self.device) + (global_model[k].float().to(self.device) - local_model_params[k].float().to(self.device))
local_model_params[k] = global_model[k].float().to(self.device) + (
global_model[k].float().to(self.device) - local_model_params[k].float().to(
self.device))
new_model_list.append((local_sample_number, local_model_params))
return new_model_list
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, args):
self.scale_factor_S = args.scale_factor_S
else:
self.scale_factor_S = None
self.training_round = 1
self.training_round = -1
self.device = fedml.device.get_device(args)

def attack_model(
Expand All @@ -47,6 +47,7 @@ def attack_model(
extra_auxiliary_info: Any = None,
):
participant_num = len(raw_client_grad_list)
self.training_round += 1
if self.attack_training_rounds is not None and self.training_round not in self.attack_training_rounds:
return raw_client_grad_list
if self.malicious_client_id is None:
Expand All @@ -67,7 +68,6 @@ def attack_model(
if is_weight_param(k):
original_client_model[k] = torch.tensor(gamma * (original_client_model[k] - global_model[k]) + global_model[k]).float().to(self.device)
raw_client_grad_list.insert(malicious_idx, (num, original_client_model))
self.training_round = self.training_round + 1
return raw_client_grad_list

def compute_gamma(self, global_model, original_client_model):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ def __init__(self, config):
self.robust_threshold = config.robust_threshold # e.g., robust threshold = 4
self.server_learning_rate = 1

def run(
def _compute_robust_learning_rates(self, client_update_sign):
client_lr = torch.abs(sum(client_update_sign))
client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate
client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate
return client_lr

def defend_on_aggregation(
self,
raw_client_grad_list: List[Tuple[float, OrderedDict]],
base_aggregation_func: Callable = None,
Expand All @@ -51,9 +57,3 @@ def run(
client_lr = self._compute_robust_learning_rates(client_update_sign)
avg_params[k] = client_lr * avg_params[k]
return avg_params

def _compute_robust_learning_rates(self, client_update_sign):
client_lr = torch.abs(sum(client_update_sign))
client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate
client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate
return client_lr
2 changes: 1 addition & 1 deletion python/fedml/core/security/fedml_defender.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def defend(
)

def is_defense_on_aggregation(self):
return self.is_defense_enabled() and self.defense_type in [DEFENSE_SLSGD, DEFENSE_RFA, DEFENSE_WISE_MEDIAN, DEFENSE_GEO_MEDIAN]
return self.is_defense_enabled() and self.defense_type in [DEFENSE_SLSGD, DEFENSE_RFA, DEFENSE_WISE_MEDIAN, DEFENSE_GEO_MEDIAN, DEFENSE_ROBUST_LEARNING_RATE]

def is_defense_before_aggregation(self):
return self.is_defense_enabled() and self.defense_type in [
Expand Down