diff --git a/prez/routers/sparql.py b/prez/routers/sparql.py index 34e01221..8e329fa7 100644 --- a/prez/routers/sparql.py +++ b/prez/routers/sparql.py @@ -1,6 +1,7 @@ import io from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse, Response from rdflib import Namespace, Graph from starlette.background import BackgroundTask from starlette.datastructures import Headers @@ -52,10 +53,25 @@ async def sparql_endpoint( headers=prof_and_mt_info.profile_headers, ) else: - response = await repo.sparql(query, request.headers.raw) - return StreamingResponse( - response.aiter_raw(), - status_code=response.status_code, - headers=dict(response.headers), - background=BackgroundTask(response.aclose), - ) + # response = await repo.sparql(query, request.headers.raw) + # return StreamingResponse( + # response.aiter_raw(), + # status_code=response.status_code, + # headers=dict(response.headers), + # background=BackgroundTask(response.aclose), + # ) + query_result = await repo.sparql(query, request.headers.raw) + if isinstance(query_result, dict): + return JSONResponse(content=query_result) + elif isinstance(query_result, Graph): + return Response( + content=query_result.serialize(format="text/turtle"), + status_code=200 + ) + else: + return StreamingResponse( + query_result.aiter_raw(), + status_code=query_result.status_code, + headers=dict(query_result.headers), + background=BackgroundTask(query_result.aclose), + ) diff --git a/prez/sparql/methods.py b/prez/sparql/methods.py index 389448c2..fbb783d2 100644 --- a/prez/sparql/methods.py +++ b/prez/sparql/methods.py @@ -155,11 +155,8 @@ def _sync_tabular_query_to_table(self, query: str, context: URIRef = None) -> tu # only return the bindings from the results. return context, results_dict["results"]["bindings"] - def _sparql(self, request: Request) -> dict | Graph: - try: - query = request.query_params["query"] - except KeyError: - raise KeyError(f"No query was provided in the request parameters.") + def _sparql(self, query: str) -> dict | Graph | bool: + """Submit a sparql query to the pyoxigraph store and return the formatted results.""" results = self.pyoxi_store.query(query) if isinstance(results, pyoxigraph.QuerySolutions): # a SELECT query result results_dict = self._handle_query_solution_results(results) @@ -181,8 +178,8 @@ async def tabular_query_to_table(self, query: str, context: URIRef = None) -> li self._sync_tabular_query_to_table, query, context ) - async def sparql(self, request: Request) -> list | Graph: - return self._sparql(request) + async def sparql(self, query: str, raw_headers: list[tuple[bytes, bytes]], method: str = "") -> list | Graph | bool: + return self._sparql(query) @staticmethod def _pyoxi_result_type(term) -> str: