Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions machine_learning/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,14 @@ def train(self, x, y):
if y.ndim != 1:
raise ValueError("Data set labels must be one-dimensional")

if len(x) < 2 * self.min_leaf_size:
self.prediction = np.mean(y)
return
mean_y = np.mean(y)

if self.depth == 1:
self.prediction = np.mean(y)
if len(x) < 2 * self.min_leaf_size or self.depth == 1:
self.prediction = mean_y
return

best_split = 0
min_error = self.mean_squared_error(x, np.mean(y)) * 2
min_error = self.mean_squared_error(x, mean_y) * 2

"""
loop over all possible splits for the decision tree. find the best split.
Expand All @@ -105,17 +103,21 @@ def train(self, x, y):
the predictor
"""
for i in range(len(x)):
if len(x[:i]) < self.min_leaf_size: # noqa: SIM114
continue
elif len(x[i:]) < self.min_leaf_size:
if len(x[:i]) < self.min_leaf_size or len(x[i:]) < self.min_leaf_size:
continue
else:
error_left = self.mean_squared_error(x[:i], np.mean(y[:i]))
error_right = self.mean_squared_error(x[i:], np.mean(y[i:]))
error = error_left + error_right
if error < min_error:
best_split = i
min_error = error

left_y = y[:i]
right_y = y[i:]
mean_left = np.mean(left_y)
mean_right = np.mean(right_y)

error_left = self.mean_squared_error(left_y, mean_left)
error_right = self.mean_squared_error(right_y, mean_right)
error = error_left + error_right

if error < min_error:
best_split = i
min_error = error

if best_split != 0:
left_x = x[:best_split]
Expand Down Expand Up @@ -184,7 +186,7 @@ def main():
x = np.arange(-1.0, 1.0, 0.005)
y = np.sin(x)

tree = DecisionTree(depth=10, min_leaf_size=10)
tree = DecisionTree(depth=6, min_leaf_size=10)
tree.train(x, y)

rng = np.random.default_rng()
Expand All @@ -201,4 +203,4 @@ def main():
main()
import doctest

doctest.testmod(name="mean_squarred_error", verbose=True)
doctest.testmod(name="mean_squared_error", verbose=True)