In this document, we give an example of how to write customized code in ATEK for the SegmentAnything2 (SAM2) model. Users are also encouraged to check out Demo_4 for an end-to-end inference workflow.
SAM2 is an image / video segmentation model designed to handle arbitrary objects in images.
To run inference on a SAM2 model, the input should contain the following data:
- An upright RGB image.
- Prompt information to guide the segmentation process, e.g. 2D object bounding boxes to specify segmentation regions.
In ATEK, users can customize preprocessing by simply adjusting the preprocessing configuration YAML file. For SAM2's requirements listed above, we can adjust the following YAML fields accordingly (see config.yaml for an example config file):
- Setting the `selected` flags in `processors` will pick RGB camera data and object annotation data to include in preprocessing:

  ```yaml
  processors:
    rgb:
      selected: true
    obb_gt:
      selected: true
  ```
- Setting the following flag will rotate the RGB image to an upright position:

  ```yaml
  rgb:
    rotate_image_cw90deg: true
  ```
- Setting the following value will automatically sync the RGB image with ground-truth annotation data, within a tolerance window:

  ```yaml
  tolerance_ns: 10_000_000
  ```
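Once the YAML file is adjusted, it needs to be loaded into a config object before it is handed to the preprocessor. Below is a minimal sketch assuming an OmegaConf-style config; the file path and the programmatic override are illustrative, and the field nesting simply follows the snippets above:

```python
from omegaconf import OmegaConf

# Load the adjusted preprocessing config (path is illustrative)
sam2_preprocessing_config = OmegaConf.load("/path/to/sam2_preprocessing_config.yaml")

# Fields can also be overridden programmatically, e.g. the upright-rotation flag
sam2_preprocessing_config.processors.rgb.rotate_image_cw90deg = True
```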
Then the user can simply run the following code to generate preprocessed WDS data suitable for SAM2 model inference:
```python
preprocessor = create_general_atek_preprocessor_from_conf(
    conf=sam2_preprocessing_config,
    raw_data_folder="/path/to/raw/data",
    sequence_name="sequence_01",
    output_wds_folder="./output",
)
num_samples = preprocessor.process_all_samples(write_to_wds_flag=True, viz_flag=True)
```
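The preprocessed WDS output is written as tar shards under `output_wds_folder`. As a small usage sketch (assuming the shards are plain `.tar` files; the glob pattern is illustrative), they can be collected into the URL list used by the data loader later in this document:

```python
import glob

# Collect the preprocessed WDS shards for the data loader below
tar_list = sorted(glob.glob("./output/*.tar"))
print(f"Found {len(tar_list)} WDS shards")
```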
To feed ATEK-preprocessed data into a SAM2 model for inference, a `ModelAdaptor` class is needed. We show the core code below; users are encouraged to check out the source code for details.
```python
from typing import List, Optional

import torch

# Note: helper imports (load_atek_wds_dataset, simple_list_collation_fn, pipelinefilter)
# are omitted here; see the ATEK source for the full file.


class Sam2ModelAdaptor:
    def __init__(self, num_boxes: int = 5):
        # Maximum number of 2D bounding boxes to use as prompts per image
        self.num_boxes = num_boxes

    @staticmethod
    def get_dict_key_mapping_all():
        dict_key_mapping = {
            "mfcd#camera-rgb+images": "image",
            "gt_data": "gt_data",
        }
        return dict_key_mapping

    def atek_to_sam2(self, data):
        for atek_wds_sample in data:
            sample = {}

            # Add images, reshaped from [1, C, H, W] to [H, W, C]
            image_torch = atek_wds_sample["image"].clone().detach()
            image_np = image_torch.squeeze(0).permute(1, 2, 0).numpy()
            sample["image"] = image_np

            # Select the first `num_boxes` GT bounding boxes as prompts
            obb2_gt = atek_wds_sample["gt_data"]["obb2_gt"]["camera-rgb"]
            bbox_ranges = obb2_gt["box_ranges"][: self.num_boxes]
            bbox_ranges = bbox_ranges[:, [0, 2, 1, 3]]  # xxyy -> xyxy
            sample["boxes"] = bbox_ranges.numpy()

            yield sample


def create_atek_dataloader_as_sam2(
    urls: List[str],
    batch_size: Optional[int] = None,
    repeat_flag: bool = False,
    shuffle_flag: bool = False,
    num_workers: int = 0,
    num_prompt_boxes: int = 5,
) -> torch.utils.data.DataLoader:
    adaptor = Sam2ModelAdaptor(num_boxes=num_prompt_boxes)
    wds_dataset = load_atek_wds_dataset(
        urls,
        batch_size=batch_size,
        dict_key_mapping=Sam2ModelAdaptor.get_dict_key_mapping_all(),
        data_transform_fn=pipelinefilter(adaptor.atek_to_sam2)(),
        collation_fn=simple_list_collation_fn,
        repeat_flag=repeat_flag,
        shuffle_flag=shuffle_flag,
    )
    return torch.utils.data.DataLoader(
        wds_dataset, batch_size=None, num_workers=num_workers, pin_memory=True
    )
```
Within this class:

- `get_dict_key_mapping_all()` returns a mapping from ATEK dictionary keys to `sam2` dictionary keys. Since we only need the RGB image and the 2D bounding box information, we only map 2 keys; ATEK will automatically discard the other key-value content in the ATEK dict.
- `atek_to_sam2` is the actual data transform function. The input `data` is a generator of dictionaries whose keys are already remapped by `get_dict_key_mapping_all()`. We perform 2 operations in this transform (see the sketch after this list):
  - Reshape the RGB image tensor from `[1, Channel, Height, Width]` to `[Height, Width, Channel]` and store it in the `sample` dict.
  - From the `gt_data` dictionary, take the first `num_boxes` 2D bounding boxes for the current RGB image, re-order the box corners from `[xmin, xmax, ymin, ymax]` to `[xmin, ymin, xmax, ymax]`, and store them in the `sample` dict.
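To make these two operations concrete, here is a small, self-contained sketch that feeds one hand-built sample through the adaptor. The tensor shapes and box values are made up for illustration; only the dictionary nesting mirrors the adaptor code above:

```python
import torch

# A hand-built stand-in for one remapped ATEK WDS sample (values are illustrative)
atek_wds_sample = {
    "image": torch.zeros(1, 3, 480, 640),  # remapped from "mfcd#camera-rgb+images"
    "gt_data": {
        "obb2_gt": {
            "camera-rgb": {
                # one box per row, in [xmin, xmax, ymin, ymax] order
                "box_ranges": torch.tensor([[10.0, 110.0, 20.0, 220.0]]),
            }
        }
    },
}

adaptor = Sam2ModelAdaptor(num_boxes=5)
sample = next(adaptor.atek_to_sam2([atek_wds_sample]))

print(sample["image"].shape)  # (480, 640, 3)
print(sample["boxes"])        # [[ 10.  20. 110. 220.]]
```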
`create_atek_dataloader_as_sam2` is a thin wrapper on top of the `Sam2ModelAdaptor` class. It lets the user pass in the URLs of the WDS files and returns a PyTorch DataLoader object that produces data samples that can be used directly in SAM2 inference:
```python
sam2_dataloader = create_atek_dataloader_as_sam2(tar_list)
first_sam2_sample = next(iter(sam2_dataloader))
print(f"Loading WDS into SAM2 format, each sample contains the following keys: {first_sam2_sample[0].keys()}")
```
With the created PyTorch DataLoader, the user can easily run SAM2 inference with the following code:
```python
import torch

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# create SAM2 predictor
predictor = SAM2ImagePredictor(build_sam2(sam2_model_cfg, sam2_model_checkpoint))

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    for sam_dict in sam2_dataloader:
        # perform inference
        predictor.set_image(sam_dict["image"])
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=sam_dict["boxes"],
            multimask_output=False,
        )

        # Visualize results
        ...
```
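The visualization step is left open above. As one possible sketch (matplotlib-based, entirely illustrative and not part of ATEK), the predicted masks can be overlaid on the input image like this:

```python
import matplotlib.pyplot as plt
import numpy as np

def show_masks(image_np, masks):
    # Overlay each predicted mask on the RGB image with a semi-transparent color
    plt.imshow(image_np)
    for mask in masks:
        mask_2d = np.squeeze(mask).astype(bool)   # [H, W]
        overlay = np.zeros((*mask_2d.shape, 4))
        overlay[mask_2d] = [1.0, 0.0, 0.0, 0.5]   # RGBA: translucent red
        plt.imshow(overlay)
    plt.axis("off")
    plt.show()

# e.g. show_masks(sam_dict["image"], masks)
```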