fix: make _exec_query thread safe
davidkleiven committed Apr 8, 2024
1 parent 424d284 commit fdc694d
Showing 3 changed files with 20 additions and 14 deletions.
14 changes: 7 additions & 7 deletions cimsparql/graphdb.py
@@ -4,6 +4,7 @@
import json
import os
from collections.abc import Callable
+from copy import deepcopy
from dataclasses import dataclass, field
from enum import auto
from http import HTTPStatus
@@ -59,7 +60,7 @@ def parse_namespaces_rdf4j(response: httpx.Response) -> dict[str, str]:
return prefixes


-@dataclass
+@dataclass(frozen=True)
class ServiceConfig:
repo: str = field(default=os.getenv("GRAPHDB_REPO", "LATEST"))
protocol: str = "https"
@@ -216,10 +217,6 @@ def update_prefixes(self, pref: dict[str, str]) -> None:
def __str__(self) -> str:
return f"<GraphDBClient object, service: {self.service_cfg.url}>"

-def _prep_query(self, query: str) -> None:
-self.sparql.setQuery(query)
-self._update_sparql_parameters()

@staticmethod
def _process_result(results: SparqlResultJson) -> dict:
cols = results.head.variables
@@ -228,14 +225,17 @@ def _process_result(results: SparqlResultJson) -> dict:
return {"out": out, "cols": cols, "data": data}

def _exec_query(self, query: str) -> SparqlResult:
-self._prep_query(query)
+# To allow _exec_query to be run in threads, we use a deepcopy of the underlying
+# sparql wrapper. This is needed since setQuery changes the state of the SPARQLWrapper
+sparql_wrapper = deepcopy(self.sparql)
+sparql_wrapper.setQuery(query)

for attempt in tenacity.Retrying(
stop=tenacity.stop_after_attempt(self.service_cfg.num_retries + 1),
wait=tenacity.wait_exponential(max=self.service_cfg.max_delay_seconds),
):
with attempt:
-results = self.sparql.queryAndConvert()
+results = sparql_wrapper.queryAndConvert()

sparql_result = SparqlResultJson(**results)
if self.service_cfg.validate:
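The heart of the fix: SPARQLWrapper.setQuery mutates the wrapper instance, so two threads sharing self.sparql could overwrite each other's query before queryAndConvert runs. Copying the pre-configured wrapper per call leaves the shared instance untouched. A minimal sketch of the same pattern outside cimsparql, assuming a reachable SPARQL endpoint; the endpoint URL, the run helper, and the example queries are placeholders, not part of the library:

```python
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

from SPARQLWrapper import JSON, SPARQLWrapper

# One shared, pre-configured wrapper (placeholder endpoint).
shared = SPARQLWrapper("http://localhost:7200/repositories/LATEST")
shared.setReturnFormat(JSON)


def run(query: str) -> dict:
    # Each call works on its own copy, so setQuery never touches shared state.
    wrapper = deepcopy(shared)
    wrapper.setQuery(query)
    return wrapper.queryAndConvert()


queries = [f"SELECT * WHERE {{?s ?p ?o}} LIMIT {n}" for n in (1, 2, 3)]
with ThreadPoolExecutor() as pool:
    results = list(pool.map(run, queries))
```

Copying rather than re-creating the wrapper also preserves whatever configuration (return format, credentials, custom parameters) was already applied to the shared instance.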
10 changes: 7 additions & 3 deletions cimsparql/type_mapper.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import dataclasses
from collections.abc import Callable, Generator
from contextlib import contextmanager
from decimal import Decimal
@@ -64,12 +65,15 @@ def str_preserve_none(x: Any) -> str | None:

@contextmanager
def enforce_no_limit(client: GraphDBClient) -> Generator[GraphDBClient, None, None]:
-orig_limit = client.service_cfg.limit
-client.service_cfg.limit = None
+orig_cfg = client.service_cfg
+client.service_cfg = dataclasses.replace(orig_cfg, limit=None)
+client._update_sparql_parameters()

try:
yield client
finally:
-client.service_cfg.limit = orig_limit
+client.service_cfg = orig_cfg
+client._update_sparql_parameters()


def build_type_map(prefixes: dict[PREFIX, URI]) -> dict[SPARQL_TYPE, TYPE_CASTER]:
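Because ServiceConfig is now frozen, enforce_no_limit can no longer flip limit in place; it swaps in a modified copy built with dataclasses.replace and restores the original on exit, refreshing the SPARQL parameters after each swap. A toy sketch of that swap-and-restore pattern, with Config, Client, and no_limit as hypothetical stand-ins rather than cimsparql names:

```python
import dataclasses
from collections.abc import Generator
from contextlib import contextmanager


@dataclasses.dataclass(frozen=True)
class Config:
    limit: int | None = 100


class Client:
    def __init__(self) -> None:
        self.cfg = Config()


@contextmanager
def no_limit(client: Client) -> Generator[Client, None, None]:
    orig = client.cfg
    # A frozen dataclass cannot be mutated, so swap in an updated copy instead.
    client.cfg = dataclasses.replace(orig, limit=None)
    try:
        yield client
    finally:
        # Restore the original config even if the body raises.
        client.cfg = orig


client = Client()
with no_limit(client) as c:
    assert c.cfg.limit is None
assert client.cfg.limit == 100
```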
10 changes: 6 additions & 4 deletions tests/test_rdf4j.py
@@ -1,8 +1,8 @@
from __future__ import annotations

+import dataclasses
import os
from collections.abc import Generator
-from copy import deepcopy
from pathlib import Path
from string import Template

@@ -97,9 +97,11 @@ def test_upload_with_context(upload_client: GraphDBClient):


def test_direct_sparql_endpoint(rdf4j_gdb: GraphDBClient):
-service_cfg_direct = deepcopy(rdf4j_gdb.service_cfg)
-service_cfg_direct.server = rdf4j_gdb.service_cfg.url
-service_cfg_direct.rest_api = RestApi.DIRECT_SPARQL_ENDPOINT
+service_cfg_direct = dataclasses.replace(
+    rdf4j_gdb.service_cfg,
+    server=rdf4j_gdb.service_cfg.url,
+    rest_api=RestApi.DIRECT_SPARQL_ENDPOINT,
+)

gdb_direct = GraphDBClient(service_cfg_direct)
query = "SELECT * {?s ?p ?o}"
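The test follows the same shift: rather than deepcopying a mutable config and assigning fields, it derives the direct-endpoint variant with dataclasses.replace, since assigning to a field of a frozen dataclass raises. A small illustration with a hypothetical Cfg class and placeholder values, not cimsparql's ServiceConfig:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Cfg:
    server: str = "https://example.org"
    rest_api: str = "split"


cfg = Cfg()
try:
    cfg.server = "https://example.org/sparql"  # the old deepcopy-then-mutate style
except dataclasses.FrozenInstanceError:
    pass  # mutation is rejected once the dataclass is frozen

# dataclasses.replace builds a modified copy and leaves the original untouched.
direct = dataclasses.replace(cfg, server="https://example.org/sparql", rest_api="direct")
assert cfg.server == "https://example.org"
assert direct.rest_api == "direct"
```

Because dataclasses.replace goes through __init__, any __post_init__ validation a config class defines also runs on the derived copy.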
