Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solved minor bugs in inverse canonicalization for discrete groups #18

Merged
merged 3 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion equiadapt/images/canonicalization/discrete_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
len(in_shape) == 3
), "Input shape should be in the format (channels, height, width)"

# DEfine all the image transformations here which are used during canonicalization
# Define all the image transformations here which are used during canonicalization
# pad and crop the input image if it is not rotated MNIST
is_grayscale = in_shape[0] == 1

Expand Down
14 changes: 7 additions & 7 deletions equiadapt/images/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ def get_action_on_image_features(
batch_size, C, H, W = feature_map.shape
if induced_rep_type == "regular":
assert feature_map.shape[1] % num_group == 0
angles = group_element_dict["group"]["rotation"]
angles = group_element_dict["rotation"]
x_out = K.geometry.rotate(feature_map, angles)

if "reflection" in group_element_dict["group"]:
reflect_indicator = group_element_dict["group"]["reflection"]
if "reflection" in group_element_dict:
reflect_indicator = group_element_dict["reflection"]
x_out_reflected = K.geometry.hflip(x_out)
x_out = x_out * reflect_indicator[:, None, None, None] + x_out_reflected * (
1 - reflect_indicator[:, None, None, None]
)

x_out = x_out.reshape(batch_size, C // num_group, num_group, H, W)
shift = angles / 360.0 * num_rotations
if "reflection" in group_element_dict["group"]:
if "reflection" in group_element_dict:
x_out = torch.cat(
[
roll_by_gather(x_out[:, :, :num_rotations], shift),
Expand All @@ -78,10 +78,10 @@ def get_action_on_image_features(
x_out = x_out.reshape(batch_size, -1, H, W)
return x_out
elif induced_rep_type == "scalar":
angles = group_element_dict["group"][0]
angles = group_element_dict["rotation"]
x_out = K.geometry.rotate(feature_map, angles)
if "reflection" in group_element_dict["group"]:
reflect_indicator = group_element_dict["group"]["reflection"]
if "reflection" in group_element_dict:
reflect_indicator = group_element_dict["reflection"]
x_out_reflected = K.geometry.hflip(x_out)
x_out = x_out * reflect_indicator[:, None, None, None] + x_out_reflected * (
1 - reflect_indicator[:, None, None, None]
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ norecursedirs =
build
.tox
testpaths = tests
filterwarnings = ignore::Warning
# Use pytest markers to select/deselect specific tests
# markers =
# slow: mark tests as slow (deselect with '-m "not slow"')
Expand Down
69 changes: 69 additions & 0 deletions tests/images/canonicalization/test_discrete_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import torch
from omegaconf import DictConfig

from equiadapt.images.canonicalization.discrete_group import (
GroupEquivariantImageCanonicalization,
)
from equiadapt.images.canonicalization_networks.escnn_networks import (
ESCNNEquivariantNetwork,
)


@pytest.fixture
def init_args() -> dict:
"""
Initialize the arguments for the canonicalization function.

Returns:
dict: A dictionary containing the initialization arguments.
"""
# Mock initialization arguments
canonicalization_hyperparams = DictConfig(
{
"input_crop_ratio": 0.9,
"resize_shape": (32, 32),
"beta": 0.1,
}
)
return {
"canonicalization_network": ESCNNEquivariantNetwork(
in_shape=(3, 64, 64),
out_channels=32,
kernel_size=3,
group_type="rotation",
num_rotations=4,
num_layers=2,
),
"canonicalization_hyperparams": canonicalization_hyperparams,
"in_shape": (3, 64, 64),
}


# try both types of induced representations (regular and scalar)
@pytest.mark.parametrize("induced_rep, num_channels", [("regular", 12), ("scalar", 3)])
def test_invert_canonicalization_induced_rep(
induced_rep: str, num_channels: int, init_args: dict
) -> None:
"""
Test the inversion of the canonicalization-induced representation.

Args:
induced_rep (str): The type of induced representation.
num_channels (int): The number of channels in the sample image.
"""

# Initialize the canonicalization function
dgic = GroupEquivariantImageCanonicalization(**init_args)

# Apply the canonicalization function
image = torch.randn((1, 3, 64, 64))

_ = dgic(image) # to populate the canonicalization_info_dict

canonicalized_image = torch.randn((1, num_channels, 64, 64))

# Invert the canonicalization-induced representation
inverted_image = dgic.invert_canonicalization(
canonicalized_image, **{"induced_rep_type": induced_rep}
)
Loading