Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
SVivdich02 committed Mar 4, 2024
1 parent 8ae5cdf commit e8ab8bf
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 19 deletions.
16 changes: 9 additions & 7 deletions src/services/preprocessing/init/instances_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,29 @@ def reduce_detail(self, masks, intersection_to_union_ratio_threshold=0.35):
if i in merged_indices:
continue

area_bbox_i = calculate_area(masks[i]['bbox'])
area_bbox_i = calculate_area(masks[i]["bbox"])

indices_merged_with_i = []
for j in range(i + 1, len(masks)):
if j in merged_indices:
continue

area_bbox_j = calculate_area(masks[j]['bbox'])
area_bbox_j = calculate_area(masks[j]["bbox"])

intersection_mask = find_intersection_mask(masks[i], masks[j])
if intersection_mask == None:
continue
area_intersection = intersection_mask['area']
area_intersection = intersection_mask["area"]

area_bbox_intersection = calculate_area(intersection_mask['bbox'])
area_bbox_intersection = calculate_area(intersection_mask["bbox"])
area_bbox_union = area_bbox_i + area_bbox_j - area_bbox_intersection
IU_ratio = area_bbox_intersection / area_bbox_union

if (IU_ratio >= intersection_to_union_ratio_threshold
or area_intersection / masks[i]['area'] >= 0.6
or area_intersection / masks[j]['area'] >= 0.6):
if (
IU_ratio >= intersection_to_union_ratio_threshold
or area_intersection / masks[i]["area"] >= 0.6
or area_intersection / masks[j]["area"] >= 0.6
):
masks[i] = find_union_mask(masks[i], masks[j])
indices_merged_with_i.append(j)

Expand Down
9 changes: 7 additions & 2 deletions src/utils/geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ def find_intersection(bbox1, bbox2):
x2_bbox2 = x1_bbox2 + w_bbox2
y2_bbox2 = y1_bbox2 + h_bbox2

if (x1_bbox1 > x2_bbox2 or x2_bbox1 < x1_bbox2 or y1_bbox1 > y2_bbox2 or y2_bbox1 < y1_bbox2):
if (
x1_bbox1 > x2_bbox2
or x2_bbox1 < x1_bbox2
or y1_bbox1 > y2_bbox2
or y2_bbox1 < y1_bbox2
):
return None

x_left = max(x1_bbox1, x1_bbox2)
x_right = min(x2_bbox1, x2_bbox2)
y_top = max(y1_bbox1, y1_bbox2)
Expand Down
20 changes: 10 additions & 10 deletions src/utils/sam_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,26 @@


def find_intersection_mask(mask1, mask2):
bbox = find_intersection(mask1['bbox'], mask2['bbox'])
bbox = find_intersection(mask1["bbox"], mask2["bbox"])
if bbox == None:
return None
segmentation = mask1['segmentation'] * mask2['segmentation']
segmentation = mask1["segmentation"] * mask2["segmentation"]
area = segmentation.sum()

intersection_mask = copy.deepcopy(mask1)
intersection_mask['segmentation'] = segmentation
intersection_mask['bbox'] = bbox
intersection_mask['area'] = area
intersection_mask["segmentation"] = segmentation
intersection_mask["bbox"] = bbox
intersection_mask["area"] = area
return intersection_mask


def find_union_mask(mask1, mask2):
segmentation = mask1['segmentation'] + mask2['segmentation']
bbox = find_union(mask1['bbox'], mask2['bbox'])
segmentation = mask1["segmentation"] + mask2["segmentation"]
bbox = find_union(mask1["bbox"], mask2["bbox"])
area = segmentation.sum()

union_mask = copy.deepcopy(mask1)
union_mask['segmentation'] = segmentation
union_mask['bbox'] = bbox
union_mask['area'] = area
union_mask["segmentation"] = segmentation
union_mask["bbox"] = bbox
union_mask["area"] = area
return union_mask

0 comments on commit e8ab8bf

Please sign in to comment.