diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 1fe11e596a359..ee708bb93bb26 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -889,10 +889,11 @@ def export_text(decision_tree, feature_names=None, max_depth=10, else: value_fmt = "{}{} value: {}\n" - if feature_names: - feature_names_ = [feature_names[i] for i in tree_.feature] + if feature_names is not None: + feature_names_ = list(feature_names) else: - feature_names_ = ["feature_{}".format(i) for i in tree_.feature] + feature_names_ = ["feature_{}".format(i) + for i in range(tree_.n_features)] export_text.report = "" @@ -928,7 +929,8 @@ def print_tree_recurse(node, depth): info_fmt_right = info_fmt if tree_.feature[node] != _tree.TREE_UNDEFINED: - name = feature_names_[node] + feature_index = tree_.feature[node] + name = feature_names_[feature_index] threshold = tree_.threshold[node] threshold = "{1:.{0}f}".format(decimals, threshold) export_text.report += right_child_fmt.format(indent, diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 503431f456874..f51f624a20159 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -397,6 +397,26 @@ def test_export_text(): assert export_text(reg, decimals=1, show_weights=True) == expected_report +def test_export_text_single_feature_with_names(): + X_single = [[-2], [-1], [-1], [1], [1], [2]] + y_single = [-1, -1, -1, 1, 1, 1] + clf = DecisionTreeClassifier(random_state=0) + clf.fit(X_single, y_single) + + report = export_text(clf, feature_names=['single']) + assert_in('single <=', report) + + +def test_export_text_single_feature_auto_names(): + X_single = [[-2], [-1], [-1], [1], [1], [2]] + y_single = [-1, -1, -1, 1, 1, 1] + clf = DecisionTreeClassifier(random_state=0) + clf.fit(X_single, y_single) + + report = export_text(clf) + assert_in('feature_0', report) + + def test_plot_tree_entropy(pyplot): # mostly smoke tests # Check correctness of export_graphviz for criterion = entropy