-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathserver.py
147 lines (119 loc) · 6.36 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import numpy as np
import copy
from scipy.linalg import svd
# Module-wide compute device: CUDA when torch reports it available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class Server():
    """Coordinate training and model aggregation across a set of clients.

    The server drives per-episode training on every client and periodically
    merges the clients' actor/critic models, either as a plain element-wise
    mean (``fed_model_average``) or as a mask-aware mean for sparse models
    (``sparse_fed_model_average``).

    Each client object is expected to expose the members used below:
    ``train_data``, ``fed_train``, ``fed_test``, ``train``,
    ``get_fed_model_params``, ``set_fed_model_params``, ``get_model_params``,
    ``get_model_masks``, ``set_model_params``, ``set_model_masks``,
    ``initialize`` and ``dynamic_update``.
    """

    def __init__(self,
                 clients,
                 eposide,
                 actor_sparsity,
                 critic_sparsity,
                 update,
                 aggregate) -> None:
        """Store the training configuration.

        Args:
            clients: Sequence of client objects (see the class docstring).
            eposide: Number of training episodes to run.
            actor_sparsity: Forwarded to ``clients[0].initialize`` for the actor.
            critic_sparsity: Forwarded to ``clients[0].initialize`` for the critic.
            update: Episode period at which ``dynamic_update`` is called on
                every client during sparse training.
            aggregate: Episode period at which client models are aggregated.
        """
        self.clients = clients
        self.eposide = eposide
        # The attribute spelling ("sparisty") is kept as-is so that any code
        # reading these attributes from outside the class keeps working.
        self.actor_sparisty = actor_sparsity
        self.critic_sparisty = critic_sparsity
        self.update = update
        self.aggregate = aggregate

    @staticmethod
    def _average_state_dicts(models):
        """Return the element-wise mean of the ``state_dict`` of ``models``.

        The accumulation happens on fresh tensors: ``state_dict()`` values
        share storage with the live model, so summing into them in place would
        overwrite the first client's weights as a side effect.
        """
        # One state_dict() call per model rather than one per (parameter, model).
        state_dicts = [model.state_dict() for model in models]
        averaged = state_dicts[0]  # fresh mapping; its values are replaced below
        for name in list(averaged):
            reference = averaged[name]
            if reference.is_floating_point() or reference.is_complex():
                total = reference.clone()
            else:
                # Integer entries cannot be true-divided in place; average them
                # in floating point and cast back (truncating) afterwards.
                total = reference.to(torch.get_default_dtype())
            for other in state_dicts[1:]:
                total += other[name]
            total /= len(state_dicts)
            averaged[name] = total.to(reference.dtype)
        return averaged

    def fed_model_average(self):
        """Average the federated actor and critic models of all clients.

        Returns:
            A tuple ``(actor_state_dict, critic_state_dict)`` holding the
            element-wise mean over all clients. The clients' own models are
            left untouched.

        Raises:
            IndexError: If there are no clients.
        """
        actors = []
        critics = []
        for client in self.clients:
            actor_model, critic_model = client.get_fed_model_params()
            actors.append(actor_model)
            critics.append(critic_model)
        return (self._average_state_dicts(actors),
                self._average_state_dicts(critics))

    @staticmethod
    def _masked_average(params_list, masks_list):
        """Average per-client parameter dicts, weighting by mask coverage.

        For every entry, each position is divided by the number of clients
        whose mask is non-zero there; positions covered by no client become
        zero. Everything is computed with torch on the parameter's own device,
        so CUDA parameters and non-float masks are handled.
        """
        # Shallow copy keeps the mapping type/order; every value is replaced.
        averaged = copy.copy(params_list[0])
        for name, reference in params_list[0].items():
            count = torch.zeros_like(
                masks_list[0][name],
                dtype=reference.dtype,
                device=reference.device,
            )
            for masks in masks_list:
                count += masks[name].to(device=reference.device,
                                        dtype=reference.dtype)
            # 1 / count where some client holds the weight, 0 elsewhere.
            weight = torch.where(count != 0,
                                 count.reciprocal(),
                                 torch.zeros_like(count))
            total = torch.zeros_like(reference)
            for params in params_list:
                total += weight * params[name]
            averaged[name] = total
        return averaged

    @staticmethod
    def _apply_masks(averaged, masks):
        """Return a deep copy of ``averaged`` with each masked entry multiplied by its mask."""
        local = copy.deepcopy(averaged)
        for name, mask in masks.items():
            # Move the mask to the parameter's device before multiplying.
            local[name] = local[name] * mask.to(local[name].device)
        return local

    def sparse_fed_model_average(self) -> None:
        """Aggregate sparse client models and push the result back.

        Collects every client's parameters and masks, builds a mask-aware
        average for actor and critic, then hands each client that average
        multiplied by the client's own masks via ``set_model_params``.
        """
        actors_params = []
        critics_params = []
        actors_masks = []
        critics_masks = []
        for client in self.clients:
            actor_params, critic_params = client.get_model_params()
            actors_params.append(actor_params)
            critics_params.append(critic_params)
            actor_masks, critic_masks = client.get_model_masks()
            actors_masks.append(actor_masks)
            critics_masks.append(critic_masks)

        avg_actor_model_params = self._masked_average(actors_params, actors_masks)
        avg_crtic_model_params = self._masked_average(critics_params, critics_masks)

        for home, client in enumerate(self.clients):
            local_critic = self._apply_masks(avg_crtic_model_params,
                                             critics_masks[home])
            local_actor = self._apply_masks(avg_actor_model_params,
                                            actors_masks[home])
            client.set_model_params(local_actor, local_critic)

    def _train_clients(self, episode, *, sparse=False):
        """Run one training step on every client and print its reward.

        Uses ``client.train`` when ``sparse`` is true, ``client.fed_train``
        otherwise. The sample index cycles through ``client.train_data``.
        """
        for home, client in enumerate(self.clients):
            index = episode % len(client.train_data)
            reward = client.train(index) if sparse else client.fed_train(index)
            print(f"home: {home}, eposide: {episode}, reward: {reward}")

    def local_train(self) -> None:
        """Train every client for ``self.eposide`` episodes without aggregation."""
        for episode in range(self.eposide):
            self._train_clients(episode)

    def fed_train(self) -> None:
        """Train all clients, periodically evaluating and averaging their models."""
        # NOTE(review): the evaluation and aggregation steps are placed at
        # episode level (after all clients have trained) - confirm.
        for episode in range(self.eposide):
            self._train_clients(episode)
            if episode % 10 == 0:
                for client in self.clients:
                    # The returned rewards are not used here.
                    client.fed_test()
            if episode % self.aggregate == 0 and episode > 0:
                actor_global_model_params, critic_global_model_params = self.fed_model_average()
                # Model distribute
                for client in self.clients:
                    client.set_fed_model_params(actor_global_model_params,
                                                critic_global_model_params)

    def sparse_fed_train(self) -> None:
        """Train sparse client models with periodic aggregation and mask updates.

        The masks produced by the first client's ``initialize`` are installed
        on every client before training starts.
        """
        actor_mask, critic_mask = self.clients[0].initialize(self.actor_sparisty,
                                                             self.critic_sparisty)
        for client in self.clients:
            client.set_model_masks(actor_mask, critic_mask)
        # NOTE(review): aggregation runs before, and dynamic_update after, the
        # per-client training step of an episode - confirm this nesting.
        for episode in range(self.eposide):
            if episode % self.aggregate == 0 and episode > 0:
                self.sparse_fed_model_average()
            self._train_clients(episode, sparse=True)
            if episode % self.update == 0 and episode > 0:
                for client in self.clients:
                    client.dynamic_update(episode)