diff --git a/machine_learning/decision_tree.py b/machine_learning/decision_tree.py
index 72970431c3fc..c0d967962965 100644
--- a/machine_learning/decision_tree.py
+++ b/machine_learning/decision_tree.py
@@ -87,16 +87,14 @@ def train(self, x, y):
         if y.ndim != 1:
             raise ValueError("Data set labels must be one-dimensional")
 
-        if len(x) < 2 * self.min_leaf_size:
-            self.prediction = np.mean(y)
-            return
+        mean_y = np.mean(y)
 
-        if self.depth == 1:
-            self.prediction = np.mean(y)
+        if len(x) < 2 * self.min_leaf_size or self.depth == 1:
+            self.prediction = mean_y
             return
 
         best_split = 0
-        min_error = self.mean_squared_error(x, np.mean(y)) * 2
+        min_error = self.mean_squared_error(x, mean_y) * 2
         """
         loop over all possible splits for the decision tree. find the best split.
         if no split exists that is less than 2 * error for the entire array
@@ -105,17 +103,21 @@ def train(self, x, y):
         the predictor
         """
         for i in range(len(x)):
-            if len(x[:i]) < self.min_leaf_size:  # noqa: SIM114
-                continue
-            elif len(x[i:]) < self.min_leaf_size:
+            if len(x[:i]) < self.min_leaf_size or len(x[i:]) < self.min_leaf_size:
                 continue
-            else:
-                error_left = self.mean_squared_error(x[:i], np.mean(y[:i]))
-                error_right = self.mean_squared_error(x[i:], np.mean(y[i:]))
-                error = error_left + error_right
-                if error < min_error:
-                    best_split = i
-                    min_error = error
+
+            left_y = y[:i]
+            right_y = y[i:]
+            mean_left = np.mean(left_y)
+            mean_right = np.mean(right_y)
+
+            error_left = self.mean_squared_error(left_y, mean_left)
+            error_right = self.mean_squared_error(right_y, mean_right)
+            error = error_left + error_right
+
+            if error < min_error:
+                best_split = i
+                min_error = error
 
         if best_split != 0:
             left_x = x[:best_split]
@@ -184,7 +186,7 @@ def main():
     x = np.arange(-1.0, 1.0, 0.005)
     y = np.sin(x)
 
-    tree = DecisionTree(depth=10, min_leaf_size=10)
+    tree = DecisionTree(depth=6, min_leaf_size=10)
     tree.train(x, y)
 
     rng = np.random.default_rng()
@@ -201,4 +203,4 @@ def main():
     main()
     import doctest
 
-    doctest.testmod(name="mean_squarred_error", verbose=True)
+    doctest.testmod(name="mean_squared_error", verbose=True)
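
For context, a minimal standalone sketch of the split search as it behaves after this patch. It is an illustration, not part of the diff: find_best_split and the module-level mean_squared_error helper are hypothetical stand-ins for the DecisionTree methods above, and the baseline deliberately mirrors the patched code, which measures x (not y) against mean_y.

import numpy as np


def mean_squared_error(labels, prediction):
    # Mean squared deviation from a constant prediction (assumed to match
    # the behaviour of DecisionTree.mean_squared_error).
    return float(np.mean((labels - prediction) ** 2))


def find_best_split(x, y, min_leaf_size):
    # Hypothetical helper mirroring the patched loop in train():
    # accept a split only if it beats twice the whole-array error.
    mean_y = np.mean(y)
    best_split = 0
    min_error = mean_squared_error(x, mean_y) * 2

    for i in range(len(x)):
        # Skip splits that would leave either side under min_leaf_size.
        if len(x[:i]) < min_leaf_size or len(x[i:]) < min_leaf_size:
            continue

        left_y = y[:i]
        right_y = y[i:]
        error_left = mean_squared_error(left_y, np.mean(left_y))
        error_right = mean_squared_error(right_y, np.mean(right_y))
        error = error_left + error_right

        if error < min_error:
            best_split = i
            min_error = error

    return best_split


x = np.arange(-1.0, 1.0, 0.005)
y = np.sin(x)
print(find_best_split(x, y, min_leaf_size=10))  # 0 means no acceptable split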