diff --git a/mindone/transformers/models/sam2/__init__.py b/mindone/transformers/models/sam2/__init__.py new file mode 100644 index 0000000000..70990f052e --- /dev/null +++ b/mindone/transformers/models/sam2/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .image_processing_sam2_fast import Sam2ImageProcessorFast +from .modeling_sam2 import Sam2HieraDetModel, Sam2Model, Sam2PreTrainedModel, Sam2VisionModel +from .processing_sam2 import Sam2Processor diff --git a/mindone/transformers/models/sam2/image_processing_sam2_fast.py b/mindone/transformers/models/sam2/image_processing_sam2_fast.py new file mode 100644 index 0000000000..6802085b4e --- /dev/null +++ b/mindone/transformers/models/sam2/image_processing_sam2_fast.py @@ -0,0 +1,714 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
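For reviewers, a minimal usage sketch of the classes exported by the new `__init__.py` above; it is illustrative only and not part of the patch. The checkpoint id, the nested `input_points` layout, and `return_tensors="ms"` are assumptions carried over from the transformers SAM API conventions.

    import numpy as np
    from PIL import Image
    from mindone.transformers.models.sam2 import Sam2Model, Sam2Processor

    processor = Sam2Processor.from_pretrained("facebook/sam2.1-hiera-tiny")  # hypothetical checkpoint id
    model = Sam2Model.from_pretrained("facebook/sam2.1-hiera-tiny")

    image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
    inputs = processor(images=image, input_points=[[[[500, 375]]]], return_tensors="ms")
    outputs = model(**inputs)

    # post_process_masks upsamples the low-resolution masks back to the original image size.
    masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"])
    print(outputs.iou_scores.shape, masks[0].shape)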
+import math +from copy import deepcopy +from itertools import product +from typing import Any, Optional, Union + +import numpy as np +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import mint, ops + +from ...mindspore_adapter.batched_nms import batched_nms +from ...image_processing_utils import BatchFeature, get_size_dict +from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + pil_mindspore_interpolation_mapping, +) +from ...processing_utils import Unpack +from ...utils import TensorType + + +class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + mask_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width": int}` to resize the segmentation maps to. + """ + + mask_size: dict[str, int] + + +def _compute_stability_score(masks: "ms.Tensor", mask_threshold: float, stability_score_offset: int): + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=ms.int16).sum(-1, dtype=ms.int32) + ) + unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=ms.int16).sum(-1, dtype=ms.int32) + stability_scores = intersections / unions + return stability_scores + + +def _mask_to_rle(input_mask: "ms.Tensor"): + """ + Encodes masks the run-length encoding (RLE), in the format expected by pycoco tools. + """ + # Put in fortran order and flatten height and width + batch_size, height, width = input_mask.shape + input_mask = input_mask.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = input_mask[:, 1:] ^ input_mask[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(batch_size): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1 + if len(cur_idxs) == 0: + # No changes => either all 0 or all 1 + # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width]. + if input_mask[i, 0] == 0: + out.append({"size": [height, width], "counts": [height * width]}) + else: + out.append({"size": [height, width], "counts": [0, height * width]}) + continue + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if input_mask[i, 0] == 0 else [0] + counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()] + out.append({"size": [height, width], "counts": counts}) + return out + + +def _batched_mask_to_box(masks: "ms.Tensor"): + """ + Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which + corresponds the following required indices: + - LEFT: left hand side of the bounding box + - TOP: top of the bounding box + - RIGHT: right of the bounding box + - BOTTOM: bottom of the bounding box + + Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape + is channel_1 x channel_2 x ... x 4. 
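As a plain NumPy illustration of `_compute_stability_score` above (a sketch, not part of the patch): the score is the IoU between the mask binarized at a stricter threshold and at a looser one, so masks whose area barely changes under threshold perturbation score close to 1.

    import numpy as np

    def stability_score_np(mask_logits, mask_threshold=0.0, offset=1.0):
        # Area at the stricter threshold ...
        intersections = (mask_logits > (mask_threshold + offset)).sum(axis=(-2, -1))
        # ... over the area at the looser threshold (the strict mask is contained in the loose one).
        unions = (mask_logits > (mask_threshold - offset)).sum(axis=(-2, -1))
        return intersections / unions

    logits = 5 * np.random.randn(2, 256, 256).astype(np.float32)
    print(stability_score_np(logits))  # values in (0, 1], higher means more stable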
+ + Args: + - masks (`ms.Tensor` of shape `(batch, nb_mask, height, width)`) + """ + + + if ops.numel(masks) == 0: + return mint.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to Cxheightxwidth + shape = masks.shape + height, width = shape[-2:] + + # Get top and bottom edges + in_height, _ = mint.max(masks, dim=-1) + in_height_coords = in_height * mint.arange(height, device=in_height.device)[None, :] + bottom_edges, _ = mint.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + height * (~in_height) + top_edges, _ = mint.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = mint.max(masks, dim=-2) + in_width_coords = in_width * mint.arange(width, device=in_width.device)[None, :] + right_edges, _ = mint.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + width * (~in_width) + left_edges, _ = mint.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = mint.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + out = out.reshape(*shape[:-2], 4) + return out + + +def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = ms.tensor(crop_box, dtype=ms.float, device=boxes.device) + orig_box_torch = ms.tensor(orig_box, dtype=ms.float, device=boxes.device) + + left, top, _, _ = crop_box + offset = ms.tensor([[left, top, left, top]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + boxes = (boxes + offset).float() + + near_crop_edge = mint.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = mint.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = mint.logical_and(near_crop_edge, ~near_image_edge) + return mint.any(near_crop_edge, dim=1) + + +def _pad_masks(masks, crop_box: list[int], orig_height: int, orig_width: int): + left, top, right, bottom = crop_box + if left == 0 and top == 0 and right == orig_width and bottom == orig_height: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top) + pad = (left, pad_x - left, top, pad_y - top) + return mint.nn.functional.pad(masks, pad, value=0) + + +def _generate_crop_boxes( + image, + target_size: int, # Is it tuple here? + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, +) -> tuple[list[list[int]], list[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (Union[`numpy.ndarray`, `PIL.Image`, `ms.Tensor`]): + Image to generate crops for. + target_size (`int`): + Size of the smallest crop. + crop_n_layers (`int`, *optional*): + If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of layers + to run, where each layer has 2**i_layer number of image crops. + overlap_ratio (`int`, *optional*): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of the + image length. Later layers with more crops scale down this overlap. 
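A NumPy sketch of the edge-coordinate trick used by `_batched_mask_to_box` above (illustrative, not part of the patch): multiplying the row/column indicator by its coordinate and taking the max yields the bottom/right edges, while pushing empty coordinates past the image size before taking the min yields the top/left edges.

    import numpy as np

    def mask_to_box_np(mask):
        # mask: (height, width) bool; returns [left, top, right, bottom] in XYXY, or zeros if empty.
        height, width = mask.shape
        if not mask.any():
            return np.zeros(4, dtype=np.int64)
        rows, cols = np.arange(height), np.arange(width)
        in_rows = mask.any(axis=-1)                      # rows containing foreground
        in_cols = mask.any(axis=-2)                      # columns containing foreground
        bottom = (in_rows * rows).max()
        top = np.where(in_rows, rows, height).min()
        right = (in_cols * cols).max()
        left = np.where(in_cols, cols, width).min()
        return np.array([left, top, right, bottom])

    mask = np.zeros((8, 8), dtype=bool)
    mask[2:5, 3:7] = True
    print(mask_to_box_np(mask))  # [3 2 6 4]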
+ points_per_crop (`int`, *optional*): + Number of points to sam2ple per crop. + crop_n_points_downscale_factor (`int`, *optional*): + The number of points-per-side sam2pled in layer n is scaled down by crop_n_points_downscale_factor**n. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + + if isinstance(image, list): + raise ValueError("Only one image is allowed for crop generation.") + original_size = image.shape[-2:] + + points_grid = [] + for i in range(crop_n_layers + 1): + n_points = int(points_per_crop / (crop_n_points_downscale_factor**i)) + points_grid.append(_build_point_grid(n_points)) + + crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size) + + cropped_images, point_grid_per_crop = _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size + ) + crop_boxes = ms.tensor(crop_boxes) + crop_boxes = crop_boxes.float() + points_per_crop = mint.stack(point_grid_per_crop) + points_per_crop = points_per_crop.unsqueeze(0).permute(0, 2, 1, 3) + cropped_images = mint.stack(cropped_images) + + input_labels = mint.ones_like(points_per_crop[:, :, :, 0], dtype=ms.int64) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + +def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size): + """ + Generates 2 ** (layers idx + 1) crops for each crop_n_layers. Crops are in the XYWH format : The XYWH format + consists of the following required indices: + - X: X coordinate of the top left of the bounding box + - Y: Y coordinate of the top left of the bounding box + - W: width of the bounding box + - H: height of the bounding box + """ + crop_boxes, layer_idxs = [], [] + im_height, im_width = original_size + short_side = min(im_height, im_width) + + # Original image + crop_boxes.append([0, 0, im_width, im_height]) + layer_idxs.append(0) + for i_layer in range(crop_n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side)) + crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side)) + + crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)] + + for left, top in product(crop_box_x0, crop_box_y0): + box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def _build_point_grid(n_per_side: int) -> ms.Tensor: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = mint.linspace(offset, 1 - offset, n_per_side) + points_x = mint.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = mint.tile(points_one_side[:, None], (1, n_per_side)) + points = mint.stack([points_x, points_y], dim=-1).reshape(-1, 2) + return points + + +def _generate_crop_images( + crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None +): + """ + Takes as an input bounding boxes that are used to crop the image. Based in the crops, the corresponding points are + also passed. 
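`_build_point_grid` above is small enough to restate in NumPy for clarity (a sketch, not part of the patch): it places an n x n grid of points at the cell centers of the unit square, which `_generate_crop_images` later scales to each crop's pixel size.

    import numpy as np

    def build_point_grid_np(n_per_side):
        offset = 1.0 / (2 * n_per_side)                    # half a cell, so points sit at cell centers
        side = np.linspace(offset, 1 - offset, n_per_side)
        xs, ys = np.meshgrid(side, side)                   # xs varies along columns, ys along rows
        return np.stack([xs, ys], axis=-1).reshape(-1, 2)  # (n_per_side ** 2, 2) in [0, 1]^2

    grid = build_point_grid_np(4)
    print(grid.shape)  # (16, 2)
    print(grid[:2])    # [[0.125 0.125] [0.375 0.125]]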
+ """ + cropped_images = [] + total_points_per_crop = [] + for i, crop_box in enumerate(crop_boxes): + left, top, right, bottom = crop_box + cropped_im = image[:, top:bottom, left:right] + + cropped_images.append(cropped_im) + + cropped_im_size = cropped_im.shape[-2:] + points_scale = ms.tensor(cropped_im_size).flip(dims=(0,)).unsqueeze(0) + + points = points_grid[layer_idxs[i]] * points_scale + normalized_points = _normalize_coordinates(target_size, points, original_size) + total_points_per_crop.append(normalized_points) + + return cropped_images, total_points_per_crop + + +def _normalize_coordinates( + target_size: int, coords: ms.Tensor, original_size: tuple[int, int], is_bounding_box=False +) -> ms.Tensor: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (height, width) + format. + """ + old_height, old_width = original_size + + scale = target_size * 1.0 / max(old_height, old_width) + new_height, new_width = old_height * scale, old_width * scale + new_width = int(new_width + 0.5) + new_height = int(new_height + 0.5) + + coords = deepcopy(coords).float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_width / old_width) + coords[..., 1] = coords[..., 1] * (new_height / old_height) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + +def _rle_to_mask(rle: dict[str, Any]) -> ms.Tensor: + """Compute a binary mask from an uncompressed RLE.""" + height, width = rle["size"] + mask = mint.empty(height * width, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity = not parity + mask = mask.reshape(width, height) + return mask.transpose(0, 1) # Reshape to original shape + + +def _post_process_for_mask_generation(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7): + """ + Perform NMS (Non Maximum Suppression) on the outputs. + + Args: + rle_masks (`ms.Tensor`): + binary masks in the RLE format + iou_scores (`ms.Tensor` of shape (nb_masks, 1)): + iou_scores predicted by the model + mask_boxes (`ms.Tensor`): + The bounding boxes corresponding to segmentation masks + amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7): + NMS threshold. 
+ """ + keep_by_nms = batched_nms( + boxes=mask_boxes.float(), + scores=iou_scores, + idxs=mint.zeros(mask_boxes.shape[0]), + iou_threshold=amg_crops_nms_thresh, + ) + + iou_scores = iou_scores[keep_by_nms] + rle_masks = [rle_masks[i] for i in keep_by_nms] + mask_boxes = mask_boxes[keep_by_nms] + masks = [_rle_to_mask(rle) for rle in rle_masks] + + return masks, iou_scores, rle_masks, mask_boxes + + +class Sam2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 1024, "width": 1024} + mask_size = {"height": 256, "width": 256} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + valid_kwargs = Sam2FastImageProcessorKwargs + + # modular artefacts + do_pad = None + pad_size = None + mask_pad_size = None + + def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _further_process_kwargs( + self, + size: Optional[SizeDict] = None, + mask_size: Optional[SizeDict] = None, + default_to_square: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[ChannelDimension] = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. + """ + if kwargs is None: + kwargs = {} + if size is not None: + size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) + if mask_size is not None: + mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size")) + if isinstance(image_mean, list): + image_mean = tuple(image_mean) + if isinstance(image_std, list): + image_std = tuple(image_std) + if data_format is None: + data_format = ChannelDimension.FIRST + + kwargs["size"] = size + kwargs["mask_size"] = mask_size + kwargs["image_mean"] = image_mean + kwargs["image_std"] = image_std + kwargs["data_format"] = data_format + kwargs["default_to_square"] = default_to_square + return kwargs + + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[Sam2FastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + **kwargs: Unpack[Sam2FastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format + ) + original_sizes = [image.shape[-2:] for image in images] + images_kwargs = kwargs.copy() + pixel_values = self._preprocess(images, **images_kwargs) + reshaped_input_sizes = [image.shape[-2:] for image in images] + data = { + "pixel_values": pixel_values, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + "interpolation": pil_mindspore_interpolation_mapping[PILImageResampling.NEAREST], + "size": segmentation_maps_kwargs.pop("mask_size"), + } + ) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + data["labels"] = processed_segmentation_maps.squeeze(1).to(ms.int64) + + return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) + + def generate_crop_boxes( + self, + image: "ms.Tensor", + target_size, + crop_n_layers: int = 0, + overlap_ratio: float = 512 / 1500, + points_per_crop: Optional[int] = 32, + crop_n_points_downscale_factor: Optional[list[int]] = 1, + ): + """ + Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer. + + Args: + image (`ms.Tensor`): + Input original image + target_size (`int`): + Target size of the resized image + crop_n_layers (`int`, *optional*, defaults to 0): + If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where + each layer has 2**i_layer number of image crops. + overlap_ratio (`float`, *optional*, defaults to 512/1500): + Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + points_per_crop (`int`, *optional*, defaults to 32): + Number of points to sam2ple from each crop. + crop_n_points_downscale_factor (`list[int]`, *optional*, defaults to 1): + The number of points-per-side sam2pled in layer n is scaled down by crop_n_points_downscale_factor**n. + device (`torch.device`, *optional*, defaults to None): + Device to use for the computation. If None, cpu will be used. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + return_tensors (`str`, *optional*, defaults to `ms`): + If `ms`, returns `ms.Tensor`. + """ + image = self._process_image(image) + crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes( + image, + target_size, + crop_n_layers, + overlap_ratio, + points_per_crop, + crop_n_points_downscale_factor, + ) + + return crop_boxes, points_per_crop, cropped_images, input_labels + + def filter_masks( + self, + masks, + iou_scores, + original_size, + cropped_box_image, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + mask_threshold=0, + stability_score_offset=1, + ): + """ + Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being + that the iou scores needs to be greater than `pred_iou_thresh`. 
The second criterion is that the stability + score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to + bounding boxes and pad the predicted masks if necessary. + + Args: + masks (`ms.Tensor`): + Input masks. + iou_scores (`ms.Tensor`): + List of IoU scores. + original_size (`tuple[int,int]`): + Size of the original image. + cropped_box_image (`ms.Tensor`): + The cropped image. + pred_iou_thresh (`float`, *optional*, defaults to 0.88): + The threshold for the iou scores. + stability_score_thresh (`float`, *optional*, defaults to 0.95): + The threshold for the stability score. + mask_threshold (`float`, *optional*, defaults to 0): + The threshold for the predicted masks. + stability_score_offset (`float`, *optional*, defaults to 1): + The offset for the stability score used in the `_compute_stability_score` method. + + """ + original_height, original_width = original_size + iou_scores = iou_scores.flatten(0, 1) + masks = masks.flatten(0, 1) + + if masks.shape[0] != iou_scores.shape[0]: + raise ValueError("masks and iou_scores must have the sam2e batch size.") + + if masks.device != iou_scores.device: + iou_scores = iou_scores.to(masks.device) + + batch_size = masks.shape[0] + + keep_mask = mint.ones(batch_size, dtype=ms.bool, device=masks.device) + + if pred_iou_thresh > 0.0: + keep_mask = keep_mask & (iou_scores > pred_iou_thresh) + + # compute stability score + if stability_score_thresh > 0.0: + stability_scores = _compute_stability_score(masks, mask_threshold, stability_score_offset) + keep_mask = keep_mask & (stability_scores > stability_score_thresh) + + scores = iou_scores[keep_mask] + masks = masks[keep_mask] + + # binarize masks + masks = masks > mask_threshold + converted_boxes = _batched_mask_to_box(masks) + + keep_mask = ~_is_box_near_crop_edge( + converted_boxes, cropped_box_image, [0, 0, original_width, original_height] + ) + + scores = scores[keep_mask] + masks = masks[keep_mask] + converted_boxes = converted_boxes[keep_mask] + + masks = _pad_masks(masks, cropped_box_image, original_height, original_width) + # conversion to rle is necessary to run non-maximum suppression + masks = _mask_to_rle(masks) + + return masks, scores, converted_boxes + + def post_process_masks( + self, + masks, + original_sizes, + mask_threshold=0.0, + binarize=True, + max_hole_area=0.0, + max_sprinkle_area=0.0, + apply_non_overlapping_constraints=False, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[ms.Tensor, List[ms.Tensor], np.ndarray, List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[ms.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarization and post-processing operations. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + max_hole_area (`float`, *optional*, defaults to 0.0): + The maximum area of a hole to fill. + max_sprinkle_area (`float`, *optional*, defaults to 0.0): + The maximum area of a sprinkle to fill. + apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): + Whether to apply non-overlapping constraints to the masks. + + Returns: + (`ms.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
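A sketch of the resizing step that `post_process_masks` documents above, using the same `F.interpolate` call as the implementation; the decoder output here is hypothetical and the snippet is not part of the patch.

    import numpy as np
    import mindspore as ms
    import mindspore.mint.nn.functional as F

    # Hypothetical decoder output: 1 image, 3 candidate masks at the low 256x256 resolution.
    low_res_logits = ms.Tensor(np.random.randn(1, 3, 256, 256).astype(np.float32))
    original_size = (480, 640)

    # Bilinear upsampling back to the original image size, then thresholding at mask_threshold.
    upscaled = F.interpolate(low_res_logits, original_size, mode="bilinear", align_corners=False)
    binary_masks = upscaled > 0.0
    print(binary_masks.shape)  # (1, 3, 480, 640)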
+ """ + if isinstance(original_sizes, (ms.Tensor, np.ndarray)): + original_sizes = original_sizes.tolist() + # TODO: add connected components kernel for postprocessing + output_masks = [] + for i, original_size in enumerate(original_sizes): + if isinstance(masks[i], np.ndarray): + masks[i] = ms.Tensor(masks[i]) + elif not isinstance(masks[i], ms.Tensor): + raise TypeError("Input masks should be a list of `ms.tensors` or a list of `np.ndarray`") + interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False) + if apply_non_overlapping_constraints: + interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask) + if binarize: + interpolated_mask = interpolated_mask > mask_threshold + output_masks.append(interpolated_mask) + + return output_masks + + def post_process_for_mask_generation(self, all_masks, all_scores, all_boxes, crops_nms_thresh): + """ + Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks. + + Args: + all_masks (`ms.Tensor`): + List of all predicted segmentation masks + all_scores (`ms.Tensor`): + List of all predicted iou scores + all_boxes (`ms.Tensor`): + List of all bounding boxes of the predicted masks + crops_nms_thresh (`float`): + Threshold for NMS (Non Maximum Suppression) algorithm. + """ + return _post_process_for_mask_generation(all_masks, all_scores, all_boxes, crops_nms_thresh) + + def pad_image(self): + raise NotImplementedError("No pad_image for SAM 2.") + + def _preprocess( + self, + images: list["ms.Tensor"], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> "ms.Tensor": + return super()._preprocess(images, return_tensors=return_tensors, **kwargs).pixel_values + + def _apply_non_overlapping_constraints(self, pred_masks: ms.Tensor) -> ms.Tensor: + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = mint.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = mint.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = mint.where(keep, pred_masks, mint.clamp(pred_masks, max=-10.0)) + return pred_masks + + +__all__ = ["Sam2ImageProcessorFast"] diff --git a/mindone/transformers/models/sam2/modeling_sam2.py b/mindone/transformers/models/sam2/modeling_sam2.py new file mode 100644 index 0000000000..bfd247eb61 --- /dev/null +++ b/mindone/transformers/models/sam2/modeling_sam2.py @@ -0,0 +1,1549 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam2/modular_sam2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. 
+# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from dataclasses import dataclass +from typing import Optional, Union, Tuple, List + +import numpy as np +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.nn.functional as F +from mindspore import Tensor, nn, ops + +from transformers.utils.generic import OutputRecorder +from ...processing_utils import Unpack +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ModelOutput +from ..auto import AutoModel +from mindone.transformers.generation import GenerationMixin +from transformers import ( + Sam2Config, + Sam2HieraDetConfig, + Sam2MaskDecoderConfig, + Sam2PromptEncoderConfig, + Sam2VisionConfig, +) + + +@dataclass +class Sam2VisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`ms.Tensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + fpn_hidden_states (`Tuple(ms.Tensor)`): + Tuple of `ms.Tensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. + fpn_position_encoding (`Tuple(ms.Tensor)`): + Tuple of `ms.Tensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. + hidden_states (`Tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`Tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. 
+ """ + + last_hidden_state: Optional[ms.Tensor] = None + fpn_hidden_states: Optional[ms.Tensor] = None + fpn_position_encoding: Optional[ms.Tensor] = None + hidden_states: Optional[Tuple[ms.Tensor, ...]] = None + attentions: Optional[Tuple[ms.Tensor, ...]] = None + + +@dataclass +class Sam2ImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`ms.Tensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`ms.Tensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`ms.Tensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. + image_embeddings (`Tuple(ms.Tensor)`): + The features from the FPN, which are used by the mask decoder. This is a Tuple of `ms.Tensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`Tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `ms.Tensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`Tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. + mask_decoder_attentions (`Tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. + """ + + iou_scores: Optional[ms.Tensor] = None + pred_masks: Optional[ms.Tensor] = None + object_score_logits: Optional[ms.Tensor] = None + image_embeddings: Tuple[ms.Tensor, ...] = None + vision_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None + vision_attentions: Optional[Tuple[ms.Tensor, ...]] = None + mask_decoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None + + +class Sam2PatchEmbeddings(nn.Cell): + r""" + Turns pixel values into patch embeddings for transformer consumption. + + Args: + pixel_values (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details. 
+ + Returns: + embeddings (`ms.Tensor`): + Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding + """ + + def __init__(self, config: Sam2HieraDetConfig): + super().__init__() + num_channels = config.num_channels + hidden_size = config.hidden_size + + self.projection = mint.nn.Conv2d( + num_channels, + hidden_size, + kernel_size=config.patch_kernel_size, + stride=config.patch_stride, + padding=config.patch_padding, + ) + + def construct(self, pixel_values): + _, num_channels, height, width = pixel_values.shape + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class Sam2SinePositionEmbedding(nn.Cell): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + def construct( + self, + x: ms.Tensor, + mask: Optional[Tensor] = None, + ) -> ms.Tensor: + if mask is None: + mask = mint.zeros((x.shape[0], x.shape[2], x.shape[3]), dtype=ms.bool) + not_mask = (~mask).to(x.dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = mint.arange(self.num_pos_feats, dtype=ms.int64).type_as(x) + dim_t = self.temperature ** (2 * mint.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = mint.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = mint.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = mint.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class Sam2VisionNeck(nn.Cell): + def __init__(self, config: Sam2VisionConfig): + super().__init__() + self.config = config + + self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True) + self.convs = nn.CellList() + for in_channels in config.backbone_channel_list: + self.convs.append( + mint.nn.Conv2d( + in_channels=in_channels, + out_channels=config.fpn_hidden_size, + kernel_size=config.fpn_kernel_size, + stride=config.fpn_stride, + padding=config.fpn_padding, + ), + ) + self.fpn_top_down_levels = config.fpn_top_down_levels + + def construct(self, hidden_states: ms.Tensor) -> Tuple[Tuple[ms.Tensor, ...], Tuple[ms.Tensor, ...]]: + fpn_hidden_states = () + fpn_position_encoding = () + + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + lateral_features = hidden_states[i].permute(0, 3, 1, 2) + lateral_features = self.convs[n - i](lateral_features) + if i not in self.fpn_top_down_levels or i == n: + prev_features = lateral_features + else: + top_down_features = F.interpolate( + prev_features, + scale_factor=2.0, + mode="nearest", + align_corners=None, + 
).to(lateral_features.dtype) + prev_features = lateral_features + top_down_features + + prev_position_encoding = self.position_encoding( + prev_features + ).to(prev_features.dtype) + + fpn_hidden_states += (prev_features,) + fpn_position_encoding += (prev_position_encoding,) + + return fpn_hidden_states, fpn_position_encoding + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = mint.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = mint.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = ops.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def do_pool(x: ms.Tensor, query_stride: Optional[int] = None) -> ms.Tensor: + if query_stride is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + query_stride = tuple(query_stride) + input_dtype = x.dtype + x = x.to(ms.float32) + x = mint.nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False) + # (B, C, H', W') -> (B, H', W', C) + x = x.to(input_dtype) + x = x.permute(0, 2, 3, 1) + return x + + +class Sam2MultiScaleAttention(nn.Cell): + def __init__( + self, + config: Sam2HieraDetConfig, + dim: int, + dim_out: int, + num_attention_heads: int, + query_stride: Optional[Tuple[int, int]] = None, + ): + super().__init__() + + self.config = config + + self.dim = dim + self.dim_out = dim_out + self.query_stride = query_stride + + self.num_attention_heads = num_attention_heads + head_dim = dim_out // num_attention_heads + self.scale = head_dim**-0.5 + self.qkv = mint.nn.Linear(dim, dim_out * 3) + self.proj = mint.nn.Linear(dim_out, dim_out) + + self.is_causal = False + + def construct(self, hidden_states: ms.Tensor, **kwargs) -> ms.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + query, key, value = mint.unbind(qkv, 2) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + attn_weights = mint.softmax(attn_weights, dtype=ms.float32, dim=-1).to(query.dtype) + + # Q pooling (for downsample at stage changes) + if self.query_stride: + query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride) + height, width = query.shape[1:3] # downsampled shape + query = query.reshape(batch_size, height * width, self.num_attention_heads, -1) + + # transpose query, key, value to (B, nHead, H * W, C) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, _ = attention_interface( + self, + query, + key, + value, + attention_mask=None, + is_causal=self.is_causal, + scaling=self.scale, + **kwargs, + ) + attn_output = attn_output.reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + return attn_output + + +class Sam2FeedForward(nn.Cell): + def __init__( + self, + input_dim: 
int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = mint.nn.Linear(input_dim, hidden_dim) + self.proj_out = mint.nn.Linear(hidden_dim, output_dim) + self.layers = nn.CellList([mint.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def construct(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +def window_partition(hidden_state, window_size): + """ + Partition into non-overlapping windows with padding if needed. + + Args: + hidden_state (`ms.Tensor`): + Input tokens with [batch_size, height, width, num_channels]. + window_size (`int`): + Window size. + + Returns: + `Tuple(ms.Tensor)` comprising various elements: + - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels]. + - (padded_height, padded_width): padded height and width before partition + """ + batch_size, height, width, num_channels = hidden_state.shape + + pad_height = (window_size - height % window_size) % window_size + pad_width = (window_size - width % window_size) % window_size + + # Noop in case pad_width == 0 and pad_height == 0. + hidden_state = mint.nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height)) + + padded_height, padded_width = height + pad_height, width + pad_width + + hidden_state = hidden_state.view( + batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels + ) + windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows, (padded_height, padded_width) + + +def window_unpartition(windows, window_size, pad_height_width, height_width): + """ + Window unpartition into original sequences and removing padding. + + Args: + windows (`ms.Tensor`): + Input tokens with [batch_size * num_windows, window_size, window_size, num_channels]. + window_size (`int`): + Window size. + pad_height_width (`Tuple[int]`): + Padded height and width (padded_height, padded_width). + height_width (`Tuple[int]`): + Original height and width before padding. + + Returns: + hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels]. 
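A NumPy sketch of the partition / unpartition round trip described above (illustrative, not part of the patch): pad the token grid to a multiple of the window size, split it into windows, then reverse the reshapes and crop the padding away.

    import numpy as np

    def window_partition_np(x, window_size):
        batch, height, width, channels = x.shape
        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        x = np.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
        ph, pw = height + pad_h, width + pad_w
        x = x.reshape(batch, ph // window_size, window_size, pw // window_size, window_size, channels)
        windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, channels)
        return windows, (ph, pw)

    def window_unpartition_np(windows, window_size, pad_hw, hw):
        (ph, pw), (height, width) = pad_hw, hw
        batch = windows.shape[0] // ((ph // window_size) * (pw // window_size))
        x = windows.reshape(batch, ph // window_size, pw // window_size, window_size, window_size, -1)
        x = x.transpose(0, 1, 3, 2, 4, 5).reshape(batch, ph, pw, -1)
        return x[:, :height, :width, :]

    tokens = np.random.rand(2, 14, 10, 8)              # height/width not multiples of the window size
    windows, pad_hw = window_partition_np(tokens, 8)
    restored = window_unpartition_np(windows, 8, pad_hw, (14, 10))
    assert np.allclose(restored, tokens)
    print(windows.shape)  # (8, 8, 8, 8): 2 images x (2 x 2) windows each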
+ """ + padded_height, padded_width = pad_height_width + height, width = height_width + batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size) + hidden_state = windows.view( + batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1 + ) + hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous() + hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1) + + # We always have height <= padded_height and width <= padded_width + hidden_state = hidden_state[:, :height, :width, :].contiguous() + return hidden_state + + +class Sam2MultiScaleBlock(GradientCheckpointingLayer): + def __init__( + self, + config: Sam2HieraDetConfig, + stage_idx: int, + block_idx: int, + total_block_idx: int, + ): + super().__init__() + + # take embed dim from previous stage if first block of stage + self.dim = ( + config.embed_dim_per_stage[stage_idx - 1] + if stage_idx > 0 and block_idx == 0 + else config.embed_dim_per_stage[stage_idx] + ) + self.dim_out = config.embed_dim_per_stage[stage_idx] + self.layer_norm1 = mint.nn.LayerNorm(self.dim, eps=config.layer_norm_eps) + # take window size from previous stage if first block of stage + self.window_size = ( + config.window_size_per_stage[stage_idx - 1] + if stage_idx > 0 and block_idx == 0 + else config.window_size_per_stage[stage_idx] + ) + self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size + # use query stride for first block of stage if stage is a query pool stage + self.query_stride = ( + config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None + ) + + self.attn = Sam2MultiScaleAttention( + config, + self.dim, + self.dim_out, + num_attention_heads=config.num_attention_heads_per_stage[stage_idx], + query_stride=self.query_stride, + ) + self.layer_norm2 = mint.nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps) + self.mlp = Sam2FeedForward( + self.dim_out, + int(self.dim_out * config.mlp_ratio), + self.dim_out, + num_layers=2, + activation=config.hidden_act, + ) + if self.dim != self.dim_out: + self.proj = mint.nn.Linear(self.dim, self.dim_out) + + def construct( + self, + hidden_states: ms.Tensor, + ) -> ms.Tensor: + residual = hidden_states # batch_size, height, width, channel + + hidden_states = self.layer_norm1(hidden_states) + + # Skip connection + if self.dim != self.dim_out: + residual = do_pool(self.proj(hidden_states), self.query_stride) + + # Window partition + window_size = self.window_size + if self.window_size > 0: + H, W = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, pad_hw = window_partition(hidden_states, window_size) + + # Window Attention + Q Pooling (if stage change) + attn_output = self.attn( + hidden_states=hidden_states, + ) + hidden_states = attn_output + if self.query_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.query_stride[0] + H, W = residual.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + return hidden_states + + +@dataclass +class Sam2HieraDetModelOutput(ModelOutput): + r""" 
+ last_hidden_state (`ms.Tensor` of shape `(batch_size, height, width, hidden_size)`): + hidden-states at the output of the last layer of the model. + intermediate_hidden_states (`Tuple[ms.Tensor]` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the intermediate layers of the model. + """ + + last_hidden_state: Optional[ms.Tensor] = None + intermediate_hidden_states: Optional[Tuple[ms.Tensor, ...]] = None + + +class Sam2PreTrainedModel(PreTrainedModel): + config_class = Sam2Config + base_model_prefix = "sam2" + main_input_name = "pixel_values" + input_modalities = "image" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (mint.nn.Linear, mint.nn.Conv2d, mint.nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, mint.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (mint.nn.LayerNorm, Sam2LayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + if isinstance(module, Sam2HieraDetModel): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + if module.pos_embed_window is not None: + module.pos_embed_window.data.zero_() + if isinstance(module, Sam2Model): + if module.no_memory_embedding is not None: + module.no_memory_embedding.data.zero_() + + +class Sam2HieraDetModel(Sam2PreTrainedModel, GenerationMixin): + config_class = Sam2HieraDetConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Sam2MultiScaleBlock, + "attentions": Sam2MultiScaleAttention, + } + + def __init__(self, config: Sam2HieraDetConfig): + super().__init__(config) + + self.patch_embed = Sam2PatchEmbeddings(config) + # Windowed positional embedding (https://huggingface.co/papers/2311.05613) + self.pos_embed = ms.Parameter( + mint.zeros((1, config.hidden_size, *config.window_positional_embedding_background_size)) + ) + self.pos_embed_window = ms.Parameter( + mint.zeros((1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])) + ) + self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist() + self.blocks = nn.CellList() + total_block_idx = 0 + for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage): + for block_idx in range(blocks_per_stage): + block = Sam2MultiScaleBlock( + config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx + ) + self.blocks.append(block) + total_block_idx += 1 + + def get_input_embeddings(self): + return self.patch_embed + + def _get_pos_embed(self, hw: Tuple[int, int]) -> ms.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile(tuple([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + + def construct( + self, + pixel_values: Optional[ms.Tensor] = None, + ) -> Union[Tuple, Sam2HieraDetModelOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3]) + + intermediate_hidden_states = () + for 
i, block_module in enumerate(self.blocks): + hidden_states = block_module(hidden_states) + + if i in self.stage_ends: + intermediate_hidden_states = intermediate_hidden_states + (hidden_states,) + + return Sam2HieraDetModelOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate_hidden_states, + ) + + +class Sam2VisionModel(Sam2PreTrainedModel, GenerationMixin): + config_class = Sam2VisionConfig + main_input_name = "pixel_values" + _can_record_outputs = { + "hidden_states": Sam2MultiScaleBlock, + "attentions": Sam2MultiScaleAttention, + } + + def __init__(self, config: Sam2VisionConfig): + super().__init__(config) + self.config = config + + self.backbone = Sam2HieraDetModel(config.backbone_config) + + self.neck = Sam2VisionNeck(config) + self.num_feature_levels = config.num_feature_levels + + self.post_init() + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def construct( + self, + pixel_values: Optional[ms.Tensor] = None, + ) -> Union[Tuple, Sam2VisionEncoderOutput]: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Forward through backbone + backbone_output = self.backbone(pixel_values) + hidden_states = backbone_output.last_hidden_state + intermediate_hidden_states = backbone_output.intermediate_hidden_states + + fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states) + # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution + fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1] + fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1] + + return Sam2VisionEncoderOutput( + last_hidden_state=hidden_states, + fpn_hidden_states=fpn_hidden_states, + fpn_position_encoding=fpn_position_encoding, + ) + + +class Sam2PositionalEmbedding(nn.Cell): + def __init__(self, config: Sam2PromptEncoderConfig): + super().__init__() + self.scale = config.scale + self.positional_embedding = ms.Parameter(self.scale * mint.randn((2, config.hidden_size // 2))) + #self.register_buffer("positional_embedding", positional_embedding) + + def construct(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + coordinates.to(ms.float32) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... 
x d_n x channel shape + return mint.cat([mint.sin(coordinates), mint.cos(coordinates)], dim=-1) + + +class Sam2MaskEmbedding(nn.Cell): + def __init__(self, config: Sam2PromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = mint.nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = mint.nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = mint.nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = Sam2LayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = Sam2LayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def construct(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class Sam2PromptEncoder(nn.Cell): + def __init__(self, config: Sam2PromptEncoderConfig): + super().__init__() + self.shared_embedding = Sam2PositionalEmbedding(config) + self.mask_embed = Sam2MaskEmbedding(config) + self.no_mask_embed = mint.nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + self.input_image_size = config.image_size + + self.point_embed = mint.nn.Embedding(config.num_point_embeddings, config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = mint.nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: ms.Tensor, labels: ms.Tensor, pad: bool) -> ms.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = mint.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = mint.nn.functional.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = mint.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. 
The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = mint.where( + labels[..., None] != -10, + point_embedding, + mint.zeros_like(point_embedding), + ) + + # Add point embeddings for labels >= 0 + point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) + + return point_embedding + + def _embed_boxes(self, boxes: ms.Tensor) -> ms.Tensor: + """Embeds box prompts.""" + boxes += 0.5 # Shift to center of pixel + coords = boxes.view(*boxes.shape[:2], 2, 2) + # add padding point for consistency with the original implementation + coords = mint.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0) + corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) + corner_embedding[:, :, 0, :] += self.point_embed.weight[2] + corner_embedding[:, :, 1, :] += self.point_embed.weight[3] + corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) + return corner_embedding + + def construct( + self, + input_points: Optional[Tuple[ms.Tensor, ms.Tensor]], + input_labels: Optional[ms.Tensor], + input_boxes: Optional[ms.Tensor], + input_masks: Optional[ms.Tensor], + ) -> Tuple[ms.Tensor, ms.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`ms.Tensor`, *optional*): + point coordinates and labels to embed. + boxes (`ms.Tensor`, *optional*): + boxes to embed + masks (`ms.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size = input_points.shape[0] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = mint.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + (batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + ) + + return sparse_embeddings, dense_embeddings + + +class Sam2Attention(nn.Cell): + """ + SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
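For illustration (not part of the patch), the random Fourier feature encoding that `Sam2PositionalEmbedding` and the prompt encoder above rely on, restated in NumPy: coordinates are normalized to [0, 1], mapped to [-1, 1], projected by a fixed Gaussian matrix, and expanded into sin/cos features (the optional `config.scale` factor and the 0.5-pixel shift are omitted here).

    import numpy as np

    hidden_size = 256
    rng = np.random.default_rng(0)
    positional_matrix = rng.normal(size=(2, hidden_size // 2))  # frozen at init, like the parameter above

    def encode_points_np(points_xy, image_size=1024):
        coords = points_xy / image_size           # normalize to [0, 1]
        coords = 2 * coords - 1                   # map to [-1, 1]
        coords = coords @ positional_matrix       # (..., 2) -> (..., hidden_size // 2)
        coords = 2 * np.pi * coords
        return np.concatenate([np.sin(coords), np.cos(coords)], axis=-1)

    points = np.array([[500.0, 375.0], [10.0, 900.0]])
    print(encode_points_np(points).shape)         # (2, 256)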
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = mint.nn.Linear(self.internal_dim, self.hidden_size) + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_similarity: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[ms.Tensor, ms.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class Sam2TwoWayAttentionBlock(nn.Cell): + def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`Sam2MaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + self.self_attn = Sam2Attention(config, downsample_rate=1) + self.layer_norm1 = mint.nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = Sam2Attention(config) + self.layer_norm2 = mint.nn.LayerNorm(config.hidden_size) + + self.mlp = Sam2FeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = mint.nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = mint.nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = Sam2Attention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + def construct( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + ): + # Self attention block + if self.skip_first_layer_pe: + queries, _ = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out, _ = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + return queries, keys, attn_out + + +class Sam2TwoWayTransformer(nn.Cell): + def __init__(self, config: Sam2MaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.CellList() + + for i in range(self.num_hidden_layers): + self.layers.append(Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = Sam2Attention(config) + self.layer_norm_final_attn = mint.nn.LayerNorm(config.hidden_size) + + def construct( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + ) -> Union[Tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = 
self.layer_norm_final_attn(queries) + return queries, keys + + +class Sam2LayerNorm(mint.nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def construct(self, features: ms.Tensor) -> ms.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().construct(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().construct(features) + return features + + +class Sam2MaskDecoder(nn.Cell): + def __init__(self, config: Sam2MaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = mint.nn.Embedding(1, self.hidden_size) + self.mask_tokens = mint.nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = Sam2TwoWayTransformer(config) + + # should we create a new class for this? + self.upscale_conv1 = mint.nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = mint.nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = mint.nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.CellList(mlps_list) + self.iou_prediction_head = Sam2FeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = mint.nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = mint.nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = mint.nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def construct( + self, + image_embeddings: ms.Tensor, + image_positional_embeddings: ms.Tensor, + sparse_prompt_embeddings: ms.Tensor, + dense_prompt_embeddings: ms.Tensor, + multimask_output: bool, + high_resolution_features: list[ms.Tensor], + attention_similarity: Optional[ms.Tensor] = None, + target_embedding: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: + """ + Predict masks given image and prompt embeddings. 
+ + Args: + image_embeddings (`ms.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`ms.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`ms.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`ms.Tensor`): + The embeddings of the mask inputs. + multimask_output (`bool`): + Whether to return multiple masks or a single mask. + high_resolution_features (`list[ms.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`ms.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`ms.Tensor`, *optional*): + The target embedding. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = mint.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + sparse_prompt_embeddings = sparse_prompt_embeddings.to(output_tokens.dtype) + tokens = mint.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[ms.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = mint.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, 
mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = mint.sum(mask_logits > stability_delta, dim=-1).float() + area_u = mint.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = mint.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = mint.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + (-1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)) + ) + best_multimask_logits = mint.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = mint.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. 
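+ # is_stable has shape (batch_size, point_batch_size, 1); where it is True the single-mask output (token 0) is kept, + # otherwise it is replaced by the best multimask candidate gathered above, with the IoU scores swapped consistently.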
+ mask_logits_out = mint.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = mint.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + + +class Sam2Model(Sam2PreTrainedModel, GenerationMixin): + input_modalities = ["image", "text"] + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)} + _keys_to_ignore_on_load_unexpected = [ + r"^memory_.*", + r"^mask_downsample.*", + r"^object_pointer_proj.*", + r"^temporal_positional_encoding_projection_layer.*", + "no_memory_positional_encoding", + "no_object_pointer", + "occlusion_spatial_embedding_parameter", + ] + + def __init__(self, config: Sam2Config): + super().__init__(config) + self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = Sam2VisionModel(config.vision_config) + self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = ms.Parameter(mint.zeros((1, 1, self.hidden_dim))) + + self.post_init() + + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self) -> ms.Tensor: + size = self.prompt_encoder.image_embedding_size + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = mint.ones((size), dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(mint.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + def get_image_embeddings( + self, + pixel_values: ms.Tensor, + ) -> list[ms.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. 
+ + Args: + pixel_values (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + batch_size = pixel_values.shape[0] + feature_maps, _, _, _ = self.get_image_features(pixel_values) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + return image_embeddings + + def get_prompt_embeddings( + self, + input_points: Optional[ms.Tensor] = None, + input_labels: Optional[ms.Tensor] = None, + input_boxes: Optional[ms.Tensor] = None, + input_masks: Optional[ms.Tensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`ms.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`ms.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`ms.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`ms.Tensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + def construct( + self, + pixel_values: Optional[ms.Tensor] = None, + input_points: Optional[ms.Tensor] = None, + input_labels: Optional[ms.Tensor] = None, + input_boxes: Optional[ms.Tensor] = None, + input_masks: Optional[ms.Tensor] = None, + image_embeddings: Optional[ms.Tensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[ms.Tensor] = None, + target_embedding: Optional[ms.Tensor] = None, + ) -> Sam2ImageSegmentationOutput: + r""" + input_points (`ms.Tensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `ms` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. 
If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`ms.Tensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels: + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`ms.Tensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `ms` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + input_masks (`ms.Tensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be + fed manually by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + image_embeddings (`ms.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `construct` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to output just a single mask, corresponding to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`ms.Tensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + target_embedding (`ms.Tensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ """ + if not ((pixel_values is None) ^ (image_embeddings is None)): + raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") + if input_points is not None and input_boxes is not None: + if input_points.shape[1] != input_boxes.shape[1]: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( + pixel_values, + ) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + if input_points is not None and input_labels is None: + input_labels = mint.ones_like(input_points[:, :, :, 0], dtype=ms.int) + + if input_points is None and input_boxes is None: + # If no points are provide, pad with an empty point (with label -1) + input_points = mint.zeros( + (batch_size, 1, 1, 2), dtype=image_embeddings[-1].dtype + ) + input_labels = -mint.ones((batch_size, 1, 1), dtype=ms.int32) + + if input_masks is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: + input_masks = F.interpolate( + input_masks.float(), + size=self.prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + ).to(input_masks.dtype) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder( + image_embeddings=image_embeddings[-1], + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + ) + + return Sam2ImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_multimasks, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + def get_image_features( + self, + pixel_values: ms.Tensor, + ) -> Tuple[ + list[ms.Tensor], + list[ms.Tensor], + Optional[Tuple[ms.Tensor, ...]], + Optional[Tuple[ms.Tensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`ms.Tensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `Tuple`: A Tuple containing: + - feature_maps (`list[ms.Tensor]`): List of feature maps from different levels. 
+ - feature_maps_position_embeddings (`list[ms.Tensor]`): List of positional embeddings for each feature level. + - vision_hidden_states (`Tuple[ms.Tensor]`, *optional*): Hidden states from the vision encoder. + - vision_attentions (`Tuple[ms.Tensor]`, *optional*): Attention weights from the vision encoder. + """ + vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder( + pixel_values, + ) + + feature_maps = vision_outputs.fpn_hidden_states + feature_maps_position_embeddings = vision_outputs.fpn_position_encoding + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + feature_maps = list(feature_maps) + feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) + feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) + + # flatten NxCxHxW to HWxNxC + feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] + feature_maps_position_embeddings = [ + feature_map_position_embedding.flatten(2).permute(2, 0, 1) + for feature_map_position_embedding in feature_maps_position_embeddings + ] + + return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions + + +__all__ = ["Sam2Model", "Sam2VisionModel", "Sam2PreTrainedModel", "Sam2HieraDetModel"] diff --git a/mindone/transformers/models/sam2/processing_sam2.py b/mindone/transformers/models/sam2/processing_sam2.py new file mode 100644 index 0000000000..b970e55dfd --- /dev/null +++ b/mindone/transformers/models/sam2/processing_sam2.py @@ -0,0 +1,524 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for SAM2. +""" + +from copy import deepcopy +from typing import Optional, Union + +import numpy as np + +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import BatchEncoding +from ...utils import TensorType, logging +import mindspore as ms +import mindspore.mint as mint + +logger = logging.get_logger(__name__) + + +class Sam2Processor(ProcessorMixin): + r""" + Constructs a SAM2 processor which wraps a SAM2 image processor and an 2D points & Bounding boxes processor into a + single processor. + + [`Sam2Processor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoProcessor`]. See the docstring of + [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoProcessor.__call__`] for more information. + + Args: + image_processor (`Sam2ImageProcessorFast`): + An instance of [`Sam2ImageProcessorFast`]. + target_size (`int`, *optional*): + The target size (target_size, target_size) to which the image will be resized. + point_pad_value (`int`, *optional*, defaults to -10): + The value used for padding input points. 
+ """ + attributes = ["image_processor"] + image_processor_class = "Sam2ImageProcessorFast" + + def __init__(self, image_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs): + super().__init__(image_processor, **kwargs) + self.point_pad_value = point_pad_value + self.target_size = target_size if target_size is not None else self.image_processor.size["height"] + + def __call__( + self, + images: Optional[ImageInput] = None, + segmentation_maps: Optional[ImageInput] = None, + input_points: Optional[Union[list[list[list[list[float]]]], ms.Tensor]] = None, + input_labels: Optional[Union[list[list[list[int]]], ms.Tensor]] = None, + input_boxes: Optional[Union[list[list[list[float]]], ms.Tensor]] = None, + original_sizes: Optional[Union[list[list[float]], ms.Tensor]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchEncoding: + r""" + This method uses [`Sam2ImageProcessorFast.__call__`] method to prepare image(s) for the model. It also prepares 2D + points and bounding boxes for the model if they are provided. + + Args: + images (`ImageInput`, *optional*): + The image(s) to process. + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to process. + input_points (`list[list[list[list[float]]]]`, `ms.Tensor`, *optional*): + The points to add to the frame. + input_labels (`list[list[list[int]]]`, `ms.Tensor`, *optional*): + The labels for the points. + input_boxes (`list[list[list[float]]]`, `ms.Tensor`, *optional*): + The bounding boxes to add to the frame. + original_sizes (`list[list[float]]`, `ms.Tensor`, *optional*): + The original sizes of the images. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. + **kwargs: + Additional keyword arguments to pass to the image processor. + + Returns: + A [`BatchEncoding`] with the following fields: + - `pixel_values` (`ms.Tensor`): The processed image(s). + - `original_sizes` (`list[list[float]]`): The original sizes of the images. + - `reshaped_input_sizes` (`ms.Tensor`): The reshaped input sizes of the images. + - `labels` (`ms.Tensor`): The processed segmentation maps (if provided). + - `input_points` (`ms.Tensor`): The processed points. + - `input_labels` (`ms.Tensor`): The processed labels. + - `input_boxes` (`ms.Tensor`): The processed bounding boxes. + """ + if images is not None: + encoding_image_processor = self.image_processor( + images, + segmentation_maps=segmentation_maps, + return_tensors=return_tensors, + **kwargs, + ) + elif original_sizes is not None: + if isinstance(original_sizes, ms.Tensor): + original_sizes = original_sizes.cpu().tolist() + encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type=return_tensors) + else: + raise ValueError("Either images or original_sizes must be provided") + + # pop arguments that are not used in the forward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + # Check original_sizes is of length 1 or len(images) + if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images): + raise ValueError( + "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size." 
+ ) + + # Process input points, labels, and boxes if provided + if input_points is not None or input_labels is not None or input_boxes is not None: + # Validate and convert inputs to standardized format + processed_points = self._validate_single_input( + input_points, + expected_depth=4, + input_name="points", + expected_format="[image level, object level, point level, point coordinates]", + expected_coord_size=2, + ) + processed_labels = self._validate_single_input( + input_labels, + expected_depth=3, + input_name="labels", + expected_format="[image level, object level, point level]", + ) + processed_boxes = self._validate_single_input( + input_boxes, + expected_depth=3, + input_name="boxes", + expected_format="[image level, box level, box coordinates]", + expected_coord_size=4, + ) + + # Get padding requirements for all inputs + if processed_points is not None: + points_max_dims = self._get_nested_dimensions(processed_points)[:3] + if processed_labels is not None: + labels_max_dims = self._get_nested_dimensions(processed_labels)[:3] + if processed_boxes is not None: + boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2] + + # Ensure points and labels have consistent dimensions + if processed_points is not None and processed_labels is not None: + if points_max_dims != labels_max_dims: + raise ValueError( + "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions." + ) + + # Check that boxes don't need padding (model limitation) + if processed_boxes is not None and len(processed_boxes) >= 2: + if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes): + raise ValueError( + "Input boxes have inconsistent dimensions that would require padding, " + "but boxes cannot be padded due to model limitations. " + "Please ensure all images have the same number of boxes." + ) + + # Pad and normalize all inputs to final tensor format + if processed_points is not None: + padded_points = self._pad_nested_list(processed_points, points_max_dims + [2]) + final_points = ms.tensor(padded_points, dtype=ms.float32) + self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True) + encoding_image_processor.update({"input_points": final_points}) + + if processed_labels is not None: + padded_labels = self._pad_nested_list(processed_labels, labels_max_dims) + final_labels = ms.tensor(padded_labels, dtype=ms.int64) + encoding_image_processor.update({"input_labels": final_labels}) + + if processed_boxes is not None: + final_boxes = ms.tensor(processed_boxes, dtype=ms.float32) + self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True) + encoding_image_processor.update({"input_boxes": final_boxes}) + encoding_image_processor['pixel_values'] = ms.Tensor(encoding_image_processor['pixel_values'].numpy()) + return encoding_image_processor + + def _normalize_coordinates( + self, target_size: int, coords: "ms.Tensor", original_size, is_bounding_box=False + ) -> "ms.Tensor": + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. + + Args: + target_size (`int`): + The target size of the image. + coords (`ms.Tensor`): + The coordinates to be normalized. + original_size (`tuple`): + The original size of the image. + is_bounding_box (`bool`, *optional*, defaults to `False`): + Whether the coordinates are bounding boxes. 
+ """ + old_h, old_w = original_size + new_h, new_w = target_size, target_size + coords = deepcopy(coords).float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _convert_to_nested_list(self, data, expected_depth, current_depth=0): + """ + Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists. + + Args: + data: Input data in any format + expected_depth: Expected nesting depth + current_depth: Current depth in recursion + + Returns: + Nested list representation of the data + """ + if data is None: + return None + + # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array + if isinstance(data, ms.Tensor): # PyTorch tensor + if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor + return data.numpy().tolist() + else: + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, np.ndarray): # NumPy array + if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array + return data.tolist() + else: + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, list): + if current_depth == expected_depth: + # We've reached the expected depth, return as is + return data + else: + # Continue recursion + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, (int, float)): + return data + else: + raise TypeError(f"Unsupported data type: {type(data)}") + + def _get_nested_dimensions(self, nested_list, max_dims=None): + """ + Get the maximum dimensions at each level of nesting. + + Args: + nested_list (`list`): + Nested list structure. + max_dims (`list`, *optional*): + Current maximum dimensions (for recursion). + + Returns: + `list`: A list of maximum dimensions for each nesting level. + """ + if max_dims is None: + max_dims = [] + + if not isinstance(nested_list, list): + return max_dims + + if len(max_dims) == 0: + max_dims.append(len(nested_list)) + else: + max_dims[0] = max(max_dims[0], len(nested_list)) + + if len(nested_list) > 0: + for item in nested_list: + if isinstance(item, list): + sub_dims = self._get_nested_dimensions(item) + # Merge sub_dims into max_dims + for i, dim in enumerate(sub_dims): + if i + 1 >= len(max_dims): + max_dims.append(dim) + else: + max_dims[i + 1] = max(max_dims[i + 1], dim) + + return max_dims + + def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None): + """ + Recursively pad a nested list to match target dimensions. + + Args: + nested_list (`list`): + Nested list to pad. + target_dims (`list`): + Target dimensions for each level. + current_level (`int`, *optional*, defaults to 0): + Current nesting level. + pad_value (`int`, *optional*): + Value to use for padding. + + Returns: + `list`: The padded nested list. 
+ """ + if pad_value is None: + pad_value = self.point_pad_value + + if current_level >= len(target_dims): + return nested_list + + # Ensure we have a list + if not isinstance(nested_list, list): + nested_list = [nested_list] + + # Pad current level + current_size = len(nested_list) + target_size = target_dims[current_level] + + # Pad with appropriate values + if current_level == len(target_dims) - 1: + # At the coordinate level, pad with pad_value + nested_list.extend([pad_value] * (target_size - current_size)) + else: + # At higher levels, pad with nested structures + if current_size > 0: + # Create appropriately sized template + if current_level < len(target_dims) - 2: + # For non-coordinate levels, create empty nested structure + template_dims = target_dims[current_level + 1 :] + template = self._create_empty_nested_structure(template_dims, pad_value) + else: + # For coordinate level, create list of pad_values + template = [pad_value] * target_dims[current_level + 1] + + nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)]) + else: + # Create from scratch + template_dims = target_dims[current_level + 1 :] + template = self._create_empty_nested_structure(template_dims, pad_value) + nested_list.extend([deepcopy(template) for _ in range(target_size)]) + + # Recursively pad sublists + if current_level < len(target_dims) - 1: + for i in range(len(nested_list)): + if isinstance(nested_list[i], list): + nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value) + + return nested_list + + def _create_empty_nested_structure(self, dims, pad_value): + """ + Create an empty nested structure with given dimensions filled with pad_value. + + Args: + dims (`list`): + The dimensions of the nested structure. + pad_value (`int`): + The value to fill the structure with. + """ + if len(dims) == 1: + return [pad_value] * dims[0] + else: + return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])] + + def _get_nesting_level(self, input_list): + """ + Get the nesting level of a list structure. + + Args: + input_list (`list`): + The list to get the nesting level of. + """ + if isinstance(input_list, list): + if len(input_list) == 0: + return 1 + return 1 + self._get_nesting_level(input_list[0]) + elif isinstance(input_list, (np.ndarray, ms.Tensor)): + # For arrays/tensors, the nesting level is the number of dimensions + return len(input_list.shape) + return 0 + + def _validate_single_input( + self, + data: Union[ms.Tensor, np.ndarray, list], + expected_depth: int, + input_name: str, + expected_format: str, + expected_coord_size: Optional[int] = None, + ) -> list: + """ + Validate a single input by ensuring proper nesting and raising an error if the input is not valid. + + Args: + data (`ms.Tensor`, `np.ndarray`, or `list`): + Input data to process. + expected_depth (`int`): + Expected nesting depth. + input_name (`str`): + Name of the input for error messages. + expected_format (`str`): + The expected format of the input. + expected_coord_size (`int`, *optional*): + Expected coordinate size (2 for points, 4 for boxes, None for labels). + . + """ + if data is None: + return None + + # Handle tensors and numpy arrays first + if isinstance(data, (ms.Tensor, np.ndarray)): + # For tensors/arrays, we can directly check the number of dimensions + if data.ndim != expected_depth: + raise ValueError( + f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. 
The expected nesting format is {expected_format}. Got {data.ndim} dimensions." + ) + elif expected_coord_size is not None: + if data.shape[-1] != expected_coord_size: + raise ValueError( + f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}." + ) + return self._convert_to_nested_list(data, expected_depth) + + # Handle nested lists + if isinstance(data, list): + current_depth = self._get_nesting_level(data) + if current_depth != expected_depth: + raise ValueError( + f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels." + ) + return self._convert_to_nested_list(data, expected_depth) + + def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False): + """ + Helper method to normalize coordinates in a tensor across multiple images. + + Args: + tensor (`ms.Tensor`): + Input tensor with coordinates. + original_sizes (`list`): + Original image sizes. + is_bounding_box (`bool`, *optional*, defaults to `False`): + Whether coordinates are bounding boxes. + preserve_padding (`bool`, *optional*, defaults to `False`): + Whether to preserve padding values (for points). + """ + if preserve_padding: + # For points: avoid normalizing pad values + mask = tensor != self.point_pad_value + coord_mask = mask.all(dim=-1, keepdim=True) + + for img_idx in range(len(original_sizes)): + if img_idx < tensor.shape[0]: + original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0] + normalized_coords = self._normalize_coordinates( + self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box + ) + + if preserve_padding: + # Only update non-padded values + img_mask = coord_mask[img_idx] + tensor[img_idx] = mint.where( + img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx] + ) + else: + tensor[img_idx] = normalized_coords + + def post_process_masks( + self, + masks, + original_sizes, + mask_threshold=0.0, + binarize=True, + max_hole_area=0.0, + max_sprinkle_area=0.0, + apply_non_overlapping_constraints=False, + **kwargs, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[ms.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[ms.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarization and post-processing operations. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + max_hole_area (`float`, *optional*, defaults to 0.0): + The maximum area of a hole to fill. + max_sprinkle_area (`float`, *optional*, defaults to 0.0): + The maximum area of a sprinkle to fill. + apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): + Whether to apply non-overlapping constraints to the masks. + + Returns: + (`ms.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. + """ + return self.image_processor.post_process_masks( + masks, + original_sizes, + mask_threshold, + binarize, + max_hole_area, + max_sprinkle_area, + apply_non_overlapping_constraints, + **kwargs, + ) + + +__all__ = ["Sam2Processor"]
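
Below is a minimal usage sketch of the image-segmentation API added by this diff (Sam2Processor + Sam2Model). It assumes the classes are re-exported from `mindone.transformers` and that a transformers-format SAM2 checkpoint is available; the checkpoint id, image path, and point coordinates are illustrative placeholders, not part of the patch.

# Usage sketch (illustrative; checkpoint id, image path and point are assumptions).
from PIL import Image

from mindone.transformers import Sam2Model, Sam2Processor  # or: from mindone.transformers.models.sam2 import Sam2Model, Sam2Processor

checkpoint = "facebook/sam2.1-hiera-tiny"  # hypothetical repo id; use any transformers-format SAM2 checkpoint
processor = Sam2Processor.from_pretrained(checkpoint)
model = Sam2Model.from_pretrained(checkpoint)
model.set_train(False)

image = Image.open("path/to/image.jpg").convert("RGB")  # placeholder path

# Nesting is [image, object, point, (x, y)] for points and [image, object, point] for labels;
# label 1 marks a positive (foreground) click.
input_points = [[[[500.0, 375.0]]]]
input_labels = [[[1]]]

inputs = processor(images=image, input_points=input_points, input_labels=input_labels)

# Sam2Model.construct does not accept the bookkeeping keys (e.g. original_sizes), so pass only what it needs.
outputs = model(
    pixel_values=inputs["pixel_values"],
    input_points=inputs["input_points"],
    input_labels=inputs["input_labels"],
    multimask_output=True,
)

# pred_masks are low-resolution mask logits of shape (batch, point_batch, 3, h, w) when multimask_output=True;
# post_process_masks upscales them to the original image size and binarizes at mask_threshold.
masks = processor.post_process_masks(outputs.pred_masks, inputs["original_sizes"])
scores = outputs.iou_scores  # (batch, point_batch, 3)

Passing `multimask_output=False` instead returns a single mask per prompt; at inference time the decoder can fall back to the most stable multimask candidate via `_dynamic_multimask_via_stability` when the corresponding config flag is set.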