diff --git a/mindone/transformers/models/sam2_video/__init__.py b/mindone/transformers/models/sam2_video/__init__.py new file mode 100644 index 0000000000..eba37b7392 --- /dev/null +++ b/mindone/transformers/models/sam2_video/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .modeling_sam2_video import Sam2VideoInferenceSession, Sam2VideoModel, Sam2VideoPreTrainedModel +from .processing_sam2_video import Sam2VideoProcessor +from .video_processing_sam2_video import Sam2VideoVideoProcessor diff --git a/mindone/transformers/models/sam2_video/modeling_sam2_video.py b/mindone/transformers/models/sam2_video/modeling_sam2_video.py new file mode 100644 index 0000000000..6d888e3835 --- /dev/null +++ b/mindone/transformers/models/sam2_video/modeling_sam2_video.py @@ -0,0 +1,2572 @@ + +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam2_video/modular_sam2_video.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam2_video.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
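+
+# The snippet below is an illustrative usage sketch only. The checkpoint id and the
+# processor/model helper names (`init_video_session`, `add_inputs_to_inference_session`,
+# `propagate_in_video_iterator`) follow the upstream transformers SAM2-video API and are
+# assumptions here, not guarantees of this MindSpore port:
+#
+#     import mindspore as ms
+#     from mindone.transformers import Sam2VideoModel, Sam2VideoProcessor
+#
+#     processor = Sam2VideoProcessor.from_pretrained("facebook/sam2.1-hiera-tiny")
+#     model = Sam2VideoModel.from_pretrained("facebook/sam2.1-hiera-tiny")
+#
+#     # `video` is an array of frames with shape (num_frames, height, width, 3)
+#     inference_session = processor.init_video_session(video=video, dtype=ms.float32)
+#     processor.add_inputs_to_inference_session(
+#         inference_session, frame_idx=0, obj_ids=1, input_points=[[[[210, 350]]]], input_labels=[[[1]]]
+#     )
+#     outputs = model(inference_session=inference_session, frame_idx=0)
+#     for frame_output in model.propagate_in_video_iterator(inference_session):
+#         low_res_masks = frame_output.pred_masks  # Sam2VideoSegmentationOutput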
+ +import math +from collections import OrderedDict +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.nn.functional as F +from mindspore import Tensor, nn, ops +from mindspore.common.initializer import Zero, initializer +from tqdm import tqdm + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput +from ...utils.generic import OutputRecorder +from ..auto import AutoModel +from transformers import Sam2VideoConfig, Sam2VideoMaskDecoderConfig, Sam2VideoPromptEncoderConfig + + +class Sam2VideoInferenceCache: + """Cache for vision features and model constants.""" + + def __init__( + self, + max_vision_features_cache_size: int = 1, + ): + self.max_vision_features_cache_size = max_vision_features_cache_size + + self._vision_features = {} + + def cache_vision_features(self, frame_idx: int, features: dict): + """Cache vision features with automatic device management.""" + cached = {} + if len(self._vision_features) >= self.max_vision_features_cache_size: + # remove the oldest frame + self._vision_features.pop(min(self._vision_features.keys())) + for key, value in features.items(): + cached[key] = value + self._vision_features[frame_idx] = cached + + def get_vision_features(self, frame_idx: int) -> Optional[dict]: + """Get cached vision features, automatically moved to inference device.""" + if frame_idx not in self._vision_features: + return None + cached = self._vision_features[frame_idx] + moved = {} + for key, value in cached.items(): + moved[key] = value + return moved + + def clear_all(self): + """Clear all cached data.""" + self._vision_features.clear() + + +class Sam2VideoInferenceSession: + r""" + Manages video inference session parameters, state and cache. + + Args: + video (`ms.Tensor`, *optional*): + The video to process. No need to provide when streaming. + video_height (`int`, *optional*): + The height of the video. + video_width (`int`, *optional*): + The width of the video. + dtype (`ms.dtype`, *optional*, defaults to `"float32"`): + The dtype to use for the video. + max_vision_features_cache_size (`int`, *optional*, defaults to 1): + The maximum number of vision features to cache. 
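+
+    Example (a minimal sketch; in practice a session is usually created through
+    `Sam2VideoProcessor.init_video_session`, which is an assumption based on the upstream
+    transformers API, rather than instantiated directly):
+
+    ```python
+    import mindspore as ms
+
+    session = Sam2VideoInferenceSession(
+        video=video_frames,  # preprocessed frames from the video processor
+        video_height=480,    # original video resolution, used for post-processing
+        video_width=640,
+        dtype=ms.float32,
+    )
+    ```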
+ """ + + def __init__( + self, + video: Optional[ms.Tensor] = None, + video_height: Optional[int] = None, + video_width: Optional[int] = None, + dtype: ms.dtype = ms.float32, + max_vision_features_cache_size: int = 1, + ): + # store as a dictionary to avoid double memory allocation with mint.cat when adding new frames + self.processed_frames = ( + dict(enumerate(video.to(dtype=dtype))) if video is not None else None + ) + self.video_height = video_height + self.video_width = video_width + self.dtype = dtype + self.max_vision_features_cache_size = max_vision_features_cache_size + + # Cache for computed features + self.cache = Sam2VideoInferenceCache( + max_vision_features_cache_size=self.max_vision_features_cache_size, + ) + + # Persistent object tracking state + self._obj_id_to_idx = OrderedDict() + self._obj_idx_to_id = OrderedDict() + self.obj_ids = [] + + # Persistent user inputs + self.point_inputs_per_obj = {} + self.mask_inputs_per_obj = {} + + # Persistent model outputs/history + self.output_dict_per_obj = {} + self.frames_tracked_per_obj = {} + + # Session state flags + self.obj_with_new_inputs = [] + + @property + def num_frames(self) -> Optional[int]: + return len(self.processed_frames) if self.processed_frames is not None else None + + # Object management + def obj_id_to_idx(self, obj_id: int) -> int: + """Map object ID to index, creating new entry if needed.""" + obj_idx = self._obj_id_to_idx.get(obj_id, None) + if obj_idx is not None: + return obj_idx + + obj_idx = len(self._obj_id_to_idx) + self._obj_id_to_idx[obj_id] = obj_idx + self._obj_idx_to_id[obj_idx] = obj_id + self.obj_ids = list(self._obj_id_to_idx) + + self.point_inputs_per_obj[obj_idx] = {} + self.mask_inputs_per_obj[obj_idx] = {} + self.output_dict_per_obj[obj_idx] = { + "cond_frame_outputs": {}, + "non_cond_frame_outputs": {}, + } + self.frames_tracked_per_obj[obj_idx] = {} + + return obj_idx + + # Video Inference specific functions + def obj_idx_to_id(self, obj_idx: int) -> int: + """Map model-side object index to client-side object id.""" + return self._obj_idx_to_id[obj_idx] + + def get_obj_num(self) -> int: + """Get the total number of unique object ids received so far in this session.""" + return len(self._obj_idx_to_id) + + # Input management with device handling + def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict): + """Add point inputs with automatic device placement.""" + device_inputs = {} + for key, value in inputs.items(): + device_inputs[key] = value + self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs + + def remove_point_inputs(self, obj_idx: int, frame_idx: int): + """Remove point inputs.""" + self.point_inputs_per_obj[obj_idx].pop(frame_idx, None) + + def add_mask_inputs(self, obj_idx: int, frame_idx: int, inputs: ms.Tensor): + """Add mask inputs with automatic device placement.""" + self.mask_inputs_per_obj[obj_idx][frame_idx] = inputs.to( + dtype=self.dtype, non_blocking=True + ) + + def remove_mask_inputs(self, obj_idx: int, frame_idx: int): + """Remove mask inputs.""" + self.mask_inputs_per_obj[obj_idx].pop(frame_idx, None) + + # Output management with smart device placement + def store_output( + self, + obj_idx: int, + frame_idx: int, + output_key: Optional[str] = None, + output_value: Optional[Union[ms.Tensor, dict]] = None, + is_conditioning_frame: bool = True, + ): + """ + Store output with smart device management. + If output_key is None, the output is stored as a dictionary. + + Args: + obj_idx (int): The index of the object. 
+ frame_idx (int): The index of the frame. + output_key (Optional[str]): The key of the output. If None, the output is stored as a dictionary. + output_value (Optional[Union[ms.Tensor, dict]]): The value of the output. + is_conditioning_frame (bool): Whether the output is for a conditioning frame. + """ + storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs" + + if output_key is None and isinstance(output_value, dict): + self.output_dict_per_obj[obj_idx][storage_key][frame_idx] = {} + for key, value in output_value.items(): + self.store_output(obj_idx, frame_idx, key, value, is_conditioning_frame) + return + + # Device placement: small tensors stay on inference device, large ones are stored directly + if output_key in ["object_pointer", "object_score_logits"]: # Small tensors + self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value + elif isinstance(output_value, ms.Tensor): # Large tensors like masks, features + self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value + else: + self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value + + def get_output( + self, + obj_idx: int, + frame_idx: int, + output_key: str, + is_conditioning_frame: bool = True, + ): + """ + Get output with smart device management. + + Args: + obj_idx (int): The index of the object. + frame_idx (int): The index of the frame. + output_key (str): The key of the output. + is_conditioning_frame (bool): Whether the output is for a conditioning frame. + """ + storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs" + out = self.output_dict_per_obj[obj_idx][storage_key].get(frame_idx, None) + # move to inference device if needed + if out is None: + return None + value = out[output_key] + return value + + # Video frame management + def add_new_frame(self, pixel_values: ms.Tensor, frame_idx: Optional[int] = None) -> int: + """Add new frame with automatic device placement.""" + pixel_values = pixel_values.to(dtype=self.dtype, non_blocking=True) + if pixel_values.dim() == 4: + pixel_values = pixel_values.squeeze(0) + + if frame_idx is None: + frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0 + + if self.processed_frames is None: + self.processed_frames = {frame_idx: pixel_values} + else: + self.processed_frames[frame_idx] = pixel_values + + return frame_idx + + def get_frame(self, frame_idx: int) -> ms.Tensor: + """Get frame from video.""" + return self.processed_frames[frame_idx] + + def reset_tracking_data(self): + """Reset tracking data but keep cache.""" + self._obj_id_to_idx.clear() + self._obj_idx_to_id.clear() + self.obj_ids.clear() + self.point_inputs_per_obj.clear() + self.mask_inputs_per_obj.clear() + self.output_dict_per_obj.clear() + self.frames_tracked_per_obj.clear() + self.obj_with_new_inputs = [] + # Note: cache and video data are preserved + + def reset_inference_session(self): + """Reset tracking data and cache.""" + self._obj_id_to_idx.clear() + self._obj_idx_to_id.clear() + self.obj_ids.clear() + self.point_inputs_per_obj.clear() + self.mask_inputs_per_obj.clear() + self.output_dict_per_obj.clear() + self.frames_tracked_per_obj.clear() + self.obj_with_new_inputs = [] + self.cache.clear_all() + + +class Sam2VideoLayerNorm(mint.nn.LayerNorm): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs): + super().__init__(normalized_shape, eps=eps, **kwargs) + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {data_format}") + self.data_format = data_format + + def construct(self, features: ms.Tensor) -> ms.Tensor: + """ + Args: + features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels) + """ + if self.data_format == "channels_first": + features = features.permute(0, 2, 3, 1) + features = super().construct(features) + features = features.permute(0, 3, 1, 2) + else: + features = super().construct(features) + return features + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class Sam2VideoPositionEmbeddingSine(nn.Cell): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale + + def construct( + self, + shape, + dtype: ms.dtype, + mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + if mask is None: + mask = mint.zeros((shape[0], shape[2], shape[3]), dtype=ms.bool) + not_mask = (~mask).to(dtype) + y_embed = not_mask.cumsum(1) + x_embed = not_mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = mint.arange(self.num_pos_feats, dtype=ms.int64).to(dtype) + dim_t = self.temperature ** (2 * mint.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = mint.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = mint.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = mint.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = ops.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = mint.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=module.training) + attn_output = ops.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Sam2VideoAttention(nn.Cell): + """ + SAM2_VIDEO's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. 
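+    The projected (internal) dimension is `hidden_size // downsample_rate`, and attention is
+    computed over `num_attention_heads` heads of size `internal_dim // num_attention_heads`.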
+ """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + self.config = config + self.hidden_size = config.hidden_size + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.internal_dim // config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.q_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.o_proj = mint.nn.Linear(self.internal_dim, self.hidden_size) + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_similarity: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ms.Tensor, ms.Tensor]: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=attention_similarity, + dropout=0.0, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class Sam2VideoTwoWayAttentionBlock(nn.Cell): + def __init__(self, config: Sam2VideoMaskDecoderConfig, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`Sam2VideoMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. 
+ """ + super().__init__() + self.self_attn = Sam2VideoAttention(config, downsample_rate=1) + self.layer_norm1 = mint.nn.LayerNorm(config.hidden_size) + + self.cross_attn_token_to_image = Sam2VideoAttention(config) + self.layer_norm2 = mint.nn.LayerNorm(config.hidden_size) + + self.mlp = Sam2VideoFeedForward( + config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers + ) + self.layer_norm3 = mint.nn.LayerNorm(config.hidden_size) + + self.layer_norm4 = mint.nn.LayerNorm(config.hidden_size) + self.cross_attn_image_to_token = Sam2VideoAttention(config) + + self.skip_first_layer_pe = skip_first_layer_pe + + def construct( + self, + queries: ms.Tensor, + keys: ms.Tensor, + query_point_embedding: ms.Tensor, + key_point_embedding: ms.Tensor, + attention_similarity: ms.Tensor, + **kwargs: Unpack[FlashAttentionKwargs], + ): + # Self attention block + if self.skip_first_layer_pe: + queries, _ = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out, _ = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + return queries, keys, attn_out + + +class Sam2VideoFeedForward(nn.Cell): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: str = "relu", + sigmoid_output: bool = False, + ): + super().__init__() + self.num_layers = num_layers + self.activation = ACT2FN[activation] + self.proj_in = mint.nn.Linear(input_dim, hidden_dim) + self.proj_out = mint.nn.Linear(hidden_dim, output_dim) + self.layers = nn.CellList([mint.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def construct(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +@dataclass +class Sam2VideoImageSegmentationOutput(ModelOutput): + r""" + iou_scores (`ms.Tensor` of shape `(batch_size, point_batch_size, num_masks)`): + The Intersection over Union (IoU) scores of the predicted masks. + pred_masks (`ms.Tensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`): + The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed + by the processor to be brought to the original image size. + object_score_logits (`ms.Tensor` of shape `(batch_size, point_batch_size, 1)`): + Logits for the object score, indicating if an object is present. 
+ image_embeddings (`tuple(ms.Tensor)`): + The features from the FPN, which are used by the mask decoder. This is a tuple of `ms.Tensor` where each + tensor has shape `(batch_size, channels, height, width)`. + vision_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `ms.Tensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. + Hidden-states of the vision model at the output of each stage. + vision_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the vision model. + mask_decoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights of the mask decoder. + high_res_masks (`ms.Tensor` of shape `(batch_size, point_batch_size, num_masks, image_size, image_size)`, *optional*): + The predicted masks, upscaled to the original image size. Only used for Sam2VideoModel. + object_pointer (`ms.Tensor` of shape `(batch_size, point_batch_size, hidden_size)`, *optional*): + A tensor representing the object pointer, used for tracking in videos. Only used for Sam2VideoModel. + """ + + iou_scores: Optional[ms.Tensor] = None + pred_masks: Optional[ms.Tensor] = None + object_score_logits: Optional[ms.Tensor] = None + image_embeddings: tuple[ms.Tensor, ...] = None + vision_hidden_states: Optional[tuple[ms.Tensor, ...]] = None + vision_attentions: Optional[tuple[ms.Tensor, ...]] = None + mask_decoder_attentions: Optional[tuple[ms.Tensor, ...]] = None + + high_res_masks: Optional[ms.Tensor] = None + object_pointer: Optional[ms.Tensor] = None + + +@dataclass +class Sam2VideoSegmentationOutput(ModelOutput): + r""" + pred_masks (`ms.Tensor` of shape `(batch_size, num_masks, height, width)`): + The predicted masks stored at the model's resolution. + frame_idx (`int`): + The frame index of the video. 
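+
+    Note that `pred_masks` are low-resolution masks at the model's internal mask resolution;
+    resizing them to the original video resolution is left to the processor's post-processing.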
+ """ + + pred_masks: Optional[ms.Tensor] = None + frame_idx: Optional[int] = None + + +class Sam2VideoPreTrainedModel(PreTrainedModel): + config_class = Sam2VideoConfig + base_model_prefix = "sam2_video" + main_input_name = "pixel_values" + input_modalities = "video" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Sam2VideoModel): + if module.no_memory_positional_encoding is not None: + module.no_memory_positional_encoding.set_data(initializer(Zero(), module.no_memory_positional_encoding.shape, module.no_memory_positional_encoding.dtype)) + if module.memory_temporal_positional_encoding is not None: + module.memory_temporal_positional_encoding.set_data(initializer(Zero(), module.memory_temporal_positional_encoding.shape, module.memory_temporal_positional_encoding.dtype)) + if module.no_object_pointer is not None: + module.no_object_pointer.set_data(initializer(Zero(), module.no_object_pointer.shape, module.no_object_pointer.dtype)) + if module.occlusion_spatial_embedding_parameter is not None: + module.occlusion_spatial_embedding_parameter.set_data(initializer(Zero(), module.occlusion_spatial_embedding_parameter.shape, module.occlusion_spatial_embedding_parameter.dtype)) + + if isinstance(module, Sam2VideoMemoryFuserCXBlock): + if module.scale is not None: + module.scale.set_data(initializer(Zero(), module.scale.shape, module.scale.dtype)) + +class Sam2VideoVisionRotaryEmbedding(nn.Cell): + """ + Vision Rotary Position Embedding for SAM2, following transformers library standards. + Supports 2D (axial) rotary embeddings for spatial dimensions. + """ + + def __init__(self, config: Sam2VideoConfig): + super().__init__() + dim = config.memory_attention_hidden_size // ( + config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads + ) + # Ensure even dimension for proper axial splitting + if dim % 4 != 0: + raise ValueError("Dimension must be divisible by 4 for axial RoPE") + end_x, end_y = config.memory_attention_rope_feat_sizes + freqs = 1.0 / (config.memory_attention_rope_theta ** (mint.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + # Generate 2D position indices for axial rotary embedding + flattened_indices = mint.arange(end_x * end_y, dtype=ms.long) + x_positions = flattened_indices % end_x + y_positions = mint.div(flattened_indices, end_x, rounding_mode="floor") + freqs_x = mint.outer(x_positions, freqs).float() + freqs_y = mint.outer(y_positions, freqs).float() + inv_freq = mint.cat([freqs_x, freqs_y], dim=-1) + inv_freq = inv_freq.repeat_interleave(2, dim=-1) + # directly register the cos and sin embeddings as we have a fixed feature shape + self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False) + self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False) + + def construct(self) -> tuple[ms.Tensor, ms.Tensor]: + # As the feature map size is fixed, we can just return the pre-computed embeddings. + return self.rope_embeddings_cos, self.rope_embeddings_sin + + +def rotate_pairwise(x): + """ + pairwise rotation of the hidden dims of the input. Differerent from Llama Half-Tensor Rotation. 
+ + This is an optimized version of the following more explicit implementation: + ```python + x_rotated = mint.zeros_like(x, dtype=x.dtype) + x_rotated[..., ::2] = -x[..., 1::2] + x_rotated[..., 1::2] = x[..., ::2] + return x_rotated + ``` + """ + x = x.view(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = mint.stack((-x2, x1), dim=-1) + return x.flatten(start_dim=-2) + + +# TODO: This leads to ~1e-07 max diff and ~1e-09 avg diff for q_embed and k_embed from the original implementation, most likely due to the use of complex tensors in the original implementation. +def apply_rotary_pos_emb_2d( + q: ms.Tensor, + k: ms.Tensor, + cos: ms.Tensor, + sin: ms.Tensor, + num_k_exclude_rope: int = 0, + repeat_freqs_k: bool = False, +) -> tuple[ms.Tensor, ms.Tensor]: + """ + Apply rotary position embedding to query and key tensors for vision models. + Follows the standard transformers library pattern. + + Args: + q: Query tensor of shape (..., seq_len, head_dim) + k: Key tensor of shape (..., seq_len, head_dim) + cos: Cosine position embedding of shape (seq_len, head_dim) + sin: Sine position embedding of shape (seq_len, head_dim) + repeat_freqs_k: Whether to repeat frequencies for keys (for cross-attention) + + Returns: + Rotated (q, k) tensors + """ + k_rot, k_pass = k[..., : k.shape[-2] - num_k_exclude_rope, :], k[..., k.shape[-2] - num_k_exclude_rope :, :] + q_embed = q.float() # force upscale to float32 as in the original implementation + q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin) + if k_rot.shape[-2] == 0: + # Handle case where keys might be empty due to dropout + return q_embed.type_as(q), mint.cat([k_rot, k_pass], dim=-2) + + # Handle key tensor - may need to repeat frequencies if different sequence length + if repeat_freqs_k and k_rot.shape[-2] != q.shape[-2]: + # Repeat cos/sin to match key sequence length + repeat_factor = k_rot.shape[-2] // q.shape[-2] + cos_k = cos.repeat(1, 1, repeat_factor, 1) + sin_k = sin.repeat(1, 1, repeat_factor, 1) + else: + cos_k = cos + sin_k = sin + + # Apply rotary embedding to keys + k_embed = k_rot.float() # force upscale to float32 as in the original implementation + k_embed = (k_embed * cos_k) + (rotate_pairwise(k_embed) * sin_k) + # Concatenate back to full shape + k_embed = mint.cat([k_embed.type_as(k), k_pass], dim=-2) + return q_embed.type_as(q), k_embed + + +class Sam2VideoRoPEAttention(nn.Cell): + """Attention with rotary position encoding.""" + + def __init__( + self, + config: Sam2VideoConfig, + kv_in_dim: Optional[int] = None, + rope_k_repeat=False, + ): + super().__init__() + self.config = config + self.hidden_size = config.memory_attention_hidden_size + self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate + self.num_attention_heads = config.memory_attention_num_attention_heads + self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else self.hidden_size + + self.q_proj = mint.nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = mint.nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = mint.nn.Linear(self.kv_in_dim, self.internal_dim) + self.o_proj = mint.nn.Linear(self.internal_dim, self.hidden_size) + + self.rope_k_repeat = rope_k_repeat + self.dropout_p = config.memory_attention_rope_dropout + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + 
num_k_exclude_rope: int = 0, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> ms.Tensor: + # Input projections + batch_size, point_batch_size = query.shape[:2] + new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim) + + query = self.q_proj(query).view(*new_shape).transpose(1, 2) + key = self.k_proj(key).view(*new_shape).transpose(1, 2) + value = self.v_proj(value).view(*new_shape).transpose(1, 2) + + cos, sin = position_embeddings + # Apply rotary position encoding, excluding some keys if specified + query, key = apply_rotary_pos_emb_2d( + query, key, cos, sin, repeat_freqs_k=self.rope_k_repeat, num_k_exclude_rope=num_k_exclude_rope + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query, + key, + value, + attention_mask=None, + dropout=0.0 if not self.training else self.dropout_p, + scaling=self.scaling, + is_causal=self.is_causal, + **kwargs, + ) + attn_output = attn_output.reshape( + batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Sam2VideoMemoryAttentionLayer(nn.Cell): + def __init__(self, config: Sam2VideoConfig): + super().__init__() + hidden_size = config.memory_attention_hidden_size + self.self_attn = Sam2VideoRoPEAttention(config) + self.cross_attn_image = Sam2VideoRoPEAttention(config, kv_in_dim=64, rope_k_repeat=True) + + # Implementation of Feedforward model + self.linear1 = mint.nn.Linear(hidden_size, config.memory_attention_feed_forward_hidden_size) + self.dropout = mint.nn.Dropout(config.memory_attention_dropout) + self.linear2 = mint.nn.Linear(config.memory_attention_feed_forward_hidden_size, hidden_size) + + self.layer_norm1 = mint.nn.LayerNorm(hidden_size) + self.layer_norm2 = mint.nn.LayerNorm(hidden_size) + self.layer_norm3 = mint.nn.LayerNorm(hidden_size) + self.dropout1 = mint.nn.Dropout(config.memory_attention_dropout) + self.dropout2 = mint.nn.Dropout(config.memory_attention_dropout) + self.dropout3 = mint.nn.Dropout(config.memory_attention_dropout) + + self.activation = ACT2FN[config.memory_attention_feed_forward_hidden_act] + + def construct( + self, + queries: Tensor, + keys: Tensor, + key_point_embedding: Tensor, + rope_position_embeddings: tuple[Tensor, Tensor], + num_k_exclude_rope: int = 0, + ) -> ms.Tensor: + # Self-Attention + query = self.layer_norm1(queries) + query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings) + queries = queries + self.dropout1(query) + + # Cross-Attention + query = self.layer_norm2(queries) + query, _ = self.cross_attn_image( + query=query, + key=keys + key_point_embedding, + value=keys, + position_embeddings=rope_position_embeddings, + num_k_exclude_rope=num_k_exclude_rope, + ) + queries = queries + self.dropout2(query) + # MLP + query = self.layer_norm3(queries) + query = self.linear2(self.dropout(self.activation(self.linear1(query)))) + queries = queries + self.dropout3(query) + return queries + + +class Sam2VideoMemoryAttention(nn.Cell): + def __init__(self, config: Sam2VideoConfig): + super().__init__() + self.layers = nn.CellList( + [Sam2VideoMemoryAttentionLayer(config) for _ in range(config.memory_attention_num_layers)] + ) + self.layer_norm = mint.nn.LayerNorm(config.memory_attention_hidden_size) + 
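+        # cos/sin tables are precomputed once for the fixed memory-attention feature-map size
+        # (see Sam2VideoVisionRotaryEmbedding) and reused on every forward pass.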
self.rotary_emb = Sam2VideoVisionRotaryEmbedding(config=config) + + def construct( + self, + current_vision_features: ms.Tensor, + memory: ms.Tensor, + current_vision_position_embeddings: Optional[Tensor] = None, + memory_posision_embeddings: Optional[Tensor] = None, + num_object_pointer_tokens: int = 0, + ): + """ + Args: + current_vision_features (`ms.Tensor`): + The current vision features used for self-attention. + memory (`ms.Tensor`): + The memory features used for cross-attention. + current_vision_position_embeddings (`ms.Tensor`, *optional*): + The position embeddings for the current vision features. + memory_posision_embeddings (`ms.Tensor`, *optional*): + The position embeddings for the memory features. + num_object_pointer_tokens (`int`, *optional*, defaults to 0): + The number of object pointer tokens. + """ + output = current_vision_features + if current_vision_position_embeddings is not None: + output = output + 0.1 * current_vision_position_embeddings + + # Convert to batch first + output = output.transpose(0, 1) + memory = memory.transpose(0, 1).unsqueeze(1) + memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1) + rope_position_embeddings = self.rotary_emb() + for layer in self.layers: + output = layer( + queries=output.unsqueeze(1) if output.ndim == 3 else output, + keys=memory, + key_point_embedding=memory_posision_embeddings, + rope_position_embeddings=rope_position_embeddings, + num_k_exclude_rope=num_object_pointer_tokens, + ) + + normed_output = self.layer_norm(output) + + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + + return normed_output + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class Sam2VideoMemoryFuserCXBlock(GradientCheckpointingLayer): + def __init__(self, config: Sam2VideoConfig): + super().__init__() + self.depthwise_conv = mint.nn.Conv2d( + config.memory_fuser_embed_dim, + config.memory_fuser_embed_dim, + kernel_size=config.memory_fuser_kernel_size, + padding=config.memory_fuser_padding, + groups=config.memory_fuser_embed_dim, + ) # depthwise conv + self.layer_norm = Sam2VideoLayerNorm(config.memory_fuser_embed_dim, eps=1e-6, data_format="channels_first") + self.activation = ACT2FN[config.memory_fuser_hidden_act] + self.pointwise_conv1 = mint.nn.Linear( + config.memory_fuser_embed_dim, config.memory_fuser_intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.pointwise_conv2 = mint.nn.Linear(config.memory_fuser_intermediate_dim, config.memory_fuser_embed_dim) + self.scale = ms.Parameter( + config.memory_fuser_layer_scale_init_value * mint.ones((config.memory_fuser_embed_dim)), + requires_grad=True, + ) + + def construct(self, hidden_states): + input = hidden_states + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + hidden_states = self.pointwise_conv1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.scale * hidden_states + hidden_states = hidden_states.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + hidden_states = input + hidden_states + return hidden_states + + +class Sam2VideoMemoryFuser(nn.Cell): + def __init__(self, config: Sam2VideoConfig): + super().__init__() + self.layers = nn.CellList( + [Sam2VideoMemoryFuserCXBlock(config) for _ in range(config.memory_fuser_num_layers)] + ) + + def 
construct(self, hidden_states): + # normally hidden_states: (N, C, H, W) + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class Sam2VideoMaskDownSamplerLayer(nn.Cell): + def __init__(self, config: Sam2VideoConfig, in_channels: int, out_channels: int): + super().__init__() + self.conv = mint.nn.Conv2d( + in_channels, + out_channels, + kernel_size=config.mask_downsampler_kernel_size, + stride=config.mask_downsampler_stride, + padding=config.mask_downsampler_padding, + ) + self.layer_norm = Sam2VideoLayerNorm(out_channels, eps=1e-6, data_format="channels_first") + self.activation = ACT2FN[config.mask_downsampler_hidden_act] + + def construct(self, x): + return self.activation(self.layer_norm(self.conv(x))) + + +class Sam2VideoMaskDownSampler(nn.Cell): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__(self, config: Sam2VideoConfig): + super().__init__() + + num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride)) + + self.layers = nn.CellList() + self.activation = ACT2FN[config.mask_downsampler_hidden_act] + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2) + self.layers.append(Sam2VideoMaskDownSamplerLayer(config, mask_in_chans, mask_out_chans)) + mask_in_chans = mask_out_chans + + self.final_conv = mint.nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1) + + def construct(self, x): + for layer in self.layers: + x = layer(x) + x = self.final_conv(x) + return x + + +class Sam2VideoMemoryEncoder(nn.Cell): + def __init__(self, config: Sam2VideoConfig): + super().__init__() + + hidden_size = config.memory_encoder_hidden_size + output_channels = config.memory_encoder_output_channels + self.mask_downsampler = Sam2VideoMaskDownSampler(config) + self.feature_projection = mint.nn.Conv2d(hidden_size, hidden_size, kernel_size=1) + self.memory_fuser = Sam2VideoMemoryFuser(config) + self.position_encoding = Sam2VideoPositionEmbeddingSine(num_pos_feats=output_channels // 2, normalize=True) + self.projection = mint.nn.Conv2d(hidden_size, output_channels, kernel_size=1) + + def construct( + self, + vision_features: ms.Tensor, + masks: ms.Tensor, + ) -> tuple[ms.Tensor, ms.Tensor]: + ## Process masks + masks = self.mask_downsampler(masks) + ## Fuse pixel_features and downsampled masks + + vision_features = self.feature_projection(vision_features) + vision_features = vision_features + masks + vision_features = self.memory_fuser(vision_features) + vision_features = self.projection(vision_features) + + vision_pos_enc = self.position_encoding(vision_features.shape, vision_features.dtype) + + return vision_features, vision_pos_enc + + +@dataclass +class Sam2VideoVisionEncoderOutput(ModelOutput): + r""" + last_hidden_state (`ms.Tensor` of shape `(batch_size, height, width, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + fpn_hidden_states (`tuple(ms.Tensor)`): + Tuple of `ms.Tensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck. 
+ fpn_position_encoding (`tuple(ms.Tensor)`): + Tuple of `ms.Tensor` (one for each feature level, from high to low resolution) of shape + `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`. + hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the + model at the output of each stage. + attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[ms.Tensor] = None + fpn_hidden_states: Optional[ms.Tensor] = None + fpn_position_encoding: Optional[ms.Tensor] = None + hidden_states: Optional[tuple[ms.Tensor, ...]] = None + attentions: Optional[tuple[ms.Tensor, ...]] = None + + +class Sam2VideoPositionalEmbedding(nn.Cell): + def __init__(self, config: Sam2VideoPromptEncoderConfig): + super().__init__() + self.scale = config.scale + self.positional_embedding = ms.Parameter(self.scale * ops.randn((2, config.hidden_size // 2))) + #self.register_buffer("positional_embedding", positional_embedding) + + def construct(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + coordinates.to(ms.float32) + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... 
x d_n x channel shape + return mint.cat([mint.sin(coordinates), mint.cos(coordinates)], dim=-1) + + +class Sam2VideoMaskEmbedding(nn.Cell): + def __init__(self, config: Sam2VideoPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = mint.nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = mint.nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = mint.nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = Sam2VideoLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = Sam2VideoLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def construct(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class Sam2VideoPromptEncoder(nn.Cell): + def __init__(self, config: Sam2VideoPromptEncoderConfig): + super().__init__() + self.shared_embedding = Sam2VideoPositionalEmbedding(config) + self.mask_embed = Sam2VideoMaskEmbedding(config) + self.no_mask_embed = mint.nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size) + self.input_image_size = config.image_size + + self.point_embed = mint.nn.Embedding(config.num_point_embeddings, config.hidden_size) + self.hidden_size = config.hidden_size + self.not_a_point_embed = mint.nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: ms.Tensor, labels: ms.Tensor, pad: bool) -> ms.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + points = F.pad(points, (0, 0, 0, 1), mode="constant", value=0) + labels = F.pad(labels, (0, 1), mode="constant", value=-1) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # mint.where and expanding the labels tensor is required by the ONNX export + point_embedding = mint.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. 
The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = mint.where( + labels[..., None] != -10, + point_embedding, + mint.zeros_like(point_embedding), + ) + + # Add point embeddings for labels >= 0 + point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1) + + return point_embedding + + def _embed_boxes(self, boxes: ms.Tensor) -> ms.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.view(*boxes.shape[:2], 2, 2) + # add padding point for consistency with the original implementation + coords = F.pad(coords, (0, 0, 0, 1), mode="constant", value=0) + corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size)) + corner_embedding[:, :, 0, :] += self.point_embed.weight[2] + corner_embedding[:, :, 1, :] += self.point_embed.weight[3] + corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :]) + return corner_embedding + + def construct( + self, + input_points: Optional[tuple[ms.Tensor, ms.Tensor]], + input_labels: Optional[ms.Tensor], + input_boxes: Optional[ms.Tensor], + input_masks: Optional[ms.Tensor], + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`ms.Tensor`, *optional*): + point coordinates and labels to embed. + boxes (`ms.Tensor`, *optional*): + boxes to embed + masks (`ms.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + if input_points is not None: + batch_size = input_points.shape[0] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = mint.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + (batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + ) + + return sparse_embeddings, dense_embeddings + + +class Sam2VideoTwoWayTransformer(nn.Cell): + def __init__(self, config: Sam2VideoMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.CellList() + + for i in range(self.num_hidden_layers): + self.layers.append(Sam2VideoTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = Sam2VideoAttention(config) + self.layer_norm_final_attn = mint.nn.LayerNorm(config.hidden_size) + + def construct( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutput]: + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + 
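+        # queries start from the prompt-token embeddings; keys are the flattened image embeddings,
+        # and both are refined by the two-way attention blocks below.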
queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, _ = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + **kwargs, + ) + # Apply the final attention layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys + + +class Sam2VideoMaskDecoder(nn.Cell): + def __init__(self, config: Sam2VideoMaskDecoderConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = mint.nn.Embedding(1, self.hidden_size) + self.mask_tokens = mint.nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = Sam2VideoTwoWayTransformer(config) + + # should we create a new class for this? + self.upscale_conv1 = mint.nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = mint.nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = Sam2VideoLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [Sam2VideoFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.CellList(mlps_list) + self.iou_prediction_head = Sam2VideoFeedForward( + self.hidden_size, + config.iou_head_hidden_dim, + self.num_mask_tokens, + config.iou_head_depth, + sigmoid_output=True, + ) + + self.conv_s0 = mint.nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1) + self.conv_s1 = mint.nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1) + + self.obj_score_token = mint.nn.Embedding(1, self.hidden_size) + self.pred_obj_score_head = Sam2VideoFeedForward(self.hidden_size, self.hidden_size, 1, 3) + + self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh + + def construct( + self, + image_embeddings: ms.Tensor, + image_positional_embeddings: ms.Tensor, + sparse_prompt_embeddings: ms.Tensor, + dense_prompt_embeddings: ms.Tensor, + multimask_output: bool, + high_resolution_features: list[ms.Tensor], + attention_similarity: Optional[ms.Tensor] = None, + target_embedding: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor, ms.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Args: + image_embeddings (`ms.Tensor`): + The embeddings from the image encoder. + image_positional_embeddings (`ms.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`ms.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`ms.Tensor`): + The embeddings of the mask inputs. 
+ multimask_output (`bool`): + Whether to return multiple masks or a single mask. + high_resolution_features (`list[ms.Tensor]`, *optional*): + The high-resolution features from the vision encoder. + attention_similarity (`ms.Tensor`, *optional*): + The attention similarity tensor. + target_embedding (`ms.Tensor`, *optional*): + The target embedding. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + # Concatenate output tokens + output_tokens = mint.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if sparse_prompt_embeddings.shape[0] != 0: + tokens = mint.cat((output_tokens, sparse_prompt_embeddings), dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + # Expand per-image data in batch direction to be per-mask + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + # Run the transformer + point_embeddings, image_embeddings = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + iou_token_out = point_embeddings[:, :, 1, :] + mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + image_embeddings = image_embeddings.transpose(2, 3).view( + batch_size * point_batch_size, num_channels, height, width + ) + + feat_s0, feat_s1 = high_resolution_features + feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0) + feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0) + upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1 + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0) + + hyper_in_list: list[ms.Tensor] = [] + for i in range(self.num_mask_tokens): + current_mlp = self.output_hypernetworks_mlps[i] + hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])] + hyper_in = mint.stack(hyper_in_list, dim=2) + + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width) + masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :]) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + elif self.dynamic_multimask_via_stability and not self.training: + mask_slice = slice(0, 1) + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + mask_slice = slice(0, 1) + masks = masks[:, :, mask_slice, :, :] + iou_pred = iou_pred[:, :, mask_slice] + + sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape + + return masks, iou_pred, sam_tokens_out, 
object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = mint.sum(mask_logits > stability_delta, dim=-1).float() + area_u = mint.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = mint.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, :, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, :, 1:] + best_scores_inds = mint.argmax(multimask_iou_scores, dim=-1) # [B, P] + best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + best_scores_inds_expanded = best_scores_inds_expanded.expand( + (-1, -1, 1, multimask_logits.shape[-2], multimask_logits.shape[-1]) + ) + best_multimask_logits = mint.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W] + best_multimask_iou_scores = mint.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1] + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, :, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, :, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = mint.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = mint.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. 
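+
+    Example (illustrative sketch only; assumes a 1D tensor of positions and the module-level
+    `mint`/`ms` imports of this file):
+        pos = mint.arange(4, dtype=ms.float32)   # temporal positions 0, 1, 2, 3
+        pe = get_1d_sine_pe(pos, dim=256)        # -> tensor of shape (4, 256)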
+ """ + pe_dim = dim // 2 + dim_t = mint.arange(pe_dim, dtype=ms.float32) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = mint.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +class Sam2VideoModel(Sam2VideoPreTrainedModel): + input_modalities = ["video", "text"] + _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)} + _keys_to_ignore_on_load_unexpected = [] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } + # need to be ignored, as it's a buffer and will not be correctly detected as tied weight + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config: Sam2VideoConfig): + super().__init__(config) + self.shared_image_embedding = Sam2VideoPositionalEmbedding(config.prompt_encoder_config) + self.vision_encoder = AutoModel.from_config(config.vision_config) + self.prompt_encoder = Sam2VideoPromptEncoder(config.prompt_encoder_config) + # The module using it is not a PreTrainedModel subclass so we need this + config.mask_decoder_config._attn_implementation = config._attn_implementation + self.mask_decoder = Sam2VideoMaskDecoder(config.mask_decoder_config) + + self.num_feature_levels = config.vision_config.num_feature_levels + self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes + # a single token to indicate no memory embedding from previous frames + self.hidden_dim = config.vision_config.fpn_hidden_size + self.no_memory_embedding = ms.Parameter(mint.zeros((1, 1, self.hidden_dim))) + self.config = config + # For video sequence inference + self.image_size = config.image_size + self.memory_attention = Sam2VideoMemoryAttention(config) + self.memory_encoder = Sam2VideoMemoryEncoder(config) + self.no_memory_positional_encoding = ms.Parameter( + mint.zeros((1, 1, config.vision_config.fpn_hidden_size)) + ) + self.mem_dim = config.memory_encoder_output_channels + self.num_maskmem = config.num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.memory_temporal_positional_encoding = ms.Parameter( + mint.zeros((self.num_maskmem, 1, 1, self.mem_dim)) + ) + + self.no_object_pointer = ms.Parameter(mint.zeros((1, self.hidden_dim))) + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. 
+ self.mask_downsample = mint.nn.Conv2d(1, 1, kernel_size=4, stride=4) + # a feedforward layer on SAM output tokens to turn them into object pointers + self.object_pointer_proj = Sam2VideoFeedForward(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + + if self.config.enable_temporal_pos_encoding_for_object_pointers: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.temporal_positional_encoding_projection_layer = mint.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.temporal_positional_encoding_projection_layer = mint.nn.Identity() + + self.occlusion_spatial_embedding_parameter = None # compatibility with Sam2 + if config.enable_occlusion_spatial_embedding: + self.occlusion_spatial_embedding_parameter = ms.Parameter(mint.zeros((1, self.mem_dim))) + + self.post_init() + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self) -> ms.Tensor: + size = self.prompt_encoder.image_embedding_size + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = mint.ones((size), dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size[0] + x_embed = x_embed / size[1] + + positional_embedding = self.shared_image_embedding(mint.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + def get_image_embeddings( + self, + pixel_values: ms.Tensor, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> list[ms.Tensor]: + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + """ + batch_size = pixel_values.shape[0] + feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + return image_embeddings + + def get_prompt_embeddings( + self, + input_points: Optional[ms.Tensor] = None, + input_labels: Optional[ms.Tensor] = None, + input_boxes: Optional[ms.Tensor] = None, + input_masks: Optional[ms.Tensor] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`ms.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`ms.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`ms.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. 
Users can also pass the input boxes manually.
+            input_masks (`ms.Tensor` of shape `(batch_size, image_size, image_size)`):
+                Optional input masks for the prompt encoder.
+        """
+        prompt_output = self.prompt_encoder(
+            input_points=input_points,
+            input_labels=input_labels,
+            input_boxes=input_boxes,
+            input_masks=input_masks,
+        )
+        return prompt_output
+
+    def construct(
+        self,
+        inference_session: Sam2VideoInferenceSession,
+        frame_idx: Optional[int] = None,
+        frame: Optional[ms.Tensor] = None,
+        reverse: bool = False,
+    ) -> Sam2VideoSegmentationOutput:
+        r"""
+        inference_session (`Sam2VideoInferenceSession`):
+            The video inference session object.
+        frame_idx (`int`, *optional*):
+            The index of the frame on which to run inference. No need to provide when inferring
+            on a new streamed frame.
+        frame (`ms.Tensor`, *optional*):
+            The frame to process. Provide when streaming.
+        reverse (`bool`, *optional*, defaults to `False`):
+            Whether to propagate in reverse.
+        """
+        if frame is not None:
+            frame_idx = inference_session.add_new_frame(frame, frame_idx)
+
+        if frame is not None and inference_session.get_obj_num() == 0:
+            raise ValueError("No objects are provided for tracking; please add inputs first.")
+
+        num_objects = inference_session.get_obj_num()
+        pred_masks_per_obj = [None] * num_objects
+        # Note: We avoid batched inference here because per-object inputs (clicks/masks)
+        # can differ across objects.
+        for obj_idx in range(num_objects):
+            obj_id = inference_session.obj_idx_to_id(obj_idx)
+            has_new_inputs = obj_id in inference_session.obj_with_new_inputs
+            has_cond_output = frame_idx in inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+            # If this object has no new inputs and this frame already has a
+            # conditioning output, reuse the cached masks instead of recomputing.
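+            # (typical case: propagating again through a frame that already received
+            # clicks or a mask prompt earlier in the session).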
+ if (not has_new_inputs) and has_cond_output: + pred_masks = inference_session.get_output(obj_idx, frame_idx, "pred_masks", is_conditioning_frame=True) + is_init_cond_frame = True + else: + # Defaults when there are no new inputs + is_init_cond_frame = False + point_inputs = None + mask_inputs = None + + if has_new_inputs: + is_init_cond_frame = frame_idx not in inference_session.frames_tracked_per_obj[obj_idx] + if is_init_cond_frame: + reverse = False + point_inputs = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None) + mask_inputs = inference_session.mask_inputs_per_obj[obj_idx].get(frame_idx, None) + if point_inputs is not None or mask_inputs is not None: + inference_session.obj_with_new_inputs.remove(obj_id) + + current_out = self._run_single_frame_inference( + inference_session=inference_session, + obj_idx=obj_idx, + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + reverse=reverse, + run_mem_encoder=True, + streaming=frame is not None, + ) + inference_session.store_output( + obj_idx, frame_idx, output_value=current_out, is_conditioning_frame=is_init_cond_frame + ) + pred_masks = current_out["pred_masks"] + + pred_masks_per_obj[obj_idx] = pred_masks + if not is_init_cond_frame: + # only for tracked frames, not for initial conditioning frames + inference_session.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + if len(pred_masks_per_obj) > 1: + all_pred_masks = mint.cat(pred_masks_per_obj, dim=0) + else: + all_pred_masks = pred_masks_per_obj[0] + + return Sam2VideoSegmentationOutput(pred_masks=all_pred_masks, frame_idx=frame_idx) + + def get_image_features( + self, + pixel_values: ms.Tensor, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + list[ms.Tensor], + list[ms.Tensor], + Optional[tuple[ms.Tensor, ...]], + Optional[tuple[ms.Tensor, ...]], + ]: + r""" + Extract and preprocess image features using the vision encoder. + + Args: + pixel_values (`ms.Tensor`): + Input pixel values of shape `(batch_size, num_channels, height, width)`. + + Returns: + `tuple`: A tuple containing: + - feature_maps (`list[ms.Tensor]`): List of feature maps from different levels. + - feature_maps_position_embeddings (`list[ms.Tensor]`): List of positional embeddings for each feature level. + - vision_hidden_states (`tuple[ms.Tensor]`, *optional*): Hidden states from the vision encoder. + - vision_attentions (`tuple[ms.Tensor]`, *optional*): Attention weights from the vision encoder. 
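+
+        Example (illustrative sketch; `model` is assumed to be a loaded `Sam2VideoModel` and
+        `pixel_values` a preprocessed frame of shape `(1, 3, image_size, image_size)`):
+            feature_maps, feature_maps_pos, _, _ = model.get_image_features(pixel_values)
+            # each returned feature level is flattened to (H*W, batch_size, channels)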
+ """ + vision_outputs: Sam2VideoVisionEncoderOutput = self.vision_encoder( + pixel_values, + **kwargs, + ) + + feature_maps = vision_outputs.fpn_hidden_states + feature_maps_position_embeddings = vision_outputs.fpn_position_encoding + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + feature_maps = list(feature_maps) + feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0]) + feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1]) + + # flatten NxCxHxW to HWxNxC + feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps] + feature_maps_position_embeddings = [ + feature_map_position_embedding.flatten(2).permute(2, 0, 1) + for feature_map_position_embedding in feature_maps_position_embeddings + ] + + return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions + + def _prepare_vision_features( + self, + inference_session: Sam2VideoInferenceSession, + frame_idx: int, + batch_size: int, + ) -> tuple[ms.Tensor, list[ms.Tensor]]: + """Prepare vision features for a frame.""" + + # Check if features are cached + if cached_features := inference_session.cache.get_vision_features(frame_idx): + vision_feats = cached_features["vision_feats"] + vision_pos_embeds = cached_features["vision_pos_embeds"] + else: + # Compute features using image encoder + image_batch = inference_session.get_frame(frame_idx).unsqueeze(0) # Add batch dimension + vision_feats, vision_pos_embeds, _, _ = self.get_image_features(image_batch) + # Cache features + inference_session.cache.cache_vision_features( + frame_idx, {"vision_feats": vision_feats, "vision_pos_embeds": vision_pos_embeds} + ) + + # Expand to batch size if needed + if batch_size > 1: + vision_feats = vision_feats.expand((batch_size, -1, -1, -1)) + vision_pos_embeds = [pe.expand(batch_size, -1, -1, -1) for pe in vision_pos_embeds] + + return vision_feats, vision_pos_embeds + + + def _single_frame_forward( + self, + pixel_values: Optional[ms.Tensor] = None, + input_points: Optional[ms.Tensor] = None, + input_labels: Optional[ms.Tensor] = None, + input_boxes: Optional[ms.Tensor] = None, + input_masks: Optional[ms.Tensor] = None, + image_embeddings: Optional[ms.Tensor] = None, + multimask_output: bool = True, + attention_similarity: Optional[ms.Tensor] = None, + target_embedding: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Sam2VideoImageSegmentationOutput: + """ + input_points (`ms.Tensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `mindspore` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. 
+ input_labels (`ms.Tensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`ms.Tensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `mindspore` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + input_masks (`ms.Tensor` of shape `(batch_size, image_size, image_size)`): + SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + image_embeddings (`ms.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `construct` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`ms.Tensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + target_embedding (`ms.Tensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048). + """ + if not ((pixel_values is None) ^ (image_embeddings is None)): + raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.") + if input_points is not None and input_boxes is not None: + if input_points.shape[1] != input_boxes.shape[1]: + raise ValueError( + f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}." 
+ ) + elif input_points is not None: + num_objects = input_points.shape[1] + elif input_boxes is not None: + num_objects = input_boxes.shape[1] + elif input_masks is not None: + num_objects = input_masks.shape[1] + else: + num_objects = 1 + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features( + pixel_values, + **kwargs, + ) + + # add no memory embedding to the last feature map + feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding + + # reshape feature maps to the same shape as the backbone feature sizes + image_embeddings = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes) + ] + + if input_points is not None and input_labels is None: + input_labels = mint.ones_like(input_points[:, :, :, 0], dtype=ms.int32) + + if input_points is None and input_boxes is None: + # If no points are provide, pad with an empty point (with label -1) + input_points = mint.zeros( + (batch_size, 1, 1, 2), dtype=image_embeddings[-1].dtype + ) + input_labels = -mint.ones((batch_size, 1, 1), dtype=ms.int32) + + if input_masks is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size: + input_masks = F.interpolate( + input_masks.float(), + size=self.prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + ).to(input_masks.dtype) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + low_res_multimasks, iou_scores, sam_output_tokens, object_score_logits = self.mask_decoder( + image_embeddings=image_embeddings[-1], + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + high_resolution_features=image_embeddings[:-1], + attention_similarity=attention_similarity, + target_embedding=target_embedding, + **kwargs, + ) + + is_obj_appearing = object_score_logits > 0 + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = mint.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + high_res_multimasks = ( + F.interpolate( + low_res_multimasks.squeeze(1).float(), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + .unsqueeze(1) + .to(low_res_multimasks.dtype) + ) + sam_output_token = sam_output_tokens[:, :, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = mint.argmax(iou_scores, dim=-1) + batch_inds = mint.arange(batch_size) + object_batch_inds = mint.arange(num_objects) + low_res_masks = low_res_multimasks[batch_inds, object_batch_inds, best_iou_inds] + high_res_masks = high_res_multimasks[batch_inds, 
object_batch_inds, best_iou_inds] + if sam_output_tokens.shape[2] > 1: + sam_output_token = sam_output_tokens[batch_inds, object_batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks[:, :, 0], high_res_multimasks[:, :, 0] + # Extract object pointer from the SAM output token (with occlusion handling) + object_pointer = self.object_pointer_proj(sam_output_token) + lambda_is_obj_appearing = is_obj_appearing.to(object_pointer.dtype) + + object_pointer = lambda_is_obj_appearing * object_pointer + object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer + + return Sam2VideoImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_masks, + high_res_masks=high_res_masks, + object_pointer=object_pointer, + object_score_logits=object_score_logits, + image_embeddings=image_embeddings, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + ) + + def _use_mask_as_output( + self, + backbone_features: ms.Tensor, + high_res_features: list[ms.Tensor], + mask_inputs: ms.Tensor, + ) -> Sam2VideoImageSegmentationOutput: + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in forward above). + """ + # Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.to(backbone_features[0].dtype) + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks.float(), + size=(high_res_masks.shape[-2] // 4, high_res_masks.shape[-1] // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ).to(backbone_features[0].dtype) + # a dummy IoU prediction of all 1's under mask input + iou_scores = mask_inputs.new_ones((mask_inputs.shape[0], 1)).to(backbone_features[0].dtype) + # produce an object pointer using the SAM decoder from the mask input + object_pointer = self._single_frame_forward( + input_masks=self.mask_downsample(mask_inputs_float.to(backbone_features[0].dtype)), + image_embeddings=high_res_features + [backbone_features], + ).object_pointer + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = mint.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.to(backbone_features[0].dtype) + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + object_pointer = lambda_is_obj_appearing * object_pointer + object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer + return Sam2VideoImageSegmentationOutput( + iou_scores=iou_scores, + pred_masks=low_res_masks, + high_res_masks=high_res_masks, + object_pointer=object_pointer, + object_score_logits=object_score_logits, + image_embeddings=high_res_features + [backbone_features], + ) + + def _gather_memory_frame_outputs( + self, + inference_session: Sam2VideoInferenceSession, + obj_idx: int, + frame_idx: int, + track_in_reverse_time: bool = False, + ) -> list[tuple[int, dict]]: + """ + Get memory frames from conditioning and non-conditioning outputs. + + Returns: + List of (relative_temporal_offset, output_data) tuples. 
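+
+        Illustrative structure of the returned list (assuming `num_maskmem=7` and forward tracking):
+            [(0, cond_out), (6, out[t - 6]), (5, out[t - 5]), ..., (1, out[t - 1])]
+            where offset 0 marks conditioning frames and `output_data` may be `None` for
+            frames that have not been tracked yet.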
+ """ + temporal_positions_and_previous_outputs = [] + + # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints) + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + if not conditioning_outputs: + raise ValueError( + "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame" + ) + + # Store (temporal_position, output_data) tuples + temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()] + + # Add non-conditioning memory frames (up to self.num_maskmem - 1) + # These are typically frames tracked by the model without direct user input. + # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity. + for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1): + # relative_temporal_offset: how many frames before (or after if reversing) the current frame + if not track_in_reverse_time: + previous_frame_idx = frame_idx - relative_temporal_offset + else: + previous_frame_idx = frame_idx + relative_temporal_offset + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + previous_frame_idx, None + ) + + temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data)) + + return temporal_positions_and_previous_outputs + + def _build_memory_attention_inputs( + self, + temporal_positions_and_previous_outputs: list[tuple[int, dict]], + ) -> tuple[list[ms.Tensor], list[ms.Tensor]]: + """ + Concatenate memory features and positional embeddings from previous frames. + + Returns: + Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate). + """ + memories_to_concatenate = [] + memory_positional_embeddings_to_concatenate = [] + + for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs: + if prev_output_data is None: + continue # Skip if no output data for this temporal position (e.g., padding frames) + + # Load memory features + # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels) + memory_features = prev_output_data.get("maskmem_features") + if memory_features is None: + continue # Skip if maskmem_features is None + memories_to_concatenate.append(memory_features) + + # Spatial positional encoding + spatial_memory_pos_embed = prev_output_data.get("maskmem_pos_enc") + if spatial_memory_pos_embed is None: + continue # Skip if maskmem_pos_enc is None + + # Add temporal positional encoding + # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim) + combined_memory_pos_embed = ( + spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1] + ) + memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed) + + return memories_to_concatenate, memory_positional_embeddings_to_concatenate + + def _get_object_pointers( + self, + inference_session: Sam2VideoInferenceSession, + obj_idx: int, + frame_idx: int, + num_total_frames: int, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> tuple[list[int], list[ms.Tensor], int]: + """ + Get object pointers and their positional embeddings from past frames. + + Returns: + Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use). 
+ """ + temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1 + + # Determine max object pointers to use + if streaming: + max_object_pointers_to_use = self.config.max_object_pointers_in_encoder + else: + max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder) + + temporal_offsets: list[int] = [] + pointer_tokens: list[ms.Tensor] = [] + + # Add object pointers from selected conditioning frames + # Optionally, only include pointers from past frames during evaluation + conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"] + eligible_conditioning_outputs = conditioning_outputs + if not self.training: + eligible_conditioning_outputs = { + temporal_idx: out + for temporal_idx, out in conditioning_outputs.items() + if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx) + } + + for temporal_idx, out_data in eligible_conditioning_outputs.items(): + temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier + temporal_offsets.append(temporal_difference) + pointer_tokens.append(out_data["object_pointer"]) + + # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1) + for t_diff_offset in range(1, max_object_pointers_to_use): + ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset + if ref_frame_idx < 0 or ( + not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames + ): + break # Stop if frame index is out of bounds + + # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU + out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get( + ref_frame_idx, None + ) + if out_data is not None: + temporal_offsets.append(t_diff_offset) + pointer_tokens.append(out_data["object_pointer"]) + + return temporal_offsets, pointer_tokens, max_object_pointers_to_use + + def _process_object_pointers( + self, + temporal_offsets: list[int], + pointer_tokens: list[ms.Tensor], + max_object_pointers_to_use: int, + batch_size: int, + num_channels: int, + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + Process object pointers and compute their positional embeddings. + + Returns: + Tuple of (object_pointers, object_pointers_pos_embed). 
+ """ + if not pointer_tokens: + return None, None + + # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels) + object_pointers = mint.stack(pointer_tokens, dim=0) + + if self.config.enable_temporal_pos_encoding_for_object_pointers: + max_temporal_diff = float(max_object_pointers_to_use - 1) + # Determine dimensionality for temporal positional encoding of pointers + pointer_tpos_dim = num_channels + + # Normalize temporal differences before sine PE calculation + normalized_temporal_diffs = ( + ms.tensor(temporal_offsets, dtype=ms.float32) / max_temporal_diff + ) + sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype) + projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe) + object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand((-1, batch_size, self.mem_dim)) + else: + object_pointers_pos_embed = object_pointers.new_zeros( + (len(temporal_offsets), batch_size, self.mem_dim), dtype=object_pointers.dtype + ) + + if self.mem_dim < num_channels: + # If memory dimension is smaller, reshape/split pointers and repeat positional encoding + num_splits = num_channels // self.mem_dim + object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim) + object_pointers = object_pointers.permute(0, 2, 1, 3).flatten( + 0, 1 + ) # (SeqLen_ptr*num_splits, Batch, MemDim) + object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0) + + return object_pointers, object_pointers_pos_embed + + def _prepare_memory_conditioned_features( + self, + inference_session: Sam2VideoInferenceSession, + frame_idx: int, + obj_idx: int, + is_initial_conditioning_frame: bool, + current_vision_features: list[ms.Tensor], + current_vision_positional_embeddings: list[ms.Tensor], + num_total_frames: int, + track_in_reverse_time: bool = False, + streaming: bool = False, + ) -> ms.Tensor: + """ + Fuse current frame's visual features with memory from previous frames for enhanced object tracking. + + This method conditions the current frame's visual features on temporal memory from previous frames, + enabling consistent object tracking across video sequences. For initial conditioning frames, it uses + no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both + conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention. + + Args: + inference_session (`Sam2VideoInferenceSession`): + The video inference session object. + frame_idx (`int`): + Index of the current frame being processed. + obj_idx (`int`): + Index of the object being processed. + is_initial_conditioning_frame (`bool`): + Whether this is an initial conditioning frame with user inputs (True) or a subsequent + tracking frame (False). + current_vision_features (`ms.Tensor`): + Highest-level vision features of shape `(seq_len, batch_size, channels)`. + current_vision_positional_embeddings (`ms.Tensor`): + Positional embedding tensors corresponding to the highest-level vision features. + num_total_frames (`int`): + Total number of frames in the video sequence. + track_in_reverse_time (`bool`, *optional*, defaults to `False`): + Whether tracking is performed in reverse temporal order. + streaming (`bool`, *optional*, defaults to `False`): + Whether this is streaming inference mode. + + Returns: + `ms.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)` + suitable for input to the SAM decoder. 
+ """ + # Get dimensions from the highest-level (lowest-resolution) feature map + batch_size = current_vision_features.shape[1] + num_channels = self.hidden_dim + height, width = self.backbone_feature_sizes[-1] + + # If memory is disabled (e.g., for single image SAM), return current features directly. + if self.num_maskmem == 0: + # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width) + # Assuming SeqLen = Height * Width for the last feature map + current_feature_map = current_vision_features.permute(1, 2, 0).view( + batch_size, num_channels, height, width + ) + return current_feature_map + + # Step 1: Handle initial conditioning frames + if is_initial_conditioning_frame: + # For initial conditioning frames, no prior memory is used directly in this block. + # If configured, directly add a learnable "no memory" embedding. + # current_vision_features has shape (SeqLen, Batch, Channels) + conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding + # Reshape to (Batch, Channels, Height, Width) + conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view( + batch_size, num_channels, height, width + ) + return conditioned_feature_map + + # Step 2: Get memory frames and concatenate their features + temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs( + inference_session, obj_idx, frame_idx, track_in_reverse_time + ) + + memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs( + temporal_positions_and_previous_outputs + ) + + # Step 3: Get and process object pointers + temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers( + inference_session, obj_idx, frame_idx, num_total_frames, track_in_reverse_time, streaming + ) + + num_object_pointer_tokens = 0 + if pointer_tokens: + object_pointers, object_pointers_pos_embed = self._process_object_pointers( + temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels + ) + + if object_pointers is not None: + memories_to_concatenate.append(object_pointers) + memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed) + num_object_pointer_tokens = object_pointers.shape[0] + + # Step 4: Concatenate all retrieved memories and their positional embeddings + combined_memory = mint.cat(memories_to_concatenate, dim=0) + combined_memory_positional_embeddings = mint.cat(memory_positional_embeddings_to_concatenate, dim=0) + + # Step 5: Forward through the memory attention mechanism + conditioned_feature_map_flat = self.memory_attention( + current_vision_features=current_vision_features, + current_vision_position_embeddings=current_vision_positional_embeddings, + memory=combined_memory, + memory_posision_embeddings=combined_memory_positional_embeddings, # Corrected typo from API + num_object_pointer_tokens=num_object_pointer_tokens, + ) + + # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width) + conditioned_feature_map = ( + conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width) + ) + return conditioned_feature_map + + def _use_multimask(self, is_init_cond_frame: bool, point_inputs: Optional[dict]) -> bool: + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].shape[2] + multimask_output = ( + self.config.multimask_output_in_sam + and (is_init_cond_frame or 
self.config.multimask_output_for_tracking) + and (self.config.multimask_min_pt_num <= num_pts <= self.config.multimask_max_pt_num) + ) + return multimask_output + + def _run_single_frame_inference( + self, + inference_session: Sam2VideoInferenceSession, + frame_idx: int, + obj_idx: int, + batch_size: int, + is_init_cond_frame: bool, + point_inputs: Optional[ms.Tensor], + mask_inputs: Optional[ms.Tensor], + reverse: bool, + run_mem_encoder: bool, + prev_sam_mask_logits: Optional[ms.Tensor] = None, + streaming: bool = False, + ) -> dict[str, Any]: + """ + Perform a single tracking step for video object segmentation. + + Args: + inference_session (`Sam2VideoInferenceSession`): + The video inference session object. + frame_idx (`int`): + Index of the current frame. + obj_idx (`int`): + Index of the current object. + batch_size (`int`): + Batch size of the current frame. + is_init_cond_frame (`bool`): + Whether this is an initial conditioning frame with user inputs. + point_inputs (`dict`, *optional*): + Point prompt inputs for the current frame. + mask_inputs (`ms.Tensor`, *optional*): + Mask prompt inputs for the current frame. + reverse (`bool`, *optional*, defaults to `False`): + Whether to track in reverse time order. + run_mem_encoder (`bool`, *optional*, defaults to `True`): + Whether to run the memory encoder on predicted masks. + prev_sam_mask_logits (`ms.Tensor`, *optional*): + Previously predicted SAM mask logits that can be fed with new clicks. + streaming (`bool`, *optional*, defaults to `False`): + Whether this is streaming inference. + + Returns: + `dict`: Dictionary containing the tracking results for the current frame, including: + - pred_masks: Predicted low-resolution masks. + - object_pointer: Object pointer for memory. + - object_score_logits: Object score logits (inference only). + - maskmem_features: Memory features for future frames. + - maskmem_pos_enc: Memory positional encodings. + """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds = self._prepare_vision_features( + inference_session, frame_idx, batch_size + ) + # point and mask should not appear as input simultaneously on the same frame + if point_inputs is not None and mask_inputs is not None: + raise ValueError( + "point_inputs and mask_inputs should not appear as input simultaneously on the same frame" + ) + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.shape[1], x.shape[2], *s) + for x, s in zip(current_vision_feats[:-1], self.backbone_feature_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None: + # We directly output the mask input (see it as a GT mask) without using a SAM prompt encoder + mask decoder. 
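+            # The user-provided mask is treated as ground truth: it is converted to logits
+            # directly, and the SAM decoder (inside `_use_mask_as_output`) is only used to
+            # derive an object pointer from it.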
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *self.backbone_feature_sizes[-1]) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + inference_session=inference_session, + frame_idx=frame_idx, + obj_idx=obj_idx, + is_initial_conditioning_frame=is_init_cond_frame, + current_vision_features=current_vision_feats[-1], + current_vision_positional_embeddings=current_vision_pos_embeds[-1], + num_total_frames=inference_session.num_frames, + track_in_reverse_time=reverse, + streaming=streaming, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._single_frame_forward( + pixel_values=None, # Vision features already computed + input_points=point_inputs["point_coords"] if point_inputs is not None else None, + input_labels=point_inputs["point_labels"] if point_inputs is not None else None, + input_masks=mask_inputs, + image_embeddings=high_res_features + [pix_feat], + multimask_output=multimask_output, + ) + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (which will be used to condition vision features in future frames) + maskmem_features = None + maskmem_pos_enc = None + if run_mem_encoder and self.num_maskmem > 0: + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats[-1], + pred_masks_high_res=sam_outputs.high_res_masks, + object_score_logits=sam_outputs.object_score_logits, + is_mask_from_pts=(point_inputs is not None or mask_inputs is not None), + ) + + current_out = { + "pred_masks": sam_outputs.pred_masks, + "object_pointer": sam_outputs.object_pointer, + "maskmem_features": maskmem_features if maskmem_features is not None else None, + "maskmem_pos_enc": maskmem_pos_enc, + } + if not self.training: + current_out["object_score_logits"] = sam_outputs.object_score_logits + + return current_out + + def _encode_new_memory( + self, + current_vision_feats: ms.Tensor, + pred_masks_high_res: ms.Tensor, + object_score_logits: ms.Tensor, + is_mask_from_pts: bool, + ) -> tuple[ms.Tensor, list[ms.Tensor]]: + """Encode the current image and its prediction into a memory feature.""" + batch_size = current_vision_feats.shape[1] # batch size on this frame + channels = self.hidden_dim + height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width) + if is_mask_from_pts and not self.training: + # binarize the mask logits + mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype) + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = mint.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc + mask_for_mem = 
mask_for_mem + self.config.sigmoid_bias_for_mem_enc + + maskmem_features, maskmem_pos_enc = self.memory_encoder( + pix_feat, + mask_for_mem, + ) + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.occlusion_spatial_embedding_parameter is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[ + ..., None, None + ].expand(maskmem_features.shape) + + # for consistency with the original implementation + maskmem_features = maskmem_features.flatten(2).permute(2, 0, 1) + maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype).flatten(2).permute(2, 0, 1) + + return maskmem_features, maskmem_pos_enc + + + def propagate_in_video_iterator( + self, + inference_session: Sam2VideoInferenceSession, + start_frame_idx: Optional[int] = None, + max_frame_num_to_track: Optional[int] = None, + reverse: bool = False, + ) -> Iterator[Sam2VideoSegmentationOutput]: + r""" + inference_session (`Sam2VideoInferenceSession`): + The video inference session object. + start_frame_idx (`int`, *optional*): + The starting frame index for propagation. + Need to be provided if `forward` hasn't been called on new inputs yet. + If not provided, the starting frame index will be the earliest frame with input points. + max_frame_num_to_track (`int`, *optional*): + The maximum number of frames to track. + reverse (`bool`, *optional*, defaults to `False`): + Whether to propagate in reverse. + """ + num_frames = inference_session.num_frames + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + frames_with_inputs = [ + frame_idx + for obj_output_dict in inference_session.output_dict_per_obj.values() + for frame_idx in obj_output_dict["cond_frame_outputs"] + ] + if not frames_with_inputs: + raise ValueError( + "Cannot determine the starting frame index; please specify it manually, or run inference on a frame with inputs first." + ) + start_frame_idx = min(frames_with_inputs) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + sam2_video_output = self(inference_session, frame_idx=frame_idx, reverse=reverse) + yield sam2_video_output + + +__all__ = ["Sam2VideoModel", "Sam2VideoInferenceSession", "Sam2VideoPreTrainedModel"] diff --git a/mindone/transformers/models/sam2_video/processing_sam2_video.py b/mindone/transformers/models/sam2_video/processing_sam2_video.py new file mode 100644 index 0000000000..f231d647ac --- /dev/null +++ b/mindone/transformers/models/sam2_video/processing_sam2_video.py @@ -0,0 +1,799 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam2_video/modular_sam2_video.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sam2_video.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+from typing import Optional, Union
+
+import numpy as np
+import mindspore as ms
+from mindspore import mint
+
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import BatchEncoding
+from ...utils import TensorType
+
+from ...video_utils import VideoInput
+from .modeling_sam2_video import Sam2VideoInferenceSession
+
+
+class Sam2VideoProcessor(ProcessorMixin):
+    r"""
+    Constructs a SAM2 video processor which wraps a SAM2 image processor, a SAM2 video processor and a 2D points &
+    bounding boxes processor into a single processor.
+
+    [`Sam2VideoProcessor`] offers all the functionalities of [`Sam2ImageProcessorFast`] and [`Sam2VideoVideoProcessor`]. See the docstring of
+    [`~Sam2ImageProcessorFast.__call__`] and [`~Sam2VideoVideoProcessor.__call__`] for more information.
+
+    Args:
+        image_processor (`Sam2ImageProcessorFast`):
+            An instance of [`Sam2ImageProcessorFast`].
+        video_processor (`Sam2VideoVideoProcessor`):
+            An instance of [`Sam2VideoVideoProcessor`].
+        target_size (`int`, *optional*):
+            The target size (target_size, target_size) to which the image will be resized.
+        point_pad_value (`int`, *optional*, defaults to -10):
+            The value used for padding input points.
+    """
+    attributes = ["image_processor", "video_processor"]
+    image_processor_class = "Sam2ImageProcessorFast"
+    video_processor_class = "Sam2VideoVideoProcessor"
+
+    def __init__(
+        self, image_processor, video_processor, target_size: Optional[int] = None, point_pad_value: int = -10, **kwargs
+    ):
+        super().__init__(image_processor, video_processor, **kwargs)
+        self.point_pad_value = point_pad_value
+        self.target_size = target_size if target_size is not None else self.image_processor.size["height"]
+
+    def __call__(
+        self,
+        images: Optional[ImageInput] = None,
+        segmentation_maps: Optional[ImageInput] = None,
+        input_points: Optional[Union[list[list[list[list[float]]]], ms.Tensor]] = None,
+        input_labels: Optional[Union[list[list[list[int]]], ms.Tensor]] = None,
+        input_boxes: Optional[Union[list[list[list[float]]], ms.Tensor]] = None,
+        original_sizes: Optional[Union[list[list[float]], ms.Tensor]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchEncoding:
+        r"""
+        This method uses the [`Sam2ImageProcessorFast.__call__`] method to prepare image(s) for the model. It also prepares 2D
+        points and bounding boxes for the model if they are provided.
+ + Args: + images (`ImageInput`, *optional*): + The image(s) to process. + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to process. + input_points (`list[list[list[list[float]]]]`, `ms.Tensor`, *optional*): + The points to add to the frame. + input_labels (`list[list[list[int]]]`, `ms.Tensor`, *optional*): + The labels for the points. + input_boxes (`list[list[list[float]]]`, `ms.Tensor`, *optional*): + The bounding boxes to add to the frame. + original_sizes (`list[list[float]]`, `ms.Tensor`, *optional*): + The original sizes of the images. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. + **kwargs: + Additional keyword arguments to pass to the image processor. + + Returns: + A [`BatchEncoding`] with the following fields: + - `pixel_values` (`ms.Tensor`): The processed image(s). + - `original_sizes` (`list[list[float]]`): The original sizes of the images. + - `reshaped_input_sizes` (`ms.Tensor`): The reshaped input sizes of the images. + - `labels` (`ms.Tensor`): The processed segmentation maps (if provided). + - `input_points` (`ms.Tensor`): The processed points. + - `input_labels` (`ms.Tensor`): The processed labels. + - `input_boxes` (`ms.Tensor`): The processed bounding boxes. + """ + if images is not None: + encoding_image_processor = self.image_processor( + images, + segmentation_maps=segmentation_maps, + return_tensors=return_tensors, + **kwargs, + ) + elif original_sizes is not None: + if isinstance(original_sizes, ms.Tensor): + original_sizes = original_sizes.cpu().tolist() + encoding_image_processor = BatchEncoding({"original_sizes": original_sizes}, tensor_type="np") + else: + raise ValueError("Either images or original_sizes must be provided") + + # pop arguments that are not used in the forward but used nevertheless + original_sizes = encoding_image_processor["original_sizes"] + # Check original_sizes is of length 1 or len(images) + if images is not None and len(original_sizes) != 1 and len(original_sizes) != len(images): + raise ValueError( + "original_sizes must be of length 1 or len(images). If you are passing a single image, you must pass a single original_size." 
+ ) + + # Process input points, labels, and boxes if provided + if input_points is not None or input_labels is not None or input_boxes is not None: + # Validate and convert inputs to standardized format + processed_points = self._validate_single_input( + input_points, + expected_depth=4, + input_name="points", + expected_format="[image level, object level, point level, point coordinates]", + expected_coord_size=2, + ) + processed_labels = self._validate_single_input( + input_labels, + expected_depth=3, + input_name="labels", + expected_format="[image level, object level, point level]", + ) + processed_boxes = self._validate_single_input( + input_boxes, + expected_depth=3, + input_name="boxes", + expected_format="[image level, box level, box coordinates]", + expected_coord_size=4, + ) + + # Get padding requirements for all inputs + if processed_points is not None: + points_max_dims = self._get_nested_dimensions(processed_points)[:3] + if processed_labels is not None: + labels_max_dims = self._get_nested_dimensions(processed_labels)[:3] + if processed_boxes is not None: + boxes_max_dims = self._get_nested_dimensions(processed_boxes)[:2] + + # Ensure points and labels have consistent dimensions + if processed_points is not None and processed_labels is not None: + if points_max_dims != labels_max_dims: + raise ValueError( + "Input points and labels have inconsistent dimensions. Please ensure they have the same dimensions." + ) + + # Check that boxes don't need padding (model limitation) + if processed_boxes is not None and len(processed_boxes) >= 2: + if any(len(img_boxes) < boxes_max_dims[1] for img_boxes in processed_boxes): + raise ValueError( + "Input boxes have inconsistent dimensions that would require padding, " + "but boxes cannot be padded due to model limitations. " + "Please ensure all images have the same number of boxes." + ) + + # Pad and normalize all inputs to final tensor format + if processed_points is not None: + padded_points = self._pad_nested_list(processed_points, points_max_dims + [2]) + final_points = ms.tensor(padded_points, dtype=ms.float32) + self._normalize_tensor_coordinates(final_points, original_sizes, preserve_padding=True) + encoding_image_processor.update({"input_points": final_points}) + + if processed_labels is not None: + padded_labels = self._pad_nested_list(processed_labels, labels_max_dims) + final_labels = ms.tensor(padded_labels, dtype=ms.int64) + encoding_image_processor.update({"input_labels": final_labels}) + + if processed_boxes is not None: + final_boxes = ms.tensor(processed_boxes, dtype=ms.float32) + self._normalize_tensor_coordinates(final_boxes, original_sizes, is_bounding_box=True) + encoding_image_processor.update({"input_boxes": final_boxes}) + + return encoding_image_processor + + def _normalize_coordinates( + self, target_size: int, coords: "ms.Tensor", original_size, is_bounding_box=False + ) -> "ms.Tensor": + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. + + Args: + target_size (`int`): + The target size of the image. + coords (`ms.Tensor`): + The coordinates to be normalized. + original_size (`tuple`): + The original size of the image. + is_bounding_box (`bool`, *optional*, defaults to `False`): + Whether the coordinates are bounding boxes. 
+ """ + old_h, old_w = original_size + new_h, new_w = target_size, target_size + coords = coords.clone().float() + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _convert_to_nested_list(self, data, expected_depth, current_depth=0): + """ + Recursively convert various input formats (tensors, numpy arrays, lists) to nested lists. + + Args: + data: Input data in any format + expected_depth: Expected nesting depth + current_depth: Current depth in recursion + + Returns: + Nested list representation of the data + """ + if data is None: + return None + + # Convert tensor/numpy to list if we're at a leaf level or if it's a multi-dimensional array + if isinstance(data, ms.Tensor): # Mindspore tensor + if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small tensor + return data.numpy().tolist() + else: + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, np.ndarray): # NumPy array + if current_depth == expected_depth - 2 or len(data.shape) <= 2: # At coordinate level or small array + return data.tolist() + else: + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, list): + if current_depth == expected_depth: + # We've reached the expected depth, return as is + return data + else: + # Continue recursion + return [self._convert_to_nested_list(item, expected_depth, current_depth + 1) for item in data] + elif isinstance(data, (int, float)): + return data + else: + raise TypeError(f"Unsupported data type: {type(data)}") + + def _get_nested_dimensions(self, nested_list, max_dims=None): + """ + Get the maximum dimensions at each level of nesting. + + Args: + nested_list (`list`): + Nested list structure. + max_dims (`list`, *optional*): + Current maximum dimensions (for recursion). + + Returns: + `list`: A list of maximum dimensions for each nesting level. + """ + if max_dims is None: + max_dims = [] + + if not isinstance(nested_list, list): + return max_dims + + if len(max_dims) == 0: + max_dims.append(len(nested_list)) + else: + max_dims[0] = max(max_dims[0], len(nested_list)) + + if len(nested_list) > 0: + for item in nested_list: + if isinstance(item, list): + sub_dims = self._get_nested_dimensions(item) + # Merge sub_dims into max_dims + for i, dim in enumerate(sub_dims): + if i + 1 >= len(max_dims): + max_dims.append(dim) + else: + max_dims[i + 1] = max(max_dims[i + 1], dim) + + return max_dims + + def _pad_nested_list(self, nested_list, target_dims, current_level=0, pad_value=None): + """ + Recursively pad a nested list to match target dimensions. + + Args: + nested_list (`list`): + Nested list to pad. + target_dims (`list`): + Target dimensions for each level. + current_level (`int`, *optional*, defaults to 0): + Current nesting level. + pad_value (`int`, *optional*): + Value to use for padding. + + Returns: + `list`: The padded nested list. 
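+
+        For example, padding `[[1, 2], [3]]` to `target_dims=[2, 3]` with `pad_value=-10` yields
+        `[[1, 2, -10], [3, -10, -10]]`; when `pad_value` is not given, `self.point_pad_value` is used.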
+ """ + if pad_value is None: + pad_value = self.point_pad_value + + if current_level >= len(target_dims): + return nested_list + + # Ensure we have a list + if not isinstance(nested_list, list): + nested_list = [nested_list] + + # Pad current level + current_size = len(nested_list) + target_size = target_dims[current_level] + + # Pad with appropriate values + if current_level == len(target_dims) - 1: + # At the coordinate level, pad with pad_value + nested_list.extend([pad_value] * (target_size - current_size)) + else: + # At higher levels, pad with nested structures + if current_size > 0: + # Create appropriately sized template + if current_level < len(target_dims) - 2: + # For non-coordinate levels, create empty nested structure + template_dims = target_dims[current_level + 1 :] + template = self._create_empty_nested_structure(template_dims, pad_value) + else: + # For coordinate level, create list of pad_values + template = [pad_value] * target_dims[current_level + 1] + + nested_list.extend([deepcopy(template) for _ in range(target_size - current_size)]) + else: + # Create from scratch + template_dims = target_dims[current_level + 1 :] + template = self._create_empty_nested_structure(template_dims, pad_value) + nested_list.extend([deepcopy(template) for _ in range(target_size)]) + + # Recursively pad sublists + if current_level < len(target_dims) - 1: + for i in range(len(nested_list)): + if isinstance(nested_list[i], list): + nested_list[i] = self._pad_nested_list(nested_list[i], target_dims, current_level + 1, pad_value) + + return nested_list + + def _create_empty_nested_structure(self, dims, pad_value): + """ + Create an empty nested structure with given dimensions filled with pad_value. + + Args: + dims (`list`): + The dimensions of the nested structure. + pad_value (`int`): + The value to fill the structure with. + """ + if len(dims) == 1: + return [pad_value] * dims[0] + else: + return [self._create_empty_nested_structure(dims[1:], pad_value) for _ in range(dims[0])] + + def _get_nesting_level(self, input_list): + """ + Get the nesting level of a list structure. + + Args: + input_list (`list`): + The list to get the nesting level of. + """ + if isinstance(input_list, list): + if len(input_list) == 0: + return 1 + return 1 + self._get_nesting_level(input_list[0]) + elif isinstance(input_list, (np.ndarray, ms.Tensor)): + # For arrays/tensors, the nesting level is the number of dimensions + return len(input_list.shape) + return 0 + + def _validate_single_input( + self, + data: Union[ms.Tensor, np.ndarray, list], + expected_depth: int, + input_name: str, + expected_format: str, + expected_coord_size: Optional[int] = None, + ) -> list: + """ + Validate a single input by ensuring proper nesting and raising an error if the input is not valid. + + Args: + data (`ms.Tensor`, `np.ndarray`, or `list`): + Input data to process. + expected_depth (`int`): + Expected nesting depth. + input_name (`str`): + Name of the input for error messages. + expected_format (`str`): + The expected format of the input. + expected_coord_size (`int`, *optional*): + Expected coordinate size (2 for points, 4 for boxes, None for labels). + . + """ + if data is None: + return None + + # Handle tensors and numpy arrays first + if isinstance(data, (ms.Tensor, np.ndarray)): + # For tensors/arrays, we can directly check the number of dimensions + if data.ndim != expected_depth: + raise ValueError( + f"Input {input_name} must be a tensor/array with {expected_depth} dimensions. 
The expected nesting format is {expected_format}. Got {data.ndim} dimensions." + ) + elif expected_coord_size is not None: + if data.shape[-1] != expected_coord_size: + raise ValueError( + f"Input {input_name} must be a tensor/array with {expected_coord_size} as the last dimension, got {data.shape[-1]}." + ) + return self._convert_to_nested_list(data, expected_depth) + + # Handle nested lists + if isinstance(data, list): + current_depth = self._get_nesting_level(data) + if current_depth != expected_depth: + raise ValueError( + f"Input {input_name} must be a nested list with {expected_depth} levels. The expected nesting format is {expected_format}. Got {current_depth} levels." + ) + return self._convert_to_nested_list(data, expected_depth) + + def _normalize_tensor_coordinates(self, tensor, original_sizes, is_bounding_box=False, preserve_padding=False): + """ + Helper method to normalize coordinates in a tensor across multiple images. + + Args: + tensor (`ms.Tensor`): + Input tensor with coordinates. + original_sizes (`list`): + Original image sizes. + is_bounding_box (`bool`, *optional*, defaults to `False`): + Whether coordinates are bounding boxes. + preserve_padding (`bool`, *optional*, defaults to `False`): + Whether to preserve padding values (for points). + """ + if preserve_padding: + # For points: avoid normalizing pad values + mask = tensor != self.point_pad_value + coord_mask = mask.all(dim=-1, keepdim=True) + + for img_idx in range(len(original_sizes)): + if img_idx < tensor.shape[0]: + original_size = original_sizes[img_idx] if img_idx < len(original_sizes) else original_sizes[0] + normalized_coords = self._normalize_coordinates( + self.target_size, tensor[img_idx], original_size, is_bounding_box=is_bounding_box + ) + + if preserve_padding: + # Only update non-padded values + img_mask = coord_mask[img_idx] + tensor[img_idx] = mint.where( + img_mask.expand_as(tensor[img_idx]), normalized_coords, tensor[img_idx] + ) + else: + tensor[img_idx] = normalized_coords + + def post_process_masks( + self, + masks, + original_sizes, + mask_threshold=0.0, + binarize=True, + max_hole_area=0.0, + max_sprinkle_area=0.0, + apply_non_overlapping_constraints=False, + **kwargs, + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[ms.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[ms.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + mask_threshold (`float`, *optional*, defaults to 0.0): + Threshold for binarization and post-processing operations. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + max_hole_area (`float`, *optional*, defaults to 0.0): + The maximum area of a hole to fill. + max_sprinkle_area (`float`, *optional*, defaults to 0.0): + The maximum area of a sprinkle to fill. + apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`): + Whether to apply non-overlapping constraints to the masks. + + Returns: + (`ms.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
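+
+        Example (illustrative sketch; `outputs` is assumed to be the output of a prior `Sam2VideoModel` forward
+        pass and `inputs` the corresponding processor output holding `original_sizes`):
+
+        ```python
+        >>> masks = processor.post_process_masks(outputs.pred_masks, original_sizes=inputs["original_sizes"])
+        >>> # one boolean mask tensor per image, upscaled to its original (height, width)
+        ```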
+ """ + return self.image_processor.post_process_masks( + masks, + original_sizes, + mask_threshold, + binarize, + max_hole_area, + max_sprinkle_area, + apply_non_overlapping_constraints, + **kwargs, + ) + + def init_video_session( + self, + video: Optional[VideoInput] = None, + max_vision_features_cache_size: int = 1, + dtype: ms.dtype = ms.float32, + ): + """ + Initializes a video session for inference. + + Args: + video (`VideoInput`, *optional*): + The video to process. No need to provide when streaming. + max_vision_features_cache_size (`int`, *optional*, defaults to 1): + The maximum number of vision features to cache. + dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): + The mindspore dtype to use for the whole session. + """ + pixel_values_video = None + video_height = None + video_width = None + if video is not None: + processed_video = self.video_processor(videos=video,return_tensors="ms") + pixel_values_video = processed_video.pixel_values_videos[0] + video_height = processed_video.original_sizes[0][0] + video_width = processed_video.original_sizes[0][1] + inference_session = Sam2VideoInferenceSession( + video=pixel_values_video, + video_height=video_height, + video_width=video_width, + dtype=dtype, + max_vision_features_cache_size=max_vision_features_cache_size, + ) + return inference_session + + def add_inputs_to_inference_session( + self, + inference_session: Sam2VideoInferenceSession, + frame_idx: int, + obj_ids: Union[list[int], int], + input_points: Optional[Union[list[list[list[list[float]]]], ms.Tensor]] = None, + input_labels: Optional[Union[list[list[list[int]]], ms.Tensor]] = None, + input_boxes: Optional[Union[list[list[list[float]]], ms.Tensor]] = None, + input_masks: Optional[Union[np.ndarray, ms.Tensor, list[np.ndarray], list[ms.Tensor]]] = None, + original_size: Optional[tuple[int, int]] = None, + clear_old_inputs: bool = True, + ) -> Sam2VideoInferenceSession: + """ + Process new points, boxes, or masks for a video frame and add them to the inference session. + + Args: + inference_session (`Sam2VideoInferenceSession`): + The inference session for the video. + frame_idx (`int`): + The index of the frame to process. + obj_ids (`list[int]` or `int`): + The object ID(s) to associate with the points or box. + These can be any integers and can be reused later on to specify an object. + input_points (`list[list[list[list[float]]]]`, `ms.Tensor`, *optional*): + The points to add to the frame. + input_labels (`list[list[list[int]]]`, `ms.Tensor`, *optional*): + The labels for the points. + input_boxes (`list[list[list[float]]]`, `ms.Tensor`, *optional*): + The bounding boxes to add to the frame. + input_masks (`np.ndarray`, `ms.Tensor`, `list[np.ndarray]`, or `list[ms.Tensor]`, *optional*): + The mask(s) to add to the frame. + original_size (`tuple[int, int]`, *optional*): + The original size of the video. Provide when streaming. + clear_old_inputs (`bool`, *optional*, defaults to `True`): + Whether to clear old inputs for the object. 
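+
+        Example (illustrative sketch; assumes `processor` is this `Sam2VideoProcessor` and `inference_session`
+        was created beforehand with `processor.init_video_session`):
+
+        ```python
+        >>> # add one positive click (label 1) for object id 1 on the first frame
+        >>> processor.add_inputs_to_inference_session(
+        ...     inference_session,
+        ...     frame_idx=0,
+        ...     obj_ids=1,
+        ...     input_points=[[[[210.0, 350.0]]]],
+        ...     input_labels=[[[1]]],
+        ... )
+        ```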
+ """ + + if isinstance(obj_ids, int): + obj_ids = [obj_ids] + + # Validate inputs + if (input_points is not None) != (input_labels is not None): + raise ValueError("points and labels must be provided together") + if input_points is None and input_boxes is None and input_masks is None: + raise ValueError("at least one of points, boxes, or masks must be provided as input") + if input_masks is not None and (input_points is not None or input_boxes is not None): + raise ValueError("masks cannot be provided together with points or boxes") + + if input_masks is not None: + return self.process_new_mask_for_video_frame(inference_session, frame_idx, obj_ids, input_masks) + else: + return self.process_new_points_or_boxes_for_video_frame( + inference_session, + frame_idx, + obj_ids, + input_points, + input_labels, + input_boxes, + original_size, + clear_old_inputs, + ) + + def process_new_points_or_boxes_for_video_frame( + self, + inference_session: Sam2VideoInferenceSession, + frame_idx: int, + obj_ids: list[int], + input_points: Optional[Union[list[list[list[list[float]]]], ms.Tensor]] = None, + input_labels: Optional[Union[list[list[list[int]]], ms.Tensor]] = None, + input_boxes: Optional[Union[list[list[list[float]]], ms.Tensor]] = None, + original_size: Optional[tuple[int, int]] = None, + clear_old_inputs: bool = True, + ) -> Sam2VideoInferenceSession: + """ + Process new points or boxes for a video frame and add them to the inference session. + + Args: + inference_session (`Sam2VideoInferenceSession`): + The inference session for the video. + frame_idx (`int`): + The index of the frame to process. + obj_ids (`list[int]`): + The object ID(s) to associate with the points or box. + These can be any integers and can be reused later on to specify an object. + input_points (`list[list[list[list[float]]]]`, `ms.Tensor`, *optional*): + The points to add to the frame. + input_labels (`list[list[list[int]]]`, `ms.Tensor`, *optional*): + The labels for the points. + input_boxes (`list[list[list[float]]]`, `ms.Tensor`, *optional*): + The bounding boxes to add to the frame. + original_size (`tuple[int, int]`, *optional*): + The original size of the video. Provide when streaming. + clear_old_inputs (`bool`, *optional*, defaults to `True`): + Whether to clear old inputs for the object. 
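+
+        Note that boxes are encoded as two corner points with the reserved labels 2 (top-left corner) and
+        3 (bottom-right corner) and are prepended to any point prompts, which is why they can only be added
+        when `clear_old_inputs=True`.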
+ """ + if original_size is not None: + inference_session.video_height = original_size[0] + inference_session.video_width = original_size[1] + elif inference_session.video_height is None or inference_session.video_width is None: + raise ValueError("original_size must be provided when adding points or boxes on a first streamed frame") + + original_sizes = [[inference_session.video_height, inference_session.video_width]] + + encoded_inputs = self( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + original_sizes=original_sizes, + return_tensors="ms", + ) + input_points = encoded_inputs.get("input_points", None) + input_labels = encoded_inputs.get("input_labels", None) + input_boxes = encoded_inputs.get("input_boxes", None) + + if input_points is not None: + if input_points.shape[1] != len(obj_ids): + raise ValueError( + f"Number of object ids ({len(obj_ids)}) does not match number of points ({input_points.shape[1]})" + ) + else: + input_points = mint.zeros(1, len(obj_ids), 0, 2, dtype=ms.float32) + if input_labels is not None: + if input_labels.shape[1] != len(obj_ids): + raise ValueError( + f"Number of object ids ({len(obj_ids)}) does not match number of labels ({input_labels.shape[1]})" + ) + else: + input_labels = mint.zeros(1, len(obj_ids), 0, dtype=ms.int32) + if input_boxes is not None: + if input_boxes.shape[1] != len(obj_ids): + raise ValueError( + f"Number of object ids ({len(obj_ids)}) does not match number of boxes ({input_boxes.shape[1]})" + ) + + if input_boxes is not None: + if not clear_old_inputs: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + box_coords = input_boxes.reshape(1, -1, 2, 2) + box_labels = ms.tensor([2, 3], dtype=ms.int32).repeat(1, box_coords.shape[1], 1) + input_points = mint.cat([box_coords, input_points], dim=2) + input_labels = mint.cat([box_labels, input_labels], dim=2) + + for obj_id, idx in zip(obj_ids, range(len(obj_ids))): + obj_idx = inference_session.obj_id_to_idx(obj_id) + input_points_for_obj = input_points[:, idx, :, :].unsqueeze(1) + input_labels_for_obj = input_labels[:, idx, :].unsqueeze(1) + # Handle existing points + if not clear_old_inputs: + existing_points = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None) + if existing_points is not None: + # Concatenate with existing points + input_points_for_obj = mint.cat( + [existing_points["point_coords"], input_points_for_obj], dim=2 + ) + input_labels_for_obj = mint.cat( + [existing_points["point_labels"], input_labels_for_obj], dim=2 + ) + point_inputs = { + "point_coords": input_points_for_obj, + "point_labels": input_labels_for_obj, + } + + inference_session.add_point_inputs(obj_idx, frame_idx, point_inputs) + inference_session.remove_mask_inputs(obj_idx, frame_idx) # Clear any mask inputs + + inference_session.obj_with_new_inputs = obj_ids + + def process_new_mask_for_video_frame( + self, + inference_session: Sam2VideoInferenceSession, + frame_idx: int, + obj_ids: list[int], + input_masks: Union[np.ndarray, ms.Tensor, list[np.ndarray], list[ms.Tensor]], + ): + """ + Add new mask to a frame and add them to the inference session. + + Args: + inference_session (`Sam2VideoInferenceSession`): + The inference session for the video. + frame_idx (`int`): + The index of the frame to process. + obj_ids (`list[int]`): + The object ID(s) to associate with the mask. 
+ These can be any integers and can be reused later on to specify an object. + input_masks (`np.ndarray`, `ms.Tensor`, `list[np.ndarray]`, or `list[ms.Tensor]`): + The mask(s) to add to the frame. + """ + if not isinstance(input_masks, list): + input_masks = [input_masks] + if len(input_masks) != len(obj_ids): + raise ValueError( + f"Number of object ids ({len(obj_ids)}) does not match number of masks ({len(input_masks)})" + ) + + for obj_id, mask in zip(obj_ids, input_masks): + obj_idx = inference_session.obj_id_to_idx(obj_id) + + + # Process mask + if not isinstance(mask, ms.Tensor): + mask = ms.tensor(mask, dtype=ms.bool) + nb_dim = mask.dim() + if nb_dim > 4 or nb_dim < 2: + raise ValueError(f"Mask has an unsupported number of dimensions: {nb_dim}") + for i in range(4 - nb_dim): + mask = mask.unsqueeze(0) + + mask_H, mask_W = mask.shape[-2:] + mask_inputs_orig = mask + mask_inputs_orig = mask_inputs_orig.float() + + # Resize mask if needed + if mask_H != self.target_size or mask_W != self.target_size: + mask_inputs = mint.nn.functional.interpolate( + mask_inputs_orig, + size=(self.target_size, self.target_size), + align_corners=False, + mode="bilinear", + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + inference_session.add_mask_inputs(obj_idx, frame_idx, mask_inputs) + inference_session.remove_point_inputs(obj_idx, frame_idx) # Clear any point inputs + + inference_session.obj_with_new_inputs = obj_ids + + +__all__ = ["Sam2VideoProcessor"] diff --git a/mindone/transformers/models/sam2_video/video_processing_sam2_video.py b/mindone/transformers/models/sam2_video/video_processing_sam2_video.py new file mode 100644 index 0000000000..3eaaf6669b --- /dev/null +++ b/mindone/transformers/models/sam2_video/video_processing_sam2_video.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for SAM2.""" + +from typing import Optional, Union + +import numpy as np +import mindspore as ms +import mindspore.mint as mint +import mindspore.mint.nn.functional as F + +from ...image_processing_utils import BatchFeature +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling, SizeDict +from ...utils import TensorType +from ...video_processing_utils import BaseVideoProcessor + + +class Sam2VideoVideoProcessor(BaseVideoProcessor): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 1024, "width": 1024} + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + + def _preprocess( + self, + videos: list["ms.Tensor"], + size: SizeDict, + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + original_sizes = [video.shape[-2:] for video in videos] + reshaped_input_sizes = [(size.height, size.width) for _ in range(len(videos))] + batch_feature = super()._preprocess(videos, size=size, return_tensors=return_tensors, **kwargs) + batch_feature = BatchFeature( + data={ + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + **batch_feature.data, + }, + tensor_type=return_tensors, + ) + return batch_feature + + def post_process_masks( + self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None + ): + """ + Remove padding and upscale masks to the original image size. + + Args: + masks (`Union[List[ms.Tensor], List[np.ndarray]]`): + Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format. + original_sizes (`Union[ms.Tensor, List[Tuple[int,int]]]`): + The original sizes of each image before it was resized to the model's expected input shape, in (height, + width) format. + reshaped_input_sizes (`Union[ms.Tensor, List[Tuple[int,int]]]`): + The size of each image as it is fed to the model, in (height, width) format. Used to remove padding. + mask_threshold (`float`, *optional*, defaults to 0.0): + The threshold to use for binarizing the masks. + binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the masks. + pad_size (`int`, *optional*, defaults to `self.pad_size`): + The target size the images were padded to before being passed to the model. If None, the target size is + assumed to be the processor's `pad_size`. + Returns: + (`ms.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) + is given by original_size. 
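+
+        The masks are first interpolated to the padded model input size, cropped to `reshaped_input_sizes` to
+        remove the padding, then interpolated to `original_sizes` and optionally binarized with `mask_threshold`.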
+        """
+        pad_size = self.size if pad_size is None else pad_size
+        target_image_size = (pad_size["height"], pad_size["width"])
+        if isinstance(original_sizes, (ms.Tensor, np.ndarray)):
+            original_sizes = original_sizes.tolist()
+        if isinstance(reshaped_input_sizes, (ms.Tensor, np.ndarray)):
+            reshaped_input_sizes = reshaped_input_sizes.tolist()
+        output_masks = []
+        for i, original_size in enumerate(original_sizes):
+            if isinstance(masks[i], np.ndarray):
+                masks[i] = ms.from_numpy(masks[i])
+            elif not isinstance(masks[i], ms.Tensor):
+                raise TypeError("Input masks should be a list of `ms.Tensor` or a list of `np.ndarray`")
+            # Upscale to the padded model input size, crop away the padding, then resize to the original size
+            interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
+            interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
+            interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)
+            if binarize:
+                interpolated_mask = interpolated_mask > mask_threshold
+            output_masks.append(interpolated_mask)
+
+        return output_masks
+
+
+__all__ = ["Sam2VideoVideoProcessor"]