-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfed_avg_ennio_remote.py
126 lines (99 loc) · 4.15 KB
/
fed_avg_ennio_remote.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 31 08:17:23 2021
@author: Usert990
"""
import numpy as np
import sys
import time
import memory_profiler
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
import torch
from torch.utils.data import Dataset, DataLoader
#from torchvision import datasets, transforms
#from syft.frameworks.torch.federated import utils
import run_websocket_client_iot as rwc
import logging
# Run configuration. `args=[]` means no command-line flags are parsed, so the
# values come from whatever defaults rwc.define_and_get_arguments defines.
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

# Hook torch so tensors can be sent to / operated on by remote syft workers.
hook = sy.TorchHook(torch)

# Remote websocket workers as (worker id, host, port); every client shares
# the same hook.
_WORKER_ENDPOINTS = (
    ("alice", "192.168.0.23", 8777),
    ("bob", "192.168.0.24", 8778),
    ("charlie", "192.168.0.25", 8779),
    ("jane", "192.168.0.26", 8780),
)
workers = [
    WebsocketClientWorker(id=worker_id, host=host, port=port, hook=hook)
    for worker_id, host, port in _WORKER_ENDPOINTS
]
alice, bob, charlie, jane = workers
print(workers)
# Notebook note: run this cell only if the next one fails with a pipeline error.
# Dataset definition.
class IoTDataset(Dataset):
    """In-memory dataset built from the benign and Gafgyt traffic CSV files.

    Rows of both files are stacked (benign first, then Gafgyt). In every row,
    column 0 is used as the label and the remaining columns as features. All
    values are loaded as float32. Features are min-max scaled with a single
    global min/max taken over the whole feature matrix (not per column).

    Args:
        benign_path: path of the benign-traffic CSV file.
        gafgyt_path: path of the Gafgyt-traffic CSV file. It must have the
            same number of columns as the benign file.
    """

    def __init__(self, benign_path="benign_traffic.csv", gafgyt_path="gafgyt_traffic.csv"):
        # ndmin=2 keeps a single-row file two-dimensional. Otherwise loadtxt
        # returns a 1-D array, and both concatenation and the column slicing
        # below fail.
        benign = np.loadtxt(benign_path, delimiter=",", dtype=np.float32, ndmin=2)
        gafgyt = np.loadtxt(gafgyt_path, delimiter=",", dtype=np.float32, ndmin=2)
        rows = np.concatenate((benign, gafgyt))

        features = rows[:, 1:]
        labels = rows[:, 0]

        # Global min-max scaling. If every feature value is identical, the
        # range is 0 and the division would produce NaN (0/0), so a unit
        # scale is used instead and all features become 0.
        lowest = features.min()
        span = features.max() - lowest
        if span == 0:
            span = np.float32(1.0)
        scaled = (features - lowest) / span

        self.len = rows.shape[0]
        self.x_data = torch.from_numpy(scaled)
        self.y_data = torch.from_numpy(labels)

    def __getitem__(self, index):
        """Return the (features, label) pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the total number of rows across both CSV files."""
        return self.len
# Send root-logger records at DEBUG level and above to stderr. This is set up
# before the data is federated, so messages emitted while the data is sent to
# the workers go through this handler too.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s")
handler.setFormatter(formatter)
logger.handlers = [handler]

full_dataset = IoTDataset()
train_size = int(len(full_dataset) * 0.8)
test_size = len(full_dataset) - train_size

# 80/20 random split. Keep the Subset objects: unwrapping them via `.dataset`
# would hand back the FULL dataset for both, so training and evaluation would
# run on identical data.
trainset, testset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

# Materialise only the training rows as a syft BaseDataset so they can be
# distributed across the remote workers.
train_indices = trainset.indices
federated_trainset = sy.BaseDataset(
    full_dataset.x_data[train_indices],
    full_dataset.y_data[train_indices],
)
federated_train_loader = sy.FederatedDataLoader(
    federated_trainset.federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True,
)

# The held-out rows stay local and are used only for evaluation.
test_loader = DataLoader(dataset=testset, batch_size=args.batch_size, shuffle=True)

model = rwc.nmodel().to(device)
print(model)

for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))

    # perf_counter is a monotonic clock meant for measuring elapsed time.
    starttbase = time.perf_counter()
    startmbase = memory_profiler.memory_usage()
    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
    endtbase = time.perf_counter()
    endmbase = memory_profiler.memory_usage()

    traintime_base = endtbase - starttbase
    # memory_usage() returns a list of samples; compare the first sample taken
    # before and after training.
    train_memory_base = endmbase[0] - startmbase[0]
    print("Training time base: {:.2f} sec".format(traintime_base))
    print("Training memory base: {:.2f} mb".format(train_memory_base))

    rwc.test(model, device, test_loader)