From 12ed4d19f3abac1b87aed6b05bca3f99aa29a823 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Thu, 7 Nov 2024 21:49:39 -0500 Subject: [PATCH 01/20] Replaces multi_index_mode in QL with a more customizable and easier-to-understand predicate_row_aggregator argument --- hatchet/graphframe.py | 10 +- hatchet/query/engine.py | 64 ++- hatchet/query/object_dialect.py | 37 +- hatchet/query/query.py | 1 + hatchet/query/string_dialect.py | 780 ++------------------------------ 5 files changed, 95 insertions(+), 797 deletions(-) diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index d041597b..eb2b6438 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -472,7 +472,7 @@ def filter( update_inc_cols=True, num_procs=mp.cpu_count(), rec_limit=1000, - multi_index_mode="off", + predicate_row_aggregator=None, ): """Filter the dataframe using a user-supplied function. @@ -537,15 +537,17 @@ def filter( # If a raw Object-dialect query is provided (not already passed to ObjectQuery), # create a new ObjectQuery object. if isinstance(filter_obj, list): - query = ObjectQuery(filter_obj, multi_index_mode) + query = ObjectQuery(filter_obj) # If a raw String-dialect query is provided (not already passed to StringQuery), # create a new StringQuery object. elif isinstance(filter_obj, str): - query = parse_string_dialect(filter_obj, multi_index_mode) + query = parse_string_dialect(filter_obj) # If an old-style query is provided, extract the underlying new-style query. 
elif issubclass(type(filter_obj), AbstractQuery): query = filter_obj._get_new_query() - query_matches = self.query_engine.apply(query, self.graph, self.dataframe) + query_matches = self.query_engine.apply( + query, self.graph, self.dataframe, predicate_row_aggregator + ) # match_set = list(set().union(*query_matches)) # filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(match_set)] filtered_df = dataframe_copy.loc[dataframe_copy["node"].isin(query_matches)] diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index 9717e240..ec7bf661 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -25,7 +25,7 @@ def reset_cache(self): """Resets the cache in the QueryEngine.""" self.search_cache = {} - def apply(self, query, graph, dframe): + def apply(self, query, graph, dframe, predicate_row_aggregator): """Apply the query to a GraphFrame. Arguments: @@ -37,11 +37,14 @@ def apply(self, query, graph, dframe): (list): A list representing the set of nodes from paths that match the query """ if issubclass(type(query), Query): + aggregator = predicate_row_aggregator + if predicate_row_aggregator is None: + aggregator = query.default_aggregator self.reset_cache() matches = [] visited = set() for root in sorted(graph.roots, key=traversal_order): - self._apply_impl(query, dframe, root, visited, matches) + self._apply_impl(query, dframe, aggregator, root, visited, matches) assert len(visited) == len(graph) matched_node_set = list(set().union(*matches)) # return matches @@ -54,12 +57,14 @@ def apply(self, query, graph, dframe): subq_obj = ObjectQuery(subq) elif isinstance(subq, str): subq_obj = parse_string_dialect(subq) - results.append(self.apply(subq_obj, graph, dframe)) + results.append( + self.apply(subq_obj, graph, dframe, predicate_row_aggregator) + ) return query._apply_op_to_results(results, graph) else: raise TypeError("Invalid query data type ({})".format(str(type(query)))) - def _cache_node(self, node, query, dframe): + def 
_cache_node(self, node, query, dframe, predicate_row_aggregator): """Cache (Memoize) the parts of the query that the node matches. Arguments: @@ -78,11 +83,16 @@ def _cache_node(self, node, query, dframe): row = dframe.xs(node, level="node", drop_level=False) else: row = dframe.loc[node] - if filter_func(row): + predicate_result = filter_func(row) + if not isinstance(predicate_result, bool): + predicate_result = predicate_row_aggregator(predicate_result) + if predicate_result: matches.append(i) self.search_cache[node._hatchet_nid] = matches - def _match_0_or_more(self, query, dframe, node, wcard_idx): + def _match_0_or_more( + self, query, dframe, predicate_row_aggregator, node, wcard_idx + ): """Process a "*" predicate in the query on a subgraph. Arguments: @@ -98,7 +108,7 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): """ # Cache the node if it's not already cached if node._hatchet_nid not in self.search_cache: - self._cache_node(node, query, dframe) + self._cache_node(node, query, dframe, predicate_row_aggregator) # If the node matches with the next non-wildcard query node, # end the recursion and return the node. if wcard_idx + 1 in self.search_cache[node._hatchet_nid]: @@ -113,7 +123,9 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): return [[node]] return None for child in sorted(node.children, key=traversal_order): - sub_match = self._match_0_or_more(query, dframe, child, wcard_idx) + sub_match = self._match_0_or_more( + query, dframe, predicate_row_aggregator, child, wcard_idx + ) if sub_match is not None: matches.extend(sub_match) if len(matches) == 0: @@ -128,7 +140,7 @@ def _match_0_or_more(self, query, dframe, node, wcard_idx): return [[]] return None - def _match_1(self, query, dframe, node, idx): + def _match_1(self, query, dframe, predicate_row_aggregator, node, idx): """Process a "." predicate in the query on a subgraph. 
Arguments: @@ -142,12 +154,12 @@ def _match_1(self, query, dframe, node, idx): Will return None if there are no matches for the "." predicate. """ if node._hatchet_nid not in self.search_cache: - self._cache_node(node, query, dframe) + self._cache_node(node, query, dframe, predicate_row_aggregator) matches = [] for child in sorted(node.children, key=traversal_order): # Cache the node if it's not already cached if child._hatchet_nid not in self.search_cache: - self._cache_node(child, query, dframe) + self._cache_node(child, query, dframe, predicate_row_aggregator) if idx in self.search_cache[child._hatchet_nid]: matches.append([child]) # To be consistent with the other matching functions, return @@ -156,7 +168,9 @@ def _match_1(self, query, dframe, node, idx): return None return matches - def _match_pattern(self, query, dframe, pattern_root, match_idx): + def _match_pattern( + self, query, dframe, predicate_row_aggregator, pattern_root, match_idx + ): """Try to match the query pattern starting at the provided root node. Arguments: @@ -186,7 +200,9 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): # Get the portion of the subgraph that matches the next # part of the query. 
if wcard == ".": - s = self._match_1(query, dframe, m[-1], pattern_idx) + s = self._match_1( + query, dframe, predicate_row_aggregator, m[-1], pattern_idx + ) if s is None: sub_match.append(s) else: @@ -196,7 +212,13 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): sub_match.append([]) else: for child in sorted(m[-1].children, key=traversal_order): - s = self._match_0_or_more(query, dframe, child, pattern_idx) + s = self._match_0_or_more( + query, + dframe, + predicate_row_aggregator, + child, + pattern_idx, + ) if s is None: sub_match.append(s) else: @@ -221,7 +243,9 @@ def _match_pattern(self, query, dframe, pattern_root, match_idx): pattern_idx += 1 return matches - def _apply_impl(self, query, dframe, node, visited, matches): + def _apply_impl( + self, query, dframe, predicate_row_aggregator, node, visited, matches + ): """Traverse the subgraph with the specified root, and collect all paths that match the query. Arguments: @@ -237,7 +261,7 @@ def _apply_impl(self, query, dframe, node, visited, matches): return # Cache the node if it's not already cached if node._hatchet_nid not in self.search_cache: - self._cache_node(node, query, dframe) + self._cache_node(node, query, dframe, predicate_row_aggregator) # If the node matches the starting/root node of the query, # try to get all query matches in the subgraph rooted at # this node. @@ -247,11 +271,15 @@ def _apply_impl(self, query, dframe, node, visited, matches): if sub_match is not None: matches.extend(sub_match) if 0 in self.search_cache[node._hatchet_nid]: - sub_match = self._match_pattern(query, dframe, node, 0) + sub_match = self._match_pattern( + query, dframe, predicate_row_aggregator, node, 0 + ) if sub_match is not None: matches.extend(sub_match) # Note that the node is now visited. visited.add(node._hatchet_nid) # Continue the Depth First Search. 
for child in sorted(node.children, key=traversal_order): - self._apply_impl(query, dframe, child, visited, matches) + self._apply_impl( + query, dframe, predicate_row_aggregator, child, visited, matches + ) diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index daf55c65..4c8114f2 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -17,17 +17,7 @@ from .query import Query -def _process_multi_index_mode(apply_result, multi_index_mode): - if multi_index_mode == "any": - return apply_result.any() - if multi_index_mode == "all": - return apply_result.all() - raise ValueError( - "Multi-Index Mode for the Object-based dialect must be either 'any' or 'all'" - ) - - -def _process_predicate(attr_filter, multi_index_mode): +def _process_predicate(attr_filter): """Converts high-level API attribute filter to a lambda""" compops = ("<", ">", "==", ">=", "<=", "<>", "!=") # , @@ -126,12 +116,6 @@ def filter_single_series(df_row, key, single_value): return matches def filter_dframe(df_row): - if multi_index_mode == "off": - raise MultiIndexModeMismatch( - "The ObjectQuery's 'multi_index_mode' argument \ - cannot be set to 'off' when using multi-indexed data" - ) - def filter_single_dframe(node, df_row, key, single_value): if key == "depth": if isinstance(single_value, str) and single_value.lower().startswith( @@ -164,21 +148,18 @@ def filter_single_dframe(node, df_row, key, single_value): raise InvalidQueryFilter( "Value for attribute {} must be a string.".format(key) ) - apply_ret = df_row[key].apply( + return df_row[key].apply( lambda x: re.match(single_value + r"\Z", x) is not None ) - return _process_multi_index_mode(apply_ret, multi_index_mode) if is_numeric_dtype(df_row[key]): if isinstance(single_value, str) and single_value.lower().startswith( compops ): - apply_ret = df_row[key].apply( + return df_row[key].apply( lambda x: eval("{} {}".format(x, single_value)) ) - return _process_multi_index_mode(apply_ret, 
multi_index_mode) if isinstance(single_value, Real): - apply_ret = df_row[key].apply(lambda x: x == single_value).any() - return _process_multi_index_mode(apply_ret, multi_index_mode) + return df_row[key].apply(lambda x: x == single_value).any() raise InvalidQueryFilter( "Attribute {} has a numeric type. Valid filters for this attribute are a string starting with a comparison operator or a real number.".format( key @@ -218,7 +199,7 @@ def filter_choice(df_row): class ObjectQuery(Query): """Class for representing and parsing queries using the Object-based dialect.""" - def __init__(self, query, multi_index_mode="off"): + def __init__(self, query): """Builds a new ObjectQuery from an instance of the Object-based dialect syntax. Arguments: @@ -229,18 +210,15 @@ def __init__(self, query, multi_index_mode="off"): else: super().__init__() assert isinstance(query, list) - assert multi_index_mode in ["off", "all", "any"] for qnode in query: if isinstance(qnode, dict): - self._add_node(predicate=_process_predicate(qnode, multi_index_mode)) + self._add_node(predicate=_process_predicate(qnode)) elif isinstance(qnode, str) or isinstance(qnode, int): self._add_node(quantifer=qnode) elif isinstance(qnode, tuple): assert isinstance(qnode[1], dict) if isinstance(qnode[0], str) or isinstance(qnode[0], int): - self._add_node( - qnode[0], _process_predicate(qnode[1], multi_index_mode) - ) + self._add_node(qnode[0], _process_predicate(qnode[1])) else: raise InvalidQueryPath( "The first value of a tuple entry in a path must be either a string or integer." 
@@ -249,3 +227,4 @@ def __init__(self, query, multi_index_mode="off"): raise InvalidQueryPath( "A query path must be a list containing String, Integer, Dict, or Tuple elements" ) + self.default_aggregator = "all" diff --git a/hatchet/query/query.py b/hatchet/query/query.py index 39f17743..28330311 100644 --- a/hatchet/query/query.py +++ b/hatchet/query/query.py @@ -12,6 +12,7 @@ class Query(object): def __init__(self): """Create new Query""" self.query_pattern = [] + self.default_aggregator = "off" def match(self, quantifier=".", predicate=lambda row: True): """Start a query with a root node described by the arguments. diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 791128fe..b51f35a0 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -97,7 +97,7 @@ def filter_check_types(type_check, df_row, filt_lambda): class StringQuery(Query): """Class for representing and parsing queries using the String-based dialect.""" - def __init__(self, cypher_query, multi_index_mode="off"): + def __init__(self, cypher_query): """Builds a new StringQuery object representing a query in the String-based dialect. 
Arguments: @@ -107,8 +107,6 @@ def __init__(self, cypher_query, multi_index_mode="off"): super(StringQuery, self).__init__() else: super().__init__() - assert multi_index_mode in ["off", "all", "any"] - self.multi_index_mode = multi_index_mode model = None try: model = cypher_query_mm.model_from_str(cypher_query) @@ -127,6 +125,7 @@ def __init__(self, cypher_query, multi_index_mode="off"): self.lambda_filters = [None for _ in self.wcards] self._build_lambdas() self._build_query() + self.default_aggregator = "all" def _build_query(self): """Builds the entire query using 'match' and 'rel' using @@ -258,13 +257,6 @@ def _parse_not_cond(self, obj): converted_subcond[2] = "not {}".format(converted_subcond[2]) return converted_subcond - def _run_method_based_on_multi_idx_mode(self, method_name, obj): - real_method_name = method_name - if self.multi_index_mode != "off": - real_method_name = method_name + "_multi_idx" - method = eval("StringQuery.{}".format(real_method_name)) - return method(self, obj) - def _parse_single_cond(self, obj): """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): @@ -272,13 +264,13 @@ def _parse_single_cond(self, obj): if self._is_num_cond(obj): return self._parse_num(obj) if cname(obj) == "NoneCond": - return self._run_method_based_on_multi_idx_mode("_parse_none", obj) + return self._parse_none(obj) if cname(obj) == "NotNoneCond": - return self._run_method_based_on_multi_idx_mode("_parse_not_none", obj) + return self._parse_not_none(obj) if cname(obj) == "LeafCond": - return self._run_method_based_on_multi_idx_mode("_parse_leaf", obj) + return self._parse_leaf(obj) if cname(obj) == "NotLeafCond": - return self._run_method_based_on_multi_idx_mode("_parse_not_leaf", obj) + return self._parse_not_leaf(obj) raise RuntimeError("Bad Single Condition") def _parse_none(self, obj): @@ -308,39 +300,6 @@ def _parse_none(self, obj): None, ] - def _add_aggregation_call_to_multi_idx_predicate(self, 
predicate): - if self.multi_index_mode == "any": - return predicate + ".any()" - return predicate + ".all()" - - def _parse_none_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth is None", - None, - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid is None", - None, - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem is None)".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) - ), - None, - ] - def _parse_not_none(self, obj): """Parses 'property IS NOT NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -368,34 +327,6 @@ def _parse_not_none(self, obj): None, ] - def _parse_not_none_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth is not None", - None, - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid is not None", - None, - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem is not None)".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) - ), - None, - ] - def _parse_leaf(self, obj): """Parses 'node IS LEAF'.""" return [ @@ -405,14 +336,6 @@ def _parse_leaf(self, obj): None, ] - def _parse_leaf_multi_idx(self, obj): - return [ - None, - obj.name, - "len(df_row.index.get_level_values('node')[0].children) == 0", - None, - ] - def _parse_not_leaf(self, obj): """Parses 'node IS NOT LEAF'.""" return [ @@ -422,14 +345,6 @@ def _parse_not_leaf(self, obj): None, ] - def 
_parse_not_leaf_multi_idx(self, obj): - return [ - None, - obj.name, - "len(df_row.index.get_level_values('node')[0].children) > 0", - None, - ] - def _is_str_cond(self, obj): """Determines whether a predicate is for string data.""" if cname(obj) in [ @@ -463,17 +378,15 @@ def _parse_str(self, obj): to the correct function. """ if cname(obj) == "StringEq": - return self._run_method_based_on_multi_idx_mode("_parse_str_eq", obj) + return self._parse_str_eq(obj) if cname(obj) == "StringStartsWith": - return self._run_method_based_on_multi_idx_mode( - "_parse_str_starts_with", obj - ) + return self._parse_str_starts_with(obj) if cname(obj) == "StringEndsWith": - return self._run_method_based_on_multi_idx_mode("_parse_str_ends_with", obj) + return self._parse_str_ends_with(obj) if cname(obj) == "StringContains": - return self._run_method_based_on_multi_idx_mode("_parse_str_contains", obj) + return self._parse_str_contains(obj) if cname(obj) == "StringMatch": - return self._run_method_based_on_multi_idx_mode("_parse_str_match", obj) + return self._parse_str_match(obj) raise RuntimeError("Bad String Op Class") def _parse_str_eq(self, obj): @@ -496,27 +409,6 @@ def _parse_str_eq(self, obj): ), ] - def _parse_str_eq_multi_idx(self, obj): - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: elem == "{}")'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_string_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_str_starts_with(self, obj): """Processes string 'startswith' predicates.""" return [ @@ -537,27 +429,6 @@ def _parse_str_starts_with(self, obj): ), ] - def _parse_str_starts_with_multi_idx(self, obj): - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: 
elem.startswith("{}"))'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_string_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_str_ends_with(self, obj): """Processes string 'endswith' predicates.""" return [ @@ -578,27 +449,6 @@ def _parse_str_ends_with(self, obj): ), ] - def _parse_str_ends_with_multi_idx(self, obj): - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: elem.endswith("{}"))'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_string_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_str_contains(self, obj): """Processes string 'contains' predicates.""" return [ @@ -619,27 +469,6 @@ def _parse_str_contains(self, obj): ), ] - def _parse_str_contains_multi_idx(self, obj): - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: "{}" in elem)'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_string_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_str_match(self, obj): """Processes string regex match predicates.""" return [ @@ -660,49 +489,28 @@ def _parse_str_match(self, obj): ), ] - def _parse_str_match_multi_idx(self, obj): - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - 'df_row[{}].apply(lambda elem: re.match("{}", elem) is not None)'.format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - 
"is_string_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num(self, obj): """Function that redirects processing of numeric predicates to the correct function. """ if cname(obj) == "NumEq": - return self._run_method_based_on_multi_idx_mode("_parse_num_eq", obj) + return self._parse_num_eq(obj) if cname(obj) == "NumLt": - return self._run_method_based_on_multi_idx_mode("_parse_num_lt", obj) + return self._parse_num_lt(obj) if cname(obj) == "NumGt": - return self._run_method_based_on_multi_idx_mode("_parse_num_gt", obj) + return self._parse_num_gt(obj) if cname(obj) == "NumLte": - return self._run_method_based_on_multi_idx_mode("_parse_num_lte", obj) + return self._parse_num_lte(obj) if cname(obj) == "NumGte": - return self._run_method_based_on_multi_idx_mode("_parse_num_gte", obj) + return self._parse_num_gte(obj) if cname(obj) == "NumNan": - return self._run_method_based_on_multi_idx_mode("_parse_num_nan", obj) + return self._parse_num_nan(obj) if cname(obj) == "NumNotNan": - return self._run_method_based_on_multi_idx_mode("_parse_num_not_nan", obj) + return self._parse_num_not_nan(obj) if cname(obj) == "NumInf": - return self._run_method_based_on_multi_idx_mode("_parse_num_inf", obj) + return self._parse_num_inf(obj) if cname(obj) == "NumNotInf": - return self._run_method_based_on_multi_idx_mode("_parse_num_not_inf", obj) + return self._parse_num_not_inf(obj) raise RuntimeError("Bad Number Op Class") def _parse_num_eq(self, obj): @@ -722,9 +530,7 @@ def _parse_num_eq(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -747,9 +553,7 @@ def _parse_num_eq(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -782,86 +586,6 @@ def _parse_num_eq(self, obj): ), ] - def _parse_num_eq_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - if obj.val == -1: - return [ - None, - obj.name, - "len(df_row.index.get_level_values('node')[0].children) == 0", - None, - ] - elif obj.val < 0: - warnings.warn( - """ - The 'depth' property of a Node is strictly non-negative. - This condition will always be false. - The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "False", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth == {}".format(obj.val), - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - if obj.val < 0: - warnings.warn( - """ - The 'node_id' property of a Node is strictly non-negative. - This condition will always be false. 
- The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "False", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid == {}".format( - obj.val - ), - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem == {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_lt(self, obj): """Processes numeric less-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -872,9 +596,7 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -897,9 +619,7 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -932,79 +652,6 @@ def _parse_num_lt(self, obj): ), ] - def _parse_num_lt_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - if obj.val < 0: - warnings.warn( - """ - The 'depth' property of a Node is strictly non-negative. - This condition will always be false. 
- The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "False", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth < {}".format(obj.val), - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - if obj.val < 0: - warnings.warn( - """ - The 'node_id' property of a Node is strictly non-negative. - This condition will always be false. - The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "False", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid < {}".format( - obj.val - ), - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem < {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_gt(self, obj): """Processes numeric greater-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1015,9 +662,7 @@ def _parse_num_gt(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1040,9 +685,7 @@ def _parse_num_gt(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1075,79 +718,6 @@ def _parse_num_gt(self, obj): ), ] - def _parse_num_gt_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - if obj.val < 0: - warnings.warn( - """ - The 'depth' property of a Node is strictly non-negative. - This condition will always be true. - The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "True", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth > {}".format(obj.val), - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - if obj.val < 0: - warnings.warn( - """ - The 'node_id' property of a Node is strictly non-negative. - This condition will always be true. 
- The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "True", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid > {}".format( - obj.val - ), - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem > {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_lte(self, obj): """Processes numeric less-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1158,9 +728,7 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1183,9 +751,7 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1218,79 +784,6 @@ def _parse_num_lte(self, obj): ), ] - def _parse_num_lte_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - if obj.val < 0: - warnings.warn( - """ - The 'depth' property of a Node is strictly non-negative. - This condition will always be false. 
- The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "False", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth <= {}".format(obj.val), - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - if obj.val < 0: - warnings.warn( - """ - The 'node_id' property of a Node is strictly non-negative. - This condition will always be false. - The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "False", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid <= {}".format( - obj.val - ), - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem <= {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_gte(self, obj): """Processes numeric greater-than-or-equal-to predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1301,9 +794,7 @@ def _parse_num_gte(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1326,9 +817,7 @@ def _parse_num_gte(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -1361,79 +850,6 @@ def _parse_num_gte(self, obj): ), ] - def _parse_num_gte_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - if obj.val < 0: - warnings.warn( - """ - The 'depth' property of a Node is strictly non-negative. - This condition will always be true. - The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "True", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._depth >= {}".format(obj.val), - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids == "node_id": - if obj.val < 0: - warnings.warn( - """ - The 'node_id' property of a Node is strictly non-negative. - This condition will always be true. 
- The statement that triggered this warning is: - {} - """.format( - obj - ), - RedundantQueryFilterWarning, - ) - return [ - None, - obj.name, - "True", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - "df_row.index.get_level_values('node')[0]._hatchet_nid >= {}".format( - obj.val - ), - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "df_row[{}].apply(lambda elem: elem >= {})".format( - ( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - obj.val, - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_nan(self, obj): """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1465,38 +881,6 @@ def _parse_num_nan(self, obj): ), ] - def _parse_num_nan_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - return [ - None, - obj.name, - "pd.isna(df_row.index.get_level_values('node')[0]._depth)", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - return [ - None, - obj.name, - "pd.isna(df_row.index.get_level_values('node')[0]._hatchet_nid)", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "pd.isna(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_not_nan(self, obj): """Processes predicates that check for NaN.""" if 
len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1528,38 +912,6 @@ def _parse_num_not_nan(self, obj): ), ] - def _parse_num_not_nan_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - return [ - None, - obj.name, - "not pd.isna(df_row.index.get_level_values('node')[0]._depth)", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - return [ - None, - obj.name, - "not pd.isna(df_row.index.get_level_values('node')[0]._hatchet_nid)", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "not pd.isna(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_inf(self, obj): """Processes predicates that check for Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1591,38 +943,6 @@ def _parse_num_inf(self, obj): ), ] - def _parse_num_inf_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - return [ - None, - obj.name, - "np.isinf(df_row.index.get_level_values('node')[0]._depth)", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - return [ - None, - obj.name, - "np.isinf(df_row.index.get_level_values('node')[0]._hatchet_nid)", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "np.isinf(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) - ), - "is_numeric_dtype(df_row[{}])".format( - 
str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - def _parse_num_not_inf(self, obj): """Processes predicates that check for not-Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -1654,40 +974,8 @@ def _parse_num_not_inf(self, obj): ), ] - def _parse_num_not_inf_multi_idx(self, obj): - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - return [ - None, - obj.name, - "not np.isinf(df_row.index.get_level_values('node')[0]._depth)", - "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", - ] - if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": - return [ - None, - obj.name, - "not np.isinf(df_row.index.get_level_values('node')[0]._hatchet_nid)", - "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", - ] - return [ - None, - obj.name, - self._add_aggregation_call_to_multi_idx_predicate( - "not np.isinf(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ) - ), - "is_numeric_dtype(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), - ] - -def parse_string_dialect(query_str, multi_index_mode="off"): +def parse_string_dialect(query_str): """Parse all types of String-based queries, including multi-queries that leverage the curly brace delimiters. 
@@ -1709,7 +997,7 @@ def parse_string_dialect(query_str, multi_index_mode="off"): if num_curly_brace_elems == 0: if sys.version_info[0] == 2: query_str = query_str.decode("utf-8") - return StringQuery(query_str, multi_index_mode) + return StringQuery(query_str) # Create an iterator over the curly brace-delimited regions curly_brace_iter = re.finditer(r"\{(.*?)\}", query_str) # Will store curly brace-delimited regions in the WHERE clause @@ -1798,14 +1086,14 @@ def parse_string_dialect(query_str, multi_index_mode="off"): query1 = "MATCH {} WHERE {}".format(match_comp, condition_list[i]) if sys.version_info[0] == 2: query1 = query1.decode("utf-8") - full_query = StringQuery(query1, multi_index_mode) + full_query = StringQuery(query1) # Get the next query as a CypherQuery where # the MATCH clause is the shared match clause and the WHERE clause is the # next curly brace-delimited region next_query = "MATCH {} WHERE {}".format(match_comp, condition_list[i + 1]) if sys.version_info[0] == 2: next_query = next_query.decode("utf-8") - next_query = StringQuery(next_query, multi_index_mode) + next_query = StringQuery(next_query) # Add the next query to the full query using the compound operator # currently being considered if op == "AND": From f65fb5b27cd517441e56d03b35297899eb29f436 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Thu, 7 Nov 2024 22:17:21 -0500 Subject: [PATCH 02/20] Fixes unit tests --- hatchet/query/string_dialect.py | 40 +++++++++++++++++++++++-------- hatchet/tests/query.py | 42 ++++++++++++++++----------------- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index b51f35a0..fd79e320 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -530,7 +530,9 @@ def _parse_num_eq(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -553,7 +555,9 @@ def _parse_num_eq(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -596,7 +600,9 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -619,7 +625,9 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -662,7 +670,9 @@ def _parse_num_gt(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -685,7 +695,9 @@ def _parse_num_gt(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -728,7 +740,9 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -751,7 +765,9 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -794,7 +810,9 @@ def _parse_num_gte(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -817,7 +835,9 @@ def _parse_num_gte(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index a8082dea..114fd498 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -401,9 +401,9 @@ def __init__(self): self.z = "hello" bad_field_test_dict = list(mock_graph_literal) - bad_field_test_dict[0]["children"][0]["children"][0]["metrics"][ - "list" - ] = DummyType() + bad_field_test_dict[0]["children"][0]["children"][0]["metrics"]["list"] = ( + DummyType() + ) gf = GraphFrame.from_literal(bad_field_test_dict) path = [{"name": "foo"}, {"name": "bar"}, {"list": DummyType()}] query = ObjectQuery(path) @@ -512,7 +512,7 @@ def test_apply_indices(calc_pi_hpct_db): ], ] matches = list(set().union(*matches)) - query = ObjectQuery(path, multi_index_mode="all") + query = ObjectQuery(path, predicate_row_aggregator="all") engine = QueryEngine() assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) @@ -588,7 +588,7 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db): gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db)) root = gf.graph.roots[0] - query = ObjectQuery([("*", {"depth": "<= 2"})], multi_index_mode="all") + query = ObjectQuery([("*", {"depth": "<= 2"})], predicate_row_aggregator="all") engine = QueryEngine() matches = [ [root, root.children[0], root.children[0].children[0]], @@ -601,12 +601,12 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db): matches = list(set().union(*matches)) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) - query = ObjectQuery([("*", {"depth": 0})], multi_index_mode="all") + query = ObjectQuery([("*", {"depth": 0})], predicate_row_aggregator="all") matches = [root] assert engine.apply(query, gf.graph, gf.dataframe) == matches with pytest.raises(InvalidQueryFilter): - query = ObjectQuery([{"depth": "hello"}], multi_index_mode="all") + query = 
ObjectQuery([{"depth": "hello"}], predicate_row_aggregator="all") engine.apply(query, gf.graph, gf.dataframe) @@ -614,7 +614,7 @@ def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db)) root = gf.graph.roots[0] - query = ObjectQuery([("*", {"node_id": "<= 2"})], multi_index_mode="all") + query = ObjectQuery([("*", {"node_id": "<= 2"})], predicate_row_aggregator="all") engine = QueryEngine() matches = [ [root, root.children[0]], @@ -626,12 +626,12 @@ def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): matches = list(set().union(*matches)) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) - query = ObjectQuery([("*", {"node_id": 0})], multi_index_mode="all") + query = ObjectQuery([("*", {"node_id": 0})], predicate_row_aggregator="all") matches = [root] assert engine.apply(query, gf.graph, gf.dataframe) == matches with pytest.raises(InvalidQueryFilter): - query = ObjectQuery([{"node_id": "hello"}], multi_index_mode="all") + query = ObjectQuery([{"node_id": "hello"}], predicate_row_aggregator="all") engine.apply(query, gf.graph, gf.dataframe) @@ -1028,9 +1028,9 @@ def __init__(self): self.z = "hello" bad_field_test_dict = list(mock_graph_literal) - bad_field_test_dict[0]["children"][0]["children"][0]["metrics"][ - "list" - ] = DummyType() + bad_field_test_dict[0]["children"][0]["children"][0]["metrics"]["list"] = ( + DummyType() + ) gf = GraphFrame.from_literal(bad_field_test_dict) path = """MATCH (p)->(q)->(r) WHERE p."name" = "foo" AND q."name" = "bar" AND p."list" = DummyType() @@ -1283,7 +1283,7 @@ def test_object_dialect_all_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() query = ObjectQuery( - [".", ("+", {"time (inc)": ">= 17983.0"})], multi_index_mode="all" + [".", ("+", {"time (inc)": ">= 17983.0"})], predicate_row_aggregator="all" ) roots = gf.graph.roots matches = [ @@ -1302,7 +1302,7 @@ def 
test_string_dialect_all_mode(tau_profile_dir): """MATCH (".")->("+", p) WHERE p."time (inc)" >= 17983.0 """, - multi_index_mode="all", + predicate_row_aggregator="all", ) roots = gf.graph.roots matches = [ @@ -1317,7 +1317,7 @@ def test_string_dialect_all_mode(tau_profile_dir): def test_object_dialect_any_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() - query = ObjectQuery([{"time": "< 24.0"}], multi_index_mode="any") + query = ObjectQuery([{"time": "< 24.0"}], predicate_row_aggregator="any") roots = gf.graph.roots matches = [ roots[0].children[2], @@ -1333,7 +1333,7 @@ def test_string_dialect_any_mode(tau_profile_dir): """MATCH (".", p) WHERE p."time" < 24.0 """, - multi_index_mode="any", + predicate_row_aggregator="any", ) roots = gf.graph.roots matches = [ @@ -1343,19 +1343,19 @@ def test_string_dialect_any_mode(tau_profile_dir): assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) -def test_multi_index_mode_assertion_error(tau_profile_dir): +def test_predicate_row_aggregator_assertion_error(tau_profile_dir): with pytest.raises(AssertionError): - _ = ObjectQuery([".", ("*", {"name": "test"})], multi_index_mode="foo") + _ = ObjectQuery([".", ("*", {"name": "test"})], predicate_row_aggregator="foo") with pytest.raises(AssertionError): _ = StringQuery( """ MATCH (".")->("*", p) WHERE p."name" = "test" """, - multi_index_mode="foo", + predicate_row_aggregator="foo", ) gf = GraphFrame.from_tau(tau_profile_dir) query = ObjectQuery( - [".", ("*", {"time (inc)": "> 17983.0"})], multi_index_mode="off" + [".", ("*", {"time (inc)": "> 17983.0"})], predicate_row_aggregator="off" ) engine = QueryEngine() with pytest.raises(MultiIndexModeMismatch): From 8e076602c3424357076ce089649edc63ebd0ed2c Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:01:13 -0500 Subject: [PATCH 03/20] Formatting --- hatchet/tests/query.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git 
a/hatchet/tests/query.py b/hatchet/tests/query.py index 114fd498..678be084 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -401,9 +401,9 @@ def __init__(self): self.z = "hello" bad_field_test_dict = list(mock_graph_literal) - bad_field_test_dict[0]["children"][0]["children"][0]["metrics"]["list"] = ( - DummyType() - ) + bad_field_test_dict[0]["children"][0]["children"][0]["metrics"][ + "list" + ] = DummyType() gf = GraphFrame.from_literal(bad_field_test_dict) path = [{"name": "foo"}, {"name": "bar"}, {"list": DummyType()}] query = ObjectQuery(path) @@ -1028,9 +1028,9 @@ def __init__(self): self.z = "hello" bad_field_test_dict = list(mock_graph_literal) - bad_field_test_dict[0]["children"][0]["children"][0]["metrics"]["list"] = ( - DummyType() - ) + bad_field_test_dict[0]["children"][0]["children"][0]["metrics"][ + "list" + ] = DummyType() gf = GraphFrame.from_literal(bad_field_test_dict) path = """MATCH (p)->(q)->(r) WHERE p."name" = "foo" AND q."name" = "bar" AND p."list" = DummyType() From 86da99c372b9319da8cc479e90ce072e6413c971 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:03:29 -0500 Subject: [PATCH 04/20] Removes MultiIndexModeMismatch --- hatchet/query/errors.py | 6 ------ hatchet/query/object_dialect.py | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/hatchet/query/errors.py b/hatchet/query/errors.py index 248c6a28..ebb835f0 100644 --- a/hatchet/query/errors.py +++ b/hatchet/query/errors.py @@ -18,9 +18,3 @@ class RedundantQueryFilterWarning(Warning): class BadNumberNaryQueryArgs(Exception): """Raised when a query filter does not have a valid syntax""" - - -class MultiIndexModeMismatch(Exception): - """Raised when an ObjectQuery or StringQuery object - is set to use multi-indexed data, but no multi-indexed - data is provided""" diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index 4c8114f2..f4767565 100644 --- a/hatchet/query/object_dialect.py +++ 
b/hatchet/query/object_dialect.py @@ -13,7 +13,7 @@ import re import sys -from .errors import InvalidQueryPath, InvalidQueryFilter, MultiIndexModeMismatch +from .errors import InvalidQueryPath, InvalidQueryFilter from .query import Query From 7ad292ed764b333fe62f7ed78083be3e12894a00 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:26:41 -0500 Subject: [PATCH 05/20] Fixes logic for handling string values of predicate_row_aggregator --- hatchet/query/engine.py | 36 +++++++++++++++++++++- hatchet/tests/query.py | 67 ++++++++++++++++++----------------------- 2 files changed, 64 insertions(+), 39 deletions(-) diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index ec7bf661..40ccbc16 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: MIT from itertools import groupby +from collections.abc import Iterable import pandas as pd from .errors import InvalidQueryFilter @@ -14,6 +15,22 @@ from .string_dialect import parse_string_dialect +def _all_aggregator(pred_result): + if isinstance(pred_result, Iterable): + return all(pred_result) + elif isinstance(pred_result, pd.Series): + return pred_result.all() + return pred_result + + +def _any_aggregator(pred_result): + if isinstance(pred_result, Iterable): + return any(pred_result) + elif isinstance(pred_result, pd.Series): + return pred_result.any() + return pred_result + + class QueryEngine: """Class for applying queries to GraphFrames.""" @@ -40,6 +57,20 @@ def apply(self, query, graph, dframe, predicate_row_aggregator): aggregator = predicate_row_aggregator if predicate_row_aggregator is None: aggregator = query.default_aggregator + elif predicate_row_aggregator == "all": + aggregator = _all_aggregator + elif predicate_row_aggregator == "any": + aggregator = _any_aggregator + elif predicate_row_aggregator == "off": + if isinstance(dframe.index, pd.MultiIndex): + raise ValueError( + "'predicate_row_aggregator' cannot be 'off' when the 
DataFrame has a row multi-index" + ) + aggregator = None + elif not callable(predicate_row_aggregator): + raise ValueError( + "Invalid value provided for 'predicate_row_aggregator'" + ) self.reset_cache() matches = [] visited = set() @@ -84,7 +115,10 @@ def _cache_node(self, node, query, dframe, predicate_row_aggregator): else: row = dframe.loc[node] predicate_result = filter_func(row) - if not isinstance(predicate_result, bool): + if ( + not isinstance(predicate_result, bool) + and predicate_row_aggregator is not None + ): predicate_result = predicate_row_aggregator(predicate_result) if predicate_result: matches.append(i) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index 678be084..324a881c 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -25,7 +25,6 @@ ExclusiveDisjunctionQuery, NegationQuery, ) -from hatchet.query.errors import MultiIndexModeMismatch def test_construct_object_dialect(): @@ -512,9 +511,9 @@ def test_apply_indices(calc_pi_hpct_db): ], ] matches = list(set().union(*matches)) - query = ObjectQuery(path, predicate_row_aggregator="all") + query = ObjectQuery(path) engine = QueryEngine() - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) gf.drop_index_levels() assert engine.apply(query, gf.graph, gf.dataframe) == matches @@ -588,7 +587,7 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db): gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db)) root = gf.graph.roots[0] - query = ObjectQuery([("*", {"depth": "<= 2"})], predicate_row_aggregator="all") + query = ObjectQuery([("*", {"depth": "<= 2"})]) engine = QueryEngine() matches = [ [root, root.children[0], root.children[0].children[0]], @@ -599,22 +598,22 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db): [root.children[0].children[1]], ] matches = list(set().union(*matches)) - assert 
sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) - query = ObjectQuery([("*", {"depth": 0})], predicate_row_aggregator="all") + query = ObjectQuery([("*", {"depth": 0})]) matches = [root] - assert engine.apply(query, gf.graph, gf.dataframe) == matches + assert engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") == matches with pytest.raises(InvalidQueryFilter): - query = ObjectQuery([{"depth": "hello"}], predicate_row_aggregator="all") - engine.apply(query, gf.graph, gf.dataframe) + query = ObjectQuery([{"depth": "hello"}]) + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): gf = GraphFrame.from_hpctoolkit(str(calc_pi_hpct_db)) root = gf.graph.roots[0] - query = ObjectQuery([("*", {"node_id": "<= 2"})], predicate_row_aggregator="all") + query = ObjectQuery([("*", {"node_id": "<= 2"})]) engine = QueryEngine() matches = [ [root, root.children[0]], @@ -624,15 +623,15 @@ def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): [root.children[0].children[0]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) - query = ObjectQuery([("*", {"node_id": 0})], predicate_row_aggregator="all") + query = ObjectQuery([("*", {"node_id": 0})]) matches = [root] - assert engine.apply(query, gf.graph, gf.dataframe) == matches + assert engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") == matches with pytest.raises(InvalidQueryFilter): - query = ObjectQuery([{"node_id": "hello"}], predicate_row_aggregator="all") - engine.apply(query, gf.graph, gf.dataframe) + query = ObjectQuery([{"node_id": "hello"}]) + engine.apply(query, gf.graph, 
gf.dataframe, predicate_row_aggregator="all") def test_object_dialect_multi_condition_one_attribute(mock_graph_literal): @@ -1283,7 +1282,7 @@ def test_object_dialect_all_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() query = ObjectQuery( - [".", ("+", {"time (inc)": ">= 17983.0"})], predicate_row_aggregator="all" + [".", ("+", {"time (inc)": ">= 17983.0"})] ) roots = gf.graph.roots matches = [ @@ -1292,7 +1291,7 @@ def test_object_dialect_all_mode(tau_profile_dir): roots[0].children[6].children[1], roots[0].children[0], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) def test_string_dialect_all_mode(tau_profile_dir): @@ -1301,8 +1300,7 @@ def test_string_dialect_all_mode(tau_profile_dir): query = StringQuery( """MATCH (".")->("+", p) WHERE p."time (inc)" >= 17983.0 - """, - predicate_row_aggregator="all", + """ ) roots = gf.graph.roots matches = [ @@ -1311,19 +1309,19 @@ def test_string_dialect_all_mode(tau_profile_dir): roots[0].children[6].children[1], roots[0].children[0], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) def test_object_dialect_any_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() - query = ObjectQuery([{"time": "< 24.0"}], predicate_row_aggregator="any") + query = ObjectQuery([{"time": "< 24.0"}]) roots = gf.graph.roots matches = [ roots[0].children[2], roots[0].children[6].children[3], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any")) == sorted(matches) def test_string_dialect_any_mode(tau_profile_dir): @@ -1332,31 +1330,24 @@ def 
test_string_dialect_any_mode(tau_profile_dir): query = StringQuery( """MATCH (".", p) WHERE p."time" < 24.0 - """, - predicate_row_aggregator="any", + """ ) roots = gf.graph.roots matches = [ roots[0].children[2], roots[0].children[6].children[3], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any")) == sorted(matches) def test_predicate_row_aggregator_assertion_error(tau_profile_dir): - with pytest.raises(AssertionError): - _ = ObjectQuery([".", ("*", {"name": "test"})], predicate_row_aggregator="foo") - with pytest.raises(AssertionError): - _ = StringQuery( - """ MATCH (".")->("*", p) - WHERE p."name" = "test" - """, - predicate_row_aggregator="foo", - ) gf = GraphFrame.from_tau(tau_profile_dir) + engine = QueryEngine() + query = ObjectQuery([".", ("*", {"name": "test"})]) + with pytest.raises(ValueError): + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="foo") query = ObjectQuery( - [".", ("*", {"time (inc)": "> 17983.0"})], predicate_row_aggregator="off" + [".", ("*", {"time (inc)": "> 17983.0"})] ) - engine = QueryEngine() - with pytest.raises(MultiIndexModeMismatch): - engine.apply(query, gf.graph, gf.dataframe) + with pytest.raises(ValueError): + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="off") From 9159da0f0a8f44ebacd0aa4a9f9e438d1d732468 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:33:39 -0500 Subject: [PATCH 06/20] Formatting --- hatchet/tests/query.py | 46 ++++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index 324a881c..764bd53d 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -513,7 +513,9 @@ def test_apply_indices(calc_pi_hpct_db): matches = list(set().union(*matches)) query = ObjectQuery(path) engine = QueryEngine() - assert 
sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) gf.drop_index_levels() assert engine.apply(query, gf.graph, gf.dataframe) == matches @@ -598,11 +600,16 @@ def test_object_dialect_depth_index_levels(calc_pi_hpct_db): [root.children[0].children[1]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) query = ObjectQuery([("*", {"depth": 0})]) matches = [root] - assert engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") == matches + assert ( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + == matches + ) with pytest.raises(InvalidQueryFilter): query = ObjectQuery([{"depth": "hello"}]) @@ -623,11 +630,16 @@ def test_object_dialect_node_id_index_levels(calc_pi_hpct_db): [root.children[0].children[0]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) query = ObjectQuery([("*", {"node_id": 0})]) matches = [root] - assert engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") == matches + assert ( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + == matches + ) with pytest.raises(InvalidQueryFilter): query = ObjectQuery([{"node_id": "hello"}]) @@ -1281,9 +1293,7 @@ def test_leaf_query(small_mock2): def test_object_dialect_all_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() - query = ObjectQuery( - [".", ("+", {"time (inc)": ">= 
17983.0"})] - ) + query = ObjectQuery([".", ("+", {"time (inc)": ">= 17983.0"})]) roots = gf.graph.roots matches = [ roots[0], @@ -1291,7 +1301,9 @@ def test_object_dialect_all_mode(tau_profile_dir): roots[0].children[6].children[1], roots[0].children[0], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) def test_string_dialect_all_mode(tau_profile_dir): @@ -1309,7 +1321,9 @@ def test_string_dialect_all_mode(tau_profile_dir): roots[0].children[6].children[1], roots[0].children[0], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="all") + ) == sorted(matches) def test_object_dialect_any_mode(tau_profile_dir): @@ -1321,7 +1335,9 @@ def test_object_dialect_any_mode(tau_profile_dir): roots[0].children[2], roots[0].children[6].children[3], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any") + ) == sorted(matches) def test_string_dialect_any_mode(tau_profile_dir): @@ -1337,7 +1353,9 @@ def test_string_dialect_any_mode(tau_profile_dir): roots[0].children[2], roots[0].children[6].children[3], ] - assert sorted(engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any")) == sorted(matches) + assert sorted( + engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="any") + ) == sorted(matches) def test_predicate_row_aggregator_assertion_error(tau_profile_dir): @@ -1346,8 +1364,6 @@ def test_predicate_row_aggregator_assertion_error(tau_profile_dir): query = ObjectQuery([".", ("*", {"name": "test"})]) with pytest.raises(ValueError): engine.apply(query, gf.graph, 
gf.dataframe, predicate_row_aggregator="foo") - query = ObjectQuery( - [".", ("*", {"time (inc)": "> 17983.0"})] - ) + query = ObjectQuery([".", ("*", {"time (inc)": "> 17983.0"})]) with pytest.raises(ValueError): engine.apply(query, gf.graph, gf.dataframe, predicate_row_aggregator="off") From 3be62760ee855f4918d8dc7a17256d0706aebdb1 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:41:55 -0500 Subject: [PATCH 07/20] Fixes a condition to properly parse the default aggregators --- hatchet/query/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index 40ccbc16..31a66291 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -57,7 +57,7 @@ def apply(self, query, graph, dframe, predicate_row_aggregator): aggregator = predicate_row_aggregator if predicate_row_aggregator is None: aggregator = query.default_aggregator - elif predicate_row_aggregator == "all": + if predicate_row_aggregator == "all": aggregator = _all_aggregator elif predicate_row_aggregator == "any": aggregator = _any_aggregator From 0c5c52bbf813d4c613acad469464c4edfe76ea05 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:49:27 -0500 Subject: [PATCH 08/20] Fixes a few testing bugs --- hatchet/query/engine.py | 8 ++++---- hatchet/tests/query.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index 31a66291..a515112c 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -57,17 +57,17 @@ def apply(self, query, graph, dframe, predicate_row_aggregator): aggregator = predicate_row_aggregator if predicate_row_aggregator is None: aggregator = query.default_aggregator - if predicate_row_aggregator == "all": + if aggregator == "all": aggregator = _all_aggregator - elif predicate_row_aggregator == "any": + elif aggregator == "any": aggregator = _any_aggregator - elif predicate_row_aggregator == 
"off": + elif aggregator == "off": if isinstance(dframe.index, pd.MultiIndex): raise ValueError( "'predicate_row_aggregator' cannot be 'off' when the DataFrame has a row multi-index" ) aggregator = None - elif not callable(predicate_row_aggregator): + elif not callable(aggregator): raise ValueError( "Invalid value provided for 'predicate_row_aggregator'" ) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index 764bd53d..eddefbc2 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -238,7 +238,7 @@ def test_node_caching(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - engine._cache_node(node, query, gf.dataframe) + engine._cache_node(node, query, gf.dataframe, None) assert 0 in engine.search_cache[node._hatchet_nid] assert 1 in engine.search_cache[node._hatchet_nid] @@ -269,12 +269,12 @@ def test_match_0_or_more_wildcard(mock_graph_literal): engine = QueryEngine() matched_paths = [] for child in sorted(node.children, key=traversal_order): - match = engine._match_0_or_more(query, gf.dataframe, child, 1) + match = engine._match_0_or_more(query, gf.dataframe, None, child, 1) if match is not None: matched_paths.extend(match) assert sorted(matched_paths, key=len) == sorted(correct_paths, key=len) - assert engine._match_0_or_more(query, gf.dataframe, none_node, 1) is None + assert engine._match_0_or_more(query, gf.dataframe, None, none_node, 1) is None def test_match_1(mock_graph_literal): @@ -287,10 +287,10 @@ def test_match_1(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - assert engine._match_1(query, gf.dataframe, gf.graph.roots[0].children[0], 2) == [ + assert engine._match_1(query, gf.dataframe, None, gf.graph.roots[0].children[0], 2) == [ [gf.graph.roots[0].children[0].children[1]] ] - assert engine._match_1(query, gf.dataframe, gf.graph.roots[0], 2) is None + assert engine._match_1(query, gf.dataframe, None, gf.graph.roots[0], 2) is None def test_match(mock_graph_literal): From 
cae50b590c56573680f4fe8039f8a3b92f5d1944 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 16:51:14 -0500 Subject: [PATCH 09/20] Formatting --- hatchet/tests/query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index eddefbc2..b9020e77 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -287,9 +287,9 @@ def test_match_1(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - assert engine._match_1(query, gf.dataframe, None, gf.graph.roots[0].children[0], 2) == [ - [gf.graph.roots[0].children[0].children[1]] - ] + assert engine._match_1( + query, gf.dataframe, None, gf.graph.roots[0].children[0], 2 + ) == [[gf.graph.roots[0].children[0].children[1]]] assert engine._match_1(query, gf.dataframe, None, gf.graph.roots[0], 2) is None From bad04109b1d5c7aa42723269d70f9504fa1332d3 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Mon, 11 Nov 2024 17:25:03 -0500 Subject: [PATCH 10/20] Restores special logic for multi-index in the string dialect --- hatchet/query/compat.py | 4 +- hatchet/query/string_dialect.py | 889 +++++++++++++++++++++++++++----- 2 files changed, 772 insertions(+), 121 deletions(-) diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index d62a0c5c..848b7572 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -125,7 +125,7 @@ def apply(self, gf): (list): A list of nodes representing the result of the query """ true_query = self._get_new_query() - return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe) + return COMPATABILITY_ENGINE.apply(true_query, gf.graph, gf.dataframe, "off") def _get_new_query(self): """Gets all the underlying 'new-style' queries in this object. 
@@ -322,7 +322,7 @@ def apply(self, gf): Returns: (list): A list representing the set of nodes from paths that match this query """ - return COMPATABILITY_ENGINE.apply(self.true_query, gf.graph, gf.dataframe) + return COMPATABILITY_ENGINE.apply(self.true_query, gf.graph, gf.dataframe, "off") def _get_new_query(self): """Get all the underlying 'new-style' query in this object. diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index fd79e320..878910bb 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -107,9 +107,9 @@ def __init__(self, cypher_query): super(StringQuery, self).__init__() else: super().__init__() - model = None + self.model = None try: - model = cypher_query_mm.model_from_str(cypher_query) + self.model = cypher_query_mm.model_from_str(cypher_query) except TextXError as e: # TODO Change to a "raise-from" expression when Python 2.7 support is dropped raise InvalidQueryPath( @@ -119,13 +119,16 @@ def __init__(self, cypher_query): ) self.wcards = [] self.wcard_pos = {} - self._parse_path(model.path_expr) + self.default_aggregator = "all" + + def parse(self, dframe): + has_multi_index = isinstance(dframe.index, pd.MultiIndex) + self._parse_path(self.model.path_expr) self.filters = [[] for _ in self.wcards] - self._parse_conditions(model.cond_expr) + self._parse_conditions(self.model.cond_expr, has_multi_index) self.lambda_filters = [None for _ in self.wcards] self._build_lambdas() self._build_query() - self.default_aggregator = "all" def _build_query(self): """Builds the entire query using 'match' and 'rel' using @@ -187,7 +190,7 @@ def _parse_path(self, path_obj): self.wcard_pos[n.name] = idx idx += 1 - def _parse_conditions(self, cond_expr): + def _parse_conditions(self, cond_expr, has_multi_index): """Top level function for parsing the WHERE statement of a String-based query. 
""" @@ -195,9 +198,9 @@ def _parse_conditions(self, cond_expr): for cond in conditions: converted_condition = None if self._is_unary_cond(cond): - converted_condition = self._parse_unary_cond(cond) + converted_condition = self._parse_unary_cond(cond, has_multi_index) elif self._is_binary_cond(cond): - converted_condition = self._parse_binary_cond(cond) + converted_condition = self._parse_binary_cond(cond, has_multi_index) else: raise RuntimeError("Bad Condition") self.filters[self.wcard_pos[converted_condition[1]]].append( @@ -225,52 +228,67 @@ def _is_binary_cond(self, obj): return True return False - def _parse_binary_cond(self, obj): + def _parse_binary_cond(self, obj, has_multi_index): """Top level function for parsing binary predicates.""" if cname(obj) == "AndCond": - return self._parse_and_cond(obj) + return self._parse_and_cond(obj, has_multi_index) if cname(obj) == "OrCond": - return self._parse_or_cond(obj) + return self._parse_or_cond(obj, has_multi_index) raise RuntimeError("Bad Binary Condition") - def _parse_or_cond(self, obj): + def _parse_or_cond(self, obj, has_multi_index): """Top level function for parsing predicates combined with logical OR.""" - converted_subcond = self._parse_unary_cond(obj.subcond) + converted_subcond = self._parse_unary_cond(obj.subcond, has_multi_index) converted_subcond[0] = "or" return converted_subcond - def _parse_and_cond(self, obj): + def _parse_and_cond(self, obj, has_multi_index): """Top level function for parsing predicates combined with logical AND.""" - converted_subcond = self._parse_unary_cond(obj.subcond) + converted_subcond = self._parse_unary_cond(obj.subcond, has_multi_index) converted_subcond[0] = "and" return converted_subcond - def _parse_unary_cond(self, obj): + def _parse_unary_cond(self, obj, has_multi_index): """Top level function for parsing unary predicates.""" if cname(obj) == "NotCond": - return self._parse_not_cond(obj) - return self._parse_single_cond(obj) + return self._parse_not_cond(obj, 
has_multi_index) + return self._parse_single_cond(obj, has_multi_index) - def _parse_not_cond(self, obj): + def _parse_not_cond(self, obj, has_multi_index): """Parse predicates containing the logical NOT operator.""" - converted_subcond = self._parse_single_cond(obj.subcond) + converted_subcond = self._parse_single_cond(obj.subcond, has_multi_index) converted_subcond[2] = "not {}".format(converted_subcond[2]) return converted_subcond - def _parse_single_cond(self, obj): + def _run_method_based_on_multi_index(self, method_name, obj, has_multi_index): + real_method_name = method_name + if has_multi_index: + real_method_name = method_name + "_multi_idx" + method = eval("StringQuery.{}".format(real_method_name)) + return method(self, obj) + + def _parse_single_cond(self, obj, has_multi_index): """Top level function for parsing individual numeric or string predicates.""" if self._is_str_cond(obj): - return self._parse_str(obj) + return self._parse_str(obj, has_multi_index) if self._is_num_cond(obj): - return self._parse_num(obj) + return self._parse_num(obj, has_multi_index) if cname(obj) == "NoneCond": - return self._parse_none(obj) + return self._run_method_based_on_multi_index( + "_parse_none", obj, has_multi_index + ) if cname(obj) == "NotNoneCond": - return self._parse_not_none(obj) + return self._run_method_based_on_multi_index( + "_parse_not_none", obj, has_multi_index + ) if cname(obj) == "LeafCond": - return self._parse_leaf(obj) + return self._run_method_based_on_multi_index( + "_parse_leaf", obj, has_multi_index + ) if cname(obj) == "NotLeafCond": - return self._parse_not_leaf(obj) + return self._run_method_based_on_multi_index( + "_parse_not_leaf", obj, has_multi_index + ) raise RuntimeError("Bad Single Condition") def _parse_none(self, obj): @@ -300,6 +318,32 @@ def _parse_none(self, obj): None, ] + def _parse_none_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + 
"df_row.index.get_level_values('node')[0]._depth is None", + None, + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._hatchet_nid is None", + None, + ] + return [ + None, + obj.name, + "df_row[{}].apply(lambda elem: elem is None)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + None, + ] + def _parse_not_none(self, obj): """Parses 'property IS NOT NONE'.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -327,6 +371,32 @@ def _parse_not_none(self, obj): None, ] + def _parse_not_none_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._depth is not None", + None, + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._hatchet_nid is not None", + None, + ] + return [ + None, + obj.name, + "df_row[{}].apply(lambda elem: elem is not None)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + None, + ] + def _parse_leaf(self, obj): """Parses 'node IS LEAF'.""" return [ @@ -336,6 +406,14 @@ def _parse_leaf(self, obj): None, ] + def _parse_leaf_multi_idx(self, obj): + return [ + None, + obj.name, + "len(df_row.index.get_level_values('node')[0].children) == 0", + None, + ] + def _parse_not_leaf(self, obj): """Parses 'node IS NOT LEAF'.""" return [ @@ -345,6 +423,14 @@ def _parse_not_leaf(self, obj): None, ] + def _parse_not_leaf_multi_idx(self, obj): + return [ + None, + obj.name, + "len(df_row.index.get_level_values('node')[0].children) > 0", + None, + ] + def _is_str_cond(self, obj): """Determines whether a predicate is for string data.""" if cname(obj) in [ @@ -373,20 +459,30 @@ def _is_num_cond(self, obj): return True return False - def _parse_str(self, obj): + def 
_parse_str(self, obj, has_multi_index): """Function that redirects processing of string predicates to the correct function. """ if cname(obj) == "StringEq": - return self._parse_str_eq(obj) + return self._run_method_based_on_multi_index( + "_parse_str_eq", obj, has_multi_index + ) if cname(obj) == "StringStartsWith": - return self._parse_str_starts_with(obj) + return self._run_method_based_on_multi_index( + "_parse_str_starts_with", obj, has_multi_index + ) if cname(obj) == "StringEndsWith": - return self._parse_str_ends_with(obj) + return self._run_method_based_on_multi_index( + "_parse_str_ends_with", obj, has_multi_index + ) if cname(obj) == "StringContains": - return self._parse_str_contains(obj) + return self._run_method_based_on_multi_index( + "_parse_str_contains", obj, has_multi_index + ) if cname(obj) == "StringMatch": - return self._parse_str_match(obj) + return self._run_method_based_on_multi_index( + "_parse_str_match", obj, has_multi_index + ) raise RuntimeError("Bad String Op Class") def _parse_str_eq(self, obj): @@ -409,6 +505,25 @@ def _parse_str_eq(self, obj): ), ] + def _parse_str_eq_multi_idx(self, obj): + return [ + None, + obj.name, + 'df_row[{}].apply(lambda elem: elem == "{}")'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_str_starts_with(self, obj): """Processes string 'startswith' predicates.""" return [ @@ -429,6 +544,25 @@ def _parse_str_starts_with(self, obj): ), ] + def _parse_str_starts_with_multi_idx(self, obj): + return [ + None, + obj.name, + 'df_row[{}].apply(lambda elem: elem.startswith("{}"))'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if 
len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_str_ends_with(self, obj): """Processes string 'endswith' predicates.""" return [ @@ -449,6 +583,25 @@ def _parse_str_ends_with(self, obj): ), ] + def _parse_str_ends_with_multi_idx(self, obj): + return [ + None, + obj.name, + 'df_row[{}].apply(lambda elem: elem.endswith("{}"))'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_str_contains(self, obj): """Processes string 'contains' predicates.""" return [ @@ -469,6 +622,25 @@ def _parse_str_contains(self, obj): ), ] + def _parse_str_contains_multi_idx(self, obj): + return [ + None, + obj.name, + 'df_row[{}].apply(lambda elem: "{}" in elem)'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_str_match(self, obj): """Processes string regex match predicates.""" return [ @@ -489,28 +661,65 @@ def _parse_str_match(self, obj): ), ] - def _parse_num(self, obj): + def _parse_str_match_multi_idx(self, obj): + return [ + None, + obj.name, + 'df_row[{}].apply(lambda elem: re.match("{}", elem) is not None)'.format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + + def _parse_num(self, obj, has_multi_index): """Function that redirects processing of numeric predicates to the correct function. 
""" if cname(obj) == "NumEq": - return self._parse_num_eq(obj) + return self._run_method_based_on_multi_index( + "_parse_num_eq", obj, has_multi_index + ) if cname(obj) == "NumLt": - return self._parse_num_lt(obj) + return self._run_method_based_on_multi_index( + "_parse_num_lt", obj, has_multi_index + ) if cname(obj) == "NumGt": - return self._parse_num_gt(obj) + return self._run_method_based_on_multi_index( + "_parse_num_gt", obj, has_multi_index + ) if cname(obj) == "NumLte": - return self._parse_num_lte(obj) + return self._run_method_based_on_multi_index( + "_parse_num_lte", obj, has_multi_index + ) if cname(obj) == "NumGte": - return self._parse_num_gte(obj) + return self._run_method_based_on_multi_index( + "_parse_num_gte", obj, has_multi_index + ) if cname(obj) == "NumNan": - return self._parse_num_nan(obj) + return self._run_method_based_on_multi_index( + "_parse_num_nan", obj, has_multi_index + ) if cname(obj) == "NumNotNan": - return self._parse_num_not_nan(obj) + return self._run_method_based_on_multi_index( + "_parse_num_not_nan", obj, has_multi_index + ) if cname(obj) == "NumInf": - return self._parse_num_inf(obj) + return self._run_method_based_on_multi_index( + "_parse_num_inf", obj, has_multi_index + ) if cname(obj) == "NumNotInf": - return self._parse_num_not_inf(obj) + return self._run_method_based_on_multi_index( + "_parse_num_not_inf", obj, has_multi_index + ) raise RuntimeError("Bad Number Op Class") def _parse_num_eq(self, obj): @@ -530,9 +739,7 @@ def _parse_num_eq(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -555,9 +762,7 @@ def _parse_num_eq(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -590,32 +795,36 @@ def _parse_num_eq(self, obj): ), ] - def _parse_num_lt(self, obj): - """Processes numeric less-than predicates.""" + def _parse_num_eq_multi_idx(self, obj): if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": - if obj.val < 0: + if obj.val == -1: + return [ + None, + obj.name, + "len(df_row.index.get_level_values('node')[0].children) == 0", + None, + ] + elif obj.val < 0: warnings.warn( """ The 'depth' property of a Node is strictly non-negative. This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ None, obj.name, "False", - "isinstance(df_row.name._depth, Real)", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] return [ None, obj.name, - "df_row.name._depth < {}".format(obj.val), - "isinstance(df_row.name._depth, Real)", + "df_row.index.get_level_values('node')[0]._depth == {}".format(obj.val), + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: @@ -625,27 +834,27 @@ def _parse_num_lt(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ None, obj.name, "False", - "isinstance(df_row.name._hatchet_nid, Real)", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row.name._hatchet_nid < {}".format(obj.val), - "isinstance(df_row.name._hatchet_nid, Real)", + "df_row.index.get_level_values('node')[0]._hatchet_nid == {}".format( + obj.val + ), + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row[{}] < {}".format( + "df_row[{}].apply(lambda elem: elem == {})".format( ( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 @@ -653,38 +862,36 @@ def _parse_num_lt(self, obj): ), obj.val, ), - "isinstance(df_row[{}], Real)".format( + "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 else "'{}'".format(obj.prop.ids[0]) ), ] - def _parse_num_gt(self, obj): - """Processes numeric greater-than predicates.""" + def _parse_num_lt(self, obj): + """Processes numeric less-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ The 'depth' property of a Node is strictly non-negative. - This condition will always be true. + This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ None, obj.name, - "True", + "False", "isinstance(df_row.name._depth, Real)", ] return [ None, obj.name, - "df_row.name._depth > {}".format(obj.val), + "df_row.name._depth < {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": @@ -692,30 +899,28 @@ def _parse_num_gt(self, obj): warnings.warn( """ The 'node_id' property of a Node is strictly non-negative. - This condition will always be true. + This condition will always be false. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ None, obj.name, - "True", + "False", "isinstance(df_row.name._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row.name._hatchet_nid > {}".format(obj.val), + "df_row.name._hatchet_nid < {}".format(obj.val), "isinstance(df_row.name._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row[{}] > {}".format( + "df_row[{}] < {}".format( ( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 @@ -730,8 +935,7 @@ def _parse_num_gt(self, obj): ), ] - def _parse_num_lte(self, obj): - """Processes numeric less-than-or-equal-to predicates.""" + def _parse_num_lt_multi_idx(self, obj): if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -740,22 +944,20 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ None, obj.name, "False", - "isinstance(df_row.name._depth, Real)", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] return [ None, obj.name, - "df_row.name._depth <= {}".format(obj.val), - "isinstance(df_row.name._depth, Real)", + "df_row.index.get_level_values('node')[0]._depth < {}".format(obj.val), + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: @@ -765,27 +967,27 @@ def _parse_num_lte(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ None, obj.name, "False", - "isinstance(df_row.name._hatchet_nid, Real)", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row.name._hatchet_nid <= {}".format(obj.val), - "isinstance(df_row.name._hatchet_nid, Real)", + "df_row.index.get_level_values('node')[0]._hatchet_nid < {}".format( + obj.val + ), + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row[{}] <= {}".format( + "df_row[{}].apply(lambda elem: elem < {})".format( ( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 @@ -793,15 +995,15 @@ def _parse_num_lte(self, obj): ), obj.val, ), - "isinstance(df_row[{}], Real)".format( + "is_numeric_dtype(df_row[{}])".format( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 else "'{}'".format(obj.prop.ids[0]) ), ] - def _parse_num_gte(self, obj): - """Processes numeric greater-than-or-equal-to predicates.""" + def _parse_num_gt(self, obj): + """Processes numeric greater-than predicates.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( @@ -810,9 +1012,7 @@ def _parse_num_gte(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -824,7 +1024,7 @@ def _parse_num_gte(self, obj): return [ None, obj.name, - "df_row.name._depth >= {}".format(obj.val), + "df_row.name._depth > {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": @@ -835,9 +1035,7 @@ def _parse_num_gte(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format( - obj - ), + """.format(obj), RedundantQueryFilterWarning, ) return [ @@ -849,13 +1047,13 @@ def _parse_num_gte(self, obj): return [ None, obj.name, - "df_row.name._hatchet_nid >= {}".format(obj.val), + "df_row.name._hatchet_nid > {}".format(obj.val), "isinstance(df_row.name._hatchet_nid, Real)", ] return [ None, obj.name, - "df_row[{}] >= {}".format( + "df_row[{}] > {}".format( ( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 @@ -870,30 +1068,265 @@ def _parse_num_gte(self, obj): ), ] - def _parse_num_nan(self, obj): - """Processes predicates that check for NaN.""" + def _parse_num_gt_multi_idx(self, obj): if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + if obj.val < 0: + warnings.warn( + """ + The 'depth' property of a Node is strictly non-negative. + This condition will always be true. + The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "True", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] return [ None, obj.name, - "pd.isna(df_row.name._depth)", - "isinstance(df_row.name._depth, Real)", + "df_row.index.get_level_values('node')[0]._depth > {}".format(obj.val), + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + if obj.val < 0: + warnings.warn( + """ + The 'node_id' property of a Node is strictly non-negative. + This condition will always be true. 
+ The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "True", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] return [ None, obj.name, - "pd.isna(df_row.name._hatchet_nid)", - "isinstance(df_row.name._hatchet_nid, Real)", + "df_row.index.get_level_values('node')[0]._hatchet_nid > {}".format( + obj.val + ), + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", ] return [ None, obj.name, - "pd.isna(df_row[{}])".format( - str(tuple(obj.prop.ids)) - if len(obj.prop.ids) > 1 - else "'{}'".format(obj.prop.ids[0]) - ), + "df_row[{}].apply(lambda elem: elem > {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + + def _parse_num_lte(self, obj): + """Processes numeric less-than-or-equal-to predicates.""" + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + if obj.val < 0: + warnings.warn( + """ + The 'depth' property of a Node is strictly non-negative. + This condition will always be false. + The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "False", + "isinstance(df_row.name._depth, Real)", + ] + return [ + None, + obj.name, + "df_row.name._depth <= {}".format(obj.val), + "isinstance(df_row.name._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + if obj.val < 0: + warnings.warn( + """ + The 'node_id' property of a Node is strictly non-negative. + This condition will always be false. 
+ The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "False", + "isinstance(df_row.name._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row.name._hatchet_nid <= {}".format(obj.val), + "isinstance(df_row.name._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row[{}] <= {}".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + + def _parse_num_lte_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + if obj.val < 0: + warnings.warn( + """ + The 'depth' property of a Node is strictly non-negative. + This condition will always be false. + The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "False", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._depth <= {}".format(obj.val), + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + if obj.val < 0: + warnings.warn( + """ + The 'node_id' property of a Node is strictly non-negative. + This condition will always be false. 
+ The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "False", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._hatchet_nid <= {}".format( + obj.val + ), + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row[{}].apply(lambda elem: elem <= {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + + def _parse_num_gte(self, obj): + """Processes numeric greater-than-or-equal-to predicates.""" + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + if obj.val < 0: + warnings.warn( + """ + The 'depth' property of a Node is strictly non-negative. + This condition will always be true. + The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "True", + "isinstance(df_row.name._depth, Real)", + ] + return [ + None, + obj.name, + "df_row.name._depth >= {}".format(obj.val), + "isinstance(df_row.name._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + if obj.val < 0: + warnings.warn( + """ + The 'node_id' property of a Node is strictly non-negative. + This condition will always be true. 
+ The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "True", + "isinstance(df_row.name._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row.name._hatchet_nid >= {}".format(obj.val), + "isinstance(df_row.name._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row[{}] >= {}".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), "isinstance(df_row[{}], Real)".format( str(tuple(obj.prop.ids)) if len(obj.prop.ids) > 1 @@ -901,6 +1334,134 @@ def _parse_num_nan(self, obj): ), ] + def _parse_num_gte_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + if obj.val < 0: + warnings.warn( + """ + The 'depth' property of a Node is strictly non-negative. + This condition will always be true. + The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "True", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._depth >= {}".format(obj.val), + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + if obj.val < 0: + warnings.warn( + """ + The 'node_id' property of a Node is strictly non-negative. + This condition will always be true. 
+ The statement that triggered this warning is: + {} + """.format(obj), + RedundantQueryFilterWarning, + ) + return [ + None, + obj.name, + "True", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row.index.get_level_values('node')[0]._hatchet_nid >= {}".format( + obj.val + ), + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "df_row[{}].apply(lambda elem: elem >= {})".format( + ( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + obj.val, + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + + def _parse_num_nan(self, obj): + """Processes predicates that check for NaN.""" + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + "pd.isna(df_row.name._depth)", + "isinstance(df_row.name._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + "pd.isna(df_row.name._hatchet_nid)", + "isinstance(df_row.name._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + + def _parse_num_nan_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + "pd.isna(df_row.index.get_level_values('node')[0]._depth)", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + "pd.isna(df_row.index.get_level_values('node')[0]._hatchet_nid)", + 
"isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_num_not_nan(self, obj): """Processes predicates that check for NaN.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -932,6 +1493,36 @@ def _parse_num_not_nan(self, obj): ), ] + def _parse_num_not_nan_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + "not pd.isna(df_row.index.get_level_values('node')[0]._depth)", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + "not pd.isna(df_row.index.get_level_values('node')[0]._hatchet_nid)", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "not pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_num_inf(self, obj): """Processes predicates that check for Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -963,6 +1554,36 @@ def _parse_num_inf(self, obj): ), ] + def _parse_num_inf_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + "np.isinf(df_row.index.get_level_values('node')[0]._depth)", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + 
"np.isinf(df_row.index.get_level_values('node')[0]._hatchet_nid)", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def _parse_num_not_inf(self, obj): """Processes predicates that check for not-Infinity.""" if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": @@ -994,6 +1615,36 @@ def _parse_num_not_inf(self, obj): ), ] + def _parse_num_not_inf_multi_idx(self, obj): + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": + return [ + None, + obj.name, + "not np.isinf(df_row.index.get_level_values('node')[0]._depth)", + "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", + ] + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": + return [ + None, + obj.name, + "not np.isinf(df_row.index.get_level_values('node')[0]._hatchet_nid)", + "isinstance(df_row.index.get_level_values('node')[0]._hatchet_nid, Real)", + ] + return [ + None, + obj.name, + "not np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + ] + def parse_string_dialect(query_str): """Parse all types of String-based queries, including multi-queries that leverage From db68dbc0fc5fb711d4388f00f8a3ed974ac226f9 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 11:57:25 -0400 Subject: [PATCH 11/20] Updates GraphFrame.filter's docstring --- hatchet/graphframe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index eb2b6438..d399e4cc 100644 --- a/hatchet/graphframe.py +++ 
b/hatchet/graphframe.py @@ -484,6 +484,10 @@ def filter( update_inc_cols (boolean, optional): if True, update inclusive columns when performing squash. rec_limit: set Python recursion limit, increase if running into recursion depth errors) (default: 1000). + predicate_row_aggregator (str or Callable, optional): function to use in Query Language + to merge multiple predicate results for each node into a single boolean. When providing + a string value, the following are accepted: "all" (equivalent to Python 'all'), "any" + (equivalent to Python 'any'), "off" (no aggregation) """ sys.setrecursionlimit(rec_limit) From 42219a13c06e0b8e95bdaea2f7270089bbb06edc Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:05:29 -0400 Subject: [PATCH 12/20] Fixes formatting --- hatchet/query/compat.py | 4 +- hatchet/query/string_dialect.py | 76 ++++++++++++++++++++++++--------- 2 files changed, 60 insertions(+), 20 deletions(-) diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index 848b7572..3820060e 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -322,7 +322,9 @@ def apply(self, gf): Returns: (list): A list representing the set of nodes from paths that match this query """ - return COMPATABILITY_ENGINE.apply(self.true_query, gf.graph, gf.dataframe, "off") + return COMPATABILITY_ENGINE.apply( + self.true_query, gf.graph, gf.dataframe, "off" + ) def _get_new_query(self): """Get all the underlying 'new-style' query in this object. diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 878910bb..32437c41 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -739,7 +739,9 @@ def _parse_num_eq(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -762,7 +764,9 @@ def _parse_num_eq(self, obj): This condition will always be false. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -811,7 +815,9 @@ def _parse_num_eq_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -834,7 +840,9 @@ def _parse_num_eq_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -879,7 +887,9 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -902,7 +912,9 @@ def _parse_num_lt(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -944,7 +956,9 @@ def _parse_num_lt_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -967,7 +981,9 @@ def _parse_num_lt_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1012,7 +1028,9 @@ def _parse_num_gt(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1035,7 +1053,9 @@ def _parse_num_gt(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1077,7 +1097,9 @@ def _parse_num_gt_multi_idx(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1100,7 +1122,9 @@ def _parse_num_gt_multi_idx(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1145,7 +1169,9 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1168,7 +1194,9 @@ def _parse_num_lte(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1210,7 +1238,9 @@ def _parse_num_lte_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1233,7 +1263,9 @@ def _parse_num_lte_multi_idx(self, obj): This condition will always be false. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1278,7 +1310,9 @@ def _parse_num_gte(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1301,7 +1335,9 @@ def _parse_num_gte(self, obj): This condition will always be true. 
The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ @@ -1343,7 +1379,9 @@ def _parse_num_gte_multi_idx(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ From 216afea8897ad6cc574e990f3b44551b2ee3e009 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:10:16 -0400 Subject: [PATCH 13/20] Restores multi_index_mode as a deprecated parameter to GraphFrame.filter so that existing code does not break --- hatchet/graphframe.py | 11 +++++++++++ hatchet/query/string_dialect.py | 4 +++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index d399e4cc..d4402d56 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -8,6 +8,7 @@ import sys import traceback from collections import defaultdict +import warnings import multiprocess as mp import numpy as np @@ -473,6 +474,7 @@ def filter( num_procs=mp.cpu_count(), rec_limit=1000, predicate_row_aggregator=None, + multi_index_mode=None, ): """Filter the dataframe using a user-supplied function. @@ -488,7 +490,16 @@ def filter( to merge multiple predicate results for each node into a single boolean. When providing a string value, the following are accepted: "all" (equivalent to Python 'all'), "any" (equivalent to Python 'any'), "off" (no aggregation) + multi_index_mode: deprecated alias for "predicate_row_aggregator" """ + if multi_index_mode is not None: + warnings.warn( + "'multi_index_mode' parameter is deprecated. 
Use 'predicate_row_aggregator' instead", + DeprecationWarning + ) + if predicate_row_aggregator is None: + predicate_row_aggregator = multi_index_mode + sys.setrecursionlimit(rec_limit) dataframe_copy = self.dataframe.copy() diff --git a/hatchet/query/string_dialect.py b/hatchet/query/string_dialect.py index 32437c41..6990d768 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -1404,7 +1404,9 @@ def _parse_num_gte_multi_idx(self, obj): This condition will always be true. The statement that triggered this warning is: {} - """.format(obj), + """.format( + obj + ), RedundantQueryFilterWarning, ) return [ From c851fde4b41b8b7ff9495cb2cff611a00d6dc7ca Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:11:39 -0400 Subject: [PATCH 14/20] Yet more formatting --- hatchet/graphframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet/graphframe.py b/hatchet/graphframe.py index d4402d56..d403ee8d 100644 --- a/hatchet/graphframe.py +++ b/hatchet/graphframe.py @@ -495,7 +495,7 @@ def filter( if multi_index_mode is not None: warnings.warn( "'multi_index_mode' parameter is deprecated. 
Use 'predicate_row_aggregator' instead", - DeprecationWarning + DeprecationWarning, ) if predicate_row_aggregator is None: predicate_row_aggregator = multi_index_mode From 994f317504fad9f64b7bed60622b163217231394 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:23:30 -0400 Subject: [PATCH 15/20] Fixes unit tests to properly pass predicate_row_aggregator --- hatchet/tests/query.py | 102 ++++++++++++++++++++--------------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index b9020e77..d50bb2d0 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -315,7 +315,7 @@ def test_match(mock_graph_literal): ] query0 = ObjectQuery(path0) engine = QueryEngine() - assert engine._match_pattern(query0, gf.dataframe, root, 0) == match0 + assert engine._match_pattern(query0, gf.dataframe, query0.default_aggregator, root, 0) == match0 engine.reset_cache() @@ -327,7 +327,7 @@ def test_match(mock_graph_literal): {"time (inc)": 7.5, "time": 7.5}, ] query1 = ObjectQuery(path1) - assert engine._match_pattern(query1, gf.dataframe, root, 0) is None + assert engine._match_pattern(query1, gf.dataframe, query0.default_aggregator, root, 0) is None def test_apply(mock_graph_literal): @@ -349,7 +349,7 @@ def test_apply(mock_graph_literal): query = ObjectQuery(path) engine = QueryEngine() - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"time (inc)": ">= 30.0"}, ".", {"name": "bar"}, "*"] match = [ @@ -360,12 +360,12 @@ def test_apply(mock_graph_literal): root.children[1].children[0].children[0].children[0].children[1], ] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"name": "foo"}, {"name": "bar"}, {"time": 5.0}] match = 
[root, root.children[0], root.children[0].children[0]] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"name": "foo"}, {"name": "qux"}, ("+", {"time (inc)": "> 15.0"})] match = [ @@ -376,22 +376,22 @@ def test_apply(mock_graph_literal): root.children[1].children[0].children[0].children[0], ] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = [{"name": "this"}, ("*", {"name": "is"}), {"name": "nonsense"}] query = ObjectQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] path = [{"name": 5}, "*", {"name": "whatever"}] query = ObjectQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = [{"time": "badstring"}, "*", {"name": "whatever"}] query = ObjectQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) class DummyType: def __init__(self): @@ -407,7 +407,7 @@ def __init__(self): path = [{"name": "foo"}, {"name": "bar"}, {"list": DummyType()}] query = ObjectQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = ["*", {"name": "bar"}, {"name": "grault"}, "*"] match = [ @@ -452,11 +452,11 @@ def __init__(self): ] match = list(set().union(*match)) query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = ["*", {"name": "bar"}, {"name": "grault"}, "+"] query = 
ObjectQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] # Test a former edge case with the + quantifier/wildcard match = [ @@ -486,7 +486,7 @@ def __init__(self): match = list(set().union(*match)) path = [("+", {"name": "ba.*"})] query = ObjectQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) def test_apply_indices(calc_pi_hpct_db): @@ -518,7 +518,7 @@ def test_apply_indices(calc_pi_hpct_db): ) == sorted(matches) gf.drop_index_levels() - assert engine.apply(query, gf.graph, gf.dataframe) == matches + assert engine.apply(query, gf.graph, gf.dataframe, None) == matches def test_object_dialect_depth(mock_graph_literal): @@ -527,7 +527,7 @@ def test_object_dialect_depth(mock_graph_literal): engine = QueryEngine() roots = gf.graph.roots matches = [c for r in roots for c in r.children] - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) query = ObjectQuery([("*", {"depth": "<= 2"})]) matches = [ @@ -554,11 +554,11 @@ def test_object_dialect_depth(mock_graph_literal): [roots[1].children[0].children[1]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) with pytest.raises(InvalidQueryFilter): query = ObjectQuery([{"depth": "hello"}]) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) def test_object_dialect_hatchet_nid(mock_graph_literal): @@ -575,14 +575,14 @@ def test_object_dialect_hatchet_nid(mock_graph_literal): [root.children[0].children[1]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == 
sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) query = ObjectQuery([{"node_id": 0}]) - assert engine.apply(query, gf.graph, gf.dataframe) == [gf.graph.roots[0]] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [gf.graph.roots[0]] with pytest.raises(InvalidQueryFilter): query = ObjectQuery([{"node_id": "hello"}]) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) def test_object_dialect_depth_index_levels(calc_pi_hpct_db): @@ -702,7 +702,7 @@ def test_object_dialect_multi_condition_one_attribute(mock_graph_literal): [roots[1].children[0]], ] matches = list(set().union(*matches)) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(matches) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(matches) def test_obj_query_is_query(): @@ -799,7 +799,7 @@ def test_conjunction_query(mock_graph_literal): roots[0].children[1], roots[0].children[1].children[0], ] - assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -822,7 +822,7 @@ def test_disjunction_query(mock_graph_literal): roots[1].children[0].children[0], roots[1].children[0].children[1], ] - assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -841,7 +841,7 @@ def test_exc_disjunction_query(mock_graph_literal): roots[0].children[2].children[0].children[1].children[0].children[0], roots[1].children[0].children[0], ] - assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -963,7 +963,7 @@ def test_apply_string_dialect(mock_graph_literal): query = StringQuery(path) engine = QueryEngine() - 
assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(".")->(q)->("*") WHERE p."time (inc)" >= 30.0 AND q."name" = "bar" @@ -976,14 +976,14 @@ def test_apply_string_dialect(mock_graph_literal): root.children[1].children[0].children[0].children[0].children[1], ] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(q)->(r) WHERE p."name" = "foo" AND q."name" = "bar" AND r."time" = 5.0 """ match = [root, root.children[0], root.children[0].children[0]] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(q)->("+", r) WHERE p."name" = "foo" AND q."name" = "qux" AND r."time (inc)" > 15.0 @@ -996,7 +996,7 @@ def test_apply_string_dialect(mock_graph_literal): root.children[1].children[0].children[0].children[0], ] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->(q) WHERE p."time (inc)" > 100 OR p."time (inc)" <= 30 AND q."time (inc)" = 20 @@ -1009,28 +1009,28 @@ def test_apply_string_dialect(mock_graph_literal): roots[1].children[0], ] query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH (p)->("*", q)->(r) WHERE p."name" = "this" AND q."name" = "is" AND r."name" = "nonsense" """ query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] path = 
"""MATCH (p)->("*")->(q) WHERE p."name" = 5 AND q."name" = "whatever" """ with pytest.raises(InvalidQueryFilter): query = StringQuery(path) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = """MATCH (p)->("*")->(q) WHERE p."time" = "badstring" AND q."name" = "whatever" """ query = StringQuery(path) with pytest.raises(InvalidQueryFilter): - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) class DummyType: def __init__(self): @@ -1048,7 +1048,7 @@ def __init__(self): """ with pytest.raises(InvalidQueryPath): query = StringQuery(path) - engine.apply(query, gf.graph, gf.dataframe) + engine.apply(query, gf.graph, gf.dataframe, None) path = """MATCH ("*")->(p)->(q)->("*") WHERE p."name" = "bar" AND q."name" = "grault" @@ -1095,13 +1095,13 @@ def __init__(self): ] match = list(set().union(*match)) query = StringQuery(path) - assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) + assert sorted(engine.apply(query, gf.graph, gf.dataframe, None)) == sorted(match) path = """MATCH ("*")->(p)->(q)->("+") WHERE p."name" = "bar" AND q."name" = "grault" """ query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == [] + assert engine.apply(query, gf.graph, gf.dataframe, None) == [] gf.dataframe["time"] = np.NaN gf.dataframe.at[gf.graph.roots[0], "time"] = 5.0 @@ -1109,7 +1109,7 @@ def __init__(self): WHERE p."time" IS NOT NAN""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["time"] = 5.0 gf.dataframe.at[gf.graph.roots[0], "time"] = np.NaN @@ -1117,7 +1117,7 @@ def __init__(self): WHERE p."time" IS NAN""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match 
gf.dataframe["time"] = np.Inf gf.dataframe.at[gf.graph.roots[0], "time"] = 5.0 @@ -1125,7 +1125,7 @@ def __init__(self): WHERE p."time" IS NOT INF""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["time"] = 5.0 gf.dataframe.at[gf.graph.roots[0], "time"] = np.Inf @@ -1133,7 +1133,7 @@ def __init__(self): WHERE p."time" IS INF""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match names = gf.dataframe["name"].copy() gf.dataframe["name"] = None @@ -1142,7 +1142,7 @@ def __init__(self): WHERE p."name" IS NOT NONE""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match gf.dataframe["name"] = names gf.dataframe.at[gf.graph.roots[0], "name"] = None @@ -1150,7 +1150,7 @@ def __init__(self): WHERE p."name" IS NONE""" match = [gf.graph.roots[0]] query = StringQuery(path) - assert engine.apply(query, gf.graph, gf.dataframe) == match + assert engine.apply(query, gf.graph, gf.dataframe, None) == match def test_string_conj_compound_query(mock_graph_literal): @@ -1173,10 +1173,10 @@ def test_string_conj_compound_query(mock_graph_literal): roots[0].children[1], roots[0].children[1].children[0], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe, None)) == sorted( matches ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -1208,10 +1208,10 @@ def test_string_disj_compound_query(mock_graph_literal): roots[1].children[0].children[0], 
roots[1].children[0].children[1], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe, None)) == sorted( matches ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -1239,10 +1239,10 @@ def test_cypher_exc_disj_compound_query(mock_graph_literal): roots[0].children[2].children[0].children[1].children[0].children[0], roots[1].children[0].children[0], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe, None)) == sorted( matches ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe, None)) == sorted( matches ) @@ -1278,15 +1278,15 @@ def test_leaf_query(small_mock2): """ ) engine = QueryEngine() - assert sorted(engine.apply(obj_query, gf.graph, gf.dataframe)) == sorted(matches) - assert sorted(engine.apply(str_query_numeric, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(obj_query, gf.graph, gf.dataframe, None)) == sorted(matches) + assert sorted(engine.apply(str_query_numeric, gf.graph, gf.dataframe, None)) == sorted( matches ) - assert sorted(engine.apply(str_query_is_leaf, gf.graph, gf.dataframe)) == sorted( + assert sorted(engine.apply(str_query_is_leaf, gf.graph, gf.dataframe, None)) == sorted( matches ) assert sorted( - engine.apply(str_query_is_not_leaf, gf.graph, gf.dataframe) + engine.apply(str_query_is_not_leaf, gf.graph, gf.dataframe, None) ) == sorted(nonleaves) From a1ccdf467eeeb17ea5e922e534f1ba4a285d32e1 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:29:46 -0400 Subject: [PATCH 16/20] Even more formatting because Black sucks --- hatchet/tests/query.py | 58 
++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index d50bb2d0..0053a638 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -315,7 +315,10 @@ def test_match(mock_graph_literal): ] query0 = ObjectQuery(path0) engine = QueryEngine() - assert engine._match_pattern(query0, gf.dataframe, query0.default_aggregator, root, 0) == match0 + assert ( + engine._match_pattern(query0, gf.dataframe, query0.default_aggregator, root, 0) + == match0 + ) engine.reset_cache() @@ -327,7 +330,10 @@ def test_match(mock_graph_literal): {"time (inc)": 7.5, "time": 7.5}, ] query1 = ObjectQuery(path1) - assert engine._match_pattern(query1, gf.dataframe, query0.default_aggregator, root, 0) is None + assert ( + engine._match_pattern(query1, gf.dataframe, query0.default_aggregator, root, 0) + is None + ) def test_apply(mock_graph_literal): @@ -1173,12 +1179,12 @@ def test_string_conj_compound_query(mock_graph_literal): roots[0].children[1], roots[0].children[1].children[0], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe, None)) == sorted( - matches - ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe, None)) == sorted( - matches - ) + assert sorted( + engine.apply(compound_query1, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(compound_query2, gf.graph, gf.dataframe, None) + ) == sorted(matches) def test_string_disj_compound_query(mock_graph_literal): @@ -1208,12 +1214,12 @@ def test_string_disj_compound_query(mock_graph_literal): roots[1].children[0].children[0], roots[1].children[0].children[1], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe, None)) == sorted( - matches - ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe, None)) == sorted( - matches - ) + assert sorted( + engine.apply(compound_query1, gf.graph, gf.dataframe, None) + ) == 
sorted(matches) + assert sorted( + engine.apply(compound_query2, gf.graph, gf.dataframe, None) + ) == sorted(matches) def test_cypher_exc_disj_compound_query(mock_graph_literal): @@ -1239,12 +1245,12 @@ def test_cypher_exc_disj_compound_query(mock_graph_literal): roots[0].children[2].children[0].children[1].children[0].children[0], roots[1].children[0].children[0], ] - assert sorted(engine.apply(compound_query1, gf.graph, gf.dataframe, None)) == sorted( - matches - ) - assert sorted(engine.apply(compound_query2, gf.graph, gf.dataframe, None)) == sorted( - matches - ) + assert sorted( + engine.apply(compound_query1, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(compound_query2, gf.graph, gf.dataframe, None) + ) == sorted(matches) def test_leaf_query(small_mock2): @@ -1278,13 +1284,15 @@ def test_leaf_query(small_mock2): """ ) engine = QueryEngine() - assert sorted(engine.apply(obj_query, gf.graph, gf.dataframe, None)) == sorted(matches) - assert sorted(engine.apply(str_query_numeric, gf.graph, gf.dataframe, None)) == sorted( - matches - ) - assert sorted(engine.apply(str_query_is_leaf, gf.graph, gf.dataframe, None)) == sorted( + assert sorted(engine.apply(obj_query, gf.graph, gf.dataframe, None)) == sorted( matches ) + assert sorted( + engine.apply(str_query_numeric, gf.graph, gf.dataframe, None) + ) == sorted(matches) + assert sorted( + engine.apply(str_query_is_leaf, gf.graph, gf.dataframe, None) + ) == sorted(matches) assert sorted( engine.apply(str_query_is_not_leaf, gf.graph, gf.dataframe, None) ) == sorted(nonleaves) From e1eb8b96ceeb711d464b8f991ffd13bfd3ad89aa Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:32:00 -0400 Subject: [PATCH 17/20] Removes some spaces that accidentally got left over --- hatchet/tests/query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index 0053a638..8b17192e 100644 --- a/hatchet/tests/query.py +++ 
b/hatchet/tests/query.py @@ -316,7 +316,7 @@ def test_match(mock_graph_literal): query0 = ObjectQuery(path0) engine = QueryEngine() assert ( - engine._match_pattern(query0, gf.dataframe, query0.default_aggregator, root, 0) + engine._match_pattern(query0, gf.dataframe, query0.default_aggregator, root, 0) == match0 ) @@ -331,7 +331,7 @@ def test_match(mock_graph_literal): ] query1 = ObjectQuery(path1) assert ( - engine._match_pattern(query1, gf.dataframe, query0.default_aggregator, root, 0) + engine._match_pattern(query1, gf.dataframe, query0.default_aggregator, root, 0) is None ) From 0fcdb0907b515188ab6fad77698afd93a2004dd3 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:53:11 -0400 Subject: [PATCH 18/20] Fixes a few bugs in testing --- hatchet/query/engine.py | 2 +- hatchet/tests/query.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index a515112c..e052091f 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -301,7 +301,7 @@ def _apply_impl( # this node. 
if query.query_pattern[0][0] == "*": if 1 in self.search_cache[node._hatchet_nid]: - sub_match = self._match_pattern(query, dframe, node, 1) + sub_match = self._match_pattern(query, dframe, predicate_row_aggregator, node, 1) if sub_match is not None: matches.extend(sub_match) if 0 in self.search_cache[node._hatchet_nid]: diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index 8b17192e..a2dfddbb 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -25,6 +25,7 @@ ExclusiveDisjunctionQuery, NegationQuery, ) +from hatchet.query.engine import _all_aggregator, _any_aggregator def test_construct_object_dialect(): @@ -316,7 +317,7 @@ def test_match(mock_graph_literal): query0 = ObjectQuery(path0) engine = QueryEngine() assert ( - engine._match_pattern(query0, gf.dataframe, query0.default_aggregator, root, 0) + engine._match_pattern(query0, gf.dataframe, _all_aggregator, root, 0) == match0 ) @@ -331,7 +332,7 @@ def test_match(mock_graph_literal): ] query1 = ObjectQuery(path1) assert ( - engine._match_pattern(query1, gf.dataframe, query0.default_aggregator, root, 0) + engine._match_pattern(query1, gf.dataframe, _all_aggregator, root, 0) is None ) From 241a823e70d4e68c705fa864e6ba8b40c6fe868c Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:55:13 -0400 Subject: [PATCH 19/20] Formatting, yet again --- hatchet/query/engine.py | 4 +++- hatchet/tests/query.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hatchet/query/engine.py b/hatchet/query/engine.py index e052091f..50ce9410 100644 --- a/hatchet/query/engine.py +++ b/hatchet/query/engine.py @@ -301,7 +301,9 @@ def _apply_impl( # this node. 
if query.query_pattern[0][0] == "*": if 1 in self.search_cache[node._hatchet_nid]: - sub_match = self._match_pattern(query, dframe, predicate_row_aggregator, node, 1) + sub_match = self._match_pattern( + query, dframe, predicate_row_aggregator, node, 1 + ) if sub_match is not None: matches.extend(sub_match) if 0 in self.search_cache[node._hatchet_nid]: diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index a2dfddbb..1c242e4a 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -317,8 +317,7 @@ def test_match(mock_graph_literal): query0 = ObjectQuery(path0) engine = QueryEngine() assert ( - engine._match_pattern(query0, gf.dataframe, _all_aggregator, root, 0) - == match0 + engine._match_pattern(query0, gf.dataframe, _all_aggregator, root, 0) == match0 ) engine.reset_cache() From d96d8f279989526f83d133c7720bdc45537af772 Mon Sep 17 00:00:00 2001 From: Ian Lumsden Date: Fri, 14 Mar 2025 14:57:53 -0400 Subject: [PATCH 20/20] More formatting that magically appeared --- hatchet/tests/query.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index 1c242e4a..9f911ae7 100644 --- a/hatchet/tests/query.py +++ b/hatchet/tests/query.py @@ -330,10 +330,7 @@ def test_match(mock_graph_literal): {"time (inc)": 7.5, "time": 7.5}, ] query1 = ObjectQuery(path1) - assert ( - engine._match_pattern(query1, gf.dataframe, _all_aggregator, root, 0) - is None - ) + assert engine._match_pattern(query1, gf.dataframe, _all_aggregator, root, 0) is None def test_apply(mock_graph_literal):