 import mindspore as ms

 from segment_anything.build_sam import sam_model_registry
-from segment_anything.dataset.transform import TransformPipeline, ImageNorm
-from segment_anything.utils.transforms import ResizeLongestSide
+from segment_anything.dataset.transform import TransformPipeline, ImageNorm, ImageResizeAndPad
 import matplotlib.pyplot as plt
 import time

@@ -29,65 +28,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         print(f'{self.name} cost time {self.end - self.start:.3f}')


-class ImageResizeAndPad:
-
-    def __init__(self, target_size):
-        """
-        Args:
-            target_size (int): target size of model input (1024 in sam)
-        """
-        self.target_size = target_size
-        self.transform = ResizeLongestSide(target_size)
-
-    def __call__(self, result_dict):
-        """
-        Resize input to the long size and then pad it to the model input size (1024*1024 in sam).
-        Pad masks and boxes to a fixed length for graph mode
-        Required keys: image, masks, boxes
-        Update keys: image, masks, boxes
-        Add keys:
-            origin_hw (np.array): array with shape (4), represents original image height, width
-            and resized height, width, respectively. This array record the trace of image shape transformation
-            and is used for visualization.
-            image_pad_area (Tuple): image padding area in h and w direction, in the format of
-            ((pad_h_left, pad_h_right), (pad_w_left, pad_w_right))
-        """
-
-        image = result_dict['image']
-        boxes = result_dict['boxes']
-
-        og_h, og_w, _ = image.shape
-        image = self.transform.apply_image(image)
-        resized_h, resized_w, _ = image.shape
-
-        # Pad image and masks to the model input
-        h, w, c = image.shape
-        max_dim = max(h, w)  # long side length
-        assert max_dim == self.target_size
-        # pad 0 to the right and bottom side
-        pad_h = max_dim - h
-        pad_w = max_dim - w
-        img_padding = ((0, pad_h), (0, pad_w), (0, 0))
-        image = np.pad(image, pad_width=img_padding, constant_values=0)  # (h, w, c)
-
-        # Adjust bounding boxes
-        boxes = self.transform.apply_boxes(boxes, (og_h, og_w)).astype(np.float32)
-
-        result_dict['origin_hw'] = np.array([og_h, og_w, resized_h, resized_w], np.int32)  # record image shape trace for visualization
-        result_dict['image'] = image
-        result_dict['boxes'] = boxes
-        result_dict['image_pad_area'] = img_padding[:2]
-
-        return result_dict
-
-
 def infer(args):
     ms.context.set_context(mode=args.mode, device_target=args.device)

     # Step1: data preparation
     with Timer('preprocess'):
         transform_list = [
-            ImageResizeAndPad(target_size=1024),
+            ImageResizeAndPad(target_size=1024, apply_mask=False),
             ImageNorm(),
         ]
         transform_pipeline = TransformPipeline(transform_list)
@@ -99,6 +46,9 @@ def infer(args):

         transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
         image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
+        # batch_size for speed test
+        # image = ms.Tensor(np.expand_dims(image, 0).repeat(8, axis=0))  # b, 3, 1023
+        # boxes = ms.Tensor(np.expand_dims(boxes, 0).repeat(8, axis=0))  # b, n, 4
         image = ms.Tensor(image).unsqueeze(0)  # b, 3, 1023
         boxes = ms.Tensor(boxes).unsqueeze(0)  # b, n, 4

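Note: the ImageResizeAndPad class deleted above is now imported from segment_anything.dataset.transform instead of being defined locally in this script. For readers without that module at hand, the sketch below is a minimal, standalone approximation of what the step does (resize so the longest side reaches 1024, zero-pad bottom/right to a square, rescale the xyxy boxes by the same factor). The function name and the use of cv2 for resizing are illustrative assumptions, not the repository implementation; the shared transform delegates to SAM's ResizeLongestSide and additionally handles masks via the apply_mask flag seen in the diff.

import numpy as np
import cv2  # assumed available here; any resizer with the same semantics works


def resize_longest_side_and_pad(image, boxes, target_size=1024):
    """Illustrative sketch only, not the repository implementation."""
    og_h, og_w = image.shape[:2]
    scale = target_size / max(og_h, og_w)
    resized_h, resized_w = int(round(og_h * scale)), int(round(og_w * scale))
    image = cv2.resize(image, (resized_w, resized_h))  # cv2 expects (width, height)

    # Zero-pad on the bottom and right so the output is target_size x target_size.
    pad_h, pad_w = target_size - resized_h, target_size - resized_w
    image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), constant_values=0)

    # Boxes are absolute xyxy pixels; aspect ratio is preserved, so one scale fits both axes.
    boxes = np.asarray(boxes, dtype=np.float32) * scale
    return image, boxes, np.array([og_h, og_w, resized_h, resized_w], np.int32)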