diff --git a/datastew/embedding.py b/datastew/embedding.py index 6724021..28eb4ec 100644 --- a/datastew/embedding.py +++ b/datastew/embedding.py @@ -1,11 +1,9 @@ import concurrent.futures import logging -import os from abc import ABC, abstractmethod from typing import List, Sequence import openai -import torch from openai.error import OpenAIError from sentence_transformers import SentenceTransformer @@ -124,19 +122,7 @@ def __init__(self, model_name="sentence-transformers/all-mpnet-base-v2", num_thr :param num_threads: The number of CPU threads for inference. """ super().__init__(model_name) - device, available_threads = self._initialize_device(num_threads) - self.model = SentenceTransformer(model_name).to(device) - - if device == "cpu": - torch.set_num_threads(available_threads) - logging.info( - f"MPNet model '{model_name} initialized on CPU with {available_threads} threads." - ) - elif device == "cuda": - logging.info( - f"MPNet model '{model_name}' initialized on GPU. GPU thread management is handled automatically." - ) - + self.model = SentenceTransformer(model_name) def get_embedding(self, text: str) -> Sequence[float]: """Retrieve an embedding for a single text input using MPnet. @@ -170,28 +156,6 @@ def get_embeddings(self, messages: List[str], batch_size: int = 64) -> Sequence[ except Exception as e: logging.error(f"Failed processing messages: {e}") return [] - - def _initialize_device(self, num_threads: int) -> tuple[str, int]: - """Determine the appropriate device (CPU or GPU) and set the thread count. - - :param num_threads: The requested number of threads for inference. - :return: A tuple containing the selected device ("cuda" or "cpu") and the final number of threads to use. - :raise RuntimeError: If CPU core count cannot be determined when no GPU is available. - """ - if torch.cuda.is_available(): - return "cuda", num_threads # num_threads does not affect GPU operations - - # Fallback to CPU - cpu_count = os.cpu_count() - if cpu_count is None: - raise RuntimeError("Unable to determine the number of CPU cores.") - - threads = min(num_threads, cpu_count) - if num_threads > cpu_count: - logging.warning( - f"Requested {num_threads} threads, but only {cpu_count} CPU cores available. Using {cpu_count} threads." - ) - return "cpu", threads class TextEmbedding: