Skip to content

Commit

Permalink
Improve logging
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-ince committed Jun 9, 2023
1 parent 8fecaf2 commit fd3dacc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 25 deletions.
58 changes: 34 additions & 24 deletions neo4j_arrow/_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections import abc
from enum import Enum
import logging as log
import json

import pyarrow as pa
Expand Down Expand Up @@ -42,11 +42,12 @@ class Neo4jArrowClient:
debug: bool
client: flight.FlightClient
call_opts: flight.FlightCallOptions
logger: logging.Logger

def __init__(self, host: str, graph: str, *, port: int = 8491, database: str = "neo4j",
user: str = "neo4j", password: str = "neo4j", tls: bool = True,
concurrency: int = 4, timeout: Optional[float] = None, max_chunk_size: int = 10_000,
debug: bool = False):
debug: bool = False, logger: logging.Logger = None):
self.host = host
self.port = port
self.user = user
Expand All @@ -61,9 +62,14 @@ def __init__(self, host: str, graph: str, *, port: int = 8491, database: str = "
self.debug = debug
self.timeout = timeout
self.max_chunk_size = max_chunk_size
if not logger:
logger = logging.getLogger("Neo4jArrowClient")
self.logger = logger

def __str__(self):
    """
    Render the client's connection settings as one URI-like string.

    The output has the form
    ``Neo4jArrowClient{user@host:port/database?graph=...&encrypted=...}``
    and also reports ``concurrency``, ``debug``, ``timeout`` and
    ``max_chunk_size``. The password is not included in the output.
    """
    # Only the extended representation is returned; the older, shorter
    # return statement that preceded it made this one unreachable.
    return (
        f"Neo4jArrowClient{{{self.user}@{self.host}:{self.port}/{self.database}?graph={self.graph}"
        f"&encrypted={self.tls}&concurrency={self.concurrency}&debug={self.debug}&timeout={self.timeout}"
        f"&max_chunk_size={self.max_chunk_size}}}"
    )

def __getstate__(self):
state = self.__dict__.copy()
Expand Down Expand Up @@ -164,7 +170,7 @@ def _map(data: Arrow) -> Arrow:
if node.label_field:
columns, fields = cls._rename_and_add_column(columns, fields, data, node.label_field, "labels")
for name in node.properties:
columns, fields = cls._rename_and_add_column(columns, fields, name, node.properties[name])
columns, fields = cls._rename_and_add_column(columns, fields, data, name, node.properties[name])

return data.from_arrays(columns, schema=pa.schema(fields))

Expand Down Expand Up @@ -232,20 +238,20 @@ def _write_batches(self, desc: Dict[str, Any],
upload_descriptor = flight.FlightDescriptor.for_command(
json.dumps(desc).encode("utf-8")
)
rows, nbytes = 0, 0
n_rows, n_bytes = 0, 0
try:
writer, _ = client.do_put(upload_descriptor, first.schema, self.call_opts)
with writer:
writer.write_batch(first)
rows += first.num_rows
nbytes += first.get_total_buffer_size()
n_rows += first.num_rows
n_bytes += first.get_total_buffer_size()
for remaining in batches:
writer.write_batch(fn(remaining))
rows += remaining.num_rows
nbytes += remaining.get_total_buffer_size()
n_rows += remaining.num_rows
n_bytes += remaining.get_total_buffer_size()
except Exception as e:
raise error.interpret(e)
return rows, nbytes
return n_rows, n_bytes

def start(self, action: str = "CREATE_GRAPH", *,
config: Dict[str, Any] = None,
Expand Down Expand Up @@ -276,20 +282,24 @@ def start_create_graph(self, *, force: bool = False, undirected_rel_types: Itera

return self._start("CREATE_GRAPH", config=config, force=force)

def start_create_database(self, *, force: bool = False, id_type: str = "", id_property: str = "",
                          record_format: str = "", high_io: bool = True, use_bad_collector: bool = False) -> \
        Dict[str, Any]:
    """
    Start a ``CREATE_DATABASE`` import job named after ``self.graph``.

    :param force: forwarded both inside the config and to ``_start``.
    :param id_type: sent as ``id_type`` only when non-empty.
    :param id_property: sent as ``id_property`` only when non-empty.
    :param record_format: sent as ``record_format`` only when non-empty.
    :param high_io: sent as ``high_io``.
    :param use_bad_collector: sent as ``use_bad_collector``.
    :return: whatever ``_start`` returns for the ``CREATE_DATABASE`` action.
    """
    config = {
        "name": self.graph,
        "concurrency": self.concurrency,
        "high_io": high_io,
        "use_bad_collector": use_bad_collector,
        "force": force,
    }

    # Empty strings mean "not specified": those keys are left out of the
    # request entirely (presumably so the server picks its own default --
    # confirm against the server's CREATE_DATABASE contract).
    optional_settings = {
        "id_type": id_type,
        "id_property": id_property,
        "record_format": record_format,
    }
    config.update({key: value for key, value in optional_settings.items() if value})

    return self._start("CREATE_DATABASE", config=config, force=force)

def _start(self, action: str = "CREATE_GRAPH", *,
Expand All @@ -309,30 +319,30 @@ def _start(self, action: str = "CREATE_GRAPH", *,
raise error.Neo4jArrowException(f"failed to start {action} for {config['name']}, got {result}")
except error.AlreadyExists as e:
if force:
log.warning(f"forcing cancellation of existing {action} import"
f" for {config['name']}")
self.logger.warning(f"forcing cancellation of existing {action} import"
f" for {config['name']}")
if self.abort():
return self._start(action, config=config)

log.error(f"{action} import job already exists for {config['name']}")
self.logger.error(f"{action} import job already exists for {config['name']}")
except Exception as e:
log.error(f"fatal error performing action {action}: {e}")
self.logger.error(f"fatal error performing action {action}: {e}")
raise e

return {}

def _write_entities(self, desc: Dict[str, Any], entities: Union[Nodes, Edges], mapper: MappingFn) -> Result:
    """
    Feed nodes or edges to the server through ``_write_batches``.

    :param desc: descriptor passed on to ``_write_batches``; its
        ``entity_type`` entry is used in the error log message.
    :param entities: a ``pyarrow.Table`` or an iterable of batches.
    :param mapper: function applied to the data before it is written.
    :return: the result of ``_write_batches`` (row and byte counts).
    :raises error.NotFound: re-raised after logging when ``_write_batches``
        raises it.
    :raises Exception: any other failure is logged and re-raised.
    """
    try:
        if isinstance(entities, pa.Table):
            # A whole table is mapped once up front and then split into
            # batches, so the per-batch mapper is swapped for self._nop
            # (presumably a pass-through -- confirm) to avoid mapping twice.
            entities = mapper(entities).to_batches(max_chunksize=self.max_chunk_size)
            mapper = self._nop

        return self._write_batches(desc, entities, mapper)
    except error.NotFound:
        # The message previously contained a stray literal "f" before the
        # graph name ("graph f<name>"); it now logs the name as-is.
        self.logger.error(f"no existing import job found for graph {self.graph}")
        raise
    except Exception as e:
        self.logger.error(f"fatal error while feeding {desc['entity_type']}s for graph {self.graph}: {e}")
        raise

def write_nodes(self, nodes: Nodes,
Expand Down Expand Up @@ -459,9 +469,9 @@ def abort(self, name: Optional[str] = None) -> bool:

raise error.Neo4jArrowException(f"invalid response for abort of graph {self.graph}, got {result}")
except error.NotFound as e:
log.warning(f"no existing import for {config['name']}")
self.logger.warning(f"no existing import for {config['name']}")
except Exception as e:
log.error(f"error aborting {config['name']}: {e}")
self.logger.error(f"error aborting {config['name']}: {e}")
return False

def wait(self, timeout: int = 0):
Expand Down
6 changes: 5 additions & 1 deletion neo4j_arrow/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

KnownExceptions = Union[ArrowException, FlightServerError, Exception]


def interpret(e: KnownExceptions) -> KnownExceptions:
"""
Try to figure out which exception occcurred based on the server response.
Try to figure out which exception occurred based on the server response.
"""
try:
message = "".join(e.args)
Expand All @@ -35,6 +36,7 @@ class Neo4jArrowException(Exception):
"""
Base class for neo4j_arrow exceptions.
"""

def __init__(self, message: str):
self.message = message

Expand All @@ -43,6 +45,7 @@ class UnknownError(Neo4jArrowException):
"""
We have no idea what is wrong :(
"""

def __init__(self, message: str):
# These errors have ugly stack traces often repeated. Try to beautify.
# nb. In reality there's an embedded gRPC dict-like message, but let's
Expand Down Expand Up @@ -73,6 +76,7 @@ class InvalidArgument(Neo4jArrowException):
"""
pass


class NotFound(Neo4jArrowException):
"""
The requested import process could not be found.
Expand Down

0 comments on commit fd3dacc

Please sign in to comment.