-
Notifications
You must be signed in to change notification settings - Fork 1
/
sub-mc.py
176 lines (124 loc) · 5.89 KB
/
sub-mc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import json
import os
import queue
import sys
import time
from json import JSONEncoder

import numpy as np
import pandas as pd
import paho.mqtt.client as mqtt
import paho.mqtt.subscribe as subscribe

from federated_utils import *
from Numpy_to_JSON_utils import *
# Hand-off queues: on_message() fills them, the main loop drains them.
qGSModel = queue.Queue()  # messages received on the "LocalModel" topic
qGSPerfm = queue.Queue()  # messages received on the "ModelPerformance" topic

# One [accuracy, F1] pair per round for the global model; saved to CSV at exit.
global_model_result = []
prev_global_model = []  # NOTE(review): never read in this script
l_rate = 0.05  # Learning rate; NOTE(review): not used in this script
def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback: report the result and subscribe.

    ``rc`` is the connection result code; 0 is reported as success, anything
    else as a failure. Afterwards every ``(topic, qos)`` pair in the
    module-level ``topic_list`` is subscribed to, one call per entry, and the
    return value of each ``subscribe`` call is printed. Subscriptions are
    issued regardless of ``rc``.
    """
    status = ("Global Server connected to broker successfylly " if rc == 0
              else f"Failed with code {rc}")
    print(status)
    for topic_spec in topic_list:
        print(client.subscribe(topic_spec))
def on_message(client, userdata, message):
    """paho-mqtt message callback: route a message to the matching queue.

    ``LocalModel`` messages go to ``qGSModel`` and ``ModelPerformance``
    messages go to ``qGSPerfm``; any other topic is ignored. The message
    object itself (not just its payload) is queued.
    """
    topic = message.topic
    if topic == "LocalModel":
        print("Message received from Local Model")
        target = qGSModel
    elif topic == 'ModelPerformance':
        print('Performance metric received from Local Models')
        target = qGSPerfm
    else:
        return
    target.put(message)
print('---------------------------------------------------------------')
# Broker address; the public Eclipse test broker is a drop-in alternative:
# mqttBroker = "mqtt.eclipseprojects.io"
mqttBroker = "127.0.0.1"
# (topic, qos) pairs that on_connect() subscribes to once connected.
topic_list = [('LocalModel', 0), ('ModelPerformance', 0)]
client = mqtt.Client(client_id="GlobalServer", clean_session=True)
# Register every callback before connecting and before loop_start(). The
# network thread started by loop_start() dispatches CONNACK and incoming
# messages, so the handlers must already be in place when it starts.
client.on_connect = on_connect
client.on_message = on_message
client.connect(mqttBroker, 1883)
client.loop_start()
#**********************************************************
time.sleep(5)  # Wait for connection setup to complete
#**********************************************************
print('---------------------------------------------------------------')
# ---------------------------------------------------------------------------
# Main federated-learning loop.
#
# Each round:
#   1. from the second round on, average the performance metrics the local
#      models published and record the global model's accuracy / F1;
#   2. wait, then drain the local model weights received so far;
#   3. average those weights (FedAvg with equal client weights) and publish
#      the result on the "GlobalModel" topic.
# The loop ends when a round (after the first) yields no performance metrics,
# or when a round yields no local model weights.
# ---------------------------------------------------------------------------
i = 0  # number of rounds completed so far
while True:
    print('---------STARTED-------------')
    print('Global Server')
    time.sleep(5)
    #====================================================
    # Global Model Performance Printing
    #====================================================
    if i > 0:  # metrics only exist after a global model has been broadcast
        print('Now Collecting Local Model erformacen Metrics....')
        local_model_performace = []
        while not qGSPerfm.empty():
            message = qGSPerfm.get()
            if message is None:
                continue
            msg_model_performance = message.payload.decode('utf-8')
            # Keep only the values, in the sender's JSON key order.
            local_model_performace.append(
                list(json.loads(msg_model_performance).values()))
        len_local_perfm = len(local_model_performace)
        print('Total Model Performance received:', len_local_perfm)
        # Check before averaging: np.mean of an empty array only warns and
        # yields nan.
        if len_local_perfm == 0:
            break  # No more data from local model
        # Element-wise mean over all clients. NOTE(review): positions 1 and 2
        # are reported as accuracy and F1; this relies on the key order the
        # local models use when publishing their metrics - confirm.
        global_performance = np.mean(np.array(local_model_performace), axis=0)
        global_model_result.append([global_performance[1], global_performance[2]])
        print('----------------------------------------------------')
        print('Global Model Accuracy:', global_performance[1])
        print('Global Model F1-score:', global_performance[2])
        print('----------------------------------------------------')
    #**********************************************************
    time.sleep(50)  # to receive model weights
    #**********************************************************
    #=========================================================
    # Local Model Receiving Part
    #=========================================================
    local_weight_sets = []
    while not qGSModel.empty():
        message = qGSModel.get()
        if message is None:
            continue
        msg = message.payload.decode('utf-8')
        # Decoded weights are converted to a list (presumably one entry per
        # layer - verify against json2NumpyWeights).
        local_weight_sets.append(list(json2NumpyWeights(msg)))
    num_local_models = len(local_weight_sets)
    print('Total Local Model Received:', num_local_models)
    i += 1  # Next round increment
    #===================================================================
    # Publish the Global Model after Federated Averaging
    #===================================================================
    if num_local_models:
        # Scale every model by 1/N, where N is the number of models actually
        # received, so the sum below is a true mean for any client count.
        # (Assumes scale_model_weights multiplies each layer by the scalar and
        # sum_scaled_weights sums layer-wise; both come from federated_utils.)
        scale = 1.0 / num_local_models
        all_local_model_weights = [scale_model_weights(w, scale)
                                   for w in local_weight_sets]
        averaged_weights = sum_scaled_weights(all_local_model_weights)
        global_weights = EagerTensor2Numpy(averaged_weights)
        encodedGlobalModelWeights = json.dumps(global_weights, cls=Numpy2JSONEncoder)
        client.publish("GlobalModel", payload=encodedGlobalModelWeights)
        print("Broadcasted Global Model to Topic:--> GlobalModel")
        #**********************************************************
        time.sleep(30)  # pause so subscribers receive the global model
        #**********************************************************
    print('---------------HERE------------------')
    #===================================================================================
    # If no local model arrived this round, stop without broadcasting an empty
    # global model; the server then saves results and closes the connection.
    #===================================================================================
    if num_local_models == 0:
        break
# Global Model Result Save
# Output directory; set FED_REMECS_RESULTS_DIR to override the default path.
folderPath = os.environ.get(
    'FED_REMECS_RESULTS_DIR',
    '/home/gp/Desktop/PhD-codes/Fed-ReMECS-mqtt/Federated_Results/')
fname_fm = os.path.join(folderPath, '_Global_Model' + '_' + '_results.csv')
column_names = ['Acc', 'F1']
try:
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(folderPath, exist_ok=True)
    global_model_result = pd.DataFrame(global_model_result, columns=column_names)
    global_model_result.to_csv(fname_fm)
    print("All done, Global Server Closed.")
finally:
    # Always close the MQTT session and stop the network loop, even if
    # writing the results failed.
    client.disconnect()
    client.loop_stop()