Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kd tree-next_branch is not none 问题 #54

Open
baowenqian2001 opened this issue Feb 25, 2024 · 0 comments
Open

kd tree-next_branch is not none 问题 #54

baowenqian2001 opened this issue Feb 25, 2024 · 0 comments

Comments

@baowenqian2001
Copy link

原来的代码如果搜索到无左或右子树就回退,但事实上无叶节点的分支区域会存在离根节点更近的点
修改后的search代码
` def _search(self, point, tree=None, k=1, k_neighbors_sets=None, depth=0):
"""算法3.3 搜索

    Args:
        point (_type_): _description_
        tree (_type_, optional): _description_. Defaults to None.
        k (int, optional): _description_. Defaults to 1.
        k_neighbors_sets (_type_, optional): _description_. Defaults to None.
        depth (int, optional): _description_. Defaults to 0.

    Returns:
        _type_: _description_
    """
    n = point.shape[1] # 看输入格式np.array([[3, 4.5]]) # shape:(1, 2)
    if k_neighbors_sets is None:
        k_neighbors_sets = []
    if tree is None:
        return k_neighbors_sets
    
    # (1)找到包含目标点x的叶节点
    if tree.left_child is None and tree.right_child is None:
        # 更新当前k近邻集
        return self._update_k_neighbor_sets(k_neighbors_sets, k ,tree, point)
    
    # 递归地向下访问kd树
    if point[0][depth % n] < tree.value[depth % n]:
        direct = 'left'
        next_branch = tree.left_child
    else:
        direct = 'right'
        next_branch = tree.right_child
        
    if next_branch is not None:
        # (3)(b)检查另一子节点对应的区域是否相交
        # 递归
        k_neighbors_sets = self._search(point, tree=next_branch, k=k, depth=depth + 1,
                                         k_neighbors_sets=k_neighbors_sets)
        # 计算目标点与切分点形成的分割超平面的距离
        temp_dist = abs(tree.value[depth % n] - point[0][depth % n])
        
        # 判断超球体是否与超平面相交
        if not(k_neighbors_sets[0][0] < temp_dist and len(k_neighbors_sets) == k): # 换到另一侧
            # 如果相交,递归地进行近邻搜索
            # 判断当前结点,并更新当前k近邻点集
            k_neighbors_sets = self._update_k_neighbor_sets(k_neighbors_sets, k, tree, point) # tree 返回父节点
            if direct == 'left':
                return self._search(point, tree=tree.right_child, k=k, depth = depth + 1, k_neighbors_sets=k_neighbors_sets)
            else:
                return self._search(point, tree=tree.left_child, k=k, depth = depth + 1, k_neighbors_sets=k_neighbors_sets)
    else:
        temp_dist = abs(tree.value[depth % n] - point[0][depth % n])
        
        # 判断超球体是否与超平面相交
        if not(len(k_neighbors_sets) == k): # 换到另一侧
            # 如果相交,递归地进行近邻搜索
            # 判断当前结点,并更新当前k近邻点集
            k_neighbors_sets = self._update_k_neighbor_sets(k_neighbors_sets, k, tree, point) # tree 返回父节点
            if direct == 'left':
                return self._search(point, tree=tree.right_child, k=k, depth = depth + 1, k_neighbors_sets=k_neighbors_sets)
            else:
                return self._search(point, tree=tree.left_child, k=k, depth = depth + 1, k_neighbors_sets=k_neighbors_sets)
        # return self._update_k_neighbor_sets(k_neighbors_sets, k,tree, point)
    
    return k_neighbors_sets`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant