diff --git a/libs/community/langchain_community/graphs/age_graph.py b/libs/community/langchain_community/graphs/age_graph.py index 434491253c6ab..116791ee5c070 100644 --- a/libs/community/langchain_community/graphs/age_graph.py +++ b/libs/community/langchain_community/graphs/age_graph.py @@ -473,71 +473,78 @@ def _get_col_name(field: str, idx: int) -> str: @staticmethod def _wrap_query(query: str, graph_name: str) -> str: """ - Convert a cypher query to an Apache Age compatible - sql query by wrapping the cypher query in ag_catalog.cypher, - casting results to agtype and building a select statement + Convert a Cyper query to an Apache Age compatible Sql Query. + Handles combined queries with UNION/EXCEPT operators Args: - query (str): a valid cypher query - graph_name (str): the name of the graph to query + query (str) : A valid cypher query, can include UNION/EXCEPT operators + graph_name (str) : The name of the graph to query - Returns: - str: an equivalent pgsql query + Returns : + str : An equivalent pgSql query wrapped with ag_catalog.cypher + + Raises: + ValueError : If query is empty, contain RETURN *, or has invalid field names """ + if not query.strip(): + raise ValueError("Empty query provided") + # pgsql template template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$ {query} $$) AS ({fields});""" - # if there are any returned fields they must be added to the pgsql query - return_match = re.search(r'\breturn\b(?![^"]*")', query, re.IGNORECASE) - if return_match: - # Extract the part of the query after the RETURN keyword - return_clause = query[return_match.end() :] - - # parse return statement to identify returned fields - fields = ( - return_clause.lower() - .split("distinct")[-1] - .split("order by")[0] - .split("skip")[0] - .split("limit")[0] - .split(",") - ) - - # raise exception if RETURN * is found as we can't resolve the fields - if "*" in [x.strip() for x in fields]: - raise ValueError( - "AGE graph does not support 'RETURN *'" - + " statements in Cypher queries" + # split the query into parts based on UNION and EXCEPT + parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE) + + all_fields = [] + + for part in parts: + if part.strip().upper() in ("UNION", "EXCEPT"): + continue + + # if there are any returned fields they must be added to the pgsql query + return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE) + if return_match: + # Extract the part of the query after the RETURN keyword + return_clause = part[return_match.end() :] + + # parse return statement to identify returned fields + fields = ( + return_clause.lower() + .split("distinct")[-1] + .split("order by")[0] + .split("skip")[0] + .split("limit")[0] + .split(",") ) - # get pgsql formatted field names - fields = [ - AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields) - ] - - # build resulting pgsql relation - fields_str = ", ".join( - [ - field.split(".")[-1] + " agtype" - for field in fields - if field.split(".")[-1] - ] - ) + # raise exception if RETURN * is found as we can't resolve the fields + clean_fileds = [f.strip() for f in fields if f.strip()] + if "*" in clean_fileds: + raise ValueError( + "Apache Age does not support RETURN * in Cypher queries" + ) - # if no return statement we still need to return a single field of type agtype - else: + # Format fields and maintain order of appearance + for idx, field in enumerate(clean_fileds): + field_name = AGEGraph._get_col_name(field, idx) + if field_name not in all_fields: + all_fields.append(field_name) + + # if no return statements found in any part + if not all_fields: fields_str = "a agtype" - select_str = "*" + else: + fields_str = ", ".join(f"{field} agtype" for field in all_fields) return template.format( graph_name=graph_name, query=query, fields=fields_str, - projection=select_str, + projection="*", ) @staticmethod diff --git a/libs/community/tests/unit_tests/graphs/test_age_graph.py b/libs/community/tests/unit_tests/graphs/test_age_graph.py index 6981c16d88aa5..19ff90803e007 100644 --- a/libs/community/tests/unit_tests/graphs/test_age_graph.py +++ b/libs/community/tests/unit_tests/graphs/test_age_graph.py @@ -53,6 +53,7 @@ def test_get_col_name(self) -> None: self.assertEqual(AGEGraph._get_col_name(*value), expected[idx]) def test_wrap_query(self) -> None: + """Test basic query wrapping functionality.""" inputs = [ # Positive case: Simple return clause """ @@ -76,46 +77,195 @@ def test_wrap_query(self) -> None: expected = [ # Expected output for the first positive case + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (keanu:Person {name:'Keanu Reeves'}) + RETURN keanu.name AS name, keanu.born AS born + $$) AS (name agtype, born agtype); + """, + # Second test case (no RETURN clause) + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MERGE (n:a {id: 1}) + $$) AS (a agtype); + """, + # Expected output for the negative cases (no RETURN clause) + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (n {description: "This will return a value"}) + MERGE (n)-[:RELATED]->(m) + $$) AS (a agtype); + """, + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (n {returnValue: "some value"}) + MERGE (n)-[:RELATED]->(m) + $$) AS (a agtype); + """, + ] + + for idx, value in enumerate(inputs): + result = AGEGraph._wrap_query(value, "test") + expected_result = expected[idx] + self.assertEqual( + re.sub(r"\s", "", result), + re.sub(r"\s", "", expected_result), + ( + f"Failed on test case {idx + 1}\n" + f"Input:\n{value}\n" + f"Expected:\n{expected_result}\n" + f"Got:\n{result}" + ), + ) + + def test_wrap_query_union_except(self) -> None: + """Test query wrapping with UNION and EXCEPT operators.""" + inputs = [ + # UNION case + """ + MATCH (n:Person) + RETURN n.name AS name, n.age AS age + UNION + MATCH (n:Employee) + RETURN n.name AS name, n.salary AS salary + """, + """ + MATCH (a:Employee {name: "Alice"}) + RETURN a.name AS name + UNION + MATCH (b:Manager {name: "Bob"}) + RETURN b.name AS name + """, + # Complex UNION case + """ + MATCH (n)-[r]->(m) + RETURN n.name AS source, type(r) AS relationship, m.name AS target + UNION + MATCH (m)-[r]->(n) + RETURN m.name AS source, type(r) AS relationship, n.name AS target + """, + """ + MATCH (a:Person)-[:FRIEND]->(b:Person) + WHERE a.age > 30 + RETURN a.name AS name + UNION + MATCH (c:Person)-[:FRIEND]->(d:Person) + WHERE c.age < 25 + RETURN c.name AS name + """, + # EXCEPT case + """ + MATCH (n:Person) + RETURN n.name AS name + EXCEPT + MATCH (n:Employee) + RETURN n.name AS name + """, + """ + MATCH (a:Person) + RETURN a.name AS name, a.age AS age + EXCEPT + MATCH (b:Person {name: "Alice", age: 30}) + RETURN b.name AS name, b.age AS age + """, + ] + + expected = [ """ SELECT * FROM ag_catalog.cypher('test', $$ - MATCH (keanu:Person {name:'Keanu Reeves'}) - RETURN keanu.name AS name, keanu.born AS born - $$) AS (name agtype, born agtype); + MATCH (n:Person) + RETURN n.name AS name, n.age AS age + UNION + MATCH (n:Employee) + RETURN n.name AS name, n.salary AS salary + $$) AS (name agtype, age agtype, salary agtype); """, """ SELECT * FROM ag_catalog.cypher('test', $$ - MERGE (n:a {id: 1}) - $$) AS (a agtype); + MATCH (a:Employee {name: "Alice"}) + RETURN a.name AS name + UNION + MATCH (b:Manager {name: "Bob"}) + RETURN b.name AS name + $$) AS (name agtype); """, - # Expected output for the negative cases (no return clause) """ SELECT * FROM ag_catalog.cypher('test', $$ - MATCH (n {description: "This will return a value"}) - MERGE (n)-[:RELATED]->(m) - $$) AS (a agtype); + MATCH (n)-[r]->(m) + RETURN n.name AS source, type(r) AS relationship, m.name AS target + UNION + MATCH (m)-[r]->(n) + RETURN m.name AS source, type(r) AS relationship, n.name AS target + $$) AS (source agtype, relationship agtype, target agtype); """, """ SELECT * FROM ag_catalog.cypher('test', $$ - MATCH (n {returnValue: "some value"}) - MERGE (n)-[:RELATED]->(m) - $$) AS (a agtype); + MATCH (a:Person)-[:FRIEND]->(b:Person) + WHERE a.age > 30 + RETURN a.name AS name + UNION + MATCH (c:Person)-[:FRIEND]->(d:Person) + WHERE c.age < 25 + RETURN c.name AS name + $$) AS (name agtype); + """, + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (n:Person) + RETURN n.name AS name + EXCEPT + MATCH (n:Employee) + RETURN n.name AS name + $$) AS (name agtype); + """, + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (a:Person) + RETURN a.name AS name, a.age AS age + EXCEPT + MATCH (b:Person {name: "Alice", age: 30}) + RETURN b.name AS name, b.age AS age + $$) AS (name agtype, age agtype); """, ] for idx, value in enumerate(inputs): + result = AGEGraph._wrap_query(value, "test") + expected_result = expected[idx] self.assertEqual( - re.sub(r"\s", "", AGEGraph._wrap_query(value, "test")), - re.sub(r"\s", "", expected[idx]), + re.sub(r"\s", "", result), + re.sub(r"\s", "", expected_result), + ( + f"Failed on test case {idx + 1}\n" + f"Input:\n{value}\n" + f"Expected:\n{expected_result}\n" + f"Got:\n{result}" + ), ) - with self.assertRaises(ValueError): - AGEGraph._wrap_query( - """ + def test_wrap_query_errors(self) -> None: + """Test error cases for query wrapping.""" + error_cases = [ + # Empty query + "", + # Return * case + """ MATCH () RETURN * """, - "test", - ) + # Return * in UNION + """ + MATCH (n:Person) + RETURN n.name + UNION + MATCH () + RETURN * + """, + ] + + for query in error_cases: + with self.assertRaises(ValueError): + AGEGraph._wrap_query(query, "test") def test_format_properties(self) -> None: inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]