-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline.py
executable file
·35 lines (30 loc) · 1.21 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from sklearn.pipeline import Pipeline
class Pipeline(Pipeline):
    """scikit-learn ``Pipeline`` assembled from a list of step configurations.

    Each entry of ``class_list`` is a dict with the keys:

    * ``"class"`` (required): the class to instantiate for the step.
    * ``"name"`` (optional): the step name; defaults to the class's
      ``__name__``.
    * ``"params"`` (optional): keyword arguments passed to the class
      constructor.

    After the steps are built, ``save_path`` is forwarded to every step whose
    class defines a ``set_save_path`` attribute.
    """

    def __init__(self, class_list, save_path=None):
        """Build the steps from ``class_list`` and propagate ``save_path``.

        Args:
            class_list: iterable of step configuration dicts (see the class
                docstring for the expected keys).
            save_path: value stored on the pipeline and pushed to the steps
                that support it; may be ``None``.

        Raises:
            RuntimeError: if an entry of ``class_list`` has no ``"class"`` key.
        """
        self.class_list = class_list
        self.steps = self.load_steps(class_list)
        super(Pipeline, self).__init__(self.steps)
        self.set_save_path(save_path)

    @staticmethod
    def _step_name(dict_):
        """Return the step name for one configuration entry.

        The explicit ``"name"`` wins; otherwise the class's ``__name__`` is
        used. Both ``load_steps`` and ``set_save_path`` go through this helper
        so they always agree on how a step is addressed.
        """
        if "name" in dict_:
            return dict_["name"]
        return dict_["class"].__name__

    def load_steps(self, class_list):
        """Instantiate every configured class and return the step list.

        Args:
            class_list: iterable of step configuration dicts.

        Returns:
            A list of ``(name, instance)`` tuples, in the order of
            ``class_list``.

        Raises:
            RuntimeError: if an entry has no ``"class"`` key.
        """
        steps = []
        for dict_ in class_list:
            if "class" not in dict_:
                raise RuntimeError("Missing class key in config of Pipeline/"
                                   "class_list")
            # A missing (or None) "params" entry means "use the defaults".
            params = dict_.get("params") or {}
            steps.append((self._step_name(dict_), dict_["class"](**params)))
        return steps

    def set_save_path(self, save_path):
        """Store ``save_path`` and forward it to the steps that support it.

        A step is considered to support it when its configured class has a
        ``set_save_path`` attribute; the value is then assigned through
        ``set_params`` as the step's ``save_path`` parameter.

        Args:
            save_path: the value to store and forward; may be ``None``.
        """
        self.save_path = save_path
        for dict_ in self.class_list:
            if hasattr(dict_["class"], "set_save_path"):
                # Nested parameters are addressed as "<step name>__<param>",
                # so the key must use the name the step was registered under
                # (which may be a custom "name"), not the class name.
                param = {self._step_name(dict_) + "__save_path": save_path}
                self.set_params(**param)