Skip to content

Commit 8ae5cdf

Browse files
committed
reduce mask details
1 parent 5656bdd commit 8ae5cdf

File tree

5 files changed

+195
-5
lines changed

5 files changed

+195
-5
lines changed

example.ipynb

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
" **{\n",
2727
" \"dataset\": kitti,\n",
2828
" \"start_index\": 20,\n",
29-
" \"end_index\": 23,\n",
29+
" \"end_index\": 24,\n",
3030
" \"start_image_index_offset\": 3,\n",
3131
" \"cam_name\": \"cam2\",\n",
3232
" \"R\": 16,\n",
@@ -50,6 +50,30 @@
5050
"points2instances = InitInstancesMatrixProcessor().process(config, init_pcd)"
5151
]
5252
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
"points2instances.shape"
60+
]
61+
},
62+
{
63+
"cell_type": "code",
64+
"execution_count": null,
65+
"metadata": {},
66+
"outputs": [],
67+
"source": [
68+
"import copy\n",
69+
"\n",
70+
"from src.utils.pcd_utils import color_pcd_by_labels\n",
71+
"from src.utils.pcd_utils import visualize_pcd\n",
72+
"\n",
73+
"colored_pcd = color_pcd_by_labels(copy.deepcopy(init_pcd), points2instances[:, 0])\n",
74+
"visualize_pcd(colored_pcd)"
75+
]
76+
},
5377
{
5478
"cell_type": "code",
5579
"execution_count": null,
@@ -79,7 +103,8 @@
79103
"for processor in processors:\n",
80104
" pcd, points2instances = processor.process(config, pcd, points2instances)\n",
81105
"\n",
82-
"pcd_for_clustering = copy.deepcopy(pcd)"
106+
"pcd_for_clustering = copy.deepcopy(pcd)\n",
107+
"points2instances_pcd_for_clustering = copy.deepcopy(points2instances)"
83108
]
84109
},
85110
{
@@ -107,7 +132,7 @@
107132
"points = np.asarray(pcd.points)\n",
108133
"spatial_distance = cdist(points, points)\n",
109134
"\n",
110-
"dist, masks = sam_label_distance(points2instances, spatial_distance, 5, 10, 2)"
135+
"dist, masks = sam_label_distance(points2instances, spatial_distance, 3, 3, 5)"
111136
]
112137
},
113138
{
@@ -156,7 +181,7 @@
156181
"source": [
157182
"from src.services.normalized_cut_service import normalized_cut\n",
158183
"\n",
159-
"T = 0.07\n",
184+
"T = 0.02\n",
160185
"eigenval = 2\n",
161186
"\n",
162187
"clusters = normalized_cut(\n",
@@ -198,6 +223,18 @@
198223
"visualize_pcd(pcd_colored)"
199224
]
200225
},
226+
{
227+
"cell_type": "code",
228+
"execution_count": null,
229+
"metadata": {},
230+
"outputs": [],
231+
"source": [
232+
"from src.utils.pcd_utils import color_pcd_by_labels\n",
233+
"\n",
234+
"pcd_src_colored = color_pcd_by_labels(pcd_for_clustering, points2instances_pcd_for_clustering[:, 2])\n",
235+
"visualize_pcd(pcd_src_colored)"
236+
]
237+
},
201238
{
202239
"cell_type": "code",
203240
"execution_count": null,
Loading

src/services/preprocessing/init/instances_matrix.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
import zope.interface
1818

1919
from src.services.preprocessing.common.interface import IProcessor
20+
from src.utils.geometry_utils import calculate_area
2021
from src.utils.pcd_utils import get_subpcd
2122
from src.utils.pcd_utils import get_visible_points
23+
from src.utils.sam_mask_utils import find_intersection_mask
24+
from src.utils.sam_mask_utils import find_union_mask
2225

2326

2427
@zope.interface.implementer(IProcessor)
@@ -56,7 +59,8 @@ def build_points2instances_matrix(
5659
points2instances = np.zeros((N, end_image_index - start_image_index), dtype=int)
5760

5861
for view_id, view in enumerate(range(start_image_index, end_image_index)):
59-
masks = dataset.get_image_instances(cam_name, view)
62+
full_masks = dataset.get_image_instances(cam_name, view)
63+
masks = self.reduce_detail(full_masks)
6064
image_labels = self.masks_to_image(masks)
6165

6266
T = dataset.get_lidar_pose(view)
@@ -111,3 +115,52 @@ def masks_to_image(self, masks):
111115
for i, mask in enumerate(masks):
112116
image_labels[mask["segmentation"]] = i + 1
113117
return image_labels
118+
119+
def reduce_detail(self, masks, intersection_to_union_ratio_threshold=0.35):
120+
merged_mask = []
121+
merged_indices = []
122+
123+
for i in range(len(masks)):
124+
if i in merged_indices:
125+
continue
126+
127+
area_bbox_i = calculate_area(masks[i]['bbox'])
128+
129+
indices_merged_with_i = []
130+
for j in range(i + 1, len(masks)):
131+
if j in merged_indices:
132+
continue
133+
134+
area_bbox_j = calculate_area(masks[j]['bbox'])
135+
136+
intersection_mask = find_intersection_mask(masks[i], masks[j])
137+
if intersection_mask == None:
138+
continue
139+
area_intersection = intersection_mask['area']
140+
141+
area_bbox_intersection = calculate_area(intersection_mask['bbox'])
142+
area_bbox_union = area_bbox_i + area_bbox_j - area_bbox_intersection
143+
IU_ratio = area_bbox_intersection / area_bbox_union
144+
145+
if (IU_ratio >= intersection_to_union_ratio_threshold
146+
or area_intersection / masks[i]['area'] >= 0.6
147+
or area_intersection / masks[j]['area'] >= 0.6):
148+
masks[i] = find_union_mask(masks[i], masks[j])
149+
indices_merged_with_i.append(j)
150+
151+
if indices_merged_with_i:
152+
merged_mask.append(masks[i])
153+
154+
merged_indices.append(i)
155+
for ind in indices_merged_with_i:
156+
merged_indices.append(ind)
157+
158+
masks_result = []
159+
for ind, mask in enumerate(masks):
160+
if ind not in merged_indices:
161+
masks_result.append(mask)
162+
163+
for mask in merged_mask:
164+
masks_result.append(mask)
165+
166+
return masks_result

src/utils/geometry_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) 2023, Sofia Vivdich and Anastasiia Kornilova
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def find_intersection(bbox1, bbox2):
17+
x1_bbox1, y1_bbox1, w_bbox1, h_bbox1 = bbox1
18+
x1_bbox2, y1_bbox2, w_bbox2, h_bbox2 = bbox2
19+
20+
x2_bbox1 = x1_bbox1 + w_bbox1
21+
y2_bbox1 = y1_bbox1 + h_bbox1
22+
x2_bbox2 = x1_bbox2 + w_bbox2
23+
y2_bbox2 = y1_bbox2 + h_bbox2
24+
25+
if (x1_bbox1 > x2_bbox2 or x2_bbox1 < x1_bbox2 or y1_bbox1 > y2_bbox2 or y2_bbox1 < y1_bbox2):
26+
return None
27+
28+
x_left = max(x1_bbox1, x1_bbox2)
29+
x_right = min(x2_bbox1, x2_bbox2)
30+
y_top = max(y1_bbox1, y1_bbox2)
31+
y_bottom = min(y2_bbox1, y2_bbox2)
32+
33+
return x_left, y_top, x_right - x_left, y_bottom - y_top
34+
35+
36+
def find_union(bbox1, bbox2):
37+
x1_bbox1, y1_bbox1, w_bbox1, h_bbox1 = bbox1
38+
x1_bbox2, y1_bbox2, w_bbox2, h_bbox2 = bbox2
39+
40+
x2_bbox1 = x1_bbox1 + w_bbox1
41+
y2_bbox1 = y1_bbox1 + h_bbox1
42+
x2_bbox2 = x1_bbox2 + w_bbox2
43+
y2_bbox2 = y1_bbox2 + h_bbox2
44+
45+
x_left = min(x1_bbox1, x1_bbox2)
46+
y_top = min(y1_bbox1, y1_bbox2)
47+
x_right = max(x2_bbox1, x2_bbox2)
48+
y_bottom = max(y2_bbox1, y2_bbox2)
49+
50+
return x_left, y_top, x_right - x_left, y_bottom - y_top
51+
52+
53+
def calculate_area(bbox):
54+
w = bbox[2]
55+
h = bbox[3]
56+
return w * h

src/utils/sam_mask_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) 2023, Sofia Vivdich and Anastasiia Kornilova
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
17+
from src.utils.geometry_utils import find_intersection
18+
from src.utils.geometry_utils import find_union
19+
20+
21+
def find_intersection_mask(mask1, mask2):
22+
bbox = find_intersection(mask1['bbox'], mask2['bbox'])
23+
if bbox == None:
24+
return None
25+
segmentation = mask1['segmentation'] * mask2['segmentation']
26+
area = segmentation.sum()
27+
28+
intersection_mask = copy.deepcopy(mask1)
29+
intersection_mask['segmentation'] = segmentation
30+
intersection_mask['bbox'] = bbox
31+
intersection_mask['area'] = area
32+
return intersection_mask
33+
34+
35+
def find_union_mask(mask1, mask2):
36+
segmentation = mask1['segmentation'] + mask2['segmentation']
37+
bbox = find_union(mask1['bbox'], mask2['bbox'])
38+
area = segmentation.sum()
39+
40+
union_mask = copy.deepcopy(mask1)
41+
union_mask['segmentation'] = segmentation
42+
union_mask['bbox'] = bbox
43+
union_mask['area'] = area
44+
return union_mask

0 commit comments

Comments
 (0)