Skip to content

Commit

Permalink
community: fix issue #29429 in age_graph.py (#29506)
Browse files Browse the repository at this point in the history
## Description:

This PR addresses issue #29429 by fixing the _wrap_query method in
langchain_community/graphs/age_graph.py. The method now correctly
handles Cypher queries with UNION and EXCEPT operators, ensuring that
the fields in the SQL query are ordered as they appear in the Cypher
query. Additionally, the method now properly handles cases where RETURN
* is not supported.

### Issue: #29429

### Dependencies: None


### Add tests and docs:

Added unit tests in tests/unit_tests/graphs/test_age_graph.py to
validate the changes.
No new integrations were added, so no example notebook is necessary.
Lint and test:

Ran make format, make lint, and make test to ensure code quality and
functionality.
  • Loading branch information
rawathemant246 authored Feb 2, 2025
1 parent 2f97916 commit db1693a
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 64 deletions.
97 changes: 52 additions & 45 deletions libs/community/langchain_community/graphs/age_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,71 +473,78 @@ def _get_col_name(field: str, idx: int) -> str:
@staticmethod
def _wrap_query(query: str, graph_name: str) -> str:
"""
Convert a cypher query to an Apache Age compatible
sql query by wrapping the cypher query in ag_catalog.cypher,
casting results to agtype and building a select statement
Convert a Cyper query to an Apache Age compatible Sql Query.
Handles combined queries with UNION/EXCEPT operators
Args:
query (str): a valid cypher query
graph_name (str): the name of the graph to query
query (str) : A valid cypher query, can include UNION/EXCEPT operators
graph_name (str) : The name of the graph to query
Returns:
str: an equivalent pgsql query
Returns :
str : An equivalent pgSql query wrapped with ag_catalog.cypher
Raises:
ValueError : If query is empty, contain RETURN *, or has invalid field names
"""

if not query.strip():
raise ValueError("Empty query provided")

# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields});"""

# if there are any returned fields they must be added to the pgsql query
return_match = re.search(r'\breturn\b(?![^"]*")', query, re.IGNORECASE)
if return_match:
# Extract the part of the query after the RETURN keyword
return_clause = query[return_match.end() :]

# parse return statement to identify returned fields
fields = (
return_clause.lower()
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)

# raise exception if RETURN * is found as we can't resolve the fields
if "*" in [x.strip() for x in fields]:
raise ValueError(
"AGE graph does not support 'RETURN *'"
+ " statements in Cypher queries"
# split the query into parts based on UNION and EXCEPT
parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE)

all_fields = []

for part in parts:
if part.strip().upper() in ("UNION", "EXCEPT"):
continue

# if there are any returned fields they must be added to the pgsql query
return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
if return_match:
# Extract the part of the query after the RETURN keyword
return_clause = part[return_match.end() :]

# parse return statement to identify returned fields
fields = (
return_clause.lower()
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)

# get pgsql formatted field names
fields = [
AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
]

# build resulting pgsql relation
fields_str = ", ".join(
[
field.split(".")[-1] + " agtype"
for field in fields
if field.split(".")[-1]
]
)
# raise exception if RETURN * is found as we can't resolve the fields
clean_fileds = [f.strip() for f in fields if f.strip()]
if "*" in clean_fileds:
raise ValueError(
"Apache Age does not support RETURN * in Cypher queries"
)

# if no return statement we still need to return a single field of type agtype
else:
# Format fields and maintain order of appearance
for idx, field in enumerate(clean_fileds):
field_name = AGEGraph._get_col_name(field, idx)
if field_name not in all_fields:
all_fields.append(field_name)

# if no return statements found in any part
if not all_fields:
fields_str = "a agtype"

select_str = "*"
else:
fields_str = ", ".join(f"{field} agtype" for field in all_fields)

return template.format(
graph_name=graph_name,
query=query,
fields=fields_str,
projection=select_str,
projection="*",
)

@staticmethod
Expand Down
188 changes: 169 additions & 19 deletions libs/community/tests/unit_tests/graphs/test_age_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_get_col_name(self) -> None:
self.assertEqual(AGEGraph._get_col_name(*value), expected[idx])

def test_wrap_query(self) -> None:
"""Test basic query wrapping functionality."""
inputs = [
# Positive case: Simple return clause
"""
Expand All @@ -76,46 +77,195 @@ def test_wrap_query(self) -> None:

expected = [
# Expected output for the first positive case
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (keanu:Person {name:'Keanu Reeves'})
RETURN keanu.name AS name, keanu.born AS born
$$) AS (name agtype, born agtype);
""",
# Second test case (no RETURN clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MERGE (n:a {id: 1})
$$) AS (a agtype);
""",
# Expected output for the negative cases (no RETURN clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {description: "This will return a value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {returnValue: "some value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
""",
]

for idx, value in enumerate(inputs):
result = AGEGraph._wrap_query(value, "test")
expected_result = expected[idx]
self.assertEqual(
re.sub(r"\s", "", result),
re.sub(r"\s", "", expected_result),
(
f"Failed on test case {idx + 1}\n"
f"Input:\n{value}\n"
f"Expected:\n{expected_result}\n"
f"Got:\n{result}"
),
)

def test_wrap_query_union_except(self) -> None:
"""Test query wrapping with UNION and EXCEPT operators."""
inputs = [
# UNION case
"""
MATCH (n:Person)
RETURN n.name AS name, n.age AS age
UNION
MATCH (n:Employee)
RETURN n.name AS name, n.salary AS salary
""",
"""
MATCH (a:Employee {name: "Alice"})
RETURN a.name AS name
UNION
MATCH (b:Manager {name: "Bob"})
RETURN b.name AS name
""",
# Complex UNION case
"""
MATCH (n)-[r]->(m)
RETURN n.name AS source, type(r) AS relationship, m.name AS target
UNION
MATCH (m)-[r]->(n)
RETURN m.name AS source, type(r) AS relationship, n.name AS target
""",
"""
MATCH (a:Person)-[:FRIEND]->(b:Person)
WHERE a.age > 30
RETURN a.name AS name
UNION
MATCH (c:Person)-[:FRIEND]->(d:Person)
WHERE c.age < 25
RETURN c.name AS name
""",
# EXCEPT case
"""
MATCH (n:Person)
RETURN n.name AS name
EXCEPT
MATCH (n:Employee)
RETURN n.name AS name
""",
"""
MATCH (a:Person)
RETURN a.name AS name, a.age AS age
EXCEPT
MATCH (b:Person {name: "Alice", age: 30})
RETURN b.name AS name, b.age AS age
""",
]

expected = [
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (keanu:Person {name:'Keanu Reeves'})
RETURN keanu.name AS name, keanu.born AS born
$$) AS (name agtype, born agtype);
MATCH (n:Person)
RETURN n.name AS name, n.age AS age
UNION
MATCH (n:Employee)
RETURN n.name AS name, n.salary AS salary
$$) AS (name agtype, age agtype, salary agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MERGE (n:a {id: 1})
$$) AS (a agtype);
MATCH (a:Employee {name: "Alice"})
RETURN a.name AS name
UNION
MATCH (b:Manager {name: "Bob"})
RETURN b.name AS name
$$) AS (name agtype);
""",
# Expected output for the negative cases (no return clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {description: "This will return a value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
MATCH (n)-[r]->(m)
RETURN n.name AS source, type(r) AS relationship, m.name AS target
UNION
MATCH (m)-[r]->(n)
RETURN m.name AS source, type(r) AS relationship, n.name AS target
$$) AS (source agtype, relationship agtype, target agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {returnValue: "some value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
MATCH (a:Person)-[:FRIEND]->(b:Person)
WHERE a.age > 30
RETURN a.name AS name
UNION
MATCH (c:Person)-[:FRIEND]->(d:Person)
WHERE c.age < 25
RETURN c.name AS name
$$) AS (name agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n:Person)
RETURN n.name AS name
EXCEPT
MATCH (n:Employee)
RETURN n.name AS name
$$) AS (name agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (a:Person)
RETURN a.name AS name, a.age AS age
EXCEPT
MATCH (b:Person {name: "Alice", age: 30})
RETURN b.name AS name, b.age AS age
$$) AS (name agtype, age agtype);
""",
]

for idx, value in enumerate(inputs):
result = AGEGraph._wrap_query(value, "test")
expected_result = expected[idx]
self.assertEqual(
re.sub(r"\s", "", AGEGraph._wrap_query(value, "test")),
re.sub(r"\s", "", expected[idx]),
re.sub(r"\s", "", result),
re.sub(r"\s", "", expected_result),
(
f"Failed on test case {idx + 1}\n"
f"Input:\n{value}\n"
f"Expected:\n{expected_result}\n"
f"Got:\n{result}"
),
)

with self.assertRaises(ValueError):
AGEGraph._wrap_query(
"""
def test_wrap_query_errors(self) -> None:
"""Test error cases for query wrapping."""
error_cases = [
# Empty query
"",
# Return * case
"""
MATCH ()
RETURN *
""",
"test",
)
# Return * in UNION
"""
MATCH (n:Person)
RETURN n.name
UNION
MATCH ()
RETURN *
""",
]

for query in error_cases:
with self.assertRaises(ValueError):
AGEGraph._wrap_query(query, "test")

def test_format_properties(self) -> None:
inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]
Expand Down

0 comments on commit db1693a

Please sign in to comment.