Skip to content

Commit

Permalink
Add retrying to batch writing (#8)
Browse files Browse the repository at this point in the history
* Add retrying to batch writing

- Fixed number of retries (10)
- Retries on three different flight errors

Co-authored-by: Max Kießling <max.kiessling@neotechnology.com>

* Read metadata before returning

The ambition is to make sure to have received an ACK from the server
before continuing to send a DONE message.

Co-authored-by: Max Kießling <max.kiessling@neotechnology.com>

* Fix numpy to <2 version and add new pyarrow versions

* Explicitly set username and password

* Update poetry

* Diagnose what is wrong in CI

* Remove extra logging

* Read the ack message for each batch

---------

Co-authored-by: Max Kießling <max.kiessling@neotechnology.com>
Co-authored-by: Ali Ince <ali.ince@neo4j.com>
Co-authored-by: Max Kießling <max.kiessling@neo4j.com>
  • Loading branch information
4 people authored Oct 24, 2024
1 parent c222cf4 commit c6c40af
Show file tree
Hide file tree
Showing 7 changed files with 618 additions and 501 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: install poetry
run: |
python3 -m venv $POETRY_HOME
$POETRY_HOME/bin/pip install poetry==1.7.1
$POETRY_HOME/bin/pip install poetry==1.8.4
$POETRY_HOME/bin/poetry --version
- name: add poetry to path
run: echo "${POETRY_HOME}/bin" >> $GITHUB_PATH
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ need the `neo4j_arrow` directory.

# Build

This project uses [poetry]() as the build tool.
This project uses [poetry](https://python-poetry.org/) as the build tool.
Install `poetry`, define your environment with `poetry env use` and invoke `poetry install` to install dependencies.

To build;
Expand Down
1,045 changes: 571 additions & 474 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ packages = [{ include = "neo4j_arrow", from = "src" }]

[tool.poetry.dependencies]
python = ">=3.9, <4"
pyarrow = ">=10, <15"
pyarrow = ">=10, <18"
numpy = "<2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
Expand All @@ -28,6 +29,7 @@ flake8-comprehensions = "^3.14.0"
flake8-bandit = "^4.1.1"
testcontainers = "^3.7.1"
neo4j = "^5.15.0"
pandas = "^2.2.3"

[build-system]
requires = ["poetry-core"]
Expand Down
45 changes: 25 additions & 20 deletions src/neo4j_arrow/_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections import abc
import time
from enum import Enum
import json

Expand All @@ -10,7 +10,6 @@
from .model import Graph

from typing import (
cast,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -276,41 +275,47 @@ def _rename_and_add_column(
def _write_batches(
self,
desc: Dict[str, Any],
batches: Union[pa.RecordBatch, Iterable[pa.RecordBatch]],
batches: List[pa.RecordBatch],
mapping_fn: Optional[MappingFn] = None,
) -> Result:
"""
Write PyArrow RecordBatches to the GDS Flight service.
"""
if isinstance(batches, abc.Iterable):
batches = iter(batches)
else:
batches = iter([batches])
if len(batches) == 0:
raise Exception("no record batches provided")

fn = mapping_fn or self._nop

first = next(batches, None)
if not first:
raise Exception("empty iterable of record batches provided")
first = cast(pa.RecordBatch, fn(first))
schema = fn(batches[0]).schema

client = self._client()
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(desc).encode("utf-8"))
n_rows, n_bytes = 0, 0
try:
writer, _ = client.do_put(upload_descriptor, first.schema, self.call_opts)
writer, metadata_reader = client.do_put(upload_descriptor, schema, self.call_opts)
with writer:
writer.write_batch(first)
n_rows += first.num_rows
n_bytes += first.get_total_buffer_size()
for remaining in batches:
writer.write_batch(fn(remaining))
n_rows += remaining.num_rows
n_bytes += remaining.get_total_buffer_size()
for batch in batches:
mapped_batch = fn(batch)
self._write_batch_with_retries(mapped_batch, writer)
metadata_reader.read() # read the ack message
n_rows += batch.num_rows
n_bytes += batch.get_total_buffer_size()
except Exception as e:
raise error.interpret(e)
return n_rows, n_bytes

def _write_batch_with_retries(self, mapped_batch, writer):
num_retries = 10
while True:
try:
writer.write_batch(mapped_batch)
break
except flight.FlightUnavailableError | flight.FlightTimedOutError | flight.FlightInternalError as e:
self.logger.exception(f"Encountered transient error; retrying {num_retries} more times ...")
time.sleep(0.1 / num_retries)
num_retries -= 1
if num_retries == 0:
raise e

def start(
self,
action: str = "CREATE_GRAPH",
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@


def gds_version(driver: neo4j.Driver) -> str:
with driver.session() as session:
data = session.run("CALL gds.debug.sysInfo()").to_df()
print(data)

with driver.session() as session:
version = session.run(
"CALL gds.debug.sysInfo() YIELD key, value WITH * WHERE key = $key RETURN value", {"key": "gdsVersion"}
Expand All @@ -19,6 +23,9 @@ def gds_version(driver: neo4j.Driver) -> str:

@pytest.fixture(scope="module")
def neo4j():
testcontainers.neo4j.Neo4jContainer.NEO4J_USER = "neo4j"
testcontainers.neo4j.Neo4jContainer.NEO4J_ADMIN_PASSWORD = "password"

container = (
testcontainers.neo4j.Neo4jContainer(os.getenv("NEO4J_IMAGE", "neo4j:5-enterprise"))
.with_volume_mapping(os.getenv("GDS_LICENSE_FILE", "/tmp/gds.license"), "/licenses/gds.license")
Expand All @@ -31,6 +38,8 @@ def neo4j():
.with_env("NEO4J_gds_arrow_listen__address", "0.0.0.0")
.with_exposed_ports(7687, 7474, 8491)
)
container.NEO4J_USER = "neo4j"
container.NEO4J_ADMIN_PASSWORD = "password"
container.start()

yield container
Expand Down
12 changes: 8 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
[tox]
env_list = py{39,310,311}-pyarrow{10,11,12,13,14}-unit, py{312}-pyarrow{14}-unit, py{39,310,311,312}-neo4j{4.4,5}-integration
env_list = py{39,310,311}-pyarrow{10,11,12,13,14,15,16,17}-unit, py{312}-pyarrow{14,15,16,17}-unit, py{39,310,311,312}-neo4j{4.4,5}-integration

[testenv]
deps =
pytest
neo4j
testcontainers
pandas
pyarrow10: pyarrow >= 10.0, < 11.0
pyarrow11: pyarrow >= 11.0, < 12.0
pyarrow12: pyarrow >= 12.0, < 13.0
pyarrow13: pyarrow >= 13.0, < 14.0
pyarrow14: pyarrow >= 14.0, < 15.0
pyarrow15: pyarrow >= 15.0, < 16.0
pyarrow16: pyarrow >= 16.0, < 17.0
pyarrow17: pyarrow >= 17.0, < 18.0
warn_args =
py{39,310,311,312}: -W error
commands =
unit: python -m pytest {[testenv]warn_args} -v {posargs} tests/unit
integration: python -m pytest -v {posargs} tests/integration
integration: python -m pytest -rx -v {posargs} tests/integration

[testenv:py{39,310,311}-pyarrow{10,11,12,13,14}-unit]
[testenv:py{39,310,311}-pyarrow{10,11,12,13,14,15,16,17}-unit]
labels = unit

[testenv:py{312}-pyarrow{14}-unit]
[testenv:py{312}-pyarrow{14,15,16,17}-unit]
labels = unit

[testenv:py{39,310,311,312}-neo4j{4.4,5}-integration]
Expand Down

0 comments on commit c6c40af

Please sign in to comment.