Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sklearn/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,10 +889,11 @@ def export_text(decision_tree, feature_names=None, max_depth=10,
else:
value_fmt = "{}{} value: {}\n"

if feature_names:
feature_names_ = [feature_names[i] for i in tree_.feature]
if feature_names is not None:
feature_names_ = list(feature_names)
else:
feature_names_ = ["feature_{}".format(i) for i in tree_.feature]
feature_names_ = ["feature_{}".format(i)
for i in range(tree_.n_features)]

export_text.report = ""

Expand Down Expand Up @@ -928,7 +929,8 @@ def print_tree_recurse(node, depth):
info_fmt_right = info_fmt

if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_names_[node]
feature_index = tree_.feature[node]
name = feature_names_[feature_index]
threshold = tree_.threshold[node]
threshold = "{1:.{0}f}".format(decimals, threshold)
export_text.report += right_child_fmt.format(indent,
Expand Down
20 changes: 20 additions & 0 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,26 @@ def test_export_text():
assert export_text(reg, decimals=1, show_weights=True) == expected_report


def test_export_text_single_feature_with_names():
X_single = [[-2], [-1], [-1], [1], [1], [2]]
y_single = [-1, -1, -1, 1, 1, 1]
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_single, y_single)

report = export_text(clf, feature_names=['single'])
assert_in('single <=', report)


def test_export_text_single_feature_auto_names():
X_single = [[-2], [-1], [-1], [1], [1], [2]]
y_single = [-1, -1, -1, 1, 1, 1]
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_single, y_single)

report = export_text(clf)
assert_in('feature_0', report)


def test_plot_tree_entropy(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz for criterion = entropy
Expand Down