Commit

Increase line length
ali-ince committed Nov 24, 2023
1 parent db80237 commit 59bc9e3
Showing 6 changed files with 38 additions and 111 deletions.
.flake8 (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
[flake8]
max-line-length = 88
max-line-length = 120
max-complexity = 10
select = C,E,F,W,B,B950
ignore = E211, E999, F401, F821, W503
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ build-backend = "poetry.core.masonry.api"
pythonpath = ["src"]

[tool.black]
line-length = 88
line-length = 120
target-version = ['py39']

[tool.isort]
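
Together with the .flake8 change above, this raises the enforced limit for both flake8 and Black from 88 to 120 characters, which is what permits the single-line rewrites in the files below. As a purely illustrative sketch (the function and argument names are hypothetical, not part of neo4j_arrow), this is the kind of call that Black wraps at 88 characters but leaves on one line at 120:

# Hypothetical example; none of these names come from the repository.
def authenticate(client, user, password):
    # At line-length = 88, Black formats the call below as a wrapped, three-line
    # expression; at line-length = 120 it fits on a single line, as in this commit.
    (header, token) = client.authenticate_basic_token(user, password)
    return header, token
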
src/neo4j_arrow/_client.py (112 changes: 28 additions & 84 deletions)
@@ -135,9 +135,7 @@ def _client(self) -> flight.FlightClient:
client = flight.FlightClient(location)
if self.user and self.password:
try:
(header, token) = client.authenticate_basic_token(
self.user, self.password
)
(header, token) = client.authenticate_basic_token(self.user, self.password)
if header:
self.call_opts = flight.FlightCallOptions(
headers=[(header, token)],
@@ -164,9 +162,7 @@ def _send_action(self, action: str, body: Dict[str, Any]) -> Dict[str, Any]:
def _get_chunks(self, ticket: Dict[str, Any]) -> Generator[Arrow, None, None]:
client = self._client()
try:
result = client.do_get(
pa.flight.Ticket(json.dumps(ticket).encode("utf8")), self.call_opts
)
result = client.do_get(pa.flight.Ticket(json.dumps(ticket).encode("utf8")), self.call_opts)
for chunk, _ in result:
yield chunk
except Exception as e:
@@ -180,9 +176,7 @@ def _nop(cls, data: Arrow) -> Arrow:
return data

@classmethod
def _node_mapper(
cls, model: Graph, source_field: Optional[str] = None
) -> MappingFn:
def _node_mapper(cls, model: Graph, source_field: Optional[str] = None) -> MappingFn:
"""
Generate a mapping function for a Node.
"""
@@ -196,33 +190,23 @@ def _map(data: Arrow) -> Arrow:
my_label = data["labels"][0].as_py()
node = model.node_by_label(my_label)
if not node:
raise Exception(
"cannot find matching node in model given " f"{data.schema}"
)
raise Exception("cannot find matching node in model given " f"{data.schema}")

columns, fields = cls._rename_and_add_column(
[], [], data, node.key_field, "nodeId"
)
columns, fields = cls._rename_and_add_column([], [], data, node.key_field, "nodeId")
if node.label:
columns.append(pa.array([node.label] * len(data), pa.string()))
fields.append(pa.field("labels", pa.string()))
if node.label_field:
columns, fields = cls._rename_and_add_column(
columns, fields, data, node.label_field, "labels"
)
columns, fields = cls._rename_and_add_column(columns, fields, data, node.label_field, "labels")
for name in node.properties:
columns, fields = cls._rename_and_add_column(
columns, fields, data, name, node.properties[name]
)
columns, fields = cls._rename_and_add_column(columns, fields, data, name, node.properties[name])

return data.from_arrays(columns, schema=pa.schema(fields))

return _map

@classmethod
def _edge_mapper(
cls, model: Graph, source_field: Optional[str] = None
) -> MappingFn:
def _edge_mapper(cls, model: Graph, source_field: Optional[str] = None) -> MappingFn:
"""
Generate a mapping function for an Edge.
"""
@@ -236,27 +220,17 @@ def _map(data: Arrow) -> Arrow:
my_type = data["type"][0].as_py()
edge = model.edge_by_type(my_type)
if not edge:
raise Exception(
"cannot find matching edge in model given " f"{data.schema}"
)

columns, fields = cls._rename_and_add_column(
[], [], data, edge.source_field, "sourceNodeId"
)
columns, fields = cls._rename_and_add_column(
columns, fields, data, edge.target_field, "targetNodeId"
)
raise Exception("cannot find matching edge in model given " f"{data.schema}")

columns, fields = cls._rename_and_add_column([], [], data, edge.source_field, "sourceNodeId")
columns, fields = cls._rename_and_add_column(columns, fields, data, edge.target_field, "targetNodeId")
if edge.type:
columns.append(pa.array([edge.type] * len(data), pa.string()))
fields.append(pa.field("relationshipType", pa.string()))
if edge.type_field:
columns, fields = cls._rename_and_add_column(
columns, fields, data, edge.type_field, "relationshipType"
)
columns, fields = cls._rename_and_add_column(columns, fields, data, edge.type_field, "relationshipType")
for name in edge.properties:
columns, fields = cls._rename_and_add_column(
columns, fields, data, name, edge.properties[name]
)
columns, fields = cls._rename_and_add_column(columns, fields, data, name, edge.properties[name])

return data.from_arrays(columns, schema=pa.schema(fields))

@@ -298,9 +272,7 @@ def _write_batches(
first = cast(pa.RecordBatch, fn(first))

client = self._client()
upload_descriptor = flight.FlightDescriptor.for_command(
json.dumps(desc).encode("utf-8")
)
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(desc).encode("utf-8"))
n_rows, n_bytes = 0, 0
try:
writer, _ = client.do_put(upload_descriptor, first.schema, self.call_opts)
@@ -401,46 +373,32 @@ def _start(
self.state = ClientState.FEEDING_NODES
return result

raise error.Neo4jArrowException(
f"failed to start {action} for {config['name']}, got {result}"
)
raise error.Neo4jArrowException(f"failed to start {action} for {config['name']}, got {result}")
except error.AlreadyExists:
if force:
self.logger.warning(
f"forcing cancellation of existing {action} import"
f" for {config['name']}"
)
self.logger.warning(f"forcing cancellation of existing {action} import" f" for {config['name']}")
if self.abort():
return self._start(action, config=config)

self.logger.error(
f"{action} import job already exists for {config['name']}"
)
self.logger.error(f"{action} import job already exists for {config['name']}")
except Exception as e:
self.logger.error(f"fatal error performing action {action}: {e}")
raise e

return {}

def _write_entities(
self, desc: Dict[str, Any], entities: Union[Nodes, Edges], mapper: MappingFn
) -> Result:
def _write_entities(self, desc: Dict[str, Any], entities: Union[Nodes, Edges], mapper: MappingFn) -> Result:
try:
if isinstance(entities, pa.Table):
entities = mapper(entities).to_batches(
max_chunksize=self.max_chunk_size
)
entities = mapper(entities).to_batches(max_chunksize=self.max_chunk_size)
mapper = self._nop

return self._write_batches(desc, entities, mapper)
except error.NotFound as e:
self.logger.error(f"no existing import job found for graph f{self.graph}")
raise e
except Exception as e:
self.logger.error(
f"fatal error while feeding {desc['entity_type']}s for "
f"graph {self.graph}: {e}"
)
self.logger.error(f"fatal error while feeding {desc['entity_type']}s for " f"graph {self.graph}: {e}")
raise e

def write_nodes(
@@ -468,9 +426,7 @@ def nodes_done(self) -> Dict[str, Any]:
self.state = ClientState.FEEDING_EDGES
return result

raise error.Neo4jArrowException(
f"invalid response for nodes_done for graph {self.graph}, got {result}"
)
raise error.Neo4jArrowException(f"invalid response for nodes_done for graph {self.graph}, got {result}")
except Exception as e:
raise error.interpret(e)

@@ -499,9 +455,7 @@ def edges_done(self) -> Dict[str, Any]:
self.state = ClientState.AWAITING_GRAPH
return result

raise error.Neo4jArrowException(
f"invalid response for edges_done for graph {self.graph}, got {result}"
)
raise error.Neo4jArrowException(f"invalid response for edges_done for graph {self.graph}, got {result}")
except Exception as e:
raise error.interpret(e)

@@ -524,19 +478,13 @@ def read_edges(
if properties:
procedure_name = "gds.graph.relationshipProperties.stream"
configuration = {
"relationship_properties": list(
properties if properties is not None else []
),
"relationship_types": list(
relationship_types if relationship_types is not None else ["*"]
),
"relationship_properties": list(properties if properties is not None else []),
"relationship_types": list(relationship_types if relationship_types is not None else ["*"]),
}
else:
procedure_name = "gds.beta.graph.relationships.stream"
configuration = {
"relationship_types": list(
relationship_types if relationship_types is not None else ["*"]
),
"relationship_types": list(relationship_types if relationship_types is not None else ["*"]),
}

return self._get_chunks(
@@ -575,9 +523,7 @@ def read_nodes(
"procedure_name": "gds.graph.nodeProperties.stream",
"configuration": {
"node_labels": list(labels if labels is not None else ["*"]),
"node_properties": list(
properties if properties is not None else []
),
"node_properties": list(properties if properties is not None else []),
},
"concurrency": concurrency,
}
@@ -594,9 +540,7 @@ def abort(self, name: Optional[str] = None) -> bool:
self.state = ClientState.READY
return True

raise error.Neo4jArrowException(
f"invalid response for abort of graph {self.graph}, got {result}"
)
raise error.Neo4jArrowException(f"invalid response for abort of graph {self.graph}, got {result}")
except error.NotFound:
self.logger.warning(f"no existing import for {config['name']}")
except Exception as e:
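
The mapper rewrites above all follow one pattern: rename the model's key field to nodeId (or to sourceNodeId/targetNodeId for edges), then attach or rename a labels column before the Arrow data is written over Flight. A minimal, self-contained sketch of that pattern in plain pyarrow follows; the table and field names are invented for illustration, and the client's _rename_and_add_column helper is replaced by inline calls.

import pyarrow as pa

def map_nodes(table: pa.Table) -> pa.Table:
    # Rename the hypothetical "person_id" key field to the "nodeId" column the server expects.
    columns = [table.column("person_id")]
    fields = [pa.field("nodeId", table.schema.field("person_id").type)]
    # Attach a constant label column, mirroring the `if node.label:` branch in _node_mapper.
    columns.append(pa.array(["Person"] * len(table), pa.string()))
    fields.append(pa.field("labels", pa.string()))
    return pa.Table.from_arrays(columns, schema=pa.schema(fields))

people = pa.table({"person_id": [1, 2, 3]})
print(map_nodes(people).schema)  # nodeId: int64, labels: string
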
src/neo4j_arrow/error.py (4 changes: 1 addition & 3 deletions)
@@ -49,9 +49,7 @@ def __init__(self, message: str):
# nb. In reality there's an embedded gRPC dict-like message, but let's
# not introduce dict parsing here because that's a security issue.
try:
self.message = (
message.replace(r"\n", "\n").replace(r"\'", "'").splitlines()[-1]
)
self.message = message.replace(r"\n", "\n").replace(r"\'", "'").splitlines()[-1]
except Exception:
self.message = message

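
For context, the collapsed line above unescapes the raw error text and keeps only its final line, which is usually the human-readable detail. A small, hypothetical illustration (the message text is made up, not captured from a real server):

raw = r"Flight RPC failed with an internal error\nDetail: import job already exists for graph 'people'"
message = raw.replace(r"\n", "\n").replace(r"\'", "'").splitlines()[-1]
print(message)  # Detail: import job already exists for graph 'people'
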
src/neo4j_arrow/model.py (25 changes: 6 additions & 19 deletions)
@@ -84,10 +84,7 @@ def validate(self) -> None:
if not self._label and not self._label_field:
raise Exception(f"either label or label_field must be provided in {self}")
if self._label and self._label_field:
raise Exception(
f"use of label and label_field at the same time is not allowed "
f"in {self}"
)
raise Exception(f"use of label and label_field at the same time is not allowed " f"in {self}")
if not self._key_field:
raise Exception(f"empty key_field in {self}")

@@ -162,9 +159,7 @@ def validate(self) -> None:
if not self._type_field and not self._type:
raise Exception(f"either type or type_field must be provided in {self}")
if self._type_field and self._type:
raise Exception(
f"use of type and type_field at the same time is not allowed in {self}"
)
raise Exception(f"use of type and type_field at the same time is not allowed in {self}")
if not self._source_field:
raise Exception(f"empty source_field in {self}")
if not self._target_field:
@@ -187,9 +182,7 @@ class Graph:
* A List of Edges (optional, though boring if none!)
"""

def __init__(
self, *, name: str, db: str = "", nodes: List[Node] = [], edges: List[Edge] = []
):
def __init__(self, *, name: str, db: str = "", nodes: List[Node] = [], edges: List[Edge] = []):
self.name = name
self.db = db
self.nodes = nodes
@@ -208,14 +201,10 @@ def with_edges(self, edges: List[Edge]) -> "Graph":
return Graph(name=self.name, db=self.db, nodes=self.nodes, edges=edges)

def with_node(self, node: Node) -> "Graph":
return Graph(
name=self.name, db=self.db, nodes=self.nodes + [node], edges=self.edges
)
return Graph(name=self.name, db=self.db, nodes=self.nodes + [node], edges=self.edges)

def with_edge(self, edge: Edge) -> "Graph":
return Graph(
name=self.name, db=self.db, nodes=self.nodes, edges=self.edges + [edge]
)
return Graph(name=self.name, db=self.db, nodes=self.nodes, edges=self.edges + [edge])

def node_for_src(self, source: str) -> Union[None, Node]:
"""Find a Node in a Graph based on matching source pattern."""
@@ -272,9 +261,7 @@ def from_json(cls, json: str) -> "Graph":
)
for e in obj.get("edges", [])
]
return Graph(
name=obj["name"], db=obj.get("db", "neo4j"), nodes=nodes, edges=edges
)
return Graph(name=obj["name"], db=obj.get("db", "neo4j"), nodes=nodes, edges=edges)

def to_dict(self) -> Dict[str, Any]:
return {
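
The builder methods shortened above are normally chained, as the test change below also shows. A brief usage sketch, with placeholder source pattern and field names, assuming the package is importable as neo4j_arrow from the src/ layout:

from neo4j_arrow.model import Graph, Node

# Placeholder names; only the constructor keywords and with_node() come from this diff.
g = (
    Graph(name="example", db="neo4j")
    .with_node(Node(source="people[.]parquet", label_field="label", key_field="key"))
)
print(len(g.nodes))  # 1 -- with_node returns a new Graph rather than mutating in place
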
tests/test_model.py (4 changes: 1 addition & 3 deletions)
@@ -97,9 +97,7 @@ def test_retrieving_by_source():
def test_retrieving_by_pattern():
g = (
Graph(name="graph", db="db")
.with_node(
Node(source="gs://.*/alpha[.]parquet", label_field="label", key_field="key")
)
.with_node(Node(source="gs://.*/alpha[.]parquet", label_field="label", key_field="key"))
.with_node(
Node(
source="beta",