I often use helper functions for working with edges, and I would like you to provide such functions. Below is sample code for the following two edge functions. Please consider adding them.

- `remove_edges`: remove edges from `edge_index`.
- `check_is_exist_edge`: check whether edges exist in `edge_index`.
```python
from typing import Optional, Tuple, overload

import torch
from torch_geometric.utils import coalesce


@overload
def remove_edges(
    edge_index: torch.Tensor,
    removed_edge_index: torch.Tensor,
    edge_attr: None = None,
) -> torch.Tensor: ...


@overload
def remove_edges(
    edge_index: torch.Tensor,
    removed_edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ...


def remove_edges(
    edge_index: torch.Tensor,
    removed_edge_index: torch.Tensor,
    edge_attr: Optional[torch.Tensor] = None,
):
    """Remove edges from edge_index.

    Args:
        edge_index: edge index, shape=(2, n), n: number of edges
        removed_edge_index: edge index of the edges to remove, shape=(2, k),
            k: number of edges to remove
        edge_attr: edge attributes, shape=(n, d), d: number of edge
            attributes, default=None

    Returns:
        remaining_edge_index: edge index after removing edges
        remaining_edge_attr: edge attributes after removing edges
            (only returned when edge_attr is given)

    Note: coalesce() sorts the edges and merges duplicates, so the output
    order can differ from the input order.
    """
    device = edge_index.device
    all_edge_index = torch.cat([edge_index, removed_edge_index.to(device)], dim=1)

    # Mark removed edges with 1 and all other edges with 0.
    dtype = edge_attr.dtype if edge_attr is not None else torch.float
    all_edge_removed_flg = torch.cat([
        torch.zeros(edge_index.size(1), dtype=dtype, device=device),
        torch.ones(removed_edge_index.size(1), dtype=dtype, device=device),
    ])

    if edge_attr is not None:
        # The attributes (edge_index_type) of removed_edge_index are padded
        # with 0 so that taking the max keeps the original values; rows
        # flagged for removal are dropped afterwards anyway.
        all_edge_attrs_ = torch.cat([
            edge_attr.to(device),
            torch.zeros(removed_edge_index.size(1), edge_attr.size(1),
                        dtype=dtype, device=device),
        ])
        # Column 0 holds the removal flag, columns 1: hold the attributes.
        all_edge_attrs = torch.cat(
            [all_edge_removed_flg.unsqueeze(1), all_edge_attrs_], dim=1
        )  # shape: (n + k, 1 + d)
    else:
        all_edge_attrs = all_edge_removed_flg  # shape: (n + k,)

    # Merge duplicate edges; any edge listed in removed_edge_index gets flag 1.
    all_edge_index, all_edge_attrs = coalesce(
        all_edge_index, all_edge_attrs, reduce='max'
    )

    # Keep only the edges whose flag is 0.
    if edge_attr is not None:
        mask = all_edge_attrs[:, 0] == 0
        return all_edge_index[:, mask], all_edge_attrs[mask, 1:]
    mask = all_edge_attrs == 0
    return all_edge_index[:, mask]


def check_is_exist_edge(
    edge_index: torch.Tensor, edges_to_check: torch.Tensor
) -> torch.Tensor:
    """Check whether each edge in `edges_to_check` is in edge_index.

    Args:
        edge_index: edge index, shape=(2, n), n: number of edges
        edges_to_check: edge index of the edges to check, shape=(2, m),
            m: number of edges to check

    Returns:
        bool tensor of shape (m,), True where the edge exists in edge_index
    """
    # Broadcast (1, n, 2) against (m, 1, 2) -> (m, n, 2); memory is O(m * n).
    edge_exists = (
        (edge_index.T.unsqueeze(0) == edges_to_check.T.unsqueeze(1))
        .all(dim=-1)
        .any(dim=-1)
    )
    return edge_exists
```
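If the O(m * n) broadcast or the reordering done by `coalesce` is a concern, both functions could also be sketched by encoding each edge `(row, col)` as a single integer key `row * num_nodes + col` and testing membership with `torch.isin` (available in PyTorch 1.10+). This is a minimal sketch under stated assumptions, not an existing PyG API: `edge_keys`, `_num_nodes`, `remove_edges_isin`, and `check_is_exist_edge_isin` are hypothetical names, node indices are assumed non-negative, and `num_nodes` is inferred from the largest index seen.

```python
import torch


def _num_nodes(*edge_indices: torch.Tensor) -> int:
    # Hypothetical helper: infer num_nodes as (largest node index + 1).
    all_idx = torch.cat([e.reshape(-1) for e in edge_indices])
    return int(all_idx.max()) + 1 if all_idx.numel() > 0 else 1


def edge_keys(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # Hypothetical helper: encode each edge (row, col) as one unique int64 key.
    return edge_index[0] * num_nodes + edge_index[1]


def check_is_exist_edge_isin(edge_index, edges_to_check):
    # Same result as check_is_exist_edge, without the (m, n, 2) intermediate.
    n = _num_nodes(edge_index, edges_to_check)
    return torch.isin(edge_keys(edges_to_check, n), edge_keys(edge_index, n))


def remove_edges_isin(edge_index, removed_edge_index, edge_attr=None):
    # Keeps the original edge order and does not merge duplicate edges.
    n = _num_nodes(edge_index, removed_edge_index)
    keep = ~torch.isin(edge_keys(edge_index, n), edge_keys(removed_edge_index, n))
    if edge_attr is None:
        return edge_index[:, keep]
    return edge_index[:, keep], edge_attr[keep]


# Example: remove the edge 1 -> 2 and check two edges.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
edge_attr = torch.tensor([[10.0], [11.0], [12.0], [13.0]])
ei, ea = remove_edges_isin(edge_index, torch.tensor([[1], [2]]), edge_attr)
# ei == [[0, 1, 2], [1, 0, 1]], ea == [[10.0], [11.0], [13.0]]
exists = check_is_exist_edge_isin(edge_index, torch.tensor([[0, 2], [1, 0]]))
# exists == [True, False]: edge 0 -> 1 exists, edge 2 -> 0 does not
```

Because the keys are plain integers, membership testing avoids materializing an (m, n) comparison matrix. The trade-off is that `num_nodes ** 2` must fit in int64, which holds for graphs with fewer than roughly three billion nodes.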