Skip to content

Commit 72f8008

Browse files
authored
Merge pull request #1 from neuroneural/sklearn_update
Sklearn update
2 parents 872cfa7 + 9efb1af commit 72f8008

File tree

7 files changed

+16
-9
lines changed

7 files changed

+16
-9
lines changed

polyssifier/poly_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,14 @@ def build_classifiers(exclude, scale, feature_selection, nCols):
125125
if 'Decision Tree' not in exclude:
126126
classifiers['Decision Tree'] = {
127127
'clf': DecisionTreeClassifier(max_depth=None,
128-
max_features='auto'),
128+
max_features='sqrt'),
129129
'parameters': {}}
130130

131131
if 'Random Forest' not in exclude:
132132
classifiers['Random Forest'] = {
133133
'clf': RandomForestClassifier(max_depth=None,
134134
n_estimators=10,
135-
max_features='auto'),
135+
max_features='sqrt'),
136136
'parameters': {'n_estimators': list(range(5, 20))}}
137137

138138
if 'Logistic Regression' not in exclude:

polyssifier/polyssifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def poly(data, label, n_folds=10, scale=True, exclude=[],
5555
_le.fit(label)
5656
label = _le.transform(label)
5757
n_class = len(np.unique(label))
58-
logger.info(f'Detected {n_class} classes in label')
58+
logger.info('Detected ' + str(n_class) + ' classes in label')
5959

6060
if save and not os.path.exists('poly_{}/models'.format(project_name)):
6161
os.makedirs('poly_{}/models'.format(project_name))
@@ -84,6 +84,7 @@ def poly(data, label, n_folds=10, scale=True, exclude=[],
8484
kf = list(skf.split(np.zeros(data.shape[0]), label))
8585

8686
# Parallel processing of tasks
87+
8788
manager = Manager()
8889
args = manager.list()
8990
args.append({}) # Store inputs

polyssifier/report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def plot_scores(scores, scoring='auc', file_name='temp', min_val=None):
143143
break
144144
ax1.text(rect.get_x() - rect.get_width() / 2., ymin + (1 - ymin) * .01,
145145
data.index[n], ha='center', va='bottom',
146-
rotation='90', color='black', fontsize=15)
146+
rotation=90, color='black', fontsize=15)
147147
plt.tight_layout()
148148
plt.savefig(file_name + '.pdf')
149149
plt.savefig(file_name + '.svg', transparent=False)

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@
6363
'Programming Language :: Python :: 3.4',
6464
'Programming Language :: Python :: 3.5',
6565
'Programming Language :: Python :: 3.6',
66+
'Programming Language :: Python :: 3.7',
67+
'Programming Language :: Python :: 3.8',
68+
'Programming Language :: Python :: 3.9',
69+
'Programming Language :: Python :: 3.10',
70+
'Programming Language :: Python :: 3.11',
71+
'Programming Language :: Python :: 3.12',
6672
],
6773

6874
# What does your project relate to?
@@ -78,6 +84,6 @@
7884
# your project is installed. For an analysis of "install_requires" vs pip's
7985
# requirements files see:
8086
# https://packaging.python.org/en/latest/requirements.html
81-
install_requires=['pandas', 'sklearn', 'numpy', 'matplotlib'],
87+
install_requires=['pandas','scikit-learn', 'numpy', 'matplotlib','joblib'],
8288

8389
zip_safe=False) # Override annoying default behavior of easy_install.

tests/test_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_run():
3636
report = poly(data, label, n_folds=2, verbose=1,
3737
feature_selection=False,
3838
save=False, project_name='test2')
39-
for key, score in report.scores.mean().iteritems():
39+
for key, score in report.scores.mean().items():
4040
assert score < 5, '{} score is too low'.format(key)
4141

4242

@@ -45,7 +45,7 @@ def test_multiclass():
4545
report = poly(data, label, n_folds=2, verbose=1,
4646
feature_selection=False,
4747
save=False, project_name='test3')
48-
for key, score in report.scores.mean().iteritems():
48+
for key, score in report.scores.mean().items():
4949
assert score < 5, '{} score is too low'.format(key)
5050

5151

tests/test_multiclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_run():
2626
report = poly(data, label, n_folds=2, verbose=1,
2727
feature_selection=False,
2828
save=False, project_name='test2')
29-
for key, score in report.scores.mean().iteritems():
29+
for key, score in report.scores.mean().items():
3030
assert score < 5, '{} score is too low'.format(key)
3131

3232

tests/test_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_feature_selection_regression():
2525
assert (report_with_features.scores.mean()[:, 'train'] > 0.2).all(),\
2626
'train score below chance'
2727

28-
for key, ypred in report_with_features.predictions.iteritems():
28+
for key, ypred in report_with_features.predictions.items():
2929
mse = np.linalg.norm(ypred - diabetes_target) / len(diabetes_target)
3030
assert mse < 5, '{} Prediction error is too high'.format(key)
3131

0 commit comments

Comments
 (0)