Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retrying to batch writing #8

Merged
merged 8 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: install poetry
run: |
python3 -m venv $POETRY_HOME
$POETRY_HOME/bin/pip install poetry==1.7.1
$POETRY_HOME/bin/pip install poetry==1.8.4
$POETRY_HOME/bin/poetry --version
- name: add poetry to path
run: echo "${POETRY_HOME}/bin" >> $GITHUB_PATH
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ need the `neo4j_arrow` directory.

# Build

This project uses [poetry]() as the build tool.
This project uses [poetry](https://python-poetry.org/) as the build tool.
Install `poetry`, define your environment with `poetry env use` and invoke `poetry install` to install dependencies.

To build:
Expand Down
1,045 changes: 571 additions & 474 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ packages = [{ include = "neo4j_arrow", from = "src" }]

[tool.poetry.dependencies]
python = ">=3.9, <4"
pyarrow = ">=10, <15"
pyarrow = ">=10, <18"
numpy = "<2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
Expand All @@ -28,6 +29,7 @@ flake8-comprehensions = "^3.14.0"
flake8-bandit = "^4.1.1"
testcontainers = "^3.7.1"
neo4j = "^5.15.0"
pandas = "^2.2.3"

[build-system]
requires = ["poetry-core"]
Expand Down
45 changes: 25 additions & 20 deletions src/neo4j_arrow/_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections import abc
import time
from enum import Enum
import json

Expand All @@ -10,7 +10,6 @@
from .model import Graph

from typing import (
cast,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -276,41 +275,47 @@ def _rename_and_add_column(
def _write_batches(
self,
desc: Dict[str, Any],
batches: Union[pa.RecordBatch, Iterable[pa.RecordBatch]],
batches: List[pa.RecordBatch],
mapping_fn: Optional[MappingFn] = None,
) -> Result:
"""
Write PyArrow RecordBatches to the GDS Flight service.
"""
if isinstance(batches, abc.Iterable):
batches = iter(batches)
else:
batches = iter([batches])
if len(batches) == 0:
raise Exception("no record batches provided")

fn = mapping_fn or self._nop

first = next(batches, None)
if not first:
raise Exception("empty iterable of record batches provided")
first = cast(pa.RecordBatch, fn(first))
schema = fn(batches[0]).schema

client = self._client()
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(desc).encode("utf-8"))
n_rows, n_bytes = 0, 0
try:
writer, _ = client.do_put(upload_descriptor, first.schema, self.call_opts)
writer, metadata_reader = client.do_put(upload_descriptor, schema, self.call_opts)
with writer:
writer.write_batch(first)
n_rows += first.num_rows
n_bytes += first.get_total_buffer_size()
for remaining in batches:
writer.write_batch(fn(remaining))
n_rows += remaining.num_rows
n_bytes += remaining.get_total_buffer_size()
for batch in batches:
mapped_batch = fn(batch)
self._write_batch_with_retries(mapped_batch, writer)
metadata_reader.read() # read the ack message
n_rows += batch.num_rows
n_bytes += batch.get_total_buffer_size()
except Exception as e:
raise error.interpret(e)
return n_rows, n_bytes

def _write_batch_with_retries(self, mapped_batch, writer):
    """
    Write a single record batch to ``writer``, retrying on transient errors.

    ``writer.write_batch(mapped_batch)`` is attempted up to 10 times. A
    failure is retried only when it is one of ``flight.FlightUnavailableError``,
    ``flight.FlightTimedOutError`` or ``flight.FlightInternalError``; any
    other exception propagates immediately. Between attempts the method
    sleeps for a short, growing delay (0.01s after the first failure, up to
    0.05s before the last attempt).

    :param mapped_batch: the batch to hand to ``writer.write_batch``
    :param writer: object exposing ``write_batch(batch)``
    :raises: the last transient error once all attempts are exhausted
    """
    # The exception classes must be given as a tuple: ``A | B`` does not
    # produce something an ``except`` clause can match against and fails with
    # a TypeError as soon as an error is actually raised.
    transient_errors = (
        flight.FlightUnavailableError,
        flight.FlightTimedOutError,
        flight.FlightInternalError,
    )
    max_attempts = 10

    # ``retries_left`` counts the attempts still available after this one.
    for retries_left in range(max_attempts - 1, -1, -1):
        try:
            writer.write_batch(mapped_batch)
            return
        except transient_errors:
            if retries_left == 0:
                # Out of attempts: re-raise with the original traceback and
                # without a pointless final sleep.
                raise
            self.logger.exception(
                "Encountered transient error; retrying %d more times ...", retries_left
            )
            time.sleep(0.1 / (retries_left + 1))

def start(
self,
action: str = "CREATE_GRAPH",
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@


def gds_version(driver: neo4j.Driver) -> str:
with driver.session() as session:
data = session.run("CALL gds.debug.sysInfo()").to_df()
print(data)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ali-ince Do you still want to keep this print?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes please, because it took quite a bit of time to figure out that the GDS license on CI was expired :)


with driver.session() as session:
version = session.run(
"CALL gds.debug.sysInfo() YIELD key, value WITH * WHERE key = $key RETURN value", {"key": "gdsVersion"}
Expand All @@ -19,6 +23,9 @@ def gds_version(driver: neo4j.Driver) -> str:

@pytest.fixture(scope="module")
def neo4j():
testcontainers.neo4j.Neo4jContainer.NEO4J_USER = "neo4j"
testcontainers.neo4j.Neo4jContainer.NEO4J_ADMIN_PASSWORD = "password"

container = (
testcontainers.neo4j.Neo4jContainer(os.getenv("NEO4J_IMAGE", "neo4j:5-enterprise"))
.with_volume_mapping(os.getenv("GDS_LICENSE_FILE", "/tmp/gds.license"), "/licenses/gds.license")
Expand All @@ -31,6 +38,8 @@ def neo4j():
.with_env("NEO4J_gds_arrow_listen__address", "0.0.0.0")
.with_exposed_ports(7687, 7474, 8491)
)
container.NEO4J_USER = "neo4j"
container.NEO4J_ADMIN_PASSWORD = "password"
container.start()

yield container
Expand Down
12 changes: 8 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
[tox]
env_list = py{39,310,311}-pyarrow{10,11,12,13,14}-unit, py{312}-pyarrow{14}-unit, py{39,310,311,312}-neo4j{4.4,5}-integration
env_list = py{39,310,311}-pyarrow{10,11,12,13,14,15,16,17}-unit, py{312}-pyarrow{14,15,16,17}-unit, py{39,310,311,312}-neo4j{4.4,5}-integration

[testenv]
deps =
pytest
neo4j
testcontainers
pandas
pyarrow10: pyarrow >= 10.0, < 11.0
pyarrow11: pyarrow >= 11.0, < 12.0
pyarrow12: pyarrow >= 12.0, < 13.0
pyarrow13: pyarrow >= 13.0, < 14.0
pyarrow14: pyarrow >= 14.0, < 15.0
pyarrow15: pyarrow >= 15.0, < 16.0
pyarrow16: pyarrow >= 16.0, < 17.0
pyarrow17: pyarrow >= 17.0, < 18.0
warn_args =
py{39,310,311,312}: -W error
commands =
unit: python -m pytest {[testenv]warn_args} -v {posargs} tests/unit
integration: python -m pytest -v {posargs} tests/integration
integration: python -m pytest -rx -v {posargs} tests/integration

[testenv:py{39,310,311}-pyarrow{10,11,12,13,14}-unit]
[testenv:py{39,310,311}-pyarrow{10,11,12,13,14,15,16,17}-unit]
labels = unit

[testenv:py{312}-pyarrow{14}-unit]
[testenv:py{312}-pyarrow{14,15,16,17}-unit]
labels = unit

[testenv:py{39,310,311,312}-neo4j{4.4,5}-integration]
Expand Down
Loading