Skip to content

Commit

Permalink
docs for equiadapt/images/utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sibasmarak committed Mar 13, 2024
1 parent 35ef33d commit 666ecd9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
73 changes: 71 additions & 2 deletions equiadapt/images/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@


def roll_by_gather(feature_map: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
"""
Shifts the feature map along the group dimension by the specified shifts.
Args:
feature_map (torch.Tensor): The input feature map. It should have the shape (batch, channel, group, x_dim, y_dim).
shifts (torch.Tensor): The shifts for each feature map in the batch.
Returns:
torch.Tensor: The shifted feature map.
"""
device = shifts.device
# assumes 2D array
batch, channel, group, x_dim, y_dim = feature_map.shape
Expand All @@ -26,8 +36,16 @@ def get_action_on_image_features(
induced_rep_type: str = "regular",
) -> torch.Tensor:
"""
This function takes the feature map and the action and returns the feature map
after the action has been applied
Applies a group action to the feature map.
Args:
feature_map (torch.Tensor): The input feature map.
group_info_dict (dict): A dictionary containing information about the group.
group_element_dict (dict): A dictionary containing the group elements.
induced_rep_type (str, optional): The type of induced representation. Defaults to "regular".
Returns:
torch.Tensor: The feature map after the group action has been applied.
"""
num_rotations = group_info_dict["num_rotations"]
num_group = group_info_dict["num_group"]
Expand Down Expand Up @@ -77,21 +95,61 @@ def get_action_on_image_features(


def flip_boxes(boxes: torch.Tensor, width: int) -> torch.Tensor:
"""
Flips bounding boxes horizontally.
Args:
boxes (torch.Tensor): The bounding boxes to flip.
width (int): The width of the image.
Returns:
torch.Tensor: The flipped bounding boxes.
"""
boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
return boxes


def flip_masks(masks: torch.Tensor) -> torch.Tensor:
"""
Flips masks horizontally.
Args:
masks (torch.Tensor): The masks to flip.
Returns:
torch.Tensor: The flipped masks.
"""
return masks.flip(-1)


def rotate_masks(masks: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
"""
Rotates masks by a specified angle.
Args:
masks (torch.Tensor): The masks to rotate.
angle (torch.Tensor): The angle to rotate the masks by.
Returns:
torch.Tensor: The rotated masks.
"""
return transforms.functional.rotate(masks, angle)


def rotate_points(
origin: List[float], point: torch.Tensor, angle: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Rotates a point around an origin by a specified angle.
Args:
origin (List[float]): The origin to rotate the point around.
point (torch.Tensor): The point to rotate.
angle (torch.Tensor): The angle to rotate the point by.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The rotated point.
"""
ox, oy = origin
px, py = point

Expand All @@ -101,6 +159,17 @@ def rotate_points(


def rotate_boxes(boxes: torch.Tensor, angle: torch.Tensor, width: int) -> torch.Tensor:
"""
Rotates bounding boxes by a specified angle.
Args:
boxes (torch.Tensor): The bounding boxes to rotate.
angle (torch.Tensor): The angle to rotate the bounding boxes by.
width (int): The width of the image.
Returns:
torch.Tensor: The rotated bounding boxes.
"""
# rotate points
origin: List[float] = [width / 2, width / 2]
x_min_rot, y_min_rot = rotate_points(origin, boxes[:, :2].T, torch.deg2rad(angle))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ formats = bdist_wheel
[flake8]
# Some sane defaults for the code style checker flake8
max_line_length = 130
extend_ignore = E203, W503, E401, E501, E741, E266, D100, D107, D400, D401
extend_ignore = E203, W503, E401, E501, E741, E266, D100, D107, D400, D401, D104
# ^ Black-compatible
# E203 and W503 have edge cases handled by black
exclude =
Expand Down

0 comments on commit 666ecd9

Please sign in to comment.