-
Notifications
You must be signed in to change notification settings - Fork 0
/
federated_concurrent.py
68 lines (54 loc) · 2.16 KB
/
federated_concurrent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from federated_model import FederatedModel
from utils import load_dataset, split_dataset, split_dataset_grouped, remove_outliers, make_missing_values, knn_impute_dataset
from local_model import LocalModel
import matplotlib.pyplot as plt
import numpy as np
def run_model(model):
    """Fit ``model`` and return that same object.

    Calls ``model.fit()`` with no arguments and discards whatever ``fit``
    returns; the caller gets back the (now fitted) ``model`` instance itself.

    NOTE(review): not called in the visible part of this file; presumably
    intended as a worker entry point for running models concurrently --
    confirm against callers.
    """
    model.fit()
    return model
if __name__ == '__main__':
    # Script entry point: load the wine-quality data, split it into one
    # dataset per local model, inject and impute missing values, then train
    # the local models through a FederatedModel and print summary statistics.
    import sys  # local import: only needed for the explicit exit at the end

    number_of_models = 2

    data, feature_names = load_dataset("winequality-red.csv")
    # data, feature_names = load_dataset("winequality-white.csv")
    data = remove_outliers(data, neighbors=3)
    print(data.shape)
    # Keep only the first 200 rows for this experiment.
    data = data[:200, :]

    # Alternative partitioning strategies, kept for experimentation:
    # datasets = split_dataset(data, number_of_models)
    datasets = split_dataset_grouped(data, number_of_models)
    # datasets = split_dataset(data, number_of_models, overlap=True, block_size=200)

    print("Datasets shapes: ")
    # Shapes are printed back-to-back on one line, without a separator.
    print("".join(str(d.shape) for d in datasets))

    # Degrade each dataset with missing values, impute them again, and keep
    # the weight that make_missing_values reports for that dataset.
    datasets_weights = []
    for i in range(number_of_models):
        datasets[i], d_weight = make_missing_values(datasets[i])
        datasets[i] = knn_impute_dataset(datasets[i])
        datasets_weights.append(d_weight)

    print("Datasets weights before normalization: ", datasets_weights)
    # Rescale the weights so that they sum to 1.
    datasets_weights = np.array(datasets_weights) / sum(datasets_weights)
    print("Datasets weights after normalization: ", datasets_weights)

    # One local model per dataset, each given its normalized weight.
    local_models = [
        LocalModel(
            datasets[i],
            datasets_weights[i],
            feature_names,
            epochs=100,
            validity_method="Spearman",
        )
        for i in range(number_of_models)
    ]

    fm = FederatedModel(local_models, feature_names, rounds_count=5)
    # Build the shared rule set and push it to every local model before fitting.
    rules = fm.create_rules()
    fm.set_federated_rules_to_local_models(rules)
    fm.fit()

    print()
    print("Number of federated rules: ", fm.number_of_fed_rules)
    print()
    print("Dataset OWA weights")
    print(datasets_weights)
    print()
    print("Average of MSE on local and aggregated fis: ")
    print(np.mean(fm.final_results, axis=1))
    print()
    print("Average standard deviation on local and aggregated fis: ")
    print(np.std(fm.final_results, axis=1))

    # sys.exit() instead of the site-provided exit() helper, which is meant
    # for interactive use and is not guaranteed to exist (e.g. python -S).
    sys.exit()