Skip to content

Commit

Permalink
Merge pull request #6 from SCAI-BIO/weaviate
Browse files Browse the repository at this point in the history
Weaviate
  • Loading branch information
tiadams authored Jul 15, 2024
2 parents 76338a1 + 2beafa5 commit 40d3780
Show file tree
Hide file tree
Showing 10 changed files with 372 additions and 24 deletions.
8 changes: 7 additions & 1 deletion datastew/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@

# Importing submodules to expose their attributes if needed
from .process import mapping, parsing
from .repository import model, sqllite
from .repository import model, sqllite, base

__all__ = [
"mapping",
"parsing",
"model",
"base",
"sqllite",
"DataDictionarySource",
"BaseRepository",
"MPNetAdapter",
"Terminology",
"Concept",
"Mapping"
]
2 changes: 1 addition & 1 deletion datastew/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, concept: Concept, variable: Variable, source: str):
self.source = source

def __eq__(self, other):
return self.concept.identifier == other.concept.identifier and self.variable.name == other.variable.name
return self.concept.identifier == other.concept.identifier and self.variable.name == other.variable.pref_label

def __hash__(self):
return hash((self.concept.identifier, self.variable.name))
Expand Down
16 changes: 10 additions & 6 deletions datastew/repository/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,25 @@ def __init__(self, name: str, id: str) -> object:

class Concept(Base):
__tablename__ = 'concept'
id = Column(String, primary_key=True)
name = Column(String)
concept_identifier = Column(String, primary_key=True)
pref_label = Column(String)
terminology_id = Column(String, ForeignKey('terminology.id'))
terminology = relationship("Terminology")
uuid = Column(String)

def __init__(self, terminology: Terminology, name: str, id: str) -> object:
def __init__(self, terminology: Terminology, pref_label: str, concept_identifier: str, id: str = None) -> object:
self.terminology = terminology
self.name = name
self.pref_label = pref_label
# should be unique
self.concept_identifier = concept_identifier
# enforced to be unique
self.id = id


class Mapping(Base):
__tablename__ = 'mapping'
id = Column(Integer, primary_key=True, autoincrement=True) # Auto-incrementing primary key
concept_id = Column(String, ForeignKey('concept.id'))
concept_identifier = Column(String, ForeignKey('concept.concept_identifier'))
concept = relationship("Concept")
text = Column(Text)
embedding_json = Column(Text)
Expand All @@ -47,7 +51,7 @@ def __init__(self, concept: Concept, text: str, embedding: list) -> object:
self.embedding_json = json.dumps(embedding) # Store embedding as JSON

def __str__(self):
return f"{self.concept.terminology.name} > {self.concept_id} : {self.concept.name} | {self.text}"
return f"{self.concept.terminology.name} > {self.concept.concept_identifier} : {self.concept.pref_label} | {self.text}"

@property
def embedding(self):
Expand Down
2 changes: 1 addition & 1 deletion datastew/repository/sqllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class SQLLiteRepository(BaseRepository):

def __init__(self, mode="memory", path="index/db/index.db"):
def __init__(self, mode="memory", path=None):
if mode == "disk":
self.engine = create_engine(f'sqlite:///{path}')
# for tests
Expand Down
203 changes: 203 additions & 0 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import shutil
from typing import List, Union

import uuid as uuid
import weaviate

from weaviate.embedded import EmbeddedOptions

from datastew import BaseRepository
from datastew.repository import Mapping, Terminology, Concept
from datastew.repository.weaviate_schema import terminology_schema, concept_schema, mapping_schema


class WeaviateRepository(BaseRepository):
    """Repository that stores terminologies, concepts and mappings in Weaviate.

    Three Weaviate classes are used: ``Terminology``, ``Concept`` (cross-referencing
    its terminology via ``hasTerminology``) and ``Mapping`` (cross-referencing its
    concept via ``hasConcept`` and carrying the embedding as the object vector).
    """

    # Persistence directory of the embedded instance in "memory" mode; removed by shut_down().
    _MEMORY_DB_PATH = "db"

    # GraphQL fragments resolving the cross-references down to the terminology.
    _TERMINOLOGY_FRAGMENT = "hasTerminology { ... on Terminology { _additional { id } name } }"
    _CONCEPT_FRAGMENT = (
        "hasConcept { ... on Concept { _additional { id } conceptID prefLabel "
        "hasTerminology { ... on Terminology { _additional { id } name } } } }"
    )

    def __init__(self, mode="memory", path=None):
        """Connect to Weaviate and make sure the three classes exist.

        :param mode: "memory" (embedded instance persisting to the local "db" directory,
            which shut_down() deletes), "disk" (embedded instance persisting to ``path``)
            or "remote" (client connecting to the URL given in ``path``).
        :param path: Persistence directory for "disk" mode or server URL for "remote" mode.
        :raises ConnectionError: If the client cannot be created. An unknown mode or a
            missing ``path`` is reported the same way, because those checks run inside
            the guarded section.
        """
        self.mode = mode
        try:
            if mode == "memory":
                self.client = weaviate.Client(embedded_options=EmbeddedOptions(
                    persistence_data_path=self._MEMORY_DB_PATH
                ))
            elif mode == "disk":
                if path is None:
                    raise ValueError("Path must be provided for disk mode.")
                self.client = weaviate.Client(embedded_options=EmbeddedOptions(
                    persistence_data_path=path
                ))
            elif mode == "remote":
                if path is None:
                    raise ValueError("Remote URL must be provided for remote mode.")
                self.client = weaviate.Client(
                    url=path
                )
            else:
                raise ValueError(f'Repository mode {mode} is not defined. Use either memory, disk or remote.')
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ConnectionError(f"Failed to initialize Weaviate client: {e}") from e

        # Order matters: Concept references Terminology and Mapping references Concept.
        for schema in (terminology_schema, concept_schema, mapping_schema):
            self._create_class_if_missing(schema)

    def _create_class_if_missing(self, schema: dict) -> None:
        """Create the class described by ``schema`` unless it already exists.

        Reopening a persisted ("disk") or shared ("remote") instance would otherwise
        attempt to create classes that are already part of the schema.
        """
        if not self.client.schema.exists(schema["class"]):
            self.client.schema.create_class(schema)

    @staticmethod
    def _terminology_from_data(terminology_data: dict) -> Terminology:
        """Build a Terminology from one resolved ``hasTerminology`` entry."""
        return Terminology(
            name=terminology_data["name"],
            id=terminology_data["_additional"]["id"]
        )

    @classmethod
    def _concept_from_data(cls, concept_data: dict) -> Concept:
        """Build a Concept (including its terminology) from a query result entry."""
        # Only the first referenced terminology is used; a concept is assumed to have exactly one.
        terminology = cls._terminology_from_data(concept_data["hasTerminology"][0])
        return Concept(
            concept_identifier=concept_data["conceptID"],
            pref_label=concept_data["prefLabel"],
            terminology=terminology,
            id=concept_data["_additional"]["id"]
        )

    @classmethod
    def _mapping_from_item(cls, item: dict) -> Mapping:
        """Build a Mapping (including concept and terminology) from a query result entry."""
        # Only the first referenced concept is used; a mapping is assumed to have exactly one.
        concept = cls._concept_from_data(item["hasConcept"][0])
        return Mapping(
            text=item["text"],
            concept=concept,
            embedding=item["_additional"]["vector"]
        )

    def store_all(self, model_object_instances):
        """Store every instance of the given iterable, one at a time, via store()."""
        for instance in model_object_instances:
            self.store(instance)

    def get_all_concepts(self) -> List[Concept]:
        """Return the stored concepts together with their terminology.

        :raises RuntimeError: If the query fails or the response cannot be parsed.
        """
        # NOTE(review): no explicit limit is set, so the server-side default result
        # limit presumably applies -- confirm whether pagination is needed here.
        try:
            result = self.client.query.get(
                "Concept",
                ["conceptID", "prefLabel", self._TERMINOLOGY_FRAGMENT]
            ).with_additional("id").do()  # the id is needed to reference the concept later
            return [self._concept_from_data(item) for item in result['data']['Get']['Concept']]
        except Exception as e:
            raise RuntimeError(f"Failed to fetch concepts: {e}") from e

    def get_all_terminologies(self) -> List[Terminology]:
        """Return the stored terminologies.

        :raises RuntimeError: If the query fails or the response cannot be parsed.
        """
        try:
            result = self.client.query.get("Terminology", ["name", "_additional { id }"]).do()
            return [self._terminology_from_data(item) for item in result['data']['Get']['Terminology']]
        except Exception as e:
            raise RuntimeError(f"Failed to fetch terminologies: {e}") from e

    def get_all_mappings(self, limit=1000) -> List[Mapping]:
        """Return up to ``limit`` stored mappings including their embedding vectors.

        :param limit: Maximum number of mappings requested from Weaviate.
        :raises RuntimeError: If the query fails or the response cannot be parsed.
        """
        try:
            result = self.client.query.get(
                "Mapping",
                ["text", self._CONCEPT_FRAGMENT]
            ).with_additional("vector").with_limit(limit).do()
            return [self._mapping_from_item(item) for item in result['data']['Get']['Mapping']]
        except Exception as e:
            raise RuntimeError(f"Failed to fetch mappings: {e}") from e

    def get_closest_mappings(self, embedding, limit=5) -> List[Mapping]:
        """Return the ``limit`` mappings found by a near-vector search for ``embedding``.

        :param embedding: Query vector passed to the near-vector search.
        :param limit: Maximum number of mappings requested from Weaviate.
        :raises RuntimeError: If the query fails or the response cannot be parsed.
        """
        try:
            result = self.client.query.get(
                "Mapping",
                ["text", "_additional { distance }", self._CONCEPT_FRAGMENT]
            ).with_additional("vector").with_near_vector({"vector": embedding}).with_limit(limit).do()
            return [self._mapping_from_item(item) for item in result['data']['Get']['Mapping']]
        except Exception as e:
            raise RuntimeError(f"Failed to fetch closest mappings: {e}") from e

    def shut_down(self):
        """Delete the local persistence directory when running in "memory" mode."""
        if self.mode == "memory":
            try:
                shutil.rmtree(self._MEMORY_DB_PATH)
            except FileNotFoundError:
                # Already removed (e.g. shut_down() called twice) -- nothing left to clean up.
                pass

    def store(self, model_object_instance: Union[Terminology, Concept, Mapping]):
        """Store a terminology, concept or mapping as a Weaviate object.

        A fresh random UUID is assigned to ``model_object_instance.id`` (and to
        ``uuid`` for concepts) and used as the Weaviate object id. Concepts are linked
        to their terminology and mappings to their concept through cross-references, so
        the referenced object must have been stored (or fetched) before.

        :param model_object_instance: The object to store.
        :raises RuntimeError: If the object type is unsupported or Weaviate rejects
            the object or its reference.
        """
        random_uuid = uuid.uuid4()
        model_object_instance.id = random_uuid
        try:
            if isinstance(model_object_instance, Terminology):
                properties = {
                    "name": model_object_instance.name
                }
                self.client.data_object.create(
                    class_name="Terminology",
                    data_object=properties,
                    uuid=random_uuid
                )
            elif isinstance(model_object_instance, Concept):
                model_object_instance.uuid = random_uuid
                properties = {
                    "conceptID": model_object_instance.concept_identifier,
                    "prefLabel": model_object_instance.pref_label,
                }
                self.client.data_object.create(
                    class_name="Concept",
                    data_object=properties,
                    uuid=random_uuid
                )
                self.client.data_object.reference.add(
                    from_class_name="Concept",
                    from_uuid=random_uuid,
                    from_property_name="hasTerminology",
                    to_class_name="Terminology",
                    to_uuid=model_object_instance.terminology.id,
                )
            elif isinstance(model_object_instance, Mapping):
                concept = model_object_instance.concept
                # Concepts stored through this repository carry their Weaviate id in
                # ``uuid``; concepts returned by the getters only carry it in ``id``.
                concept_uuid = concept.uuid or concept.id
                properties = {
                    "text": model_object_instance.text,
                }
                self.client.data_object.create(
                    class_name="Mapping",
                    data_object=properties,
                    uuid=random_uuid,
                    vector=model_object_instance.embedding
                )
                self.client.data_object.reference.add(
                    from_class_name="Mapping",
                    from_uuid=random_uuid,
                    from_property_name="hasConcept",
                    to_class_name="Concept",
                    to_uuid=concept_uuid,
                )
            else:
                raise ValueError("Unsupported model object instance type.")

        except Exception as e:
            raise RuntimeError(f"Failed to store object in Weaviate: {e}") from e
48 changes: 48 additions & 0 deletions datastew/repository/weaviate_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
def _property(name, data_type):
    """Return one Weaviate property definition with a single data type."""
    return {"name": name, "dataType": [data_type]}


def _class_schema(class_name, description, *properties):
    """Assemble a Weaviate class definition from its name, description and properties."""
    return {
        "class": class_name,
        "description": description,
        "properties": list(properties),
    }


# Terminology: identified by its name only.
terminology_schema = _class_schema(
    "Terminology",
    "A terminology entry",
    _property("name", "string"),
)

# Concept: identifier and preferred label plus a reference to its Terminology.
concept_schema = _class_schema(
    "Concept",
    "A concept entry",
    _property("conceptID", "string"),
    _property("prefLabel", "string"),
    _property("hasTerminology", "Terminology"),
)

# Mapping: free text, a numeric array property and a reference to its Concept.
mapping_schema = _class_schema(
    "Mapping",
    "A mapping entry",
    _property("text", "string"),
    _property("vector", "number[]"),
    _property("hasConcept", "Concept"),
)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ wheel==0.37.1
aiofiles~=0.7.0
SQLAlchemy~=2.0.27
scipy~=1.11.4
pydantic~=1.10.14
pydantic~=2.5.0
weaviate-client~=4.6.5
Loading

0 comments on commit 40d3780

Please sign in to comment.