-
Notifications
You must be signed in to change notification settings - Fork 1
/
CVProgressBar.py
58 lines (44 loc) · 2.34 KB
/
CVProgressBar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sklearn.model_selection
# sets up from cmd: "conda install progressbar2"
import progressbar
from tqdm import tqdm
class GridSearchCVProgressBar(sklearn.model_selection.GridSearchCV):
    """GridSearchCV that displays a tqdm progress bar while fitting.

    For the duration of ``fit``, the ``Parallel`` name looked up in
    ``sklearn.model_selection._search`` is replaced by a subclass whose
    ``__call__`` wraps the task iterable in ``tqdm``. The previous
    ``Parallel`` is restored in a ``finally`` clause. As a result, plain
    ``GridSearchCV`` objects are unaffected, and repeated fits do not stack
    one progress bar inside another.

    NOTE(review): the patch rebinds a module global, so running several
    searches concurrently in threads of one process is not safe.
    """

    def _get_param_iterator(self):
        """Return the ParameterGrid of candidate settings for ``param_grid``."""
        return sklearn.model_selection.ParameterGrid(self.param_grid)

    def _progress_total(self, X, y, groups):
        """Return the expected number of fits (candidates x CV splits).

        Returns None when the split count cannot be determined without side
        effects (``cv`` is a one-shot iterator of splits), or when computing
        it fails. tqdm then shows a bar without a known total.
        """
        import numbers
        from sklearn.base import is_classifier

        try:
            n_candidates = len(self._get_param_iterator())
            cv = self.cv
            if (cv is None or isinstance(cv, numbers.Integral)
                    or hasattr(cv, "split")):
                splitter = sklearn.model_selection.check_cv(
                    cv, y, classifier=is_classifier(self.estimator))
                # get_n_splits is the splitter's own answer (e.g. LeaveOneOut
                # needs X), instead of guessing from an `n_splits` attribute.
                n_splits = splitter.get_n_splits(X, y, groups)
            elif hasattr(cv, "__len__"):
                # An explicit sequence of (train, test) index pairs.
                n_splits = len(cv)
            else:
                # A generator/iterator of splits: consuming it here would
                # leave nothing for fit() itself.
                return None
        except (TypeError, ValueError):
            # The bar is cosmetic; fit() reports any real error itself.
            return None
        return n_candidates * n_splits

    def fit(self, X, y=None, *args, **kwargs):
        """Run ``GridSearchCV.fit`` with a progress bar over the fits.

        All arguments are forwarded unchanged to ``GridSearchCV.fit``. A third
        positional argument, or a ``groups`` keyword, is also used to size
        the bar.

        Returns:
            Whatever ``GridSearchCV.fit`` returns.
        """
        search_module = sklearn.model_selection._search
        original_parallel = search_module.Parallel
        groups = args[0] if args else kwargs.get("groups")
        total = self._progress_total(X, y, groups)

        class ParallelProgressBar(original_parallel):
            def __call__(self, iterable):
                # The bar advances as Parallel pulls tasks from the iterable,
                # which can run ahead of task completion when n_jobs > 1.
                with tqdm(iterable, total=total) as bar:
                    bar.set_description("GridSearchCV")
                    return super(ParallelProgressBar, self).__call__(bar)

        search_module.Parallel = ParallelProgressBar
        try:
            return super(GridSearchCVProgressBar, self).fit(
                X, y, *args, **kwargs)
        finally:
            # Always undo the patch so it cannot leak into later searches.
            search_module.Parallel = original_parallel
class RandomizedSearchCVProgressBar(sklearn.model_selection.RandomizedSearchCV):
    """RandomizedSearchCV that displays a tqdm progress bar while fitting.

    For the duration of ``fit``, the ``Parallel`` name looked up in
    ``sklearn.model_selection._search`` is replaced by a subclass whose
    ``__call__`` wraps the task iterable in ``tqdm``. The previous
    ``Parallel`` is restored in a ``finally`` clause. As a result, plain
    ``RandomizedSearchCV`` objects are unaffected, and repeated fits do not
    stack one progress bar inside another.

    NOTE(review): the patch rebinds a module global, so running several
    searches concurrently in threads of one process is not safe.
    """

    def _get_param_iterator(self):
        """Return the ParameterSampler of candidate settings to evaluate."""
        return sklearn.model_selection.ParameterSampler(
            self.param_distributions, self.n_iter,
            random_state=self.random_state)

    def _progress_total(self, X, y, groups):
        """Return the expected number of fits (candidates x CV splits).

        Returns None when the split count cannot be determined without side
        effects (``cv`` is a one-shot iterator of splits), or when computing
        it fails. tqdm then shows a bar without a known total.
        """
        import numbers
        from sklearn.base import is_classifier

        try:
            # len() of a ParameterSampler does not draw samples, so the
            # random state used by fit() is left untouched.
            n_candidates = len(self._get_param_iterator())
            cv = self.cv
            if (cv is None or isinstance(cv, numbers.Integral)
                    or hasattr(cv, "split")):
                splitter = sklearn.model_selection.check_cv(
                    cv, y, classifier=is_classifier(self.estimator))
                # get_n_splits is the splitter's own answer (e.g. LeaveOneOut
                # needs X), instead of guessing from an `n_splits` attribute.
                n_splits = splitter.get_n_splits(X, y, groups)
            elif hasattr(cv, "__len__"):
                # An explicit sequence of (train, test) index pairs.
                n_splits = len(cv)
            else:
                # A generator/iterator of splits: consuming it here would
                # leave nothing for fit() itself.
                return None
        except (TypeError, ValueError):
            # The bar is cosmetic; fit() reports any real error itself.
            return None
        return n_candidates * n_splits

    def fit(self, X, y=None, *args, **kwargs):
        """Run ``RandomizedSearchCV.fit`` with a progress bar over the fits.

        All arguments are forwarded unchanged to ``RandomizedSearchCV.fit``.
        A third positional argument, or a ``groups`` keyword, is also used to
        size the bar.

        Returns:
            Whatever ``RandomizedSearchCV.fit`` returns.
        """
        search_module = sklearn.model_selection._search
        original_parallel = search_module.Parallel
        groups = args[0] if args else kwargs.get("groups")
        total = self._progress_total(X, y, groups)

        class ParallelProgressBar(original_parallel):
            def __call__(self, iterable):
                # The bar advances as Parallel pulls tasks from the iterable,
                # which can run ahead of task completion when n_jobs > 1.
                with tqdm(iterable, total=total) as bar:
                    bar.set_description("RandomizedSearchCV")
                    return super(ParallelProgressBar, self).__call__(bar)

        search_module.Parallel = ParallelProgressBar
        try:
            return super(RandomizedSearchCVProgressBar, self).fit(
                X, y, *args, **kwargs)
        finally:
            # Always undo the patch so it cannot leak into later searches.
            search_module.Parallel = original_parallel