Skip to content

Commit

Permalink
Merge pull request #63 from SCAI-BIO/feat-adapter-parallel-processing
Browse files Browse the repository at this point in the history
feat: adapter parallel processing
  • Loading branch information
tiadams authored Jan 6, 2025
2 parents d1e2ce5 + 9dea79c commit fcd4a2f
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 58 deletions.
179 changes: 123 additions & 56 deletions datastew/embedding.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,164 @@
import concurrent.futures
import logging
from abc import ABC
from abc import ABC, abstractmethod
from typing import List, Sequence

import numpy as np
import openai
from openai.error import OpenAIError
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class EmbeddingModel(ABC):
    """Abstract base class for text-embedding backends.

    Concrete adapters (e.g. OpenAI, sentence-transformers) implement the two
    abstract methods; the shared helpers hold the model name and normalize
    input text.
    """

    def __init__(self, model_name: str):
        """Store the identifier of the underlying embedding model.

        :param model_name: Name of the embedding model.
        """
        self.model_name = model_name

    @abstractmethod
    def get_embedding(self, text: str) -> Sequence[float]:
        """Retrieve the embedding vector for a single text input.

        :param text: The input text to embed.
        :return: A sequence of floats representing the embedding.
        """
        pass

    @abstractmethod
    def get_embeddings(self, messages: List[str]) -> Sequence[Sequence[float]]:
        """Retrieve embeddings for a list of text messages.

        :param messages: A list of text messages to embed.
        :return: A sequence of embedding vectors.
        """
        pass

    def get_model_name(self) -> str:
        """Return the name of the embedding model.

        :return: The name of the model.
        """
        return self.model_name

    def sanitize(self, message: str) -> str:
        """Clean up the input text by trimming and converting to lowercase.

        :param message: The input text message.
        :return: Sanitized text.
        """
        return message.strip().lower()


class GPT4Adapter(EmbeddingModel):
    """EmbeddingModel adapter backed by the OpenAI embeddings API."""

    def __init__(self, api_key: str, model_name: str = "text-embedding-ada-002"):
        """Initialize the GPT-4 adapter with OpenAI API key and model name.

        :param api_key: The API key for accessing OpenAI services.
        :param model_name: The specific embedding model to use.
        """
        super().__init__(model_name)
        self.api_key = api_key
        # NOTE: the openai key is module-global — every openai call in this
        # process will use the key set by the most recent adapter.
        openai.api_key = api_key
        logging.getLogger().setLevel(logging.INFO)

    def get_embedding(self, text: str) -> Sequence[float]:
        """Retrieve an embedding for a single text input using the OpenAI API.

        :param text: The input text to embed.
        :return: The embedding vector, or an empty list on invalid input or
            API error (errors are logged, not raised).
        """
        if not text or not isinstance(text, str):
            logging.warning("Empty or invalid text passed to get_embedding")
            return []
        text = self.sanitize(text.replace("\n", " "))
        try:
            response = openai.Embedding.create(input=[text], model=self.model_name)
            return response["data"][0]["embedding"]
        except OpenAIError as e:
            logging.error(f"OpenAI API error: {e}")
            return []

    def get_embeddings(self, messages: List[str], max_length: int = 2048,
                       num_workers: int = 4) -> Sequence[Sequence[float]]:
        """Retrieve embeddings for a list of text messages using batching and
        multithreading.

        :param messages: A list of text messages to embed.
        :param max_length: Maximum number of messages per API batch.
        :param num_workers: Number of threads for parallel processing.
        :return: Embedding vectors in the same order as ``messages``. A chunk
            that fails contributes no vectors (the error is logged).
        """
        if max_length <= 0:
            logging.warning(f"max_length is set to {max_length}, using default value 2048")
            max_length = 2048

        sanitized_messages = [self.sanitize(msg) for msg in messages]
        chunks = [sanitized_messages[i:i + max_length]
                  for i in range(0, len(sanitized_messages), max_length)]

        # Collect results by chunk index: as_completed yields futures in
        # completion order, so extending a flat list directly would scramble
        # the correspondence between inputs and embeddings.
        chunk_results: List[Sequence[Sequence[float]]] = [[] for _ in chunks]
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_index = {executor.submit(self._process_chunk, chunk): i
                               for i, chunk in enumerate(chunks)}
            for future in concurrent.futures.as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    chunk_results[index] = future.result()
                except Exception as e:
                    logging.error(f"Error in processing chunk: {e}")
        return [embedding for result in chunk_results for embedding in result]

    def _process_chunk(self, chunk: List[str]) -> Sequence[Sequence[float]]:
        """Process a batch of text messages to retrieve embeddings.

        :param chunk: A list of sanitized messages.
        :return: A sequence of embedding vectors, or an empty list on error.
        """
        try:
            response = openai.Embedding.create(input=chunk, model=self.model_name)
            return [item["embedding"] for item in response["data"]]
        except Exception as e:
            logging.error(f"Error processing chunk: {e}")
            return []


class MPNetAdapter(EmbeddingModel):
    """EmbeddingModel adapter backed by a local sentence-transformers model."""

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2"):
        """Initialize the MPNet adapter with a specified model name.

        :param model_name: The model name for sentence transformers.
        """
        super().__init__(model_name)
        self.model = SentenceTransformer(model_name)
        logging.getLogger().setLevel(logging.INFO)

    def get_embedding(self, text: str) -> Sequence[float]:
        """Retrieve an embedding for a single text input using MPNet.

        :param text: The input text to embed.
        :return: The embedding vector, or an empty list on invalid input or
            encoding error (errors are logged, not raised).
        """
        if not text or not isinstance(text, str):
            logging.warning("Empty or invalid text passed to get_embedding")
            return []
        text = self.sanitize(text.replace("\n", " "))
        try:
            return self.model.encode(text)
        except Exception as e:
            logging.error(f"Error getting embedding for {text}: {e}")
            return []

    def get_embeddings(self, messages: List[str], batch_size: int = 64) -> Sequence[Sequence[float]]:
        """Retrieve embeddings for a list of text messages using MPNet.

        :param messages: A list of text messages to embed.
        :param batch_size: The batch size for model inference.
        :return: Embedding vectors as plain lists of floats, in input order;
            an empty list on failure (the error is logged).
        """
        sanitized_messages = [self.sanitize(msg) for msg in messages]
        try:
            embeddings = self.model.encode(sanitized_messages, batch_size=batch_size,
                                           show_progress_bar=True)
            # Convert numpy rows to plain lists of floats for serialization.
            return [[float(element) for element in row] for row in embeddings]
        except Exception as e:
            logging.error(f"Failed processing messages: {e}")
            return []


class TextEmbedding:
    """Simple value object pairing a text with its embedding vector."""

    def __init__(self, text: str, embedding: List[float]):
        """
        :param text: The embedded text.
        :param embedding: The embedding vector computed for ``text``.
        """
        self.text = text
        self.embedding = embedding
4 changes: 2 additions & 2 deletions datastew/scripts/ols_snomed_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from datastew.process.ols import OLSTerminologyImportTask
from datastew.repository.sqllite import SQLLiteRepository

# Import SNOMED CT terms from OLS into a local SQLite repository,
# embedding each term with MPNet.
repository = SQLLiteRepository()
embedding_model = MPNetAdapter()

task = OLSTerminologyImportTask(repository, embedding_model, "SNOMED CT", "snomed")
task.process()
print("done")

0 comments on commit fcd4a2f

Please sign in to comment.