diff --git a/code/decision_tree.py b/code/decision_tree.py index 8ad3337..406b2fb 100644 --- a/code/decision_tree.py +++ b/code/decision_tree.py @@ -5,22 +5,44 @@ class TreeNode: - def __init__(self, data_idx, depth, child_lst=[]): + def __init__(self, data_idx, depth, child_lst=None): + """ + Initialize a TreeNode object. + + Parameters: + - data_idx (int): Index of the data associated with the node. + - depth (int): Depth of the node in the tree. + - child_lst (list, optional): List of child nodes. Defaults to an empty list. + """ self.data_idx = data_idx self.depth = depth - self.child = child_lst + self.child = child_lst if child_lst is not None else [] self.label = None self.split_col = None self.child_cate_order = None def set_attribute(self, split_col, child_cate_order=None): + """ + Set the attributes of the node. + + Parameters: + - split_col: Column used for splitting the data. + - child_cate_order (list, optional): Order of child categories. Defaults to None. + """ self.split_col = split_col self.child_cate_order = child_cate_order def set_label(self, label): + """ + Set the label for the node. + + Parameters: + - label: The label to be assigned to the node. + """ self.label = label + class DecisionTree(metaclass=ABCMeta): def __init__(self, max_depth, min_sample_leaf, min_split_criterion=1e-4, verbose=False): self.max_depth = max_depth @@ -296,4 +318,4 @@ def get_nex_node(self, node: TreeNode, x: np.array): if tree_type == "classification": print(classification_report(y_true=y_test, y_pred=y_pred)) else: - print(mean_squared_error(y_test, y_pred)) \ No newline at end of file + print(mean_squared_error(y_test, y_pred))