Skip to content

Commit

Permalink
⚡️ make a bit async
Browse files Browse the repository at this point in the history
  • Loading branch information
simonwoerpel committed Jul 25, 2024
1 parent c856ff6 commit 7a1aad8
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 30 deletions.
16 changes: 9 additions & 7 deletions ftmstore_fastapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def dataset_list(request: Request) -> CatalogResponse:
This is basically a list of the available dataset within this api instance.
"""
return views.dataset_list(request)
return await views.dataset_list(request)


@app.get(
Expand All @@ -67,7 +67,7 @@ async def dataset_detail(request: Request, dataset: Datasets) -> DatasetResponse
Show metadata for given dataset (as described in
[nomenklatura.Dataset](https://github.com/opensanctions/nomenklatura))
"""
return views.dataset_detail(request, dataset)
return await views.dataset_detail(request, dataset)


def get_authenticated(
Expand Down Expand Up @@ -146,7 +146,9 @@ async def entities(
Use optional `q` parameter for a search term. This does a simple name matching
search, use the `/search` endpoint for actual fulltext search via `ftmq-search`
"""
return views.entity_list(request, retrieve_params, authenticated=authenticated)
return await views.entity_list(
request, retrieve_params, authenticated=authenticated
)


@app.get(
Expand Down Expand Up @@ -175,7 +177,7 @@ async def detail_entity(
`x-entity-id` - the new entity id
`x-entity-schema` - the new entity schema
"""
return views.entity_detail(request, entity_id, retrieve_params)
return await views.entity_detail(request, entity_id, retrieve_params)


@app.get(
Expand Down Expand Up @@ -203,7 +205,7 @@ async def aggregation(
?aggMax=amount&aggMax=date
"""
return views.aggregation(request)
return await views.aggregation(request)


@app.get(
Expand All @@ -225,7 +227,7 @@ async def search(
Returned entities are "dehydrated" and only contain properties defined
during indexing.
"""
return views.search(request, authenticated=authenticated)
return await views.search(request, authenticated=authenticated)


@app.get(
Expand All @@ -239,4 +241,4 @@ async def autocomplete(q: str) -> AutocompleteResponse:
"""
Simple autocomplete by names
"""
return views.autocomplete(q)
return await views.autocomplete(q)
27 changes: 20 additions & 7 deletions ftmstore_fastapi/store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from collections.abc import AsyncGenerator
from functools import cache, lru_cache
from typing import TYPE_CHECKING, Literal

from fastapi import HTTPException
from ftmq.dedupe import get_resolver
from ftmq.model import Catalog, Dataset
from ftmq.model import Catalog, Dataset, DatasetStats
from ftmq.query import Q
from ftmq.store import Store
from ftmq.store import get_store as _get_store
from ftmq.types import CE, CEGenerator
from ftmq.store.sql import AggregatorResult
from ftmq.types import CE
from ftmq.util import get_dehydrated_proxy, get_featured_proxy

from ftmstore_fastapi.logging import get_logger
Expand Down Expand Up @@ -66,11 +68,20 @@ def __init__(
self.query = self.store.query()
self.view = self.store.default_view()

self.stats = self.query.stats
self.aggregations = self.query.aggregations
self.get_adjacents = self.query.get_adjacents
async def stats(self, *args, **kwargs) -> DatasetStats:
return self.query.stats(*args, **kwargs)

def get_entity(self, entity_id: str, params: "RetrieveParams") -> CE | None:
async def aggregations(self, *args, **kwargs) -> AggregatorResult:
return self.query.aggregations(*args, **kwargs)

async def get_adjacents(self, *args, **kwargs) -> set[CE]:
return self.query.get_adjacents(*args, **kwargs)

async def get_adjacent(self, *args, **kwargs) -> AsyncGenerator:
for res in self.view.get_adjacent(*args, **kwargs):
yield res

async def get_entity(self, entity_id: str, params: "RetrieveParams") -> CE | None:
canonical = self.store.linker.get_canonical(entity_id)
proxy = get_cached_entity(self.view, canonical)
if proxy is None:
Expand All @@ -81,7 +92,9 @@ def get_entity(self, entity_id: str, params: "RetrieveParams") -> CE | None:
return get_featured_proxy(proxy)
return proxy

def get_entities(self, query: Q, params: "RetrieveParams") -> CEGenerator:
async def get_entities(
self, query: Q, params: "RetrieveParams"
) -> AsyncGenerator[CE, None]:
for proxy in self.query.entities(query):
if params.dehydrate:
proxy = get_dehydrated_proxy(proxy)
Expand Down
34 changes: 18 additions & 16 deletions ftmstore_fastapi/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,27 @@ def get_aggregation_params(


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def dataset_list(request: Request) -> CatalogResponse:
async def dataset_list(request: Request) -> CatalogResponse:
catalog = get_catalog()
datasets: list[Dataset] = []
for dataset in catalog.datasets:
view = get_view(dataset.name)
dataset.apply_stats(view.stats())
dataset.apply_stats(await view.stats())
datasets.append(dataset)
catalog.datasets = datasets
return CatalogResponse.from_catalog(request, catalog)


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def dataset_detail(request: Request, name: str) -> DatasetResponse:
async def dataset_detail(request: Request, name: str) -> DatasetResponse:
view = get_view(name)
dataset = get_dataset(name)
dataset.apply_stats(view.stats())
dataset.apply_stats(await view.stats())
return DatasetResponse.from_dataset(request, dataset)


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def entity_list(
async def entity_list(
request: Request,
retrieve_params: RetrieveParams,
authenticated: bool | None = False,
Expand All @@ -97,29 +97,29 @@ def entity_list(
params = ViewQueryParams.from_request(request, authenticated)
query = Query.from_params(params)
adjacents = []
entities = [e for e in view.get_entities(query, retrieve_params)]
entities = [e async for e in view.get_entities(query, retrieve_params)]
if retrieve_params.nested:
adjacents = view.get_adjacents(entities)
adjacents = await view.get_adjacents(entities)
return EntitiesResponse.from_view(
request=request,
entities=entities,
adjacents=adjacents,
stats=view.stats(query),
stats=await view.stats(query),
authenticated=authenticated,
)


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def entity_detail(
async def entity_detail(
request: Request,
entity_id: str,
retrieve_params: RetrieveParams,
) -> EntityResponse | RedirectResponse:
view = get_view()
entity = view.get_entity(entity_id, retrieve_params)
entity = await view.get_entity(entity_id, retrieve_params)
adjacents: Iterable[CE] = []
if retrieve_params.nested:
adjacents = [e[1] for e in view.view.get_adjacent(entity)]
adjacents = [e[1] async for e in view.get_adjacent(entity)]
if retrieve_params.dehydrate_nested:
adjacents = [get_dehydrated_proxy(e) for e in adjacents]
if entity.id != entity_id: # we have a redirect to a merged entity
Expand All @@ -133,19 +133,21 @@ def entity_detail(


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def aggregation(request: Request) -> AggregationResponse:
async def aggregation(request: Request) -> AggregationResponse:
view = get_view()
params = ViewQueryParams.from_request(request)
query = Query.from_params(params)
return AggregationResponse.from_view(
request=request,
aggregations=view.aggregations(query),
stats=view.stats(query),
aggregations=await view.aggregations(query),
stats=await view.stats(query),
)


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def search(request: Request, authenticated: bool | None = False) -> EntitiesResponse:
async def search(
request: Request, authenticated: bool | None = False
) -> EntitiesResponse:
params = SearchQueryParams.from_request(request, authenticated)
q = params.q
if q is None or len(q) < 4:
Expand All @@ -162,7 +164,7 @@ def search(request: Request, authenticated: bool | None = False) -> EntitiesResp


@anycache(key_func=get_cache_key, serialization_mode="pickle")
def autocomplete(q: str) -> AutocompleteResponse:
async def autocomplete(q: str) -> AutocompleteResponse:
if q is None or len(q) < 4:
raise HTTPException(400, [f"Invalid search query: `{q}`"])
store = get_search_store()
Expand Down

0 comments on commit 7a1aad8

Please sign in to comment.