-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfed_utils.py
175 lines (148 loc) · 5.65 KB
/
fed_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 31 09:22:44 2021
@author: Usert990
"""
import logging
from typing import Any, List, Mapping, Sequence, Union

import syft as sy
import torch
logger = logging.getLogger(__name__)
def extract_batches_per_worker(federated_train_loader: sy.FederatedDataLoader):
    """Group the batches of a federated data loader by the worker holding them.

    Args:
        federated_train_loader: the federated data loader whose batches are
            extracted. Its ``workers`` are used to pre-create one entry per
            worker, so workers that yield no batch still appear (with an
            empty list).

    Returns:
        dict: maps each worker (the ``location`` of that worker's dataset) to
        the list of ``(data, target)`` batches located on it, in the order the
        loader yielded them.
    """
    logging_interval = 100
    datasets = federated_train_loader.federated_dataset.datasets
    # Pre-seed every worker so the dict has a key for each one, even if empty.
    batches = {
        datasets[worker_id].location: [] for worker_id in federated_train_loader.workers
    }

    for batch_idx, (data, target) in enumerate(federated_train_loader):
        if batch_idx % logging_interval == 0:
            logger.debug("Extracted %s batches from federated_train_loader", batch_idx)
        # Raises KeyError if a batch lives on a location that was not pre-seeded above.
        batches[data.location].append((data, target))

    return batches
def add_model(dst_model, src_model):
    """Add the parameters of ``src_model`` into ``dst_model`` in place.

    Parameters are matched by name. A parameter of ``src_model`` that has no
    same-named counterpart in ``dst_model`` is skipped.

    Args:
        dst_model (torch.nn.Module): the model that receives the sums.
        src_model (torch.nn.Module): the model whose parameters are added.

    Returns:
        torch.nn.Module: ``dst_model``, now holding the element-wise sums.
    """
    targets = dict(dst_model.named_parameters())
    with torch.no_grad():
        for name, addend in src_model.named_parameters():
            if name not in targets:
                continue
            target = targets[name]
            # Rebind the existing Parameter object to the storage of the new sum tensor.
            target.set_(addend.data + target.data)
    return dst_model
def scale_model(model, scale):
    """Multiply every parameter of ``model`` by ``scale`` in place.

    Args:
        model (torch.nn.Module): the model whose parameters are rescaled.
        scale (float): the factor applied to each parameter.

    Returns:
        torch.nn.Module: ``model`` itself, with its parameters scaled.
    """
    with torch.no_grad():
        for param in dict(model.named_parameters()).values():
            # Rebind the Parameter to the storage of the scaled copy.
            param.set_(param.data * scale)
    return model
def federated_avg(
    models: Union[Mapping[Any, torch.nn.Module], Sequence[torch.nn.Module]]
) -> torch.nn.Module:
    """Calculate the federated average of a collection of models.

    The parameters of all models are summed into the first model, which is
    then scaled by ``1 / number_of_models``. The first model is therefore
    modified in place, and it is the object returned.

    Args:
        models: the models to average. This is either a mapping (only its
            values are used, in iteration order) or a sequence of models.

    Returns:
        torch.nn.Module: the first model, now holding the averaged parameters.

    Raises:
        ValueError: if ``models`` is empty.
    """
    # Mappings contribute their values; plain sequences are used as-is.
    model_list = list(models.values()) if hasattr(models, "values") else list(models)
    nr_models = len(model_list)
    if nr_models == 0:
        raise ValueError("federated_avg requires at least one model")

    model = model_list[0]
    for other in model_list[1:]:
        model = add_model(model, other)
    return scale_model(model, 1.0 / nr_models)
def accuracy(pred_softmax, target):
    """Calculate the accuracy of a given prediction.

    The final prediction is the argmax of ``pred_softmax`` along dim 1.

    Args:
        pred_softmax: tensor of floats with ``nr_classes`` scores (along dim 1)
            for each element of ``target``.
        target: tensor of ints holding the correct classes, with values in
            ``[0, nr_classes)``. It is reshaped to the shape of the predictions.

    Returns:
        float: fraction of correct predictions, or ``nan`` if ``target`` is empty.
    """
    nr_elems = len(target)
    if nr_elems == 0:
        # No samples: accuracy is undefined.
        return float("nan")
    pred = pred_softmax.argmax(dim=1)
    # reshape (unlike view) also accepts non-contiguous targets; .item() works
    # for tensors on any device, unlike .numpy() which is CPU-only.
    nr_correct = (pred.float() == target.reshape(pred.shape).float()).sum().item()
    return nr_correct / float(nr_elems)
def create_gaussian_mixture_toy_data(nr_samples: int):  # pragma: no cover
    """Create a balanced two-class toy data set for binary classification.

    Each 2-D point is drawn from a unit-variance normal distribution:
        target = 0: mean 0
        target = 1: mean 2 (a standard normal draw shifted by +2)

    Each class gets ``int(nr_samples / 2)`` points, so an odd ``nr_samples``
    yields one sample fewer than requested. The class-0 points come first.

    Args:
        nr_samples: number of samples to generate.

    Returns:
        (data, targets): ``data`` has shape ``(2 * half, 2)`` and requires
        grad; ``targets`` is a long tensor of zeros followed by ones.
    """
    half = int(nr_samples / 2)
    dim = 2
    # Draw order matters for reproducibility under a fixed seed: class 0 first.
    negatives = torch.randn(half, dim, requires_grad=True)
    positives = torch.randn(half, dim, requires_grad=True) + 2
    labels = torch.cat(
        [torch.zeros(half, dtype=torch.long), torch.ones(half, dtype=torch.long)],
        dim=0,
    )
    return torch.cat([negatives, positives], dim=0), labels
def iris_data_partial():
    """Return 30 samples from the Iris data set, 10 per class.

    Source: https://archive.ics.uci.edu/ml/datasets/iris

    Class indices: 0 = Iris-setosa, 1 = Iris-versicolor, 2 = Iris-virginica.

    Returns:
        (data, targets): ``data`` is a float tensor of shape ``(30, 4)`` (four
        measurements per sample), and ``targets`` is an integer tensor of
        shape ``(30,)`` holding the class index of each row.
    """
    setosa = [
        [5.1, 3.5, 1.4, 0.2],
        [4.9, 3.0, 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5.0, 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5.0, 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
    ]
    versicolor = [
        [7.0, 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4.0, 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1.0],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
    ]
    virginica = [
        [6.3, 3.3, 6.0, 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3.0, 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3.0, 5.8, 2.2],
        [7.6, 3.0, 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
    ]
    data = setosa + versicolor + virginica
    # Derive the labels from the block sizes so data and targets cannot drift apart.
    targets = [0] * len(setosa) + [1] * len(versicolor) + [2] * len(virginica)
    return torch.tensor(data), torch.tensor(targets)