diff --git a/groundingdino/models/GroundingDINO/fuse_modules.py b/groundingdino/models/GroundingDINO/fuse_modules.py
index 2753b3d..ea1bb04 100644
--- a/groundingdino/models/GroundingDINO/fuse_modules.py
+++ b/groundingdino/models/GroundingDINO/fuse_modules.py
@@ -203,19 +203,22 @@ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
 
         # mask vison for language
         if attention_mask_v is not None:
-            attention_mask_v = (
-                attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
-            )
-            attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
-
+            attention_mask_v = (attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1))
+            if attn_weights_l.device.type == "mps":
+                attn_weights_l = attn_weights_l.masked_fill(attention_mask_v, float("-inf"))
+            else:
+                attn_weights_l.masked_fill_(attention_mask_v, float("-inf"))
+
         attn_weights_l = attn_weights_l.softmax(dim=-1)
-
         # mask language for vision
         if attention_mask_l is not None:
             attention_mask_l = (
                 attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
             )
-            attn_weights.masked_fill_(attention_mask_l, float("-inf"))
+            if attn_weights.device.type == "mps":
+                attn_weights = attn_weights.masked_fill(attention_mask_l, float("-inf"))
+            else:
+                attn_weights.masked_fill_(attention_mask_l, float("-inf"))
         attn_weights_v = attn_weights.softmax(dim=-1)
 
         attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
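
Context for the change (not part of the patch): on the "mps" device the code switches from the in-place masked_fill_ to the out-of-place masked_fill and reassigns the result, which appears intended to work around issues reported with in-place masked_fill_ under PyTorch's MPS backend; the patch itself does not state the reason. Below is a minimal standalone sketch of the same device-conditional pattern, with a hypothetical helper name (mask_attention_weights) and arbitrary shapes chosen only for illustration:

    import torch

    def mask_attention_weights(attn_weights: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper, illustrative only: mirrors the device check used in the patch.
        if attn_weights.device.type == "mps":
            # Out-of-place variant; the result must be reassigned by the caller.
            return attn_weights.masked_fill(mask, float("-inf"))
        # On other devices the in-place variant is fine and avoids an extra allocation.
        attn_weights.masked_fill_(mask, float("-inf"))
        return attn_weights

    # Usage sketch: mask out the last key position before a softmax.
    weights = torch.randn(2, 4, 4)
    mask = torch.zeros(2, 4, 4, dtype=torch.bool)
    mask[..., -1] = True
    weights = mask_attention_weights(weights, mask).softmax(dim=-1)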