-
Notifications
You must be signed in to change notification settings - Fork 1
/
training2.py
137 lines (120 loc) · 4.27 KB
/
training2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import pickle
import numpy as np
import convnet as co
import io
import tenseal as ts
# Uncomment for reproducible random numbers (e.g. DataLoader shuffling).
# torch.manual_seed(73)

# MNIST train/test splits stored under ./data (download=True fetches them if
# they are missing); every image is turned into a tensor by ToTensor.
to_tensor = transforms.ToTensor()
train_data, test_data = (
    datasets.MNIST('data', train=is_train, download=True, transform=to_tensor)
    for is_train in (True, False)
)

batch_size = 64
# NOTE(review): shuffling the test split is not needed for evaluation;
# shuffle=False there would make evaluation order deterministic.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for dataset in (train_data, test_data)
)
def train(model, train_loader, criterion, optimizer, n_epochs=10):
    """Train ``model`` in place for ``n_epochs`` epochs and return it.

    Args:
        model: the ``torch.nn.Module`` to optimise.
        train_loader: iterable of ``(data, target)`` batches that also supports
            ``len()`` (e.g. a ``DataLoader``).
        criterion: loss function, called as ``criterion(output, target)``.
        optimizer: optimizer over the parameters of ``model``.
        n_epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, switched to evaluation mode.
    """
    # Training mode (matters for layers such as dropout / batch norm, if present).
    model.train()
    # The batch count is loop-invariant; clamp it to 1 so an empty loader
    # reports a 0.0 average instead of raising ZeroDivisionError.
    n_batches = max(len(train_loader), 1)
    for epoch in range(1, n_epochs + 1):
        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Mean loss per batch for this epoch.
        train_loss = train_loss / n_batches
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    # Leave the model in evaluation mode for whoever uses it next.
    model.eval()
    return model
print("Training 2")

# Read the pickled parameters stored in model.pt.
# NOTE(review): pickle.load executes arbitrary code embedded in the file. If
# model.pt can come from another participant, this is unsafe — confirm the
# file's origin is trusted before loading it.
with open("model.pt", "rb") as model_file:
    pickload = pickle.load(model_file)

# Re-serialize the object with torch and read it back from memory so that
# torch.load can remap every tensor onto the CPU.
buffer = io.BytesIO()
torch.save(pickload, buffer)
buffer.seek(0)  # rewind so torch.load reads from the beginning
loaded_model = torch.load(buffer, map_location="cpu")
print(loaded_model)

# Rebuild the network and restore the stored parameters into it.
model = co.ConvNet()
model.load_state_dict(loaded_model)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# TODO(review): the model has to be updated every round; how it is
# transported between rounds is still open. The retrained weights are
# currently not written back to model.pt.
model = train(model, train_loader, criterion, optimizer, n_epochs=10)
# Collect parameter names (param_ten) and their tensors (weights) from the
# trained model, both in state_dict order.
# state_dict() builds a fresh dict on every call, so build it once rather
# than once per key.
state = model.state_dict()
weights = []
param_ten = []
for name, tensor in state.items():
    print(name)
    param_ten.append(name)
    weights.append(tensor)
print(weights)
# CKKS encryption context used for every encrypted weight vector.
bits_scale = 26  # bit size of the scale and of each intermediate modulus
# Modulus chain: 31-bit moduli at both ends, six bits_scale-bit moduli between.
coeff_mod_bit_sizes = [31] + [bits_scale] * 6 + [31]
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=coeff_mod_bit_sizes,
)
# Encode values with a scale of 2**bits_scale.
context.global_scale = 2 ** bits_scale
# Galois keys are required for ciphertext rotations.
context.generate_galois_keys()
def _innermost_rows(array):
    """Yield every innermost 1-D slice of ``array`` in row-major order.

    A 0-D or 1-D array yields itself as a single 1-D vector; higher-rank
    arrays are walked recursively, in the same order nested ``for`` loops
    over the leading axes would visit them.
    """
    if array.ndim <= 1:
        # reshape(-1) also turns a 0-D scalar into a length-1 vector, since
        # ckks_vector only encrypts 1-D data.
        yield array.reshape(-1)
    else:
        for sub_array in array:
            yield from _innermost_rows(sub_array)


# Encrypt every weight tensor row by row with CKKS.
#   encrypted_weights: flat list of all encrypted vectors, in state_dict order.
#   weights_lists: one entry per tensor: [parameter name, vec_0, vec_1, ...].
# Each 1-D tensor (e.g. a bias) is encrypted as one vector, using its own
# values rather than a leftover loop variable.
encrypted_weights = []
weights_lists = []
for name, weight in zip(param_ten, weights):
    new = [name]
    print(name, tuple(weight.shape))
    for row in _innermost_rows(weight.numpy()):
        # Plain Python floats are the documented input type of ckks_vector.
        encrypted_weight = ts.ckks_vector(context, row.tolist())
        new.append(encrypted_weight)
        encrypted_weights.append(encrypted_weight)
    weights_lists.append(new)
print(weights_lists)
print(encrypted_weights)
#Hier verschlüsseln?
#global updated
#updated = pickle.dumps(model.state_dict())
#print(updated)