diff --git a/.flake8 b/.flake8
index 54c243e..0736422 100644
--- a/.flake8
+++ b/.flake8
@@ -1,5 +1,5 @@
 [flake8]
-max-line-length = 88
+max-line-length = 120
 max-complexity = 10
 select = C,E,F,W,B,B950
 ignore = E211, E999, F401, F821, W503
diff --git a/pyproject.toml b/pyproject.toml
index 8816e4e..6bbd5f2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ build-backend = "poetry.core.masonry.api"
 pythonpath = ["src"]
 
 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ['py39']
 
 [tool.isort]
diff --git a/src/neo4j_arrow/_client.py b/src/neo4j_arrow/_client.py
index 31edb5e..da4e199 100644
--- a/src/neo4j_arrow/_client.py
+++ b/src/neo4j_arrow/_client.py
@@ -135,9 +135,7 @@ def _client(self) -> flight.FlightClient:
         client = flight.FlightClient(location)
         if self.user and self.password:
            try:
-                (header, token) = client.authenticate_basic_token(
-                    self.user, self.password
-                )
+                (header, token) = client.authenticate_basic_token(self.user, self.password)
                if header:
                    self.call_opts = flight.FlightCallOptions(
                        headers=[(header, token)],
@@ -164,9 +162,7 @@ def _send_action(self, action: str, body: Dict[str, Any]) -> Dict[str, Any]:
     def _get_chunks(self, ticket: Dict[str, Any]) -> Generator[Arrow, None, None]:
         client = self._client()
         try:
-            result = client.do_get(
-                pa.flight.Ticket(json.dumps(ticket).encode("utf8")), self.call_opts
-            )
+            result = client.do_get(pa.flight.Ticket(json.dumps(ticket).encode("utf8")), self.call_opts)
             for chunk, _ in result:
                 yield chunk
         except Exception as e:
@@ -180,9 +176,7 @@ def _nop(cls, data: Arrow) -> Arrow:
         return data
 
     @classmethod
-    def _node_mapper(
-        cls, model: Graph, source_field: Optional[str] = None
-    ) -> MappingFn:
+    def _node_mapper(cls, model: Graph, source_field: Optional[str] = None) -> MappingFn:
         """
        Generate a mapping function for a Node.
        """
@@ -196,33 +190,23 @@ def _map(data: Arrow) -> Arrow:
             my_label = data["labels"][0].as_py()
             node = model.node_by_label(my_label)
             if not node:
-                raise Exception(
-                    "cannot find matching node in model given " f"{data.schema}"
-                )
+                raise Exception("cannot find matching node in model given " f"{data.schema}")
 
-            columns, fields = cls._rename_and_add_column(
-                [], [], data, node.key_field, "nodeId"
-            )
+            columns, fields = cls._rename_and_add_column([], [], data, node.key_field, "nodeId")
             if node.label:
                 columns.append(pa.array([node.label] * len(data), pa.string()))
                 fields.append(pa.field("labels", pa.string()))
             if node.label_field:
-                columns, fields = cls._rename_and_add_column(
-                    columns, fields, data, node.label_field, "labels"
-                )
+                columns, fields = cls._rename_and_add_column(columns, fields, data, node.label_field, "labels")
             for name in node.properties:
-                columns, fields = cls._rename_and_add_column(
-                    columns, fields, data, name, node.properties[name]
-                )
+                columns, fields = cls._rename_and_add_column(columns, fields, data, name, node.properties[name])
 
             return data.from_arrays(columns, schema=pa.schema(fields))
 
         return _map
 
     @classmethod
-    def _edge_mapper(
-        cls, model: Graph, source_field: Optional[str] = None
-    ) -> MappingFn:
+    def _edge_mapper(cls, model: Graph, source_field: Optional[str] = None) -> MappingFn:
         """
        Generate a mapping function for an Edge.
        """
@@ -236,27 +220,17 @@ def _map(data: Arrow) -> Arrow:
             my_type = data["type"][0].as_py()
             edge = model.edge_by_type(my_type)
             if not edge:
-                raise Exception(
-                    "cannot find matching edge in model given " f"{data.schema}"
-                )
-
-            columns, fields = cls._rename_and_add_column(
-                [], [], data, edge.source_field, "sourceNodeId"
-            )
-            columns, fields = cls._rename_and_add_column(
-                columns, fields, data, edge.target_field, "targetNodeId"
-            )
+                raise Exception("cannot find matching edge in model given " f"{data.schema}")
+
+            columns, fields = cls._rename_and_add_column([], [], data, edge.source_field, "sourceNodeId")
+            columns, fields = cls._rename_and_add_column(columns, fields, data, edge.target_field, "targetNodeId")
             if edge.type:
                 columns.append(pa.array([edge.type] * len(data), pa.string()))
                 fields.append(pa.field("relationshipType", pa.string()))
             if edge.type_field:
-                columns, fields = cls._rename_and_add_column(
-                    columns, fields, data, edge.type_field, "relationshipType"
-                )
+                columns, fields = cls._rename_and_add_column(columns, fields, data, edge.type_field, "relationshipType")
             for name in edge.properties:
-                columns, fields = cls._rename_and_add_column(
-                    columns, fields, data, name, edge.properties[name]
-                )
+                columns, fields = cls._rename_and_add_column(columns, fields, data, name, edge.properties[name])
 
             return data.from_arrays(columns, schema=pa.schema(fields))
 
@@ -298,9 +272,7 @@ def _write_batches(
         first = cast(pa.RecordBatch, fn(first))
 
         client = self._client()
-        upload_descriptor = flight.FlightDescriptor.for_command(
-            json.dumps(desc).encode("utf-8")
-        )
+        upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(desc).encode("utf-8"))
         n_rows, n_bytes = 0, 0
         try:
             writer, _ = client.do_put(upload_descriptor, first.schema, self.call_opts)
@@ -401,35 +373,24 @@ def _start(
                    self.state = ClientState.FEEDING_NODES
                return result
 
-            raise error.Neo4jArrowException(
-                f"failed to start {action} for {config['name']}, got {result}"
-            )
+            raise error.Neo4jArrowException(f"failed to start {action} for {config['name']}, got {result}")
         except error.AlreadyExists:
             if force:
-                self.logger.warning(
-                    f"forcing cancellation of existing {action} import"
-                    f" for {config['name']}"
-                )
+                self.logger.warning(f"forcing cancellation of existing {action} import" f" for {config['name']}")
                 if self.abort():
                     return self._start(action, config=config)
 
-            self.logger.error(
-                f"{action} import job already exists for {config['name']}"
-            )
+            self.logger.error(f"{action} import job already exists for {config['name']}")
         except Exception as e:
             self.logger.error(f"fatal error performing action {action}: {e}")
             raise e
 
         return {}
 
-    def _write_entities(
-        self, desc: Dict[str, Any], entities: Union[Nodes, Edges], mapper: MappingFn
-    ) -> Result:
+    def _write_entities(self, desc: Dict[str, Any], entities: Union[Nodes, Edges], mapper: MappingFn) -> Result:
         try:
             if isinstance(entities, pa.Table):
-                entities = mapper(entities).to_batches(
-                    max_chunksize=self.max_chunk_size
-                )
+                entities = mapper(entities).to_batches(max_chunksize=self.max_chunk_size)
                 mapper = self._nop
 
             return self._write_batches(desc, entities, mapper)
@@ -437,10 +398,7 @@ def _write_entities(
             self.logger.error(f"no existing import job found for graph f{self.graph}")
             raise e
         except Exception as e:
-            self.logger.error(
-                f"fatal error while feeding {desc['entity_type']}s for "
-                f"graph {self.graph}: {e}"
-            )
+            self.logger.error(f"fatal error while feeding {desc['entity_type']}s for " f"graph {self.graph}: {e}")
             raise e
 
     def write_nodes(
@@ -468,9 +426,7 @@ def nodes_done(self) -> Dict[str, Any]:
                 self.state = ClientState.FEEDING_EDGES
                 return result
 
-            raise error.Neo4jArrowException(
-                f"invalid response for nodes_done for graph {self.graph}, got {result}"
-            )
+            raise error.Neo4jArrowException(f"invalid response for nodes_done for graph {self.graph}, got {result}")
         except Exception as e:
             raise error.interpret(e)
 
@@ -499,9 +455,7 @@ def edges_done(self) -> Dict[str, Any]:
                 self.state = ClientState.AWAITING_GRAPH
                 return result
 
-            raise error.Neo4jArrowException(
-                f"invalid response for edges_done for graph {self.graph}, got {result}"
-            )
+            raise error.Neo4jArrowException(f"invalid response for edges_done for graph {self.graph}, got {result}")
         except Exception as e:
             raise error.interpret(e)
 
@@ -524,19 +478,13 @@ def read_edges(
         if properties:
             procedure_name = "gds.graph.relationshipProperties.stream"
             configuration = {
-                "relationship_properties": list(
-                    properties if properties is not None else []
-                ),
-                "relationship_types": list(
-                    relationship_types if relationship_types is not None else ["*"]
-                ),
+                "relationship_properties": list(properties if properties is not None else []),
+                "relationship_types": list(relationship_types if relationship_types is not None else ["*"]),
             }
         else:
             procedure_name = "gds.beta.graph.relationships.stream"
             configuration = {
-                "relationship_types": list(
-                    relationship_types if relationship_types is not None else ["*"]
-                ),
+                "relationship_types": list(relationship_types if relationship_types is not None else ["*"]),
             }
 
         return self._get_chunks(
@@ -575,9 +523,7 @@ def read_nodes(
                 "procedure_name": "gds.graph.nodeProperties.stream",
                 "configuration": {
                     "node_labels": list(labels if labels is not None else ["*"]),
-                    "node_properties": list(
-                        properties if properties is not None else []
-                    ),
+                    "node_properties": list(properties if properties is not None else []),
                 },
                 "concurrency": concurrency,
             }
@@ -594,9 +540,7 @@ def abort(self, name: Optional[str] = None) -> bool:
                 self.state = ClientState.READY
                 return True
 
-            raise error.Neo4jArrowException(
-                f"invalid response for abort of graph {self.graph}, got {result}"
-            )
+            raise error.Neo4jArrowException(f"invalid response for abort of graph {self.graph}, got {result}")
         except error.NotFound:
             self.logger.warning(f"no existing import for {config['name']}")
         except Exception as e:
diff --git a/src/neo4j_arrow/error.py b/src/neo4j_arrow/error.py
index 003b57d..5a5c474 100644
--- a/src/neo4j_arrow/error.py
+++ b/src/neo4j_arrow/error.py
@@ -49,9 +49,7 @@ def __init__(self, message: str):
         # nb. In reality there's an embedded gRPC dict-like message, but let's
         # not introduce dict parsing here because that's a security issue.
         try:
-            self.message = (
-                message.replace(r"\n", "\n").replace(r"\'", "'").splitlines()[-1]
-            )
+            self.message = message.replace(r"\n", "\n").replace(r"\'", "'").splitlines()[-1]
         except Exception:
             self.message = message
 
diff --git a/src/neo4j_arrow/model.py b/src/neo4j_arrow/model.py
index ed3c52f..ce939ef 100644
--- a/src/neo4j_arrow/model.py
+++ b/src/neo4j_arrow/model.py
@@ -84,10 +84,7 @@ def validate(self) -> None:
         if not self._label and not self._label_field:
             raise Exception(f"either label or label_field must be provided in {self}")
         if self._label and self._label_field:
-            raise Exception(
-                f"use of label and label_field at the same time is not allowed "
-                f"in {self}"
-            )
+            raise Exception(f"use of label and label_field at the same time is not allowed " f"in {self}")
         if not self._key_field:
             raise Exception(f"empty key_field in {self}")
 
@@ -162,9 +159,7 @@ def validate(self) -> None:
         if not self._type_field and not self._type:
             raise Exception(f"either type or type_field must be provided in {self}")
         if self._type_field and self._type:
-            raise Exception(
-                f"use of type and type_field at the same time is not allowed in {self}"
-            )
+            raise Exception(f"use of type and type_field at the same time is not allowed in {self}")
         if not self._source_field:
             raise Exception(f"empty source_field in {self}")
         if not self._target_field:
@@ -187,9 +182,7 @@ class Graph:
     * A List of Edges (optional, though boring if none!)
     """
 
-    def __init__(
-        self, *, name: str, db: str = "", nodes: List[Node] = [], edges: List[Edge] = []
-    ):
+    def __init__(self, *, name: str, db: str = "", nodes: List[Node] = [], edges: List[Edge] = []):
         self.name = name
         self.db = db
         self.nodes = nodes
@@ -208,14 +201,10 @@ def with_edges(self, edges: List[Edge]) -> "Graph":
         return Graph(name=self.name, db=self.db, nodes=self.nodes, edges=edges)
 
     def with_node(self, node: Node) -> "Graph":
-        return Graph(
-            name=self.name, db=self.db, nodes=self.nodes + [node], edges=self.edges
-        )
+        return Graph(name=self.name, db=self.db, nodes=self.nodes + [node], edges=self.edges)
 
     def with_edge(self, edge: Edge) -> "Graph":
-        return Graph(
-            name=self.name, db=self.db, nodes=self.nodes, edges=self.edges + [edge]
-        )
+        return Graph(name=self.name, db=self.db, nodes=self.nodes, edges=self.edges + [edge])
 
     def node_for_src(self, source: str) -> Union[None, Node]:
         """Find a Node in a Graph based on matching source pattern."""
@@ -272,9 +261,7 @@ def from_json(cls, json: str) -> "Graph":
             )
             for e in obj.get("edges", [])
         ]
-        return Graph(
-            name=obj["name"], db=obj.get("db", "neo4j"), nodes=nodes, edges=edges
-        )
+        return Graph(name=obj["name"], db=obj.get("db", "neo4j"), nodes=nodes, edges=edges)
 
     def to_dict(self) -> Dict[str, Any]:
         return {
diff --git a/tests/test_model.py b/tests/test_model.py
index 175e396..2433e9f 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -97,9 +97,7 @@ def test_retrieving_by_source():
 def test_retrieving_by_pattern():
     g = (
         Graph(name="graph", db="db")
-        .with_node(
-            Node(source="gs://.*/alpha[.]parquet", label_field="label", key_field="key")
-        )
+        .with_node(Node(source="gs://.*/alpha[.]parquet", label_field="label", key_field="key"))
         .with_node(
             Node(
                 source="beta",