Skip to content

Commit

Permalink
Add a POST endpoint for the /sparql passthru feature. (#239)
Browse files Browse the repository at this point in the history
* Add a POST endpoint for the /sparql passthru feature.
This causes the proxied query to the graph db to be called as a POST too

* Always set correct content-type on request to Graphdb backend, regardless of proxyied headers content-type.
ensure quote_plus is used to encode the payload, so use raw body content, because if relying on httpx it will use regular urlquote instead.

* Don't stream the proxied /sparql response to the client if the response has an attachment, this breaks when Prez is exposed through Azure APIM, because APIM strips out the attached file.
Instead we wait for the whole response with attached file, read it, then send the response.

* Add minimal post tests; ensure updates are not possible. Add Exception handler

---------

Co-authored-by: david <dcchabgood@gmail.com>
  • Loading branch information
ashleysommer and recalcitrantsupplant committed Jul 2, 2024
1 parent a06547b commit 971f21b
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 34 deletions.
4 changes: 3 additions & 1 deletion prez/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from prez.models.model_exceptions import (
ClassNotFoundException,
URINotFoundException,
NoProfilesException,
NoProfilesException, InvalidSPARQLQueryException,
)
from prez.routers.catprez import router as catprez_router
from prez.routers.cql import router as cql_router
Expand All @@ -47,6 +47,7 @@
catch_class_not_found_exception,
catch_uri_not_found_exception,
catch_no_profiles_exception,
catch_invalid_sparql_query,
)
from prez.services.generate_profiles import create_profiles_graph
from prez.services.prez_logging import setup_logger
Expand Down Expand Up @@ -167,6 +168,7 @@ def assemble_app(
ClassNotFoundException: catch_class_not_found_exception,
URINotFoundException: catch_uri_not_found_exception,
NoProfilesException: catch_no_profiles_exception,
InvalidSPARQLQueryException: catch_invalid_sparql_query
},
**kwargs
)
Expand Down
10 changes: 10 additions & 0 deletions prez/models/model_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,13 @@ def __init__(self, classes: list):
f"for which a profile was searched was/were: {', '.join(klass for klass in classes)}"
)
super().__init__(self.message)


class InvalidSPARQLQueryException(Exception):
"""
Raised when a SPARQL query is invalid.
"""

def __init__(self, error: str):
self.message = f"Invalid SPARQL query: {error}"
super().__init__(self.message)
81 changes: 57 additions & 24 deletions prez/routers/sparql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse, Response
from rdflib import Namespace, Graph
from starlette.background import BackgroundTask
Expand All @@ -17,19 +18,31 @@

router = APIRouter(tags=["SPARQL"])

# TODO: Split this into two routes on the same /sparql path.
# One to handle SPARQL GET requests, the other for SPARQL POST requests.

@router.post("/sparql")
async def sparql_post_passthrough(
# To maintain compatibility with the other SPARQL endpoints,
# /sparql POST endpoint is not a JSON API, it uses
# values encoded with x-www-form-urlencoded
query: Annotated[str, Form()], # Pydantic validation prevents update queries (the Form would need to be "update")
request: Request,
repo: Repo = Depends(get_repo),
):
return await sparql_endpoint_handler(query, request, repo, method="POST")


@router.get("/sparql")
async def sparql_endpoint(
async def sparql_get_passthrough(
query: str,
request: Request,
repo: Repo = Depends(get_repo),
):
request_mediatype = request.headers.get("accept").split(",")[
0
] # can't default the MT where not provided as it could be
return await sparql_endpoint_handler(query, request, repo, method="GET")


async def sparql_endpoint_handler(query: str, request: Request, repo: Repo, method="GET"):
request_mediatype = request.headers.get("accept").split(",")[0]
# can't default the MT where not provided as it could be
# graph (CONSTRUCT like queries) or tabular (SELECT queries)

# Intercept "+anot" mediatypes
Expand All @@ -39,11 +52,11 @@ async def sparql_endpoint(
)
non_anot_mediatype = request_mediatype.replace("anot+", "")
request._headers = Headers({**request.headers, "accept": non_anot_mediatype})
response = await repo.sparql(request)
response = await repo.sparql(query, request.headers.raw, method=method)
await response.aread()
g = Graph()
g.parse(data=response.text, format=non_anot_mediatype)
graph = await return_annotated_rdf(g, prof_and_mt_info.profile)
graph = await return_annotated_rdf(g, prof_and_mt_info.profile, repo)
content = io.BytesIO(
graph.serialize(format=non_anot_mediatype, encoding="utf-8")
)
Expand All @@ -52,19 +65,39 @@ async def sparql_endpoint(
media_type=non_anot_mediatype,
headers=prof_and_mt_info.profile_headers,
)
query_result: 'httpx.Response' = await repo.sparql(query, request.headers.raw, method=method)
if isinstance(query_result, dict):
return JSONResponse(content=query_result)
elif isinstance(query_result, Graph):
return Response(
content=query_result.serialize(format="text/turtle"),
status_code=200
)

dispositions = query_result.headers.get_list("Content-Disposition")
for disposition in dispositions:
if disposition.lower().startswith("attachment"):
is_attachment = True
break
else:
is_attachment = False
if is_attachment:
# remove transfer-encoding chunked, disposition=attachment, and content-length
headers = dict()
for k, v in query_result.headers.items():
if k.lower() not in ("transfer-encoding", "content-disposition", "content-length"):
headers[k] = v
content = await query_result.aread()
await query_result.aclose()
return Response(
content=content,
status_code=query_result.status_code,
headers=headers,
)
else:
query_result = await repo.sparql(query, request.headers.raw)
if isinstance(query_result, dict):
return JSONResponse(content=query_result)
elif isinstance(query_result, Graph):
return Response(
content=query_result.serialize(format="text/turtle"),
status_code=200
)
else:
return StreamingResponse(
query_result.aiter_raw(),
status_code=query_result.status_code,
headers=dict(query_result.headers),
background=BackgroundTask(query_result.aclose),
)
return StreamingResponse(
query_result.aiter_raw(),
status_code=query_result.status_code,
headers=dict(query_result.headers),
background=BackgroundTask(query_result.aclose),
)
12 changes: 11 additions & 1 deletion prez/services/exception_catchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from prez.models.model_exceptions import (
ClassNotFoundException,
URINotFoundException,
NoProfilesException,
NoProfilesException, InvalidSPARQLQueryException,
)


Expand Down Expand Up @@ -51,3 +51,13 @@ async def catch_no_profiles_exception(request: Request, exc: NoProfilesException
"detail": exc.message,
},
)


async def catch_invalid_sparql_query(request: Request, exc: InvalidSPARQLQueryException):
return JSONResponse(
status_code=400,
content={
"error": "Bad Request",
"detail": exc.message,
},
)
26 changes: 18 additions & 8 deletions prez/sparql/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rdflib import Namespace, Graph, URIRef, Literal, BNode

from prez.config import settings
from prez.models.model_exceptions import InvalidSPARQLQueryException

PREZ = Namespace("https://prez.dev/")

Expand Down Expand Up @@ -96,18 +97,24 @@ async def sparql(
):
"""Sends a starlette Request object (containing a SPARQL query in the URL parameters) to a proxied SPARQL
endpoint."""
# TODO: This only supports SPARQL GET requests because the query is sent as a query parameter.
# Uses GET if the proxied query was received as a query param using GET
# Uses POST if the proxied query was received as a form-encoded body using POST

query_escaped_as_bytes = f"query={quote_plus(query)}".encode("utf-8")

# TODO: Global app settings should be passed in as a function argument.
url = httpx.URL(url=settings.sparql_endpoint, query=query_escaped_as_bytes)
headers = []
for header in raw_headers:
if header[0] != b"host":
if header[0] not in (b"host", b"content-length", b"content-type"):
headers.append(header)
query_escaped_as_bytes = f"query={quote_plus(query)}".encode("utf-8")
if method == "GET":
url = httpx.URL(url=settings.sparql_endpoint, query=query_escaped_as_bytes)
content = None
else:
headers.append((b"content-type", b"application/x-www-form-urlencoded"))
url = httpx.URL(url=settings.sparql_endpoint)
content = query_escaped_as_bytes

headers.append((b"host", str(url.host).encode("utf-8")))
rp_req = self.async_client.build_request(method, url, headers=headers)
rp_req = self.async_client.build_request(method, url, headers=headers, content=content)
return await self.async_client.send(rp_req, stream=True)


Expand Down Expand Up @@ -157,7 +164,10 @@ def _sync_tabular_query_to_table(self, query: str, context: URIRef = None) -> tu

def _sparql(self, query: str) -> dict | Graph | bool:
"""Submit a sparql query to the pyoxigraph store and return the formatted results."""
results = self.pyoxi_store.query(query)
try:
results = self.pyoxi_store.query(query)
except SyntaxError as e:
raise InvalidSPARQLQueryException(e.msg)
if isinstance(results, pyoxigraph.QuerySolutions): # a SELECT query result
results_dict = self._handle_query_solution_results(results)
return results_dict
Expand Down
38 changes: 38 additions & 0 deletions tests/test_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,41 @@ def test_ask(client):
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D"
)
assert (r.status_code, 200)


def test_post(client):
"""check that a valid post query returns a 200 response."""
r = client.post(
"/sparql",
data={
"query": "SELECT * WHERE { ?s ?p ?o } LIMIT 1",
"format": "application/x-www-form-urlencoded",
},
)
assert (r.status_code, 200)


def test_post_invalid_data(client):
"""check that a post query with invalid data returns a 400 response."""
r = client.post(
"/sparql",
data={
"query": "INVALID QUERY",
"format": "application/x-www-form-urlencoded",
},
)
assert r.status_code == 400


def test_insert_as_query(client):
"""
Also tested manually with Fuseki
"""
r = client.post(
"/sparql",
data={
"query": "INSERT {<:s> <:p> <:o>}",
"format": "application/x-www-form-urlencoded",
},
)
assert r.status_code == 400

0 comments on commit 971f21b

Please sign in to comment.