Skip to content

Commit

Permalink
refactor: remove device handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mehmetcanay committed Dec 20, 2024
1 parent eefefde commit c85ebe7
Showing 1 changed file with 1 addition and 37 deletions.
38 changes: 1 addition & 37 deletions datastew/embedding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import concurrent.futures
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Sequence

import openai
import torch
from openai.error import OpenAIError
from sentence_transformers import SentenceTransformer

Expand Down Expand Up @@ -124,19 +122,7 @@ def __init__(self, model_name="sentence-transformers/all-mpnet-base-v2", num_thr
:param num_threads: The number of CPU threads for inference.
"""
super().__init__(model_name)
device, available_threads = self._initialize_device(num_threads)
self.model = SentenceTransformer(model_name).to(device)

if device == "cpu":
torch.set_num_threads(available_threads)
logging.info(
f"MPNet model '{model_name} initialized on CPU with {available_threads} threads."
)
elif device == "cuda":
logging.info(
f"MPNet model '{model_name}' initialized on GPU. GPU thread management is handled automatically."
)

self.model = SentenceTransformer(model_name)

def get_embedding(self, text: str) -> Sequence[float]:
"""Retrieve an embedding for a single text input using MPnet.
Expand Down Expand Up @@ -170,28 +156,6 @@ def get_embeddings(self, messages: List[str], batch_size: int = 64) -> Sequence[
except Exception as e:
logging.error(f"Failed processing messages: {e}")
return []

def _initialize_device(self, num_threads: int) -> tuple[str, int]:
    """Determine the appropriate device (CPU or GPU) and set the thread count.

    On the CPU path the requested thread count is clamped to the range
    ``[1, os.cpu_count()]``; a warning is logged whenever the request had to
    be adjusted.

    :param num_threads: The requested number of threads for inference.
    :return: A tuple containing the selected device ("cuda" or "cpu") and the final number of threads to use.
    :raise RuntimeError: If CPU core count cannot be determined when no GPU is available.
    """
    if torch.cuda.is_available():
        return "cuda", num_threads  # num_threads does not affect GPU operations

    # Fallback to CPU
    cpu_count = os.cpu_count()
    if cpu_count is None:
        raise RuntimeError("Unable to determine the number of CPU cores.")

    if num_threads < 1:
        # A zero or negative value would otherwise be returned unchanged by
        # min() below and forwarded by the caller to torch.set_num_threads,
        # which expects a positive integer.
        logging.warning(
            f"Requested {num_threads} threads, but at least 1 thread is required. Using 1 thread."
        )
        return "cpu", 1

    threads = min(num_threads, cpu_count)
    if num_threads > cpu_count:
        logging.warning(
            f"Requested {num_threads} threads, but only {cpu_count} CPU cores available. Using {cpu_count} threads."
        )
    return "cpu", threads


class TextEmbedding:
Expand Down

0 comments on commit c85ebe7

Please sign in to comment.