Skip to content

Commit

Permalink
Merge pull request #75 from iamnotagentleman/add-mps-support
Browse files Browse the repository at this point in the history
  • Loading branch information
luca-medeiros authored Oct 15, 2024
2 parents 3f8af23 + 3b5d453 commit e82f70f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
Empty file added lang_sam/models/__init__.py
Empty file.
9 changes: 7 additions & 2 deletions lang_sam/models/gdino.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from lang_sam.models.utils import get_device_type

device_type = get_device_type()
DEVICE = torch.device(device_type)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
Expand All @@ -18,7 +21,9 @@ def __init__(self):
def build_model(self, ckpt_path: str | None = None):
model_id = "IDEA-Research/grounding-dino-base"
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
DEVICE
)

def predict(
self,
Expand Down
4 changes: 3 additions & 1 deletion lang_sam/models/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from omegaconf import OmegaConf
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
from lang_sam.models.utils import get_device_type

DEVICE = torch.device(get_device_type())

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
Expand Down
12 changes: 12 additions & 0 deletions lang_sam/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import logging
import torch


def get_device_type() -> str:
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
else:
logging.warning("No GPU found, using CPU instead")
return "cpu"

0 comments on commit e82f70f

Please sign in to comment.