diff --git a/samgeo/text_sam.py b/samgeo/text_sam.py index aa5679c3..1c010111 100644 --- a/samgeo/text_sam.py +++ b/samgeo/text_sam.py @@ -337,7 +337,10 @@ def predict( masks = torch.tensor([]) if len(boxes) > 0: masks = self.predict_sam(image_pil, boxes) - if 1 in masks.shape: + # If masks have 4 dimensions and the second dimension is 1 (e.g., [boxes, 1, height, width]), + # squeeze that dimension to reduce it to 3 dimensions ([boxes, height, width]). + # If boxes = 1, the mask's shape will be [1, height, width] after squeezing. + if masks.ndim == 4 and masks.shape[1] == 1: masks = masks.squeeze(1) if boxes.nelement() == 0: # No "object" instances found