-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscaler.py
31 lines (23 loc) · 810 Bytes
/
scaler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
class StandardScaler():
    """Standardize tensors to zero mean and unit standard deviation.

    Statistics are computed along dim 0 of the tensor passed to ``fit`` and
    are broadcast against inputs in ``transform`` / ``inverse_transform``.

    Args:
        mean: Optional precomputed mean, used as-is without calling ``fit``.
        std: Optional precomputed standard deviation, used as-is without
            calling ``fit``.
    """

    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

    def save(self, filename):
        """Write the current ``mean`` and ``std`` to *filename* via ``torch.save``."""
        torch.save({'mean': self.mean, 'std': self.std}, filename)

    def load(self, filename, map_location=None):
        """Load ``mean`` and ``std`` previously written by ``save``.

        Args:
            filename: Path or file-like object accepted by ``torch.load``.
            map_location: Forwarded to ``torch.load``; e.g. ``'cpu'`` to load
                statistics that were saved from GPU tensors on a CPU-only host.

        Raises:
            ValueError: If the file does not hold a dict containing both
                ``'mean'`` and ``'std'``.
        """
        # NOTE(review): torch.load unpickles its input; only load trusted files.
        data = torch.load(filename, map_location=map_location)
        # Check the container type first: ``'mean' in <tensor>`` would not
        # give a meaningful answer for a file that holds something else.
        if not isinstance(data, dict) or 'mean' not in data or 'std' not in data:
            raise ValueError('No mean or std in the file.')
        self.mean = data['mean']
        self.std = data['std']

    def fit(self, X):
        """Compute per-feature mean and std of *X* along dim 0.

        Args:
            X: Tensor of samples; statistics are reduced over its first dim.
                Integer/bool tensors are cast to the default float dtype for
                the statistics.

        Returns:
            ``self``, so calls can be chained.
        """
        if not (X.is_floating_point() or X.is_complex()):
            # torch.mean / torch.std reject integer and bool dtypes.
            X = X.to(torch.get_default_dtype())
        self.mean = torch.mean(X, dim=0)
        std = torch.std(X, dim=0)
        # A constant feature has std 0, and a single sample yields NaN with
        # torch.std's default unbiased estimator; dividing by either would
        # produce inf/NaN in transform. Leave such features unscaled (std = 1).
        self.std = torch.where(std > 0, std, torch.ones_like(std))
        return self

    def transform(self, X):
        """Return ``(X - mean) / std`` using the stored statistics."""
        return (X - self.mean) / self.std

    def fit_transform(self, X):
        """Fit on *X*, then return *X* standardized with the new statistics."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X):
        """Undo ``transform``: return ``X * std + mean``."""
        return X * self.std + self.mean