diff --git a/datastew/repository/weaviate.py b/datastew/repository/weaviate.py index f2337f1..22c4c87 100644 --- a/datastew/repository/weaviate.py +++ b/datastew/repository/weaviate.py @@ -1,6 +1,6 @@ import logging import shutil -from typing import List, Union +from typing import List, Union, Tuple import uuid as uuid import weaviate from weaviate.embedded import EmbeddedOptions @@ -11,7 +11,6 @@ class WeaviateRepository(BaseRepository): - logger = logging.getLogger(__name__) def __init__(self, mode="memory", path=None): @@ -101,7 +100,8 @@ def get_all_mappings(self, limit=1000) -> List[Mapping]: try: result = self.client.query.get( "Mapping", - ["text", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] + ["text", + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] ).with_additional("vector").with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] @@ -132,7 +132,8 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: try: result = self.client.query.get( "Mapping", - ["text", "_additional { distance }", "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] + ["text", "_additional { distance }", + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() for item in result['data']['Get']['Mapping']: embedding_vector = item["_additional"]["vector"] @@ -158,6 +159,39 @@ def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]: raise RuntimeError(f"Failed to fetch closest mappings: {e}") return mappings + def get_closest_mappings_with_similarities(self, embedding, limit=5) -> List[Tuple[Mapping, float]]: + mappings_with_similarities = [] + try: + result = self.client.query.get( + "Mapping", + ["text", "_additional { distance }", + "hasConcept { ... on Concept { _additional { id } conceptID prefLabel hasTerminology { ... on Terminology { _additional { id } name } } } }"] + ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do() + for item in result['data']['Get']['Mapping']: + similarity = 1 - item["_additional"]["distance"] + embedding_vector = item["_additional"]["vector"] + concept_data = item["hasConcept"][0] # Assuming it has only one concept + terminology_data = concept_data["hasTerminology"][0] # Assuming it has only one terminology + terminology = Terminology( + name=terminology_data["name"], + id=terminology_data["_additional"]["id"] + ) + concept = Concept( + concept_identifier=concept_data["conceptID"], + pref_label=concept_data["prefLabel"], + terminology=terminology, + id=concept_data["_additional"]["id"] + ) + mapping = Mapping( + text=item["text"], + concept=concept, + embedding=embedding_vector + ) + mappings_with_similarities.append((mapping, similarity)) + except Exception as e: + raise RuntimeError(f"Failed to fetch closest mappings with similarities: {e}") + return mappings_with_similarities + def shut_down(self): if self.mode == "memory": shutil.rmtree("db") diff --git a/tests/test_weaviate_repository.py b/tests/test_weaviate_repository.py index 2aea22b..3e7599b 100644 --- a/tests/test_weaviate_repository.py +++ b/tests/test_weaviate_repository.py @@ -68,10 +68,17 @@ def test_repository(self): terminologies = repository.get_all_terminologies() self.assertEqual(len(terminologies), 1) - closest_mappings = repository.get_closest_mappings(embedding_model.get_embedding(text10)) + test_embedding = embedding_model.get_embedding(text10) + + closest_mappings = repository.get_closest_mappings(test_embedding) self.assertEqual(len(closest_mappings), 5) self.assertEqual(closest_mappings[0].text, "Influenza") + closest_mappings_with_similarities = repository.get_closest_mappings_with_similarities(test_embedding) + self.assertEqual(len(closest_mappings_with_similarities), 5) + self.assertEqual(closest_mappings_with_similarities[0][0].text, "Influenza") + self.assertEqual(closest_mappings_with_similarities[0][1], 0.86187172) + # check if it crashed (due to schema re-creation) after restart repository = WeaviateRepository(mode="disk", path="db") @@ -82,8 +89,3 @@ def test_repository(self): concept8, mapping8, concept9, mapping9 ]) - - - - -