Skip to content

Commit

Permalink
Merge pull request #10 from SCAI-BIO/weaviate
Browse files Browse the repository at this point in the history
Add failsafe for duplicate entries and DB restarts
  • Loading branch information
tiadams authored Jul 15, 2024
2 parents 0ba2d04 + c07b9be commit d0b3891
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 47 deletions.
146 changes: 101 additions & 45 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
import shutil
from typing import List, Union

import uuid as uuid
import weaviate

from weaviate.embedded import EmbeddedOptions

from datastew.repository import Mapping, Terminology, Concept
Expand All @@ -13,6 +12,8 @@

class WeaviateRepository(BaseRepository):

logger = logging.getLogger(__name__)

def __init__(self, mode="memory", path=None):
self.mode = mode
try:
Expand All @@ -37,9 +38,22 @@ def __init__(self, mode="memory", path=None):
except Exception as e:
raise ConnectionError(f"Failed to initialize Weaviate client: {e}")

self.client.schema.create_class(terminology_schema)
self.client.schema.create_class(concept_schema)
self.client.schema.create_class(mapping_schema)
try:
self._create_schema_if_not_exists(terminology_schema)
self._create_schema_if_not_exists(concept_schema)
self._create_schema_if_not_exists(mapping_schema)
except Exception as e:
raise RuntimeError(f"Failed to create schema: {e}")

def _create_schema_if_not_exists(self, schema):
class_name = schema['class']
try:
if not self.client.schema.exists(class_name):
self.client.schema.create_class(schema)
else:
self.logger.info(f"Schema for {class_name} already exists. Skipping.")
except Exception as e:
raise RuntimeError(f"Failed to check/create schema for {class_name}: {e}")

def store_all(self, model_object_instances):
for instance in model_object_instances:
Expand Down Expand Up @@ -153,51 +167,93 @@ def store(self, model_object_instance: Union[Terminology, Concept, Mapping]):
model_object_instance.id = random_uuid
try:
if isinstance(model_object_instance, Terminology):
properties = {
"name": model_object_instance.name
}
self.client.data_object.create(
class_name="Terminology",
data_object=properties,
uuid=random_uuid
)
if not self._terminology_exists(model_object_instance.name):
properties = {
"name": model_object_instance.name
}
self.client.data_object.create(
class_name="Terminology",
data_object=properties,
uuid=random_uuid
)
else:
self.logger.info(f'Terminology with name {model_object_instance.name} already exists. Skipping.')
elif isinstance(model_object_instance, Concept):
model_object_instance.uuid = random_uuid
properties = {
"conceptID": model_object_instance.concept_identifier,
"prefLabel": model_object_instance.pref_label,
}
self.client.data_object.create(
class_name="Concept",
data_object=properties,
uuid=random_uuid
)
self.client.data_object.reference.add(
from_class_name="Concept",
from_uuid=random_uuid,
from_property_name="hasTerminology",
to_class_name="Terminology",
to_uuid=model_object_instance.terminology.id,
)
if not self._concept_exists(model_object_instance.concept_identifier):
properties = {
"conceptID": model_object_instance.concept_identifier,
"prefLabel": model_object_instance.pref_label,
}
self.client.data_object.create(
class_name="Concept",
data_object=properties,
uuid=random_uuid
)
self.client.data_object.reference.add(
from_class_name="Concept",
from_uuid=random_uuid,
from_property_name="hasTerminology",
to_class_name="Terminology",
to_uuid=model_object_instance.terminology.id,
)
else:
self.logger.info(f'Concept with identifier {model_object_instance.concept_identifier} '
f'already exists. Skipping.')
elif isinstance(model_object_instance, Mapping):
properties = {
"text": model_object_instance.text,
}
self.client.data_object.create(
class_name="Mapping",
data_object=properties,
uuid=random_uuid,
vector=model_object_instance.embedding
)
self.client.data_object.reference.add(
from_class_name="Mapping",
from_uuid=random_uuid,
from_property_name="hasConcept",
to_class_name="Concept",
to_uuid=model_object_instance.concept.uuid,
)
if not self._mapping_exists(model_object_instance.embedding):
properties = {
"text": model_object_instance.text,
}
self.client.data_object.create(
class_name="Mapping",
data_object=properties,
uuid=random_uuid,
vector=model_object_instance.embedding
)
self.client.data_object.reference.add(
from_class_name="Mapping",
from_uuid=random_uuid,
from_property_name="hasConcept",
to_class_name="Concept",
to_uuid=model_object_instance.concept.uuid,
)
else:
self.logger.info(f'Mapping with same embedding already exists. Skipping.')
else:
raise ValueError("Unsupported model object instance type.")

except Exception as e:
raise RuntimeError(f"Failed to store object in Weaviate: {e}")

def _terminology_exists(self, name: str) -> bool:
try:
result = self.client.query.get("Terminology", ["name"]).with_where({
"path": ["name"],
"operator": "Equal",
"valueText": name
}).do()
return len(result['data']['Get']['Terminology']) > 0
except Exception as e:
raise RuntimeError(f"Failed to check if terminology exists: {e}")

def _concept_exists(self, concept_id: str) -> bool:
try:
result = self.client.query.get("Concept", ["conceptID"]).with_where({
"path": ["conceptID"],
"operator": "Equal",
"valueText": concept_id
}).do()
return len(result['data']['Get']['Concept']) > 0
except Exception as e:
raise RuntimeError(f"Failed to check if concept exists: {e}")

def _mapping_exists(self, embedding) -> bool:
try:
result = self.client.query.get("Mapping", ["_additional { vector }"]).with_near_vector({
"vector": embedding,
"distance": float(0) # Ensure distance is explicitly casted to float
}).do()
return len(result['data']['Get']['Mapping']) > 0
except Exception as e:
raise RuntimeError(f"Failed to check if mapping exists: {e}")
17 changes: 15 additions & 2 deletions tests/test_weaviate_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Test(TestCase):

def test_repository(self):

repository = WeaviateRepository(mode="memory")
repository = WeaviateRepository(mode="disk", path="db")

embedding_model = MPNetAdapter()

Expand Down Expand Up @@ -72,5 +72,18 @@ def test_repository(self):
self.assertEqual(len(closest_mappings), 5)
self.assertEqual(closest_mappings[0].text, "Influenza")

repository.shut_down()
# check if it crashed (due to schema re-creation) after restart
repository = WeaviateRepository(mode="disk", path="db")

# try to store all again (should not create new entries since they already exist)
repository.store_all([
terminology, concept1, mapping1, concept2, mapping2, concept3, mapping3,
concept4, mapping4, concept5, mapping5, concept6, mapping6, concept7, mapping7,
concept8, mapping8, concept9, mapping9
])






0 comments on commit d0b3891

Please sign in to comment.