Skip to content

Commit

Permalink
Use different procedure names with different versions of GDS (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-ince authored Dec 13, 2023
1 parent 7e79be6 commit 2cd465c
Show file tree
Hide file tree
Showing 12 changed files with 1,174 additions and 95 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,11 @@ jobs:
- name: install dependencies and build
run: poetry install
- name: run tests
run: poetry run tox -p
run: poetry run tox -m unit
- name: run integration tests
env:
gds_license: ${{ secrets.GDS_LICENSE }}
run: |
LICENSE_FILE=${{ runner.temp }}/license.tmp
echo "${gds_license}" > $LICENSE_FILE
GDS_LICENSE_FILE=$LICENSE_FILE poetry run tox -m integration
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[mypy]
python_version = 3.8
python_version = 3.9
warn_return_any = True
ignore_missing_imports = True
787 changes: 720 additions & 67 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ flake8-isort = "^6.1.1"
flake8-spellcheck = "^0.28.0"
flake8-comprehensions = "^3.14.0"
flake8-bandit = "^4.1.1"
testcontainers = "^3.7.1"
neo4j = "^5.15.0"

[build-system]
requires = ["poetry-core"]
Expand Down
30 changes: 27 additions & 3 deletions src/neo4j_arrow/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ class ClientState(Enum):
GRAPH_READY = "done"


class ProcedureNames:
nodes_single_property: str = "gds.graph.nodeProperty.stream"
nodes_multiple_property: str = "gds.graph.nodeProperties.stream"
edges_single_property: str = "gds.graph.relationshipProperty.stream"
edges_multiple_property: str = "gds.graph.relationshipProperties.stream"
edges_topology: str = "gds.graph.relationships.stream"


def procedure_names(version: Optional[str] = None) -> ProcedureNames:
if not version:
return ProcedureNames()
elif version.startswith("2.5"):
return ProcedureNames()
else:
names = ProcedureNames()
names.edges_topology = "gds.beta.graph.relationships.stream"
return names


class Neo4jArrowClient:
host: str
port: int
Expand Down Expand Up @@ -68,6 +87,7 @@ def __init__(
max_chunk_size: int = 10_000,
debug: bool = False,
logger: Optional[logging.Logger] = None,
proc_names: Optional[ProcedureNames] = None,
):
self.host = host
self.port = port
Expand All @@ -86,6 +106,9 @@ def __init__(
if not logger:
logger = logging.getLogger("Neo4jArrowClient")
self.logger = logger
if not proc_names:
proc_names = procedure_names()
self.proc_names = proc_names

def __str__(self) -> str:
return (
Expand Down Expand Up @@ -476,13 +499,13 @@ def read_edges(
if concurrency < 1:
raise ValueError("concurrency cannot be negative")
if properties:
procedure_name = "gds.graph.relationshipProperties.stream"
procedure_name = self.proc_names.edges_multiple_property
configuration = {
"relationship_properties": list(properties if properties is not None else []),
"relationship_types": list(relationship_types if relationship_types is not None else ["*"]),
}
else:
procedure_name = "gds.beta.graph.relationships.stream"
procedure_name = self.proc_names.edges_topology
configuration = {
"relationship_types": list(relationship_types if relationship_types is not None else ["*"]),
}
Expand Down Expand Up @@ -520,10 +543,11 @@ def read_nodes(
{
"graph_name": self.graph,
"database_name": self.database,
"procedure_name": "gds.graph.nodeProperties.stream",
"procedure_name": self.proc_names.nodes_multiple_property,
"configuration": {
"node_labels": list(labels if labels is not None else ["*"]),
"node_properties": list(properties if properties is not None else []),
"list_node_labels": True,
},
"concurrency": concurrency,
}
Expand Down
46 changes: 31 additions & 15 deletions src/neo4j_arrow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from typing import Any, Dict, Generic, List, Optional, Union, TypeVar


class ValidationError(Exception):
def __init__(self, message: str):
self.message = message


class _NodeEncoder(JSONEncoder):
def default(self, n: "Node") -> object:
return n.to_dict()
Expand All @@ -32,13 +37,14 @@ def __init__(
label: str = "",
label_field: str = "",
key_field: str,
**properties: Dict[str, Any],
properties: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
self._source = source
self._label = label
self._label_field = label_field
self._key_field = key_field
self._properties = properties
self._properties = dict(properties or {}, **kwargs)
self._pattern: Optional[re.Pattern[str]] = None
try:
self._pattern = re.compile(self._source)
Expand Down Expand Up @@ -81,12 +87,14 @@ def to_dict(self) -> Dict[str, Any]:
}

def validate(self) -> None:
if not self._source:
raise ValidationError(f"source must be provided in {self}")
if not self._key_field:
raise ValidationError(f"key_field must be provided in {self}")
if not self._label and not self._label_field:
raise Exception(f"either label or label_field must be provided in {self}")
raise ValidationError(f"either label or label_field must be provided in {self}")
if self._label and self._label_field:
raise Exception(f"use of label and label_field at the same time is not allowed " f"in {self}")
if not self._key_field:
raise Exception(f"empty key_field in {self}")
raise ValidationError(f"use of label and label_field at the same time is not allowed " f"in {self}")

def __str__(self) -> str:
return str(self.to_dict())
Expand All @@ -99,20 +107,26 @@ def __init__(
source: str,
edge_type: str = "",
type_field: str = "",
type: str = "",
source_field: str,
target_field: str,
**properties: Dict[str, Any],
properties: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
self._source = source
self._type = edge_type
if not self._type and type:
self._type = type
self._type_field = type_field
self._source_field = source_field
self._target_field = target_field
self._properties = properties
self._properties = dict(properties or {}, **kwargs)
self._pattern: Optional[re.Pattern[str]] = None
try:
self._pattern = re.compile(source)
except Exception:
except ValueError:
pass
except TypeError:
pass

@property
Expand Down Expand Up @@ -156,14 +170,16 @@ def to_dict(self) -> Dict[str, Any]:
}

def validate(self) -> None:
if not self._type_field and not self._type:
raise Exception(f"either type or type_field must be provided in {self}")
if self._type_field and self._type:
raise Exception(f"use of type and type_field at the same time is not allowed in {self}")
if not self._source:
raise ValidationError(f"source must be provided in {self}")
if not self._source_field:
raise Exception(f"empty source_field in {self}")
raise ValidationError(f"source_field must be provided in {self}")
if not self._target_field:
raise Exception(f"empty target_field in {self}")
raise ValidationError(f"target_field must be provided in {self}")
if not self._type_field and not self._type:
raise ValidationError(f"either type or type_field must be provided in {self}")
if self._type_field and self._type:
raise ValidationError(f"use of type and type_field at the same time is not allowed in {self}")

def __str__(self) -> str:
return str(self.to_dict())
Expand Down
71 changes: 71 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
from typing import Callable

import neo4j
import pytest
import testcontainers.neo4j

import neo4j_arrow._client
from neo4j_arrow import Neo4jArrowClient


def gds_version(driver: neo4j.Driver) -> str:
with driver.session() as session:
version = session.run(
"CALL gds.debug.sysInfo() YIELD key, value WITH * WHERE key = $key RETURN value", {"key": "gdsVersion"}
).single(strict=True)[0]
return version


@pytest.fixture(scope="module")
def neo4j():
container = (
testcontainers.neo4j.Neo4jContainer(os.getenv("NEO4J_IMAGE", "neo4j:5-enterprise"))
.with_volume_mapping(os.getenv("GDS_LICENSE_FILE", "/tmp/gds.license"), "/licenses/gds.license")
.with_env("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
.with_env("NEO4J_PLUGINS", '["graph-data-science"]')
.with_env("NEO4J_gds_enterprise_license__file", "/licenses/gds.license")
.with_env("NEO4J_dbms_security_procedures_unrestricted", "gds.*")
.with_env("NEO4J_dbms_security_procedures_allowlist", "gds.*")
.with_env("NEO4J_gds_arrow_enabled", "true")
.with_env("NEO4J_gds_arrow_listen__address", "0.0.0.0")
.with_exposed_ports(7687, 7474, 8491)
)
container.start()

yield container

container.stop()


@pytest.fixture(scope="module")
def driver(neo4j):
driver = neo4j.get_driver()

yield driver

driver.close()


@pytest.fixture(scope="module")
def arrow_client_factory(neo4j, driver) -> Callable[[str], Neo4jArrowClient]:
def _arrow_client_factory(graph_name: str) -> Neo4jArrowClient:
return Neo4jArrowClient(
neo4j.get_container_host_ip(),
graph=graph_name,
user=neo4j.NEO4J_USER,
password=neo4j.NEO4J_ADMIN_PASSWORD,
port=int(neo4j.get_exposed_port(8491)),
tls=False,
proc_names=neo4j_arrow._client.procedure_names(gds_version(driver)),
)

return _arrow_client_factory


@pytest.fixture(autouse=True)
def setup(driver, arrow_client_factory):
with driver.session() as session:
session.run("CREATE OR REPLACE DATABASE neo4j WAIT").consume()

yield
Loading

0 comments on commit 2cd465c

Please sign in to comment.