diff --git a/hatchet/query/compat.py b/hatchet/query/compat.py index b94433f2..d62a0c5c 100644 --- a/hatchet/query/compat.py +++ b/hatchet/query/compat.py @@ -33,7 +33,6 @@ class AbstractQuery(ABC): - """Base class for all 'old-style' queries.""" @abstractmethod @@ -87,7 +86,6 @@ def _get_new_query(self): class NaryQuery(AbstractQuery): - """Base class for all compound queries that act on and merged N separate subqueries.""" @@ -149,7 +147,6 @@ def _convert_to_new_query(self, subqueries): class AndQuery(NaryQuery): - """Compound query that returns the intersection of the results of the subqueries.""" @@ -160,8 +157,7 @@ def __init__(self, *args): *args (AbstractQuery, str, or list): the subqueries to be performed """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. Please use new-style queries (e.g., hatchet.query.ConjunctionQuery) instead.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. Please use new-style queries (e.g., hatchet.query.ConjunctionQuery) instead.", DeprecationWarning, stacklevel=2, ) @@ -181,7 +177,6 @@ def _convert_to_new_query(self, subqueries): class OrQuery(NaryQuery): - """Compound query that returns the union of the results of the subqueries""" @@ -192,8 +187,7 @@ def __init__(self, *args): *args (AbstractQuery, str, or list): the subqueries to be performed """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. Please use new-style queries (e.g., hatchet.query.DisjunctionQuery) instead.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. 
Please use new-style queries (e.g., hatchet.query.DisjunctionQuery) instead.", DeprecationWarning, stacklevel=2, ) @@ -213,7 +207,6 @@ def _convert_to_new_query(self, subqueries): class XorQuery(NaryQuery): - """Compound query that returns the symmetric difference (i.e., set-based XOR) of the results of the subqueries""" @@ -224,8 +217,7 @@ def __init__(self, *args): *args (AbstractQuery, str, or list): the subqueries to be performed """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. Please use new-style queries (e.g., hatchet.query.ExclusiveDisjunctionQuery) instead.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. Please use new-style queries (e.g., hatchet.query.ExclusiveDisjunctionQuery) instead.", DeprecationWarning, stacklevel=2, ) @@ -245,7 +237,6 @@ def _convert_to_new_query(self, subqueries): class NotQuery(NaryQuery): - """Compound query that returns all nodes in the GraphFrame that are not returned from the subquery.""" @@ -256,8 +247,7 @@ def __init__(self, *args): *args (AbstractQuery, str, or list): the subquery to be performed """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. Please use new-style queries (e.g., hatchet.query.NegationQuery) instead.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. Please use new-style queries (e.g., hatchet.query.NegationQuery) instead.", DeprecationWarning, stacklevel=2, ) @@ -273,7 +263,6 @@ def _convert_to_new_query(self, subqueries): class QueryMatcher(AbstractQuery): - """Processes and applies base syntax queries and Object-based queries to GraphFrames.""" def __init__(self, query=None): @@ -284,10 +273,7 @@ def __init__(self, query=None): into its internal representation """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. 
Please use new-style queries instead. For QueryMatcher, the equivalent \ - new-style queries are hatchet.query.Query for base-syntax queries and \ - hatchet.query.ObjectQuery for the object-dialect.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. Please use new-style queries instead. For QueryMatcher, the equivalent new-style queries are hatchet.query.Query for base-syntax queries and hatchet.query.ObjectQuery for the object-dialect.", DeprecationWarning, stacklevel=2, ) @@ -348,7 +334,6 @@ def _get_new_query(self): class CypherQuery(QueryMatcher): - """Processes and applies Strinb-based queries to GraphFrames.""" def __init__(self, cypher_query): @@ -358,9 +343,7 @@ def __init__(self, cypher_query): cypher_query (str): the String-based query """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. Please use new-style queries instead. For CypherQuery, the equivalent \ - new-style query is hatchet.query.StringQuery.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. Please use new-style queries instead. For CypherQuery, the equivalent new-style query is hatchet.query.StringQuery.", DeprecationWarning, stacklevel=2, ) @@ -386,8 +369,7 @@ def parse_cypher_query(cypher_query): (CypherQuery): a Hatchet query for this String-based query """ warnings.warn( - "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the \ - future. Please use new-style queries (e.g., hatchet.query.parse_string_dialect) instead.", + "Old-style queries are deprecated as of Hatchet 2023.1.0 and will be removed in the future. 
Please use new-style queries (e.g., hatchet.query.parse_string_dialect) instead.", DeprecationWarning, stacklevel=2, ) diff --git a/hatchet/query/object_dialect.py b/hatchet/query/object_dialect.py index e4010377..daf55c65 100644 --- a/hatchet/query/object_dialect.py +++ b/hatchet/query/object_dialect.py @@ -106,6 +106,9 @@ def filter_single_series(df_row, key, single_value): matches = True for k, v in attr_filter.items(): + metric_name = k + if isinstance(k, (tuple, list)) and len(k) == 1: + metric_name = k[0] try: _ = iter(v) # Manually raise TypeError if v is a string so that @@ -114,10 +117,12 @@ def filter_single_series(df_row, key, single_value): raise TypeError # Runs if v is not iterable (e.g., list, tuple, etc.) except TypeError: - matches = matches and filter_single_series(df_row, k, v) + matches = matches and filter_single_series(df_row, metric_name, v) else: for single_value in v: - matches = matches and filter_single_series(df_row, k, single_value) + matches = matches and filter_single_series( + df_row, metric_name, single_value + ) return matches def filter_dframe(df_row): @@ -186,16 +191,19 @@ def filter_single_dframe(node, df_row, key, single_value): matches = True node = df_row.name.to_frame().index[0][0] for k, v in attr_filter.items(): + metric_name = k + if isinstance(k, (tuple, list)) and len(k) == 1: + metric_name = k[0] try: _ = iter(v) if isinstance(v, str): raise TypeError except TypeError: - matches = matches and filter_single_dframe(node, df_row, k, v) + matches = matches and filter_single_dframe(node, df_row, metric_name, v) else: for single_value in v: matches = matches and filter_single_dframe( - node, df_row, k, single_value + node, df_row, metric_name, single_value ) return matches @@ -208,7 +216,6 @@ def filter_choice(df_row): class ObjectQuery(Query): - """Class for representing and parsing queries using the Object-based dialect.""" def __init__(self, query, multi_index_mode="off"): diff --git a/hatchet/query/string_dialect.py 
b/hatchet/query/string_dialect.py index ed4e8714..1fb83c48 100644 --- a/hatchet/query/string_dialect.py +++ b/hatchet/query/string_dialect.py @@ -18,7 +18,7 @@ # PEG grammar for the String-based dialect -CYPHER_GRAMMAR = u""" +CYPHER_GRAMMAR = """ FullQuery: path_expr=MatchExpr(cond_expr=WhereExpr)?; MatchExpr: 'MATCH' path=PathQuery; PathQuery: '(' nodes=NodeExpr ')'('->' '(' nodes=NodeExpr ')')*; @@ -32,26 +32,28 @@ UnaryCond: NotCond | SingleCond; NotCond: 'NOT' subcond=SingleCond; SingleCond: StringCond | NumberCond | NoneCond | NotNoneCond | LeafCond | NotLeafCond; -NoneCond: name=ID '.' prop=STRING 'IS NONE'; -NotNoneCond: name=ID '.' prop=STRING 'IS NOT NONE'; +NoneCond: name=ID '.' prop=MetricId 'IS NONE'; +NotNoneCond: name=ID '.' prop=MetricId 'IS NOT NONE'; LeafCond: name=ID 'IS LEAF'; NotLeafCond: name=ID 'IS NOT LEAF'; StringCond: StringEq | StringStartsWith | StringEndsWith | StringContains | StringMatch; -StringEq: name=ID '.' prop=STRING '=' val=STRING; -StringStartsWith: name=ID '.' prop=STRING 'STARTS WITH' val=STRING; -StringEndsWith: name=ID '.' prop=STRING 'ENDS WITH' val=STRING; -StringContains: name=ID '.' prop=STRING 'CONTAINS' val=STRING; -StringMatch: name=ID '.' prop=STRING '=~' val=STRING; +StringEq: name=ID '.' prop=MetricId '=' val=STRING; +StringStartsWith: name=ID '.' prop=MetricId 'STARTS WITH' val=STRING; +StringEndsWith: name=ID '.' prop=MetricId 'ENDS WITH' val=STRING; +StringContains: name=ID '.' prop=MetricId 'CONTAINS' val=STRING; +StringMatch: name=ID '.' prop=MetricId '=~' val=STRING; NumberCond: NumEq | NumLt | NumGt | NumLte | NumGte | NumNan | NumNotNan | NumInf | NumNotInf; -NumEq: name=ID '.' prop=STRING '=' val=NUMBER; -NumLt: name=ID '.' prop=STRING '<' val=NUMBER; -NumGt: name=ID '.' prop=STRING '>' val=NUMBER; -NumLte: name=ID '.' prop=STRING '<=' val=NUMBER; -NumGte: name=ID '.' prop=STRING '>=' val=NUMBER; -NumNan: name=ID '.' prop=STRING 'IS NAN'; -NumNotNan: name=ID '.' 
prop=STRING 'IS NOT NAN'; -NumInf: name=ID '.' prop=STRING 'IS INF'; -NumNotInf: name=ID '.' prop=STRING 'IS NOT INF'; +NumEq: name=ID '.' prop=MetricId '=' val=NUMBER; +NumLt: name=ID '.' prop=MetricId '<' val=NUMBER; +NumGt: name=ID '.' prop=MetricId '>' val=NUMBER; +NumLte: name=ID '.' prop=MetricId '<=' val=NUMBER; +NumGte: name=ID '.' prop=MetricId '>=' val=NUMBER; +NumNan: name=ID '.' prop=MetricId 'IS NAN'; +NumNotNan: name=ID '.' prop=MetricId 'IS NOT NAN'; +NumInf: name=ID '.' prop=MetricId 'IS INF'; +NumNotInf: name=ID '.' prop=MetricId 'IS NOT INF'; +MetricId: '(' ids+=SingleMetricId[','] ')' | ids=SingleMetricId; +SingleMetricId: INT | STRING; """ # TextX metamodel for the String-based dialect @@ -85,8 +87,14 @@ def filter_check_types(type_check, df_row, filt_lambda): return False -class StringQuery(Query): +######################################################################## +# NOTE: the use of single and double quotes in processing string-dialect +# queries is EXTREMELY important. Inner strings (e.g., for metric +# names) MUST use single quotes. 
+######################################################################## + +class StringQuery(Query): """Class for representing and parsing queries using the String-based dialect.""" def __init__(self, cypher_query, multi_index_mode="off"): @@ -275,14 +283,14 @@ def _parse_single_cond(self, obj): def _parse_none(self, obj): """Parses 'property IS NONE'.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "df_row.name._depth is None", None, ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -292,7 +300,11 @@ def _parse_none(self, obj): return [ None, obj.name, - 'df_row["{}"] is None'.format(obj.prop), + "df_row[{}] is None".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), None, ] @@ -302,47 +314,43 @@ def _add_aggregation_call_to_multi_idx_predicate(self, predicate): return predicate + ".all()" def _parse_none_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "df_row.index.get_level_values('node')[0]._depth is None", None, ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, "df_row.index.get_level_values('node')[0]._hatchet_nid is None", None, ] - if self.multi_index_mode == "any": - return [ - None, - obj.name, - "df_row['{}'].apply(lambda elem: elem is None).any()".format(obj.prop), - None, - ] - # if self.multi_index_mode == "all": return [ None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - "df_row['{}'].apply(lambda elem: elem is None)".format(obj.prop) + "df_row[{}].apply(lambda elem: elem is None)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ) ), None, ] def _parse_not_none(self, obj): """Parses 'property IS NOT NONE'.""" - if obj.prop == "depth": + if 
len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "df_row.name._depth is not None", None, ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -352,19 +360,23 @@ def _parse_not_none(self, obj): return [ None, obj.name, - 'df_row["{}"] is not None'.format(obj.prop), + "df_row[{}] is not None".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), None, ] def _parse_not_none_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "df_row.index.get_level_values('node')[0]._depth is not None", None, ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -375,7 +387,11 @@ def _parse_not_none_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - "df_row['{}'].apply(lambda elem: elem is not None)".format(obj.prop) + "df_row[{}].apply(lambda elem: elem is not None)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ) ), None, ] @@ -465,8 +481,17 @@ def _parse_str_eq(self, obj): return [ None, obj.name, - 'df_row["{}"] == "{}"'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], str)".format(obj.prop), + 'df_row[{}] == "{}"'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], str)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_eq_multi_idx(self, obj): @@ -474,11 +499,18 @@ def _parse_str_eq_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem == "{}")'.format( - obj.prop, obj.val + 'df_row[{}].apply(lambda elem: elem == "{}")'.format( + str(tuple(obj.prop.ids)) 
+ if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, ) ), - "is_string_dtype(df_row['{}'])".format(obj.prop), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_starts_with(self, obj): @@ -486,8 +518,17 @@ def _parse_str_starts_with(self, obj): return [ None, obj.name, - 'df_row["{}"].startswith("{}")'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], str)".format(obj.prop), + 'df_row[{}].startswith("{}")'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], str)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_starts_with_multi_idx(self, obj): @@ -495,11 +536,18 @@ def _parse_str_starts_with_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem.startswith("{}"))'.format( - obj.prop, obj.val + 'df_row[{}].apply(lambda elem: elem.startswith("{}"))'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, ) ), - "is_string_dtype(df_row['{}'])".format(obj.prop), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_ends_with(self, obj): @@ -507,8 +555,17 @@ def _parse_str_ends_with(self, obj): return [ None, obj.name, - 'df_row["{}"].endswith("{}")'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], str)".format(obj.prop), + 'df_row[{}].endswith("{}")'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], str)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_ends_with_multi_idx(self, 
obj): @@ -516,11 +573,18 @@ def _parse_str_ends_with_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem.endswith("{}"))'.format( - obj.prop, obj.val + 'df_row[{}].apply(lambda elem: elem.endswith("{}"))'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, ) ), - "is_string_dtype(df_row['{}'])".format(obj.prop), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_contains(self, obj): @@ -528,8 +592,17 @@ def _parse_str_contains(self, obj): return [ None, obj.name, - '"{}" in df_row["{}"]'.format(obj.val, obj.prop), - "isinstance(df_row['{}'], str)".format(obj.prop), + '"{}" in df_row[{}]'.format( + obj.val, + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + ), + "isinstance(df_row[{}], str)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_contains_multi_idx(self, obj): @@ -537,11 +610,18 @@ def _parse_str_contains_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: "{}" in elem)'.format( - obj.prop, obj.val + 'df_row[{}].apply(lambda elem: "{}" in elem)'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, ) ), - "is_string_dtype(df_row['{}'])".format(obj.prop), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_match(self, obj): @@ -549,8 +629,17 @@ def _parse_str_match(self, obj): return [ None, obj.name, - 're.match("{}", df_row["{}"]) is not None'.format(obj.val, obj.prop), - "isinstance(df_row['{}'], str)".format(obj.prop), + 're.match("{}", df_row[{}]) is not None'.format( + 
obj.val, + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + ), + "isinstance(df_row[{}], str)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_str_match_multi_idx(self, obj): @@ -558,11 +647,18 @@ def _parse_str_match_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: re.match("{}", elem) is not None)'.format( - obj.prop, obj.val + 'df_row[{}].apply(lambda elem: re.match("{}", elem) is not None)'.format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, ) ), - "is_string_dtype(df_row['{}'])".format(obj.prop), + "is_string_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num(self, obj): @@ -591,7 +687,7 @@ def _parse_num(self, obj): def _parse_num_eq(self, obj): """Processes numeric equivalence predicates.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val == -1: return [ None, @@ -623,7 +719,7 @@ def _parse_num_eq(self, obj): "df_row.name._depth == {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -651,12 +747,21 @@ def _parse_num_eq(self, obj): return [ None, obj.name, - 'df_row["{}"] == {}'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "df_row[{}] == {}".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_eq_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] 
== "depth": if obj.val == -1: return [ None, @@ -688,7 +793,7 @@ def _parse_num_eq_multi_idx(self, obj): "df_row.index.get_level_values('node')[0]._depth == {}".format(obj.val), "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -719,14 +824,23 @@ def _parse_num_eq_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem == {})'.format(obj.prop, obj.val) + "df_row[{}].apply(lambda elem: elem == {})".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_lt(self, obj): """Processes numeric less-than predicates.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -751,7 +865,7 @@ def _parse_num_lt(self, obj): "df_row.name._depth < {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -779,12 +893,21 @@ def _parse_num_lt(self, obj): return [ None, obj.name, - 'df_row["{}"] < {}'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "df_row[{}] < {}".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_lt_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: 
warnings.warn( """ @@ -809,7 +932,7 @@ def _parse_num_lt_multi_idx(self, obj): "df_row.index.get_level_values('node')[0]._depth < {}".format(obj.val), "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -840,14 +963,23 @@ def _parse_num_lt_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem < {})'.format(obj.prop, obj.val) + "df_row[{}].apply(lambda elem: elem < {})".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_gt(self, obj): """Processes numeric greater-than predicates.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -872,7 +1004,7 @@ def _parse_num_gt(self, obj): "df_row.name._depth > {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -900,12 +1032,21 @@ def _parse_num_gt(self, obj): return [ None, obj.name, - 'df_row["{}"] > {}'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "df_row[{}] > {}".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_gt_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -930,7 
+1071,7 @@ def _parse_num_gt_multi_idx(self, obj): "df_row.index.get_level_values('node')[0]._depth > {}".format(obj.val), "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -961,14 +1102,23 @@ def _parse_num_gt_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem > {})'.format(obj.prop, obj.val) + "df_row[{}].apply(lambda elem: elem > {})".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_lte(self, obj): """Processes numeric less-than-or-equal-to predicates.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -993,7 +1143,7 @@ def _parse_num_lte(self, obj): "df_row.name._depth <= {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -1021,12 +1171,21 @@ def _parse_num_lte(self, obj): return [ None, obj.name, - 'df_row["{}"] <= {}'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "df_row[{}] <= {}".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_lte_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -1051,7 +1210,7 @@ 
def _parse_num_lte_multi_idx(self, obj): "df_row.index.get_level_values('node')[0]._depth <= {}".format(obj.val), "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -1082,14 +1241,23 @@ def _parse_num_lte_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem <= {})'.format(obj.prop, obj.val) + "df_row[{}].apply(lambda elem: elem <= {})".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_gte(self, obj): """Processes numeric greater-than-or-equal-to predicates.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -1114,7 +1282,7 @@ def _parse_num_gte(self, obj): "df_row.name._depth >= {}".format(obj.val), "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -1142,12 +1310,21 @@ def _parse_num_gte(self, obj): return [ None, obj.name, - 'df_row["{}"] >= {}'.format(obj.prop, obj.val), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "df_row[{}] >= {}".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_gte_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": if obj.val < 0: warnings.warn( """ @@ -1172,7 +1349,7 @@ def 
_parse_num_gte_multi_idx(self, obj): "df_row.index.get_level_values('node')[0]._depth >= {}".format(obj.val), "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": if obj.val < 0: warnings.warn( """ @@ -1203,21 +1380,30 @@ def _parse_num_gte_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'df_row["{}"].apply(lambda elem: elem >= {})'.format(obj.prop, obj.val) + "df_row[{}].apply(lambda elem: elem >= {})".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]), + obj.val, + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_nan(self, obj): """Processes predicates that check for NaN.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "pd.isna(df_row.name._depth)", "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1227,19 +1413,27 @@ def _parse_num_nan(self, obj): return [ None, obj.name, - 'pd.isna(df_row["{}"])'.format(obj.prop), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_nan_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "pd.isna(df_row.index.get_level_values('node')[0]._depth)", "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == 
"node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1250,21 +1444,29 @@ def _parse_num_nan_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'pd.isna(df_row["{}"])'.format(obj.prop) + "pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_not_nan(self, obj): """Processes predicates that check for NaN.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "not pd.isna(df_row.name._depth)", "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1274,19 +1476,27 @@ def _parse_num_not_nan(self, obj): return [ None, obj.name, - 'not pd.isna(df_row["{}"])'.format(obj.prop), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "not pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_not_nan_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "not pd.isna(df_row.index.get_level_values('node')[0]._depth)", "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1297,21 +1507,29 @@ def _parse_num_not_nan_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'not 
pd.isna(df_row["{}"])'.format(obj.prop) + "not pd.isna(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_inf(self, obj): """Processes predicates that check for Infinity.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "np.isinf(df_row.name._depth)", "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1321,19 +1539,27 @@ def _parse_num_inf(self, obj): return [ None, obj.name, - 'np.isinf(df_row["{}"])'.format(obj.prop), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_inf_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "np.isinf(df_row.index.get_level_values('node')[0]._depth)", "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1344,21 +1570,29 @@ def _parse_num_inf_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'np.isinf(df_row["{}"])'.format(obj.prop) + "np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) 
> 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] def _parse_num_not_inf(self, obj): """Processes predicates that check for not-Infinity.""" - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "not np.isinf(df_row.name._depth)", "isinstance(df_row.name._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1368,19 +1602,27 @@ def _parse_num_not_inf(self, obj): return [ None, obj.name, - 'not np.isinf(df_row["{}"])'.format(obj.prop), - "isinstance(df_row['{}'], Real)".format(obj.prop), + "not np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), + "isinstance(df_row[{}], Real)".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ), ] def _parse_num_not_inf_multi_idx(self, obj): - if obj.prop == "depth": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "depth": return [ None, obj.name, "not np.isinf(df_row.index.get_level_values('node')[0]._depth)", "isinstance(df_row.index.get_level_values('node')[0]._depth, Real)", ] - if obj.prop == "node_id": + if len(obj.prop.ids) == 1 and obj.prop.ids[0] == "node_id": return [ None, obj.name, @@ -1391,9 +1633,17 @@ def _parse_num_not_inf_multi_idx(self, obj): None, obj.name, self._add_aggregation_call_to_multi_idx_predicate( - 'not np.isinf(df_row["{}"])'.format(obj.prop) + "not np.isinf(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) + ) + ), + "is_numeric_dtype(df_row[{}])".format( + str(tuple(obj.prop.ids)) + if len(obj.prop.ids) > 1 + else "'{}'".format(obj.prop.ids[0]) ), - "is_numeric_dtype(df_row['{}'])".format(obj.prop), ] diff --git a/hatchet/tests/query.py b/hatchet/tests/query.py index f2705aa3..a8082dea 100644 --- a/hatchet/tests/query.py +++ 
b/hatchet/tests/query.py @@ -45,10 +45,18 @@ def test_construct_object_dialect(): (3, {"time (inc)": 0.1}), {"name": "ibv[_a-zA-Z]*"}, ] + # Note: the commas in the keys are necessary. In Python, creating a tuple + # from a single string results in a tuple containing every character of + # the string as a separate element. In other words, + # tuple("name") == ( "n", "a", "m", "e" ). + # The comma tells Python to create a tuple with a single element. In other words, + # ("name",) == tuple(["name"]), whereas ("name") is just the string "name". + path5 = [{("name",): "MPI_[_a-zA-Z]*"}, "*", {("name",): "ibv[_a-zA-Z]*"}] query1 = ObjectQuery(path1) query2 = ObjectQuery(path2) query3 = ObjectQuery(path3) query4 = ObjectQuery(path4) + query5 = ObjectQuery(path5) assert query1.query_pattern[0][0] == "." assert query1.query_pattern[0][1](mock_node_mpi) @@ -105,6 +113,17 @@ def test_construct_object_dialect(): assert not query4.query_pattern[3][1](mock_node_time_false) assert query4.query_pattern[4][0] == "." + assert query5.query_pattern[0][0] == "." + assert query5.query_pattern[0][1](mock_node_mpi) + assert not query5.query_pattern[0][1](mock_node_ibv) + assert not query5.query_pattern[0][1](mock_node_time_true) + assert query5.query_pattern[1][0] == "*" + assert query5.query_pattern[1][1](mock_node_mpi) + assert query5.query_pattern[1][1](mock_node_ibv) + assert query5.query_pattern[1][1](mock_node_time_true) + assert query5.query_pattern[1][1](mock_node_time_false) + assert query5.query_pattern[2][0] == "." 
+ invalid_path = [ {"name": "MPI_[_a-zA-Z]*"}, ({"bad": "wildcard"}, {"time (inc)": 0.1}), @@ -821,22 +840,26 @@ def test_construct_string_dialect(): mock_node_ibv = {"name": "ibv_reg_mr"} mock_node_time_true = {"time (inc)": 0.1} mock_node_time_false = {"time (inc)": 0.001} - path1 = u"""MATCH (p)->("*")->(q) + path1 = """MATCH (p)->("*")->(q) WHERE p."name" STARTS WITH "MPI_" AND q."name" STARTS WITH "ibv" """ - path2 = u"""MATCH (p)->(2)->(q) + path2 = """MATCH (p)->(2)->(q) WHERE p."name" STARTS WITH "MPI_" AND q."name" STARTS WITH "ibv" """ - path3 = u"""MATCH (p)->("+", a)->(q) + path3 = """MATCH (p)->("+", a)->(q) WHERE p."name" STARTS WITH "MPI" AND a."time (inc)" >= 0.1 AND q."name" STARTS WITH "ibv" """ - path4 = u"""MATCH (p)->(3, a)->(q) + path4 = """MATCH (p)->(3, a)->(q) WHERE p."name" STARTS WITH "MPI" AND a."time (inc)" = 0.1 AND q."name" STARTS WITH "ibv" """ + path5 = """MATCH (p)->("*")->(q) + WHERE p.("name") STARTS WITH "MPI_" AND q.("name") STARTS WITH "ibv" + """ query1 = StringQuery(path1) query2 = StringQuery(path2) query3 = StringQuery(path3) query4 = StringQuery(path4) + query5 = StringQuery(path5) assert query1.query_pattern[0][0] == "." assert query1.query_pattern[0][1](mock_node_mpi) @@ -893,7 +916,18 @@ def test_construct_string_dialect(): assert not query4.query_pattern[3][1](mock_node_time_false) assert query4.query_pattern[4][0] == "." - invalid_path = u"""MATCH (p)->({"bad": "wildcard"}, a)->(q) + assert query5.query_pattern[0][0] == "." + assert query5.query_pattern[0][1](mock_node_mpi) + assert not query5.query_pattern[0][1](mock_node_ibv) + assert not query5.query_pattern[0][1](mock_node_time_true) + assert query5.query_pattern[1][0] == "*" + assert query5.query_pattern[1][1](mock_node_mpi) + assert query5.query_pattern[1][1](mock_node_ibv) + assert query5.query_pattern[1][1](mock_node_time_true) + assert query5.query_pattern[1][1](mock_node_time_false) + assert query5.query_pattern[2][0] == "." 
+ + invalid_path = """MATCH (p)->({"bad": "wildcard"}, a)->(q) WHERE p."name" STARTS WITH "MPI" AND a."time (inc)" = 0.1 AND q."name" STARTS WITH "ibv" """ @@ -903,7 +937,7 @@ def test_construct_string_dialect(): def test_apply_string_dialect(mock_graph_literal): gf = GraphFrame.from_literal(mock_graph_literal) - path = u"""MATCH (p)->(2, q)->("*", r)->(s) + path = """MATCH (p)->(2, q)->("*", r)->(s) WHERE p."time (inc)" >= 30.0 AND NOT q."name" STARTS WITH "b" AND r."name" =~ "[^b][a-z]+" AND s."name" STARTS WITH "gr" """ @@ -920,7 +954,7 @@ def test_apply_string_dialect(mock_graph_literal): assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) - path = u"""MATCH (p)->(".")->(q)->("*") + path = """MATCH (p)->(".")->(q)->("*") WHERE p."time (inc)" >= 30.0 AND q."name" = "bar" """ match = [ @@ -933,14 +967,14 @@ def test_apply_string_dialect(mock_graph_literal): query = StringQuery(path) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) - path = u"""MATCH (p)->(q)->(r) + path = """MATCH (p)->(q)->(r) WHERE p."name" = "foo" AND q."name" = "bar" AND r."time" = 5.0 """ match = [root, root.children[0], root.children[0].children[0]] query = StringQuery(path) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) - path = u"""MATCH (p)->(q)->("+", r) + path = """MATCH (p)->(q)->("+", r) WHERE p."name" = "foo" AND q."name" = "qux" AND r."time (inc)" > 15.0 """ match = [ @@ -953,7 +987,7 @@ def test_apply_string_dialect(mock_graph_literal): query = StringQuery(path) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) - path = u"""MATCH (p)->(q) + path = """MATCH (p)->(q) WHERE p."time (inc)" > 100 OR p."time (inc)" <= 30 AND q."time (inc)" = 20 """ roots = gf.graph.roots @@ -966,21 +1000,21 @@ def test_apply_string_dialect(mock_graph_literal): query = StringQuery(path) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) - path = u"""MATCH (p)->("*", q)->(r) + 
path = """MATCH (p)->("*", q)->(r) WHERE p."name" = "this" AND q."name" = "is" AND r."name" = "nonsense" """ query = StringQuery(path) assert engine.apply(query, gf.graph, gf.dataframe) == [] - path = u"""MATCH (p)->("*")->(q) + path = """MATCH (p)->("*")->(q) WHERE p."name" = 5 AND q."name" = "whatever" """ with pytest.raises(InvalidQueryFilter): query = StringQuery(path) engine.apply(query, gf.graph, gf.dataframe) - path = u"""MATCH (p)->("*")->(q) + path = """MATCH (p)->("*")->(q) WHERE p."time" = "badstring" AND q."name" = "whatever" """ query = StringQuery(path) @@ -998,14 +1032,14 @@ def __init__(self): "list" ] = DummyType() gf = GraphFrame.from_literal(bad_field_test_dict) - path = u"""MATCH (p)->(q)->(r) + path = """MATCH (p)->(q)->(r) WHERE p."name" = "foo" AND q."name" = "bar" AND p."list" = DummyType() """ with pytest.raises(InvalidQueryPath): query = StringQuery(path) engine.apply(query, gf.graph, gf.dataframe) - path = u"""MATCH ("*")->(p)->(q)->("*") + path = """MATCH ("*")->(p)->(q)->("*") WHERE p."name" = "bar" AND q."name" = "grault" """ match = [ @@ -1052,7 +1086,7 @@ def __init__(self): query = StringQuery(path) assert sorted(engine.apply(query, gf.graph, gf.dataframe)) == sorted(match) - path = u"""MATCH ("*")->(p)->(q)->("+") + path = """MATCH ("*")->(p)->(q)->("+") WHERE p."name" = "bar" AND q."name" = "grault" """ query = StringQuery(path) @@ -1060,7 +1094,7 @@ def __init__(self): gf.dataframe["time"] = np.NaN gf.dataframe.at[gf.graph.roots[0], "time"] = 5.0 - path = u"""MATCH ("*", p) + path = """MATCH ("*", p) WHERE p."time" IS NOT NAN""" match = [gf.graph.roots[0]] query = StringQuery(path) @@ -1068,7 +1102,7 @@ def __init__(self): gf.dataframe["time"] = 5.0 gf.dataframe.at[gf.graph.roots[0], "time"] = np.NaN - path = u"""MATCH ("*", p) + path = """MATCH ("*", p) WHERE p."time" IS NAN""" match = [gf.graph.roots[0]] query = StringQuery(path) @@ -1076,7 +1110,7 @@ def __init__(self): gf.dataframe["time"] = np.Inf 
gf.dataframe.at[gf.graph.roots[0], "time"] = 5.0 - path = u"""MATCH ("*", p) + path = """MATCH ("*", p) WHERE p."time" IS NOT INF""" match = [gf.graph.roots[0]] query = StringQuery(path) @@ -1084,7 +1118,7 @@ def __init__(self): gf.dataframe["time"] = 5.0 gf.dataframe.at[gf.graph.roots[0], "time"] = np.Inf - path = u"""MATCH ("*", p) + path = """MATCH ("*", p) WHERE p."time" IS INF""" match = [gf.graph.roots[0]] query = StringQuery(path) @@ -1093,7 +1127,7 @@ def __init__(self): names = gf.dataframe["name"].copy() gf.dataframe["name"] = None gf.dataframe.at[gf.graph.roots[0], "name"] = names.iloc[0] - path = u"""MATCH ("*", p) + path = """MATCH ("*", p) WHERE p."name" IS NOT NONE""" match = [gf.graph.roots[0]] query = StringQuery(path) @@ -1101,7 +1135,7 @@ def __init__(self): gf.dataframe["name"] = names gf.dataframe.at[gf.graph.roots[0], "name"] = None - path = u"""MATCH ("*", p) + path = """MATCH ("*", p) WHERE p."name" IS NONE""" match = [gf.graph.roots[0]] query = StringQuery(path) @@ -1111,13 +1145,13 @@ def __init__(self): def test_string_conj_compound_query(mock_graph_literal): gf = GraphFrame.from_literal(mock_graph_literal) compound_query1 = parse_string_dialect( - u""" + """ {MATCH ("*", p) WHERE p."time (inc)" >= 20 AND p."time (inc)" <= 60} AND {MATCH ("*", p) WHERE p."time (inc)" >= 60} """ ) compound_query2 = parse_string_dialect( - u""" + """ MATCH ("*", p) WHERE {p."time (inc)" >= 20 AND p."time (inc)" <= 60} AND {p."time (inc)" >= 60} """ @@ -1139,13 +1173,13 @@ def test_string_conj_compound_query(mock_graph_literal): def test_string_disj_compound_query(mock_graph_literal): gf = GraphFrame.from_literal(mock_graph_literal) compound_query1 = parse_string_dialect( - u""" + """ {MATCH ("*", p) WHERE p."time (inc)" = 5.0} OR {MATCH ("*", p) WHERE p."time (inc)" = 10.0} """ ) compound_query2 = parse_string_dialect( - u""" + """ MATCH ("*", p) WHERE {p."time (inc)" = 5.0} OR {p."time (inc)" = 10.0} """ @@ -1174,13 +1208,13 @@ def 
test_string_disj_compound_query(mock_graph_literal): def test_cypher_exc_disj_compound_query(mock_graph_literal): gf = GraphFrame.from_literal(mock_graph_literal) compound_query1 = parse_string_dialect( - u""" + """ {MATCH ("*", p) WHERE p."time (inc)" >= 5.0 AND p."time (inc)" <= 10.0} XOR {MATCH ("*", p) WHERE p."time (inc)" = 10.0} """ ) compound_query2 = parse_string_dialect( - u""" + """ MATCH ("*", p) WHERE {p."time (inc)" >= 5.0 AND p."time (inc)" <= 10.0} XOR {p."time (inc)" = 10.0} """ @@ -1215,19 +1249,19 @@ def test_leaf_query(small_mock2): nonleaves = list(nodes - set(matches)) obj_query = ObjectQuery([{"depth": -1}]) str_query_numeric = parse_string_dialect( - u""" + """ MATCH (p) WHERE p."depth" = -1 """ ) str_query_is_leaf = parse_string_dialect( - u""" + """ MATCH (p) WHERE p IS LEAF """ ) str_query_is_not_leaf = parse_string_dialect( - u""" + """ MATCH (p) WHERE p IS NOT LEAF """ @@ -1265,7 +1299,7 @@ def test_string_dialect_all_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() query = StringQuery( - u"""MATCH (".")->("+", p) + """MATCH (".")->("+", p) WHERE p."time (inc)" >= 17983.0 """, multi_index_mode="all", @@ -1296,7 +1330,7 @@ def test_string_dialect_any_mode(tau_profile_dir): gf = GraphFrame.from_tau(tau_profile_dir) engine = QueryEngine() query = StringQuery( - u"""MATCH (".", p) + """MATCH (".", p) WHERE p."time" < 24.0 """, multi_index_mode="any", @@ -1314,7 +1348,7 @@ def test_multi_index_mode_assertion_error(tau_profile_dir): _ = ObjectQuery([".", ("*", {"name": "test"})], multi_index_mode="foo") with pytest.raises(AssertionError): _ = StringQuery( - u""" MATCH (".")->("*", p) + """ MATCH (".")->("*", p) WHERE p."name" = "test" """, multi_index_mode="foo", diff --git a/pyproject.toml b/pyproject.toml index 47e4ff69..4dec4d54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,5 @@ [build-system] -requires = [ - "setuptools", - "wheel", - "Cython", -] +requires = ["setuptools", "wheel", 
"Cython"] build-backend = "setuptools.build_meta" [tool.poetry] @@ -17,6 +13,23 @@ authors = [ ] license = "MIT" +[tool.ruff] +line-length = 88 +target-version = 'py37' +include = ['\.pyi?$'] +exclude = [ + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".tox", + ".venv", + "_build", + "buck-out", + "build", + "dist", +] + [tool.black] line-length = 88 target-version = ['py27', 'py35', 'py36', 'py37', 'py38']