-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #63 from SCAI-BIO/feat-adapter-parallel-processing
feat: adapter parallel processing
- Loading branch information
Showing
2 changed files
with
125 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,164 @@ | ||
import concurrent.futures | ||
import logging | ||
from abc import ABC | ||
from abc import ABC, abstractmethod | ||
from typing import List, Sequence | ||
|
||
import numpy as np | ||
import openai | ||
from openai.error import OpenAIError | ||
from sentence_transformers import SentenceTransformer | ||
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | ||
|
||
|
||
class EmbeddingModel(ABC):
    """Abstract base class for text-embedding backends.

    Concrete adapters must implement :meth:`get_embedding` and
    :meth:`get_embeddings`; name handling and text sanitization are shared.
    """

    def __init__(self, model_name: str):
        """
        :param model_name: Identifier of the underlying embedding model.
        """
        self.model_name = model_name

    @abstractmethod
    def get_embedding(self, text: str) -> Sequence[float]:
        """Retrieve the embedding vector for a single text input.

        :param text: The input text to embed.
        :return: A sequence of floats representing the embedding.
        """

    @abstractmethod
    def get_embeddings(self, messages: List[str]) -> Sequence[Sequence[float]]:
        """Retrieve embeddings for a list of text messages.

        :param messages: A list of text messages to embed.
        :return: A sequence of embedding vectors.
        """

    def get_model_name(self) -> str:
        """Return the name of the embedding model.

        :return: The name of the model.
        """
        return self.model_name

    def sanitize(self, message: str) -> str:
        """Clean up the input text by trimming and converting to lowercase.

        :param message: The input text message.
        :return: Sanitized text.
        """
        return message.strip().lower()
|
||
|
||
class GPT4Adapter(EmbeddingModel):
    """Embedding adapter backed by the OpenAI embeddings API."""

    def __init__(self, api_key: str, model_name: str = "text-embedding-ada-002"):
        """Initialize the adapter with an OpenAI API key and model name.

        :param api_key: The API key for accessing OpenAI services.
        :param model_name: The specific embedding model to use.
        """
        super().__init__(model_name)
        self.api_key = api_key
        # NOTE: openai < 1.0 uses a module-level key; setting it here affects
        # all clients in the process.
        openai.api_key = api_key

    def get_embedding(self, text: str) -> Sequence[float]:
        """Retrieve an embedding for a single text input via the OpenAI API.

        :param text: The input text to embed.
        :return: A sequence of floats representing the embedding; an empty
            list on invalid input or API error.
        """
        if not text or not isinstance(text, str):
            logging.warning("Empty or invalid text passed to get_embedding")
            return []
        # Newlines degrade embedding quality per OpenAI guidance; flatten first.
        text = self.sanitize(text.replace("\n", " "))
        try:
            response = openai.Embedding.create(input=[text], model=self.model_name)
            return response["data"][0]["embedding"]
        except OpenAIError as e:
            logging.error(f"OpenAI API error: {e}")
            return []

    def get_embeddings(self, messages: List[str], max_length: int = 2048, num_workers: int = 4) -> Sequence[Sequence[float]]:
        """Retrieve embeddings for a list of messages using batching and threads.

        :param messages: A list of text messages to embed.
        :param max_length: Maximum number of messages per API batch.
        :param num_workers: Number of threads for parallel processing.
        :return: A sequence of embedding vectors, aligned with *messages* order.
        """
        if max_length <= 0:
            logging.warning(f"max_length is set to {max_length}, using default value 2048")
            max_length = 2048

        sanitized_messages = [self.sanitize(msg) for msg in messages]
        chunks = [sanitized_messages[i:i + max_length] for i in range(0, len(sanitized_messages), max_length)]
        # BUGFIX: as_completed yields futures in completion order, not submit
        # order, so extending a flat list there would misalign embeddings with
        # their input messages. Keep results slotted by chunk index instead.
        ordered_results: List[Sequence[Sequence[float]]] = [[] for _ in chunks]
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_index = {executor.submit(self._process_chunk, chunk): idx
                               for idx, chunk in enumerate(chunks)}
            for future in concurrent.futures.as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    ordered_results[idx] = future.result()
                except Exception as e:
                    logging.error(f"Error in processing chunk: {e}")
        embeddings: List[Sequence[float]] = []
        for chunk_embeddings in ordered_results:
            embeddings.extend(chunk_embeddings)
        return embeddings

    def _process_chunk(self, chunk: List[str]) -> Sequence[Sequence[float]]:
        """Embed one batch of sanitized messages.

        :param chunk: A list of sanitized messages.
        :return: A sequence of embedding vectors; an empty list on error.
        """
        try:
            response = openai.Embedding.create(input=chunk, model=self.model_name)
            return [item["embedding"] for item in response["data"]]
        except Exception as e:
            logging.error(f"Error processing chunk: {e}")
            return []
|
||
|
||
class MPNetAdapter(EmbeddingModel):
    """Embedding adapter backed by a local SentenceTransformer (MPNet) model."""

    def __init__(self, model_name="sentence-transformers/all-mpnet-base-v2"):
        """Initialize the MPNet adapter with a specified model name.

        :param model_name: The model name for sentence transformers.
        """
        super().__init__(model_name)  # sets self.model_name (used by Weaviate)
        self.model = SentenceTransformer(model_name)

    def get_embedding(self, text: str) -> Sequence[float]:
        """Retrieve an embedding for a single text input using MPNet.

        :param text: The input text to embed.
        :return: A sequence of floats representing the embedding; an empty
            list on invalid input or encoding error.
        """
        if not text or not isinstance(text, str):
            logging.warning("Empty or invalid text passed to get_embedding")
            return []
        text = self.sanitize(text.replace("\n", " "))
        try:
            return self.model.encode(text)
        except Exception as e:
            logging.error(f"Error getting embedding for {text}: {e}")
            return []

    def get_embeddings(self, messages: List[str], batch_size: int = 64) -> Sequence[Sequence[float]]:
        """Retrieve embeddings for a list of text messages using MPNet.

        :param messages: A list of text messages to embed.
        :param batch_size: The batch size for processing.
        :return: A sequence of embedding vectors (as plain lists of floats);
            an empty list on error.
        """
        sanitized_messages = [self.sanitize(msg) for msg in messages]
        try:
            embeddings = self.model.encode(sanitized_messages, batch_size=batch_size, show_progress_bar=True)
            # Convert numpy rows to plain lists of floats for serialization.
            return [[float(element) for element in row] for row in embeddings]
        except Exception as e:
            logging.error(f"Failed processing messages: {e}")
            return []
|
||
|
||
class TextEmbedding:
    """Value object pairing an input text with its embedding vector."""

    def __init__(self, text: str, embedding: List[float]):
        """
        :param text: The original text.
        :param embedding: The embedding vector computed for *text*.
        """
        self.text = text
        self.embedding = embedding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters