Skip to content

Commit

Permalink
fixed mypy errors of equiadapt and examples/pointcloud
Browse files Browse the repository at this point in the history
  • Loading branch information
arnab39 committed Mar 13, 2024
1 parent 5f19380 commit 302c4d8
Show file tree
Hide file tree
Showing 18 changed files with 227 additions and 285 deletions.
30 changes: 0 additions & 30 deletions equiadapt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from equiadapt import common
from equiadapt import images
from equiadapt import pointcloud

from equiadapt.common import (
BaseCanonicalization,
ContinuousGroupCanonicalization,
Expand Down Expand Up @@ -34,33 +30,22 @@
custom_group_equivariant_layers,
custom_nonequivariant_networks,
escnn_networks,
flip_boxes,
flip_masks,
get_action_on_image_features,
roll_by_gather,
rotate_boxes,
rotate_masks,
rotate_points,
)
from equiadapt.pointcloud import (
ContinuousGroupPointcloudCanonicalization,
EPS,
EquivariantPointcloudCanonicalization,
VNBatchNorm,
VNBilinear,
VNLeakyReLU,
VNLinear,
VNLinearAndLeakyReLU,
VNLinearLeakyReLU,
VNMaxPool,
VNSmall,
VNSoftplus,
VNStdFeature,
equivariant_networks,
get_graph_feature_cross,
knn,
mean_pool,
vector_neuron_layers,
)

__all__ = [
Expand All @@ -72,7 +57,6 @@
"CustomEquivariantNetwork",
"DiscreteGroupCanonicalization",
"DiscreteGroupImageCanonicalization",
"EPS",
"ESCNNEquivariantNetwork",
"ESCNNSteerableNetwork",
"ESCNNWRNEquivariantNetwork",
Expand All @@ -94,32 +78,18 @@
"VNBilinear",
"VNLeakyReLU",
"VNLinear",
"VNLinearAndLeakyReLU",
"VNLinearLeakyReLU",
"VNMaxPool",
"VNSmall",
"VNSoftplus",
"VNStdFeature",
"basecanonicalization",
"common",
"custom_equivariant_networks",
"custom_group_equivariant_layers",
"custom_nonequivariant_networks",
"equivariant_networks",
"escnn_networks",
"flip_boxes",
"flip_masks",
"get_action_on_image_features",
"get_graph_feature_cross",
"gram_schmidt",
"images",
"knn",
"mean_pool",
"pointcloud",
"roll_by_gather",
"rotate_boxes",
"rotate_masks",
"rotate_points",
"utils",
"vector_neuron_layers",
]
30 changes: 19 additions & 11 deletions equiadapt/common/basecanonicalization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union, Optional

import torch

Expand All @@ -13,7 +13,7 @@ def __init__(self, canonicalization_network: torch.nn.Module):
self.canonicalization_info_dict: Dict[str, torch.Tensor] = {}

def forward(
self, x: torch.Tensor, targets: List = None, **kwargs: Any
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
Forward method for the canonicalization which takes the input data and
Expand All @@ -33,7 +33,7 @@ def forward(
return self.canonicalize(x, targets, **kwargs)

def canonicalize(
self, x: torch.Tensor, targets: List = None, **kwargs: Any
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
This method takes an input data with, optionally, targets that need to be canonicalized
Expand All @@ -48,7 +48,9 @@ def canonicalize(
"""
raise NotImplementedError()

def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
def invert_canonicalization(
self, x_canonicalized_out: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
"""
This method takes the output of the canonicalized data
and returns the output for the original data orientation
Expand All @@ -69,14 +71,16 @@ def __init__(self, canonicalization_network: torch.nn.Module = torch.nn.Identity
super().__init__(canonicalization_network)

def canonicalize(
self, x: torch.Tensor, targets: List = None, **kwargs: Any
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
if targets:
return x, targets
return x

def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
return x
def invert_canonicalization(
self, x_canonicalized_out: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
return x_canonicalized_out

def get_prior_regularization_loss(self) -> torch.Tensor:
return torch.tensor(0.0)
Expand Down Expand Up @@ -135,7 +139,7 @@ def groupactivations_to_groupelementonehot(
return group_element_onehot

def canonicalize(
self, x: torch.Tensor, targets: List = None, **kwargs: Any
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
This method takes an input data and
Expand All @@ -145,7 +149,9 @@ def canonicalize(
"""
raise NotImplementedError()

def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
def invert_canonicalization(
self, x_canonicalized_out: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
"""
This method takes the output of the canonicalized data
and returns the output for the original data orientation
Expand Down Expand Up @@ -185,7 +191,7 @@ def canonicalizationnetworkout_to_groupelement(
raise NotImplementedError()

def canonicalize(
self, x: torch.Tensor, targets: List = None, **kwargs: Any
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
This method takes an input data and
Expand All @@ -195,7 +201,9 @@ def canonicalize(
"""
raise NotImplementedError()

def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
def invert_canonicalization(
self, x_canonicalized_out: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
"""
This method takes the output of the canonicalized data
and returns the output for the original data orientation
Expand Down
52 changes: 29 additions & 23 deletions equiadapt/images/canonicalization/continuous_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from typing import Any, List, Tuple, Union
from omegaconf import DictConfig
from typing import Any, List, Tuple, Union, Optional, Dict

import kornia as K
import torch
Expand All @@ -15,7 +16,7 @@ class ContinuousGroupImageCanonicalization(ContinuousGroupCanonicalization):
def __init__(
self,
canonicalization_network: torch.nn.Module,
canonicalization_hyperparams: dict,
canonicalization_hyperparams: DictConfig,
in_shape: tuple,
):
super().__init__(canonicalization_network)
Expand Down Expand Up @@ -55,9 +56,9 @@ def __init__(
if is_grayscale
else transforms.Resize(size=canonicalization_hyperparams.resize_shape)
)
self.group_info_dict = {}
self.group_info_dict: Dict[str, Any] = {}

def get_groupelement(self, x: torch.Tensor):
def get_groupelement(self, x: torch.Tensor) -> dict:
"""
This method takes the input image and
maps it to the group element
Expand All @@ -70,7 +71,9 @@ def get_groupelement(self, x: torch.Tensor):
"""
raise NotImplementedError("get_groupelement method is not implemented")

def transformations_before_canonicalization_network_forward(self, x: torch.Tensor):
def transformations_before_canonicalization_network_forward(
self, x: torch.Tensor
) -> torch.Tensor:
"""
This method takes an image as input and
returns the pre-canonicalized image
Expand All @@ -79,7 +82,9 @@ def transformations_before_canonicalization_network_forward(self, x: torch.Tenso
x = self.resize_canonization(x)
return x

def get_group_from_out_vectors(self, out_vectors: torch.Tensor):
def get_group_from_out_vectors(
self, out_vectors: torch.Tensor
) -> Tuple[dict, torch.Tensor]:
"""
This method takes the output of the canonicalization network and
returns the group element
Expand Down Expand Up @@ -128,7 +133,7 @@ def get_group_from_out_vectors(self, out_vectors: torch.Tensor):
)

def canonicalize(
self, x: torch.Tensor, targets: List = None, **kwargs: Any
self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
"""
This method takes an image as input and
Expand Down Expand Up @@ -178,17 +183,18 @@ def canonicalize(
return x

def invert_canonicalization(
self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = "vector"
):
self, x_canonicalized_out: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
"""
This method takes the output of canonicalized image as input and
returns output of the original image
"""
induced_rep_type = kwargs.get("induced_rep_type", "vector")
return get_action_on_image_features(
feature_map=x_canonicalized_out,
group_info_dict=self.group_info_dict,
group_element_dict=self.canonicalization_info_dict["group_element"],
group_element_dict=self.canonicalization_info_dict["group_element"], # type: ignore
induced_rep_type=induced_rep_type,
)

Expand All @@ -197,15 +203,15 @@ class SteerableImageCanonicalization(ContinuousGroupImageCanonicalization):
def __init__(
self,
canonicalization_network: torch.nn.Module,
canonicalization_hyperparams: dict,
canonicalization_hyperparams: DictConfig,
in_shape: tuple,
):
super().__init__(
canonicalization_network, canonicalization_hyperparams, in_shape
)
self.group_type = canonicalization_network.group_type

def get_rotation_matrix_from_vector(self, vectors: torch.Tensor):
def get_rotation_matrix_from_vector(self, vectors: torch.Tensor) -> torch.Tensor:
"""
This method takes the input vector and returns the rotation matrix
Expand All @@ -220,7 +226,7 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor):
rotation_matrices = torch.stack([v1, v2], dim=1)
return rotation_matrices

def get_groupelement(self, x: torch.Tensor):
def get_groupelement(self, x: torch.Tensor) -> dict:
"""
This method takes the input image and
maps it to the group element
Expand All @@ -232,7 +238,7 @@ def get_groupelement(self, x: torch.Tensor):
group_element: group element
"""

group_element_dict = {}
group_element_dict: Dict[str, Any] = {}

x = self.transformations_before_canonicalization_network_forward(x)

Expand All @@ -251,7 +257,7 @@ def get_groupelement(self, x: torch.Tensor):
group_element_representation
)

self.canonicalization_info_dict["group_element"] = group_element_dict
self.canonicalization_info_dict["group_element"] = group_element_dict # type: ignore

return group_element_dict

Expand All @@ -260,15 +266,15 @@ class OptimizedSteerableImageCanonicalization(ContinuousGroupImageCanonicalizati
def __init__(
self,
canonicalization_network: torch.nn.Module,
canonicalization_hyperparams: dict,
canonicalization_hyperparams: DictConfig,
in_shape: tuple,
):
super().__init__(
canonicalization_network, canonicalization_hyperparams, in_shape
)
self.group_type = canonicalization_hyperparams.group_type

def get_rotation_matrix_from_vector(self, vectors: torch.Tensor):
def get_rotation_matrix_from_vector(self, vectors: torch.Tensor) -> torch.Tensor:
"""
This method takes the input vector and returns the rotation matrix
Expand All @@ -283,7 +289,7 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor):
rotation_matrices = torch.stack([v1, v2], dim=1)
return rotation_matrices

def group_augment(self, x):
def group_augment(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Augmentation of the input images by applying random rotations and,
if applicable, reflections, with corresponding transformation matrices.
Expand Down Expand Up @@ -322,7 +328,7 @@ def group_augment(self, x):
x = self.pad(x)

# Note: F.affine_grid expects theta of shape (N, 2, 3) for 2D affine transformations
grid = F.affine_grid(rotation_matrices, x.size(), align_corners=False)
grid = F.affine_grid(rotation_matrices, list(x.size()), align_corners=False)

augmented_images = F.grid_sample(x, grid, align_corners=False)

Expand All @@ -336,7 +342,7 @@ def group_augment(self, x):
# Return augmented images and the transformation matrices used
return augmented_images, rotation_matrices[:, :, :2]

def get_groupelement(self, x: torch.Tensor):
def get_groupelement(self, x: torch.Tensor) -> dict:
"""
This method takes the input image and
maps it to the group element
Expand All @@ -348,7 +354,7 @@ def get_groupelement(self, x: torch.Tensor):
group_element: group element
"""

group_element_dict = {}
group_element_dict: Dict[str, Any] = {}

batch_size = x.shape[0]

Expand Down Expand Up @@ -385,7 +391,7 @@ def get_groupelement(self, x: torch.Tensor):
self.canonicalization_info_dict["group_element_matrix_representation"] = (
group_element_representations
)
self.canonicalization_info_dict["group_element"] = group_element_dict
self.canonicalization_info_dict["group_element"] = group_element_dict # type: ignore

_, group_element_representations_augmented = self.get_group_from_out_vectors(
out_vectors_augmented
Expand All @@ -399,7 +405,7 @@ def get_groupelement(self, x: torch.Tensor):

return group_element_dict

def get_optimization_specific_loss(self):
def get_optimization_specific_loss(self) -> torch.Tensor:
"""
This method returns the optimization specific loss
Expand Down
Loading

0 comments on commit 302c4d8

Please sign in to comment.