-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrocket.py
65 lines (46 loc) · 2.14 KB
/
rocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from aeon.classification.convolution_based import RocketClassifier
import numpy as np
from constants import BATCH_SIZE
class Rocket:
    """Thin wrapper around aeon's ``RocketClassifier`` with batched inference.

    ``predict`` and ``predict_proba`` split the input along its first axis
    into chunks of ``batch_size`` samples and concatenate the per-chunk
    results. The motivation is presumably to bound peak memory of the ROCKET
    transform on large inputs (TODO confirm).
    """

    def __init__(self, num_kernels: int = 500):
        """Build the underlying ``RocketClassifier``.

        Args:
            num_kernels: Number of random convolutional kernels passed
                through to ``RocketClassifier``.
        """
        self.classifier = RocketClassifier(num_kernels=num_kernels)
        print("RocketClassifier built")

    def fit(self, X_train, y_train):
        """Fit the underlying classifier on ``X_train`` / ``y_train``."""
        self.classifier.fit(X_train, y_train)

    def _predict_in_batches(self, predict_fn, X, batch_size):
        """Apply ``predict_fn`` to consecutive slices of ``X`` and stack results.

        Args:
            predict_fn: Callable applied to each slice ``X[start:start + batch_size]``.
            X: Array-like indexable along axis 0 and exposing ``shape``.
            batch_size: Maximum number of samples per slice; must be positive.

        Returns:
            The per-slice outputs concatenated along axis 0, or an empty
            1-D array when ``X`` has no samples.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        n_samples = X.shape[0]
        batches = [
            predict_fn(X[start:start + batch_size])
            for start in range(0, n_samples, batch_size)
        ]
        if not batches:
            # np.concatenate raises on an empty list, so short-circuit here.
            return np.empty(0)
        return np.concatenate(batches, axis=0)

    def predict(self, X_test, batch_size=BATCH_SIZE):
        """Predict labels for ``X_test``, ``batch_size`` samples at a time.

        Args:
            X_test: Samples to classify, indexed along axis 0.
            batch_size: Maximum number of samples sent to the classifier
                per call; must be positive.

        Returns:
            The concatenated output of ``self.classifier.predict``, or an
            empty 1-D array when ``X_test`` has no samples.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        return self._predict_in_batches(self.classifier.predict, X_test, batch_size)

    def predict_proba(self, X_test, batch_size=BATCH_SIZE):
        """Return column 1 of ``predict_proba`` for ``X_test``, computed in batches.

        NOTE(review): selecting column 1 assumes a binary problem where
        column 1 is the class of interest (the second entry of the
        classifier's class ordering) — confirm against the training labels.

        Args:
            X_test: Samples to score, indexed along axis 0.
            batch_size: Maximum number of samples sent to the classifier
                per call; must be positive.

        Returns:
            1-D array with one score per sample; empty when ``X_test`` has
            no samples.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        proba = self._predict_in_batches(
            self.classifier.predict_proba, X_test, batch_size
        )
        if proba.size == 0:
            # Empty input yields a 1-D empty array with no column to select.
            return np.empty(0)
        return proba[:, 1]

    def predict_sample(self, sample):
        """Return column 1 of ``predict_proba`` for ``sample`` in a single call.

        NOTE(review): ``sample`` is passed straight to the classifier, so it
        presumably must already carry a leading sample axis — verify callers.
        """
        return self.classifier.predict_proba(sample)[:, 1]

    def dump(self, path: str):
        """Serialize the underlying classifier via its ``save`` method.

        NOTE(review): ``load`` reads ``path + ".zip"``, which suggests
        ``save`` appends a ``.zip`` suffix itself — confirm for the installed
        aeon version.
        """
        self.classifier.save(path)

    def load(self, path: str):
        """Replace the classifier with one loaded from ``path + ".zip"``.

        ``path`` is expected to be the same value previously given to ``dump``.
        """
        self.classifier = RocketClassifier.load_from_path(path + ".zip")

    # set seed for reproducibility
    def set_seeds(self, seed: int = 42):
        """Seed NumPy's global random number generator.

        NOTE(review): this only affects the classifier if it draws from the
        global NumPy RNG (e.g. when it was built without an explicit
        ``random_state``) and is called before ``fit`` — confirm.
        """
        np.random.seed(seed)