diff --git a/lang_sam/models/__init__.py b/lang_sam/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lang_sam/models/gdino.py b/lang_sam/models/gdino.py
index e9b6afe..4522a90 100644
--- a/lang_sam/models/gdino.py
+++ b/lang_sam/models/gdino.py
@@ -2,8 +2,11 @@ import torch
 from PIL import Image
 from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from lang_sam.models.utils import get_device_type
+
+device_type = get_device_type()
+DEVICE = torch.device(device_type)
 
 if torch.cuda.is_available():
     torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
@@ -18,7 +21,9 @@ def __init__(self):
     def build_model(self, ckpt_path: str | None = None):
         model_id = "IDEA-Research/grounding-dino-base"
         self.processor = AutoProcessor.from_pretrained(model_id)
-        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
+        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
+            DEVICE
+        )
 
     def predict(
         self,
diff --git a/lang_sam/models/sam.py b/lang_sam/models/sam.py
index e013b4e..220352e 100644
--- a/lang_sam/models/sam.py
+++ b/lang_sam/models/sam.py
@@ -5,8 +5,10 @@ from omegaconf import OmegaConf
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from lang_sam.models.utils import get_device_type
+
+DEVICE = torch.device(get_device_type())
 
 if torch.cuda.is_available():
     torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
diff --git a/lang_sam/models/utils.py b/lang_sam/models/utils.py
new file mode 100644
index 0000000..bb4f003
--- /dev/null
+++ b/lang_sam/models/utils.py
@@ -0,0 +1,12 @@
+import logging
+import torch
+
+
+def get_device_type() -> str:
+    if torch.backends.mps.is_available():
+        return "mps"
+    elif torch.cuda.is_available():
+        return "cuda"
+    else:
+        logging.warning("No GPU found, using CPU instead")
+        return "cpu"
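
Usage sketch (not part of the diff): a minimal illustration of how the new get_device_type() helper is meant to be consumed, mirroring the module-level pattern now shared by gdino.py and sam.py. The tensor example below is hypothetical. Note that the MPS check takes precedence over CUDA, so Apple-silicon machines resolve to "mps" even though the autocast block that follows DEVICE in both modules only fires for CUDA.

import torch

from lang_sam.models.utils import get_device_type

# Resolve the accelerator once at import time, as gdino.py and sam.py now do.
device_type = get_device_type()  # "mps", "cuda", or "cpu"
DEVICE = torch.device(device_type)

# Any tensor (or model) can then be placed on the resolved device.
x = torch.zeros(2, 3).to(DEVICE)
print(x.device)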