-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoverfit_prevention.py
33 lines (25 loc) · 1.08 KB
/
overfit_prevention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
class Thresholdout:
    """Reusable-holdout guard based on the Thresholdout mechanism.

    Each call to :meth:`score` compares a statistic measured on the training
    set with the same statistic measured on the holdout set. If they agree to
    within a noisy threshold, the training value is answered and the holdout
    is not revealed; otherwise a Laplace-noised holdout value is answered and
    the noisy threshold is re-drawn.

    Args:
        labels: Holdout labels. Only ``len(labels)`` is used, to scale the
            noise level and the threshold by ``1 / sqrt(n)``.
        noise_rate: Base noise level before the ``1 / sqrt(n)`` scaling.
            Must be non-negative.
        budget: Optional number of holdout-revealing answers (queries whose
            train/holdout gap exceeds the noisy threshold) that may be given.
            ``None`` (default) means unlimited, matching the original
            behaviour. Once the budget is used up, :meth:`score` raises
            ``RuntimeError``.
        rng: Optional object exposing a numpy-style
            ``laplace(loc=..., scale=...)`` method, e.g.
            ``np.random.default_rng(seed)``, for reproducible noise. ``None``
            (default) draws from numpy's global RNG via
            :meth:`sample_laplace_noise`.

    Raises:
        ValueError: if ``labels`` is empty or ``noise_rate`` is negative.
    """

    def __init__(self, labels, noise_rate, *, budget=None, rng=None):
        n = len(labels)
        if n == 0:
            # Otherwise both scales below become noise_rate / 0 -> inf/nan.
            raise ValueError("labels must be non-empty")
        if noise_rate < 0:
            raise ValueError("noise_rate must be non-negative")
        self.labels = labels
        self.budget = budget
        self._rng = rng
        # Noise scale (sigma) and threshold both shrink as 1 / sqrt(n).
        self.noise_rate = noise_rate / np.sqrt(n)
        self.threshold = 2 * noise_rate / np.sqrt(n)
        # Initial noisy threshold: threshold + Laplace(0, 2 * sigma).
        self.noisy_t = self.threshold + self._laplace(self.noise_rate * 2)

    @staticmethod
    def sample_laplace_noise(scale):
        """Draw one sample from Laplace(0, scale) using numpy's global RNG."""
        return np.random.laplace(loc=0, scale=scale)

    def _laplace(self, scale):
        """Draw Laplace(0, scale) noise from the configured noise source."""
        if self._rng is None:
            # Default path: identical to the original global-RNG behaviour.
            return self.sample_laplace_noise(scale)
        return self._rng.laplace(loc=0, scale=scale)

    def score(self, train_acc, val_acc):
        """Answer one query through the Thresholdout mechanism.

        Args:
            train_acc: The statistic (accuracy) measured on the training set.
            val_acc: The same statistic measured on the holdout set.

        Returns:
            ``1 - train_acc`` when the train/holdout gap is within the noisy
            threshold; otherwise ``1 - (val_acc + xi)`` with
            ``xi ~ Laplace(0, sigma)``. The noised value is not clipped, so it
            can fall outside ``[0, 1]``.

        Raises:
            RuntimeError: if a ``budget`` was given and has been used up.
        """
        if self.budget is not None and self.budget < 1:
            raise RuntimeError("Thresholdout budget exhausted")
        # Fresh per-query noise on the comparison itself.
        eta = self._laplace(4 * self.noise_rate)
        if abs(val_acc - train_acc) > self.noisy_t + eta:
            xi = self._laplace(self.noise_rate)
            gamma = self._laplace(2 * self.noise_rate)
            # Re-randomise the threshold after every holdout-revealing answer.
            self.noisy_t = self.threshold + gamma
            if self.budget is not None:
                self.budget -= 1
            return 1 - (val_acc + xi)
        return 1 - train_acc