Skip to content

Commit

Permalink
Add more tests for SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
bjonnh committed Dec 27, 2023
1 parent bce7174 commit d7c059a
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 269 deletions.
94 changes: 28 additions & 66 deletions api/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging

from fastapi import FastAPI, HTTPException
from fastapi import Depends, FastAPI, HTTPException
from fastapi_versioning import VersionedFastAPI, version

from api.models import (CoupleResult, Item, ReferenceInfo, ReferenceResult,
StructureInfo, StructureResult, TaxonInfo, TaxonResult)
from api.queries import ( # get_matching_references_from_couple_in_item,
get_matching_references_from_reference_in_item,
combine_and_filter_outputs, get_matching_references_from_reference_in_item,
get_matching_references_from_structure_in_item,
get_matching_references_from_taxon_in_item,
get_matching_structures_from_reference_in_item,
Expand All @@ -21,13 +21,14 @@
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


description = """
LOTUSFast API helps you do awesome stuff. 🚀
"""

dm = DataModel()
# Should likely move in the data model if that's used all the time

def get_dm():
return DataModel()


app = FastAPI(
title="LOTUS FastAPI",
Expand All @@ -50,7 +51,7 @@

@app.post("/couples")
@version(1, 0)
async def search_couples(item: Item) -> CoupleResult:
async def search_couples(item: Item, dm: DataModel = Depends(get_dm)) -> CoupleResult:
# WIP
selected_references = get_matching_references_from_reference_in_item(dm, item)
selected_structures = get_matching_structures_from_structure_in_item(dm, item)
Expand Down Expand Up @@ -101,7 +102,7 @@ async def search_couples(item: Item) -> CoupleResult:

@app.post("/structures")
@version(1, 0)
async def search_structures(item: Item) -> StructureResult:
async def search_structures(item: Item, dm: DataModel = Depends(get_dm)) -> StructureResult:
matching_structures_by_structure = get_matching_structures_from_structure_in_item(
dm, item
)
Expand All @@ -110,102 +111,63 @@ async def search_structures(item: Item) -> StructureResult:
dm, item
)

non_empty_sets = [
s
for s in [
matching_structures_by_reference,
matching_structures_by_taxon,
matching_structures_by_structure,
]
if s
]
matching_structures = set.intersection(*non_empty_sets) if non_empty_sets else set()
ids = combine_and_filter_outputs([matching_structures_by_reference,
matching_structures_by_taxon,
matching_structures_by_structure], limit=item.limit)

items = list(dm.get_dict_of_sid_to_smiles(matching_structures).items())

if item.limit == 0:
items = items
else:
items = items[:item.limit]
items = list(dm.get_dict_of_sid_to_smiles(ids).items())

return StructureResult(
ids=matching_structures,
ids=ids,
structures={sid: StructureInfo(smiles=value) for sid, value in items},
description="Structures matching the query",
count=len(matching_structures),
count=len(items),
)


@app.post("/taxa")
@version(1, 0)
async def search_taxa(item: Item) -> TaxonResult:
async def search_taxa(item: Item, dm: DataModel = Depends(get_dm)) -> TaxonResult:
matching_taxa_by_taxon = get_matching_taxa_from_taxon_in_item(dm, item)
matching_taxa_by_structure = get_matching_taxa_from_structure_in_item(dm, item)
matching_taxa_by_reference = get_matching_taxa_from_reference_in_item(dm, item)

non_empty_sets = [
s
for s in [
matching_taxa_by_reference,
matching_taxa_by_structure,
matching_taxa_by_taxon,
]
if s
]
matching_taxa = set.intersection(*non_empty_sets) if non_empty_sets else set()
ids = combine_and_filter_outputs([matching_taxa_by_reference,
matching_taxa_by_structure,
matching_taxa_by_taxon], limit=item.limit)

items = list(dm.get_dict_of_tid_to_taxon_name(matching_taxa).items())

if item.limit == 0:
items = items
else:
items = items[: item.limit]
items = list(dm.get_dict_of_tid_to_taxon_name(ids).items())

return TaxonResult(
ids=matching_taxa,
ids=ids,
taxa={tid: TaxonInfo(name=value) for tid, value in items},
description="Taxa matching the query",
count=len(matching_taxa),
count=len(ids),
)


@app.post("/references")
@version(1, 0)
async def search_references(item: Item) -> ReferenceResult:
async def search_references(item: Item, dm: DataModel = Depends(get_dm)) -> ReferenceResult:
matching_references_by_reference = get_matching_references_from_reference_in_item(
dm, item
)
# matching_references_by_couple = get_matching_references_from_couple_in_item(
# dm, item
# )
matching_references_by_structure = get_matching_references_from_structure_in_item(
dm, item
)
matching_references_by_taxon = get_matching_references_from_taxon_in_item(dm, item)

non_empty_sets = [
s
for s in [
matching_references_by_reference,
matching_references_by_structure,
matching_references_by_taxon,
]
if s
]
matching_references = set.intersection(*non_empty_sets) if non_empty_sets else set()
ids = combine_and_filter_outputs([matching_references_by_reference,
matching_references_by_structure,
matching_references_by_taxon], limit=item.limit)

items = list(dm.get_dict_of_rid_to_reference_doi(matching_references).items())

if item.limit == 0:
items = items
else:
items = items[: item.limit]
items = list(dm.get_dict_of_rid_to_reference_doi(ids).items())

return ReferenceResult(
ids=matching_references,
ids=ids,
references={rid: ReferenceInfo(doi=value) for rid, value in items},
description="References matching the query",
count=len(matching_references),
count=len(ids),
)


Expand Down
72 changes: 52 additions & 20 deletions api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
)


def get_matching_references_from_reference_in_item(
dm: DataModel, item: Item
) -> set[int]:
def get_matching_references_from_reference_in_item(dm: DataModel, item: Item) -> set[int] | None:
"""Returns the WID of matching references."""
references = set()
references = None
if item.reference_doi and item.reference_wid:
raise HTTPException(
status_code=500,
Expand All @@ -26,6 +24,8 @@ def get_matching_references_from_reference_in_item(
if item.reference_wid:
if item.reference_wid in dm.get_references():
return {item.reference_wid}
else:
return set()
else:
if item.reference_doi:
references = set(dm.get_references_with_doi(item.reference_doi))
Expand All @@ -39,11 +39,10 @@ def get_matching_references_from_reference_in_item(
return references


def get_matching_structures_from_structure_in_item(
dm: DataModel, item: Item
) -> set[int]:
def get_matching_structures_from_structure_in_item(dm: DataModel, item: Item) -> set[int] | None:
"""Returns the WID of matching structures."""
structures = set()
structures = None

if item.structure and item.structure_wid:
raise HTTPException(
status_code=500,
Expand All @@ -54,6 +53,8 @@ def get_matching_structures_from_structure_in_item(
if item.structure_wid:
if item.structure_wid in dm.structures_set():
return {item.structure_wid}
else:
return set()
else:
if item.structure:
if item.substructure_search:
Expand Down Expand Up @@ -85,14 +86,16 @@ def get_matching_structures_from_structure_in_item(
return structures


def get_matching_taxa_from_taxon_in_item(dm: DataModel, item: Item) -> set[int]:
def get_matching_taxa_from_taxon_in_item(dm: DataModel, item: Item) -> set[int] | None:
"""Returns the WID of matching taxa."""
taxa = set()
taxa = None
if item.taxon_wid is not None or item.taxon_name is not None:
# This needs to be explained in the API doc
if item.taxon_wid:
if item.taxon_wid in dm.get_taxa():
return {item.taxon_wid}
else:
return set()
else:
if item.taxon_name:
taxa = set(dm.get_taxa_with_name_containing(item.taxon_name))
Expand Down Expand Up @@ -130,31 +133,39 @@ def get_matching_taxa_from_taxon_in_item(dm: DataModel, item: Item) -> set[int]:
# return out


def get_matching_references_from_structure_in_item(
dm: DataModel, item: Item
) -> set[int]:
def get_matching_references_from_structure_in_item(dm: DataModel, item: Item) -> set[int] | None:
structures = get_matching_structures_from_structure_in_item(dm, item)

if structures is None:
return None

return dm.get_references_of_structures(structures)


def get_matching_references_from_taxon_in_item(dm: DataModel, item: Item) -> set[int]:
def get_matching_references_from_taxon_in_item(dm: DataModel, item: Item) -> set[int] | None:
taxa = get_matching_taxa_from_taxon_in_item(dm, item)

if taxa is None:
return None

return dm.get_references_of_taxa(taxa)


def get_matching_structures_from_reference_in_item(
dm: DataModel, item: Item
) -> set[int]:
def get_matching_structures_from_reference_in_item(dm: DataModel, item: Item) -> set[int] | None:
references = get_matching_references_from_reference_in_item(dm, item)

if references is None:
return None

return dm.get_structures_of_references(references)


def get_matching_structures_from_taxon_in_item(dm: DataModel, item: Item) -> set[int]:
def get_matching_structures_from_taxon_in_item(dm: DataModel, item: Item) -> set[int] | None:
taxa = get_matching_taxa_from_taxon_in_item(dm, item)

if taxa is None:
return None

# TODO Set recursive=True to have all the structures from the parents too?
# We may have issues if we have a lot, and it will require a bit more work to get it with the db
# We could also have all the parenting relations in the DB and it would be much much faster
Expand All @@ -165,13 +176,34 @@ def get_matching_structures_from_taxon_in_item(dm: DataModel, item: Item) -> set
return out


def get_matching_taxa_from_structure_in_item(dm: DataModel, item: Item) -> set[int]:
def get_matching_taxa_from_structure_in_item(dm: DataModel, item: Item) -> set[int] | None:
structures = get_matching_structures_from_structure_in_item(dm, item)

if structures is None:
return None

return dm.get_taxa_of_structures(structures)


def get_matching_taxa_from_reference_in_item(dm: DataModel, item: Item) -> set[int]:
def get_matching_taxa_from_reference_in_item(dm: DataModel, item: Item) -> set[int] | None:
references = get_matching_references_from_reference_in_item(dm, item)

if references is None:
return None

return dm.get_taxa_of_references(references)


def combine_and_filter_outputs(sets: list[set], limit: int) -> list[int]:
non_none_outputs = [
s
for s in sets
if s is not None
]

items = list(set.intersection(*non_none_outputs) if non_none_outputs else set())

if limit == 0:
return items
else:
return items[:limit]
2 changes: 2 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __new__(cls, path: Path = Path("./data")):
cls.instance = super(DataModel, cls).__new__(cls)
cls.instance.db = cls.load_all_data(path)
cls.instance.storage = Storage(path)
cls.instance.path = path
return cls.instance

@classmethod
Expand Down Expand Up @@ -176,6 +177,7 @@ def get_structure_smiles_from_list_of_sids(self, sids: list[int]) -> list[str]:
return [self.db["structure_smiles"][i] for i in ids if 0 <= i < len(llen)]

def get_dict_of_sid_to_smiles(self, sid: Iterable[int]) -> dict[int, str]:
# TODO This is gross lets use the db
ids = {
s: self.db["structure_id"][s] for s in sid if s in self.db["structure_id"]
}
Expand Down
Loading

0 comments on commit d7c059a

Please sign in to comment.