From 971f21b849b75889df60c534de76c787e9e9fb4f Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Tue, 2 Jul 2024 10:40:44 +1000 Subject: [PATCH] Add a POST endpoint for the /sparql passthru feature. (#239) * Add a POST endpoint for the /sparql passthru feature. This causes the proxied query to the graph db to be called as a POST too * Always set correct content-type on request to Graphdb backend, regardless of proxyied headers content-type. ensure quote_plus is used to encode the payload, so use raw body content, because if relying on httpx it will use regular urlquote instead. * Don't stream the proxied /sparql response to the client if the response has an attachment, this breaks when Prez is exposed through Azure APIM, because APIM strips out the attached file. Instead we wait for the whole response with attached file, read it, then send the response. * Add minimal post tests; ensure updates are not possible. Add Exception handler --------- Co-authored-by: david --- prez/app.py | 4 +- prez/models/model_exceptions.py | 10 ++++ prez/routers/sparql.py | 81 ++++++++++++++++++++--------- prez/services/exception_catchers.py | 12 ++++- prez/sparql/methods.py | 26 ++++++--- tests/test_sparql.py | 38 ++++++++++++++ 6 files changed, 137 insertions(+), 34 deletions(-) diff --git a/prez/app.py b/prez/app.py index ee184673..bc7ed9bb 100644 --- a/prez/app.py +++ b/prez/app.py @@ -20,7 +20,7 @@ from prez.models.model_exceptions import ( ClassNotFoundException, URINotFoundException, - NoProfilesException, + NoProfilesException, InvalidSPARQLQueryException, ) from prez.routers.catprez import router as catprez_router from prez.routers.cql import router as cql_router @@ -47,6 +47,7 @@ catch_class_not_found_exception, catch_uri_not_found_exception, catch_no_profiles_exception, + catch_invalid_sparql_query, ) from prez.services.generate_profiles import create_profiles_graph from prez.services.prez_logging import setup_logger @@ -167,6 +168,7 @@ def assemble_app( ClassNotFoundException: catch_class_not_found_exception, URINotFoundException: catch_uri_not_found_exception, NoProfilesException: catch_no_profiles_exception, + InvalidSPARQLQueryException: catch_invalid_sparql_query }, **kwargs ) diff --git a/prez/models/model_exceptions.py b/prez/models/model_exceptions.py index 65a44369..1f01e890 100644 --- a/prez/models/model_exceptions.py +++ b/prez/models/model_exceptions.py @@ -35,3 +35,13 @@ def __init__(self, classes: list): f"for which a profile was searched was/were: {', '.join(klass for klass in classes)}" ) super().__init__(self.message) + + +class InvalidSPARQLQueryException(Exception): + """ + Raised when a SPARQL query is invalid. + """ + + def __init__(self, error: str): + self.message = f"Invalid SPARQL query: {error}" + super().__init__(self.message) \ No newline at end of file diff --git a/prez/routers/sparql.py b/prez/routers/sparql.py index 359343f9..6986c8e5 100644 --- a/prez/routers/sparql.py +++ b/prez/routers/sparql.py @@ -1,6 +1,7 @@ import io +from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Form from fastapi.responses import JSONResponse, Response from rdflib import Namespace, Graph from starlette.background import BackgroundTask @@ -17,19 +18,31 @@ router = APIRouter(tags=["SPARQL"]) -# TODO: Split this into two routes on the same /sparql path. -# One to handle SPARQL GET requests, the other for SPARQL POST requests. + +@router.post("/sparql") +async def sparql_post_passthrough( + # To maintain compatibility with the other SPARQL endpoints, + # /sparql POST endpoint is not a JSON API, it uses + # values encoded with x-www-form-urlencoded + query: Annotated[str, Form()], # Pydantic validation prevents update queries (the Form would need to be "update") + request: Request, + repo: Repo = Depends(get_repo), +): + return await sparql_endpoint_handler(query, request, repo, method="POST") @router.get("/sparql") -async def sparql_endpoint( +async def sparql_get_passthrough( query: str, request: Request, repo: Repo = Depends(get_repo), ): - request_mediatype = request.headers.get("accept").split(",")[ - 0 - ] # can't default the MT where not provided as it could be + return await sparql_endpoint_handler(query, request, repo, method="GET") + + +async def sparql_endpoint_handler(query: str, request: Request, repo: Repo, method="GET"): + request_mediatype = request.headers.get("accept").split(",")[0] + # can't default the MT where not provided as it could be # graph (CONSTRUCT like queries) or tabular (SELECT queries) # Intercept "+anot" mediatypes @@ -39,11 +52,11 @@ async def sparql_endpoint( ) non_anot_mediatype = request_mediatype.replace("anot+", "") request._headers = Headers({**request.headers, "accept": non_anot_mediatype}) - response = await repo.sparql(request) + response = await repo.sparql(query, request.headers.raw, method=method) await response.aread() g = Graph() g.parse(data=response.text, format=non_anot_mediatype) - graph = await return_annotated_rdf(g, prof_and_mt_info.profile) + graph = await return_annotated_rdf(g, prof_and_mt_info.profile, repo) content = io.BytesIO( graph.serialize(format=non_anot_mediatype, encoding="utf-8") ) @@ -52,19 +65,39 @@ async def sparql_endpoint( media_type=non_anot_mediatype, headers=prof_and_mt_info.profile_headers, ) + query_result: 'httpx.Response' = await repo.sparql(query, request.headers.raw, method=method) + if isinstance(query_result, dict): + return JSONResponse(content=query_result) + elif isinstance(query_result, Graph): + return Response( + content=query_result.serialize(format="text/turtle"), + status_code=200 + ) + + dispositions = query_result.headers.get_list("Content-Disposition") + for disposition in dispositions: + if disposition.lower().startswith("attachment"): + is_attachment = True + break + else: + is_attachment = False + if is_attachment: + # remove transfer-encoding chunked, disposition=attachment, and content-length + headers = dict() + for k, v in query_result.headers.items(): + if k.lower() not in ("transfer-encoding", "content-disposition", "content-length"): + headers[k] = v + content = await query_result.aread() + await query_result.aclose() + return Response( + content=content, + status_code=query_result.status_code, + headers=headers, + ) else: - query_result = await repo.sparql(query, request.headers.raw) - if isinstance(query_result, dict): - return JSONResponse(content=query_result) - elif isinstance(query_result, Graph): - return Response( - content=query_result.serialize(format="text/turtle"), - status_code=200 - ) - else: - return StreamingResponse( - query_result.aiter_raw(), - status_code=query_result.status_code, - headers=dict(query_result.headers), - background=BackgroundTask(query_result.aclose), - ) + return StreamingResponse( + query_result.aiter_raw(), + status_code=query_result.status_code, + headers=dict(query_result.headers), + background=BackgroundTask(query_result.aclose), + ) diff --git a/prez/services/exception_catchers.py b/prez/services/exception_catchers.py index 8e5b88d2..a474f5eb 100644 --- a/prez/services/exception_catchers.py +++ b/prez/services/exception_catchers.py @@ -5,7 +5,7 @@ from prez.models.model_exceptions import ( ClassNotFoundException, URINotFoundException, - NoProfilesException, + NoProfilesException, InvalidSPARQLQueryException, ) @@ -51,3 +51,13 @@ async def catch_no_profiles_exception(request: Request, exc: NoProfilesException "detail": exc.message, }, ) + + +async def catch_invalid_sparql_query(request: Request, exc: InvalidSPARQLQueryException): + return JSONResponse( + status_code=400, + content={ + "error": "Bad Request", + "detail": exc.message, + }, + ) \ No newline at end of file diff --git a/prez/sparql/methods.py b/prez/sparql/methods.py index ab88ad5b..7657e127 100644 --- a/prez/sparql/methods.py +++ b/prez/sparql/methods.py @@ -11,6 +11,7 @@ from rdflib import Namespace, Graph, URIRef, Literal, BNode from prez.config import settings +from prez.models.model_exceptions import InvalidSPARQLQueryException PREZ = Namespace("https://prez.dev/") @@ -96,18 +97,24 @@ async def sparql( ): """Sends a starlette Request object (containing a SPARQL query in the URL parameters) to a proxied SPARQL endpoint.""" - # TODO: This only supports SPARQL GET requests because the query is sent as a query parameter. + # Uses GET if the proxied query was received as a query param using GET + # Uses POST if the proxied query was received as a form-encoded body using POST - query_escaped_as_bytes = f"query={quote_plus(query)}".encode("utf-8") - - # TODO: Global app settings should be passed in as a function argument. - url = httpx.URL(url=settings.sparql_endpoint, query=query_escaped_as_bytes) headers = [] for header in raw_headers: - if header[0] != b"host": + if header[0] not in (b"host", b"content-length", b"content-type"): headers.append(header) + query_escaped_as_bytes = f"query={quote_plus(query)}".encode("utf-8") + if method == "GET": + url = httpx.URL(url=settings.sparql_endpoint, query=query_escaped_as_bytes) + content = None + else: + headers.append((b"content-type", b"application/x-www-form-urlencoded")) + url = httpx.URL(url=settings.sparql_endpoint) + content = query_escaped_as_bytes + headers.append((b"host", str(url.host).encode("utf-8"))) - rp_req = self.async_client.build_request(method, url, headers=headers) + rp_req = self.async_client.build_request(method, url, headers=headers, content=content) return await self.async_client.send(rp_req, stream=True) @@ -157,7 +164,10 @@ def _sync_tabular_query_to_table(self, query: str, context: URIRef = None) -> tu def _sparql(self, query: str) -> dict | Graph | bool: """Submit a sparql query to the pyoxigraph store and return the formatted results.""" - results = self.pyoxi_store.query(query) + try: + results = self.pyoxi_store.query(query) + except SyntaxError as e: + raise InvalidSPARQLQueryException(e.msg) if isinstance(results, pyoxigraph.QuerySolutions): # a SELECT query result results_dict = self._handle_query_solution_results(results) return results_dict diff --git a/tests/test_sparql.py b/tests/test_sparql.py index 876bb4d8..c2b8059a 100644 --- a/tests/test_sparql.py +++ b/tests/test_sparql.py @@ -65,3 +65,41 @@ def test_ask(client): "/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D" ) assert (r.status_code, 200) + + +def test_post(client): + """check that a valid post query returns a 200 response.""" + r = client.post( + "/sparql", + data={ + "query": "SELECT * WHERE { ?s ?p ?o } LIMIT 1", + "format": "application/x-www-form-urlencoded", + }, + ) + assert (r.status_code, 200) + + +def test_post_invalid_data(client): + """check that a post query with invalid data returns a 400 response.""" + r = client.post( + "/sparql", + data={ + "query": "INVALID QUERY", + "format": "application/x-www-form-urlencoded", + }, + ) + assert r.status_code == 400 + + +def test_insert_as_query(client): + """ + Also tested manually with Fuseki + """ + r = client.post( + "/sparql", + data={ + "query": "INSERT {<:s> <:p> <:o>}", + "format": "application/x-www-form-urlencoded", + }, + ) + assert r.status_code == 400 \ No newline at end of file