From fd3dacc091357dab289924d8f9fe9fb52dfb491b Mon Sep 17 00:00:00 2001 From: Ali Ince Date: Fri, 9 Jun 2023 15:50:54 +0100 Subject: [PATCH] Improve logging --- neo4j_arrow/_client.py | 58 +++++++++++++++++++++++++----------------- neo4j_arrow/error.py | 6 ++++- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/neo4j_arrow/_client.py b/neo4j_arrow/_client.py index dec9a97..6f89104 100644 --- a/neo4j_arrow/_client.py +++ b/neo4j_arrow/_client.py @@ -1,6 +1,6 @@ +import logging from collections import abc from enum import Enum -import logging as log import json import pyarrow as pa @@ -42,11 +42,12 @@ class Neo4jArrowClient: debug: bool client: flight.FlightClient call_opts: flight.FlightCallOptions + logger: logging.Logger def __init__(self, host: str, graph: str, *, port: int = 8491, database: str = "neo4j", user: str = "neo4j", password: str = "neo4j", tls: bool = True, concurrency: int = 4, timeout: Optional[float] = None, max_chunk_size: int = 10_000, - debug: bool = False): + debug: bool = False, logger: logging.Logger = None): self.host = host self.port = port self.user = user @@ -61,9 +62,14 @@ def __init__(self, host: str, graph: str, *, port: int = 8491, database: str = " self.debug = debug self.timeout = timeout self.max_chunk_size = max_chunk_size + if not logger: + logger = logging.getLogger("Neo4jArrowClient") + self.logger = logger def __str__(self): - return f"Neo4jArrowClient{{{self.user}@{self.host}:{self.port}/{self.database}?graph={self.graph}}}" + return f"Neo4jArrowClient{{{self.user}@{self.host}:{self.port}/{self.database}?graph={self.graph}" \ + f"&encrypted={self.tls}&concurrency={self.concurrency}&debug={self.debug}&timeout={self.timeout}" \ + f"&max_chunk_size={self.max_chunk_size}}}" def __getstate__(self): state = self.__dict__.copy() @@ -164,7 +170,7 @@ def _map(data: Arrow) -> Arrow: if node.label_field: columns, fields = cls._rename_and_add_column(columns, fields, data, node.label_field, "labels") for name in node.properties: - columns, fields = cls._rename_and_add_column(columns, fields, name, node.properties[name]) + columns, fields = cls._rename_and_add_column(columns, fields, data, name, node.properties[name]) return data.from_arrays(columns, schema=pa.schema(fields)) @@ -232,20 +238,20 @@ def _write_batches(self, desc: Dict[str, Any], upload_descriptor = flight.FlightDescriptor.for_command( json.dumps(desc).encode("utf-8") ) - rows, nbytes = 0, 0 + n_rows, n_bytes = 0, 0 try: writer, _ = client.do_put(upload_descriptor, first.schema, self.call_opts) with writer: writer.write_batch(first) - rows += first.num_rows - nbytes += first.get_total_buffer_size() + n_rows += first.num_rows + n_bytes += first.get_total_buffer_size() for remaining in batches: writer.write_batch(fn(remaining)) - rows += remaining.num_rows - nbytes += remaining.get_total_buffer_size() + n_rows += remaining.num_rows + n_bytes += remaining.get_total_buffer_size() except Exception as e: raise error.interpret(e) - return rows, nbytes + return n_rows, n_bytes def start(self, action: str = "CREATE_GRAPH", *, config: Dict[str, Any] = None, @@ -276,20 +282,24 @@ def start_create_graph(self, *, force: bool = False, undirected_rel_types: Itera return self._start("CREATE_GRAPH", config=config, force=force) - def start_create_database(self, *, force: bool = False, id_type: str = "INTEGER", id_property: str = "originalId", - record_format: str = "", high_io: bool = False, use_bad_collector: bool = False) -> \ + def start_create_database(self, *, force: bool = False, id_type: str = "", id_property: str = "", + record_format: str = "", high_io: bool = True, use_bad_collector: bool = False) -> \ Dict[str, Any]: config = { "name": self.graph, "concurrency": self.concurrency, - "id_type": id_type, - "id_property": id_property, - "record_format": record_format, "high_io": high_io, "use_bad_collector": use_bad_collector, "force": force } + if id_type: + config["id_type"] = id_type + if id_property: + config["id_property"] = id_property + if record_format: + config["record_format"] = record_format + return self._start("CREATE_DATABASE", config=config, force=force) def _start(self, action: str = "CREATE_GRAPH", *, @@ -309,14 +319,14 @@ def _start(self, action: str = "CREATE_GRAPH", *, raise error.Neo4jArrowException(f"failed to start {action} for {config['name']}, got {result}") except error.AlreadyExists as e: if force: - log.warning(f"forcing cancellation of existing {action} import" - f" for {config['name']}") + self.logger.warning(f"forcing cancellation of existing {action} import" + f" for {config['name']}") if self.abort(): return self._start(action, config=config) - log.error(f"{action} import job already exists for {config['name']}") + self.logger.error(f"{action} import job already exists for {config['name']}") except Exception as e: - log.error(f"fatal error performing action {action}: {e}") + self.logger.error(f"fatal error performing action {action}: {e}") raise e return {} @@ -324,15 +334,15 @@ def _start(self, action: str = "CREATE_GRAPH", *, def _write_entities(self, desc: Dict[str, Any], entities: Union[Nodes, Edges], mapper: MappingFn) -> Result: try: if isinstance(entities, pa.Table): - nodes = mapper(entities).to_batches(max_chunksize=self.max_chunk_size) + entities = mapper(entities).to_batches(max_chunksize=self.max_chunk_size) mapper = self._nop return self._write_batches(desc, entities, mapper) except error.NotFound as e: - log.error(f"no existing import job found for graph f{self.graph}") + self.logger.error(f"no existing import job found for graph f{self.graph}") raise e except Exception as e: - log.error(f"fatal error while feeding {desc['entity_type']}s for graph {self.graph}: {e}") + self.logger.error(f"fatal error while feeding {desc['entity_type']}s for graph {self.graph}: {e}") raise e def write_nodes(self, nodes: Nodes, @@ -459,9 +469,9 @@ def abort(self, name: Optional[str] = None) -> bool: raise error.Neo4jArrowException(f"invalid response for abort of graph {self.graph}, got {result}") except error.NotFound as e: - log.warning(f"no existing import for {config['name']}") + self.logger.warning(f"no existing import for {config['name']}") except Exception as e: - log.error(f"error aborting {config['name']}: {e}") + self.logger.error(f"error aborting {config['name']}: {e}") return False def wait(self, timeout: int = 0): diff --git a/neo4j_arrow/error.py b/neo4j_arrow/error.py index 6a9c999..7d21359 100644 --- a/neo4j_arrow/error.py +++ b/neo4j_arrow/error.py @@ -8,9 +8,10 @@ KnownExceptions = Union[ArrowException, FlightServerError, Exception] + def interpret(e: KnownExceptions) -> KnownExceptions: """ - Try to figure out which exception occcurred based on the server response. + Try to figure out which exception occurred based on the server response. """ try: message = "".join(e.args) @@ -35,6 +36,7 @@ class Neo4jArrowException(Exception): """ Base class for neo4j_arrow exceptions. """ + def __init__(self, message: str): self.message = message @@ -43,6 +45,7 @@ class UnknownError(Neo4jArrowException): """ We have no idea what is wrong :( """ + def __init__(self, message: str): # These errors have ugly stack traces often repeated. Try to beautify. # nb. In reality there's an embedded gRPC dict-like message, but let's @@ -73,6 +76,7 @@ class InvalidArgument(Neo4jArrowException): """ pass + class NotFound(Neo4jArrowException): """ The requested import process could not be found.