-
Notifications
You must be signed in to change notification settings - Fork 328
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #191 from XinghaoWu/master
Add FedCAC algorithm
- Loading branch information
Showing
4 changed files
with
251 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import numpy as np | ||
import time | ||
import torch | ||
import torch.nn as nn | ||
import copy | ||
from flcore.clients.clientbase import Client | ||
|
||
class clientCAC(Client): | ||
def __init__(self, args, id, train_samples, test_samples, **kwargs): | ||
super().__init__(args, id, train_samples, test_samples, **kwargs) | ||
self.args = args | ||
self.critical_parameter = None # record the critical parameter positions in FedCAC | ||
self.customized_model = copy.deepcopy(self.model) # customized global model | ||
self.critical_parameter, self.global_mask, self.local_mask = None, None, None | ||
|
||
def train(self): | ||
trainloader = self.load_train_data() | ||
|
||
start_time = time.time() | ||
|
||
# record the model before local updating, used for critical parameter selection | ||
initial_model = copy.deepcopy(self.model) | ||
|
||
# self.model.to(self.device) | ||
self.model.train() | ||
|
||
max_local_epochs = self.local_epochs | ||
if self.train_slow: | ||
max_local_epochs = np.random.randint(1, max_local_epochs // 2) | ||
|
||
for epoch in range(max_local_epochs): | ||
for i, (x, y) in enumerate(trainloader): | ||
if type(x) == type([]): | ||
x[0] = x[0].to(self.device) | ||
else: | ||
x = x.to(self.device) | ||
y = y.to(self.device) | ||
if self.train_slow: | ||
time.sleep(0.1 * np.abs(np.random.rand())) | ||
output = self.model(x) | ||
loss = self.loss(output, y) | ||
self.optimizer.zero_grad() | ||
loss.backward() | ||
self.optimizer.step() | ||
|
||
if self.learning_rate_decay: | ||
self.learning_rate_scheduler.step() | ||
|
||
# self.model.to('cpu') | ||
|
||
# select the critical parameters | ||
self.critical_parameter, self.global_mask, self.local_mask = self.evaluate_critical_parameter( | ||
prevModel=initial_model, model=self.model, tau=self.args.tau | ||
) | ||
|
||
self.train_time_cost['num_rounds'] += 1 | ||
self.train_time_cost['total_cost'] += time.time() - start_time | ||
|
||
def evaluate_critical_parameter(self, prevModel: nn.Module, model: nn.Module, tau: float): | ||
r""" | ||
Overview: | ||
Implement critical parameter selection. | ||
""" | ||
global_mask = [] # mark non-critical parameter | ||
local_mask = [] # mark critical parameter | ||
critical_parameter = [] | ||
|
||
# self.model.to(self.device) | ||
# prevModel.to(self.device) | ||
|
||
# select critical parameters in each layer | ||
for (name1, prevparam), (name2, param) in zip(prevModel.named_parameters(), model.named_parameters()): | ||
g = (param.data - prevparam.data) | ||
v = param.data | ||
c = torch.abs(g * v) | ||
|
||
metric = c.view(-1) | ||
num_params = metric.size(0) | ||
nz = int(tau * num_params) | ||
top_values, _ = torch.topk(metric, nz) | ||
thresh = top_values[-1] if len(top_values) > 0 else np.inf | ||
# if threshold equals 0, select minimal nonzero element as threshold | ||
if thresh <= 1e-10: | ||
new_metric = metric[metric > 1e-20] | ||
if len(new_metric) == 0: # this means all items in metric are zero | ||
print(f'Abnormal!!! metric:{metric}') | ||
else: | ||
thresh = new_metric.sort()[0][0] | ||
|
||
# Get the local mask and global mask | ||
mask = (c >= thresh).int().to('cpu') | ||
global_mask.append((c < thresh).int().to('cpu')) | ||
local_mask.append(mask) | ||
critical_parameter.append(mask.view(-1)) | ||
model.zero_grad() | ||
critical_parameter = torch.cat(critical_parameter) | ||
|
||
# self.model.to('cpu') | ||
# prevModel.to('cpu') | ||
|
||
return critical_parameter, global_mask, local_mask | ||
|
||
def set_parameters(self, model): | ||
if self.local_mask != None: | ||
# self.model.to(self.device) | ||
# model.to(self.device) | ||
# self.customized_model.to(self.device) | ||
|
||
index = 0 | ||
for (name1, param1), (name2, param2), (name3, param3) in zip( | ||
self.model.named_parameters(), model.named_parameters(), | ||
self.customized_model.named_parameters()): | ||
param1.data = self.local_mask[index].to(self.device).float() * param3.data + \ | ||
self.global_mask[index].to(self.args.device).float() * param2.data | ||
index += 1 | ||
|
||
# self.model.to('cpu') | ||
# model.to('cpu') | ||
# self.customized_model.to('cpu') | ||
else: | ||
super().set_parameters(model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import time | ||
import torch | ||
import copy | ||
|
||
from flcore.clients.clientcac import clientCAC | ||
from flcore.servers.serverbase import Server | ||
from utils.data_utils import read_client_data | ||
|
||
class FedCAC(Server): | ||
def __init__(self, args, times): | ||
super().__init__(args, times) | ||
args.beta = int(args.beta) | ||
# select slow clients | ||
self.set_slow_clients() | ||
self.set_clients(clientCAC) | ||
|
||
print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}") | ||
print("Finished creating server and clients.") | ||
|
||
# self.load_model() | ||
self.Budget = [] | ||
|
||
# To be consistent with the existing pipeline interface. Maintaining an epoch counter. | ||
self.epoch = -1 | ||
|
||
def train(self): | ||
for i in range(self.global_rounds+1): | ||
self.epoch = i | ||
s_t = time.time() | ||
self.selected_clients = self.select_clients() | ||
self.send_models() | ||
|
||
if i%self.eval_gap == 0: | ||
print(f"\n-------------Round number: {i}-------------") | ||
print("\nEvaluate personalized models") | ||
self.evaluate() | ||
|
||
for client in self.selected_clients: | ||
client.train() | ||
|
||
# threads = [Thread(target=client.train) | ||
# for client in self.selected_clients] | ||
# [t.start() for t in threads] | ||
# [t.join() for t in threads] | ||
|
||
self.receive_models() | ||
self.aggregate_parameters() | ||
|
||
self.Budget.append(time.time() - s_t) | ||
print('-'*25, 'time cost', '-'*25, self.Budget[-1]) | ||
|
||
if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): | ||
break | ||
|
||
print("\nBest accuracy.") | ||
# self.print_(max(self.rs_test_acc), max( | ||
# self.rs_train_acc), min(self.rs_train_loss)) | ||
print(max(self.rs_test_acc)) | ||
print("\nAverage time cost per round.") | ||
print(sum(self.Budget[1:])/len(self.Budget[1:])) | ||
|
||
self.save_results() | ||
|
||
if self.num_new_clients > 0: | ||
self.eval_new_clients = True | ||
self.set_new_clients(clientCAC) | ||
print(f"\n-------------Fine tuning round-------------") | ||
print("\nEvaluate new clients") | ||
self.evaluate() | ||
|
||
def get_customized_global_models(self): | ||
r""" | ||
Overview: | ||
Aggregating customized global models for clients to collaborate critical parameters. | ||
""" | ||
assert type(self.args.beta) == int and self.args.beta >= 1 | ||
overlap_buffer = [[] for i in range(self.args.num_clients)] | ||
|
||
# calculate overlap rate between client i and client j | ||
for i in range(self.args.num_clients): | ||
for j in range(self.args.num_clients): | ||
if i == j: | ||
continue | ||
overlap_rate = 1 - torch.sum( | ||
torch.abs(self.clients[i].critical_parameter.to(self.device) - self.clients[j].critical_parameter.to(self.args.device)) | ||
) / float(torch.sum(self.clients[i].critical_parameter.to(self.args.device)).cpu() * 2) | ||
overlap_buffer[i].append(overlap_rate) | ||
|
||
# calculate the global threshold | ||
overlap_buffer_tensor = torch.tensor(overlap_buffer) | ||
overlap_sum = overlap_buffer_tensor.sum() | ||
overlap_avg = overlap_sum / ((self.args.num_clients - 1) * self.args.num_clients) | ||
overlap_max = overlap_buffer_tensor.max() | ||
threshold = overlap_avg + (self.epoch + 1) / self.args.beta * (overlap_max - overlap_avg) | ||
|
||
# calculate the customized global model for each client | ||
for i in range(self.args.num_clients): | ||
w_customized_global = copy.deepcopy(self.clients[i].model.state_dict()) | ||
collaboration_clients = [i] | ||
# find clients whose critical parameter locations are similar to client i | ||
index = 0 | ||
for j in range(self.args.num_clients): | ||
if i == j: | ||
continue | ||
if overlap_buffer[i][index] >= threshold: | ||
collaboration_clients.append(j) | ||
index += 1 | ||
|
||
for key in w_customized_global.keys(): | ||
for client in collaboration_clients: | ||
if client == i: | ||
continue | ||
w_customized_global[key] += self.clients[client].model.state_dict()[key] | ||
w_customized_global[key] = torch.div(w_customized_global[key], float(len(collaboration_clients))) | ||
# send the customized global model to client i | ||
self.clients[i].customized_model.load_state_dict(w_customized_global) | ||
|
||
def send_models(self): | ||
if self.epoch != 0: | ||
self.get_customized_global_models() | ||
|
||
super().send_models() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters