Skip to content

Commit

Permalink
Merge pull request #191 from XinghaoWu/master
Browse files Browse the repository at this point in the history
Add FedCAC algorithm
  • Loading branch information
TsingZ0 authored Jun 3, 2024
2 parents 097b6ed + 8725a06 commit 5012fce
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ Thanks to [@Stonesjtu](https://github.com/Stonesjtu/pytorch_memlab/blob/d590c489
- **GPFL**[GPFL: Simultaneously Learning Generic and Personalized Feature Information for Personalized Federated Learning](https://arxiv.org/pdf/2308.10279v3.pdf) *ICCV 2023*
- **FedGH**[FedGH: Heterogeneous Federated Learning with Generalized Global Header](https://dl.acm.org/doi/10.1145/3581783.3611781) *ACM MM 2023*
- **DBE**[Eliminating Domain Bias for Federated Learning in Representation Space](https://openreview.net/forum?id=nO5i1XdUS0) *NeurIPS 2023*
- **FedCAC**[Bold but Cautious: Unlocking the Potential of Personalized Federated Learning through Cautiously Aggressive Collaboration](https://arxiv.org/abs/2309.11103) *ICCV 2023*

***Knowledge-distillation-based pFL***

Expand Down
121 changes: 121 additions & 0 deletions system/flcore/clients/clientcac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np
import time
import torch
import torch.nn as nn
import copy
from flcore.clients.clientbase import Client

class clientCAC(Client):
def __init__(self, args, id, train_samples, test_samples, **kwargs):
super().__init__(args, id, train_samples, test_samples, **kwargs)
self.args = args
self.critical_parameter = None # record the critical parameter positions in FedCAC
self.customized_model = copy.deepcopy(self.model) # customized global model
self.critical_parameter, self.global_mask, self.local_mask = None, None, None

def train(self):
trainloader = self.load_train_data()

start_time = time.time()

# record the model before local updating, used for critical parameter selection
initial_model = copy.deepcopy(self.model)

# self.model.to(self.device)
self.model.train()

max_local_epochs = self.local_epochs
if self.train_slow:
max_local_epochs = np.random.randint(1, max_local_epochs // 2)

for epoch in range(max_local_epochs):
for i, (x, y) in enumerate(trainloader):
if type(x) == type([]):
x[0] = x[0].to(self.device)
else:
x = x.to(self.device)
y = y.to(self.device)
if self.train_slow:
time.sleep(0.1 * np.abs(np.random.rand()))
output = self.model(x)
loss = self.loss(output, y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

if self.learning_rate_decay:
self.learning_rate_scheduler.step()

# self.model.to('cpu')

# select the critical parameters
self.critical_parameter, self.global_mask, self.local_mask = self.evaluate_critical_parameter(
prevModel=initial_model, model=self.model, tau=self.args.tau
)

self.train_time_cost['num_rounds'] += 1
self.train_time_cost['total_cost'] += time.time() - start_time

def evaluate_critical_parameter(self, prevModel: nn.Module, model: nn.Module, tau: float):
r"""
Overview:
Implement critical parameter selection.
"""
global_mask = [] # mark non-critical parameter
local_mask = [] # mark critical parameter
critical_parameter = []

# self.model.to(self.device)
# prevModel.to(self.device)

# select critical parameters in each layer
for (name1, prevparam), (name2, param) in zip(prevModel.named_parameters(), model.named_parameters()):
g = (param.data - prevparam.data)
v = param.data
c = torch.abs(g * v)

metric = c.view(-1)
num_params = metric.size(0)
nz = int(tau * num_params)
top_values, _ = torch.topk(metric, nz)
thresh = top_values[-1] if len(top_values) > 0 else np.inf
# if threshold equals 0, select minimal nonzero element as threshold
if thresh <= 1e-10:
new_metric = metric[metric > 1e-20]
if len(new_metric) == 0: # this means all items in metric are zero
print(f'Abnormal!!! metric:{metric}')
else:
thresh = new_metric.sort()[0][0]

# Get the local mask and global mask
mask = (c >= thresh).int().to('cpu')
global_mask.append((c < thresh).int().to('cpu'))
local_mask.append(mask)
critical_parameter.append(mask.view(-1))
model.zero_grad()
critical_parameter = torch.cat(critical_parameter)

# self.model.to('cpu')
# prevModel.to('cpu')

return critical_parameter, global_mask, local_mask

def set_parameters(self, model):
if self.local_mask != None:
# self.model.to(self.device)
# model.to(self.device)
# self.customized_model.to(self.device)

index = 0
for (name1, param1), (name2, param2), (name3, param3) in zip(
self.model.named_parameters(), model.named_parameters(),
self.customized_model.named_parameters()):
param1.data = self.local_mask[index].to(self.device).float() * param3.data + \
self.global_mask[index].to(self.args.device).float() * param2.data
index += 1

# self.model.to('cpu')
# model.to('cpu')
# self.customized_model.to('cpu')
else:
super().set_parameters(model)
122 changes: 122 additions & 0 deletions system/flcore/servers/servercac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import time
import torch
import copy

from flcore.clients.clientcac import clientCAC
from flcore.servers.serverbase import Server
from utils.data_utils import read_client_data

class FedCAC(Server):
def __init__(self, args, times):
super().__init__(args, times)
args.beta = int(args.beta)
# select slow clients
self.set_slow_clients()
self.set_clients(clientCAC)

print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
print("Finished creating server and clients.")

# self.load_model()
self.Budget = []

# To be consistent with the existing pipeline interface. Maintaining an epoch counter.
self.epoch = -1

def train(self):
for i in range(self.global_rounds+1):
self.epoch = i
s_t = time.time()
self.selected_clients = self.select_clients()
self.send_models()

if i%self.eval_gap == 0:
print(f"\n-------------Round number: {i}-------------")
print("\nEvaluate personalized models")
self.evaluate()

for client in self.selected_clients:
client.train()

# threads = [Thread(target=client.train)
# for client in self.selected_clients]
# [t.start() for t in threads]
# [t.join() for t in threads]

self.receive_models()
self.aggregate_parameters()

self.Budget.append(time.time() - s_t)
print('-'*25, 'time cost', '-'*25, self.Budget[-1])

if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt):
break

print("\nBest accuracy.")
# self.print_(max(self.rs_test_acc), max(
# self.rs_train_acc), min(self.rs_train_loss))
print(max(self.rs_test_acc))
print("\nAverage time cost per round.")
print(sum(self.Budget[1:])/len(self.Budget[1:]))

self.save_results()

if self.num_new_clients > 0:
self.eval_new_clients = True
self.set_new_clients(clientCAC)
print(f"\n-------------Fine tuning round-------------")
print("\nEvaluate new clients")
self.evaluate()

def get_customized_global_models(self):
r"""
Overview:
Aggregating customized global models for clients to collaborate critical parameters.
"""
assert type(self.args.beta) == int and self.args.beta >= 1
overlap_buffer = [[] for i in range(self.args.num_clients)]

# calculate overlap rate between client i and client j
for i in range(self.args.num_clients):
for j in range(self.args.num_clients):
if i == j:
continue
overlap_rate = 1 - torch.sum(
torch.abs(self.clients[i].critical_parameter.to(self.device) - self.clients[j].critical_parameter.to(self.args.device))
) / float(torch.sum(self.clients[i].critical_parameter.to(self.args.device)).cpu() * 2)
overlap_buffer[i].append(overlap_rate)

# calculate the global threshold
overlap_buffer_tensor = torch.tensor(overlap_buffer)
overlap_sum = overlap_buffer_tensor.sum()
overlap_avg = overlap_sum / ((self.args.num_clients - 1) * self.args.num_clients)
overlap_max = overlap_buffer_tensor.max()
threshold = overlap_avg + (self.epoch + 1) / self.args.beta * (overlap_max - overlap_avg)

# calculate the customized global model for each client
for i in range(self.args.num_clients):
w_customized_global = copy.deepcopy(self.clients[i].model.state_dict())
collaboration_clients = [i]
# find clients whose critical parameter locations are similar to client i
index = 0
for j in range(self.args.num_clients):
if i == j:
continue
if overlap_buffer[i][index] >= threshold:
collaboration_clients.append(j)
index += 1

for key in w_customized_global.keys():
for client in collaboration_clients:
if client == i:
continue
w_customized_global[key] += self.clients[client].model.state_dict()[key]
w_customized_global[key] = torch.div(w_customized_global[key], float(len(collaboration_clients)))
# send the customized global model to client i
self.clients[i].customized_model.load_state_dict(w_customized_global)

def send_models(self):
if self.epoch != 0:
self.get_customized_global_models()

super().send_models()
9 changes: 7 additions & 2 deletions system/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from flcore.servers.serverntd import FedNTD
from flcore.servers.servergh import FedGH
from flcore.servers.serveravgDBE import FedAvgDBE
from flcore.servers.servercac import FedCAC

from flcore.trainmodel.models import *

Expand Down Expand Up @@ -355,6 +356,10 @@ def run(args):
args.model.fc = nn.Identity()
args.model = BaseHeadSplit(args.model, args.head)
server = FedAvgDBE(args, i)

elif args.algorithm == 'FedCAC':
server = FedCAC(args, i)


else:
raise NotImplementedError
Expand Down Expand Up @@ -429,7 +434,7 @@ def run(args):
help="Whether to group and select clients at each round according to time cost")
parser.add_argument('-tth', "--time_threthold", type=float, default=10000,
help="The threthold for droping slow clients")
# pFedMe / PerAvg / FedProx / FedAMP / FedPHP / GPFL
# pFedMe / PerAvg / FedProx / FedAMP / FedPHP / GPFL / FedCAC
parser.add_argument('-bt', "--beta", type=float, default=0.0)
parser.add_argument('-lam', "--lamda", type=float, default=1.0,
help="Regularization weight")
Expand All @@ -452,7 +457,7 @@ def run(args):
parser.add_argument('-al', "--alpha", type=float, default=1.0)
# Ditto / FedRep
parser.add_argument('-pls', "--plocal_epochs", type=int, default=1)
# MOON
# MOON / FedCAC
parser.add_argument('-tau', "--tau", type=float, default=1.0)
# FedBABU
parser.add_argument('-fte', "--fine_tuning_epochs", type=int, default=10)
Expand Down

0 comments on commit 5012fce

Please sign in to comment.