From 302c4d892ee6314e4547065a9af32065c55bc3be Mon Sep 17 00:00:00 2001 From: arnab39 Date: Wed, 13 Mar 2024 01:04:04 -0400 Subject: [PATCH] fixed mypy errors of equiadapt and examples/pointcloud --- equiadapt/__init__.py | 30 ----- equiadapt/common/basecanonicalization.py | 30 +++-- .../canonicalization/continuous_group.py | 52 ++++---- .../images/canonicalization/discrete_group.py | 44 ++++--- .../custom_equivariant_networks.py | 29 +++-- .../custom_group_equivariant_layers.py | 98 +++++++------- equiadapt/images/utils.py | 19 +-- equiadapt/pointcloud/__init__.py | 2 - .../canonicalization/continuous_group.py | 15 ++- .../canonicalization_networks/__init__.py | 2 - .../equivariant_networks.py | 14 +- .../vector_neuron_layers.py | 121 +++++++----------- .../pointcloud/classification/model_utils.py | 2 +- examples/pointcloud/classification/train.py | 2 +- .../pointcloud/classification/train_utils.py | 24 +--- .../part_segmentation/model_utils.py | 2 +- .../pointcloud/part_segmentation/train.py | 2 +- .../part_segmentation/train_utils.py | 24 +--- 18 files changed, 227 insertions(+), 285 deletions(-) diff --git a/equiadapt/__init__.py b/equiadapt/__init__.py index 14b3d6b..7e91706 100644 --- a/equiadapt/__init__.py +++ b/equiadapt/__init__.py @@ -1,7 +1,3 @@ -from equiadapt import common -from equiadapt import images -from equiadapt import pointcloud - from equiadapt.common import ( BaseCanonicalization, ContinuousGroupCanonicalization, @@ -34,23 +30,15 @@ custom_group_equivariant_layers, custom_nonequivariant_networks, escnn_networks, - flip_boxes, - flip_masks, get_action_on_image_features, - roll_by_gather, - rotate_boxes, - rotate_masks, - rotate_points, ) from equiadapt.pointcloud import ( ContinuousGroupPointcloudCanonicalization, - EPS, EquivariantPointcloudCanonicalization, VNBatchNorm, VNBilinear, VNLeakyReLU, VNLinear, - VNLinearAndLeakyReLU, VNLinearLeakyReLU, VNMaxPool, VNSmall, @@ -58,9 +46,6 @@ VNStdFeature, equivariant_networks, get_graph_feature_cross, - knn, - mean_pool, - vector_neuron_layers, ) __all__ = [ @@ -72,7 +57,6 @@ "CustomEquivariantNetwork", "DiscreteGroupCanonicalization", "DiscreteGroupImageCanonicalization", - "EPS", "ESCNNEquivariantNetwork", "ESCNNSteerableNetwork", "ESCNNWRNEquivariantNetwork", @@ -94,32 +78,18 @@ "VNBilinear", "VNLeakyReLU", "VNLinear", - "VNLinearAndLeakyReLU", "VNLinearLeakyReLU", "VNMaxPool", "VNSmall", "VNSoftplus", "VNStdFeature", "basecanonicalization", - "common", "custom_equivariant_networks", "custom_group_equivariant_layers", "custom_nonequivariant_networks", "equivariant_networks", "escnn_networks", - "flip_boxes", - "flip_masks", "get_action_on_image_features", "get_graph_feature_cross", "gram_schmidt", - "images", - "knn", - "mean_pool", - "pointcloud", - "roll_by_gather", - "rotate_boxes", - "rotate_masks", - "rotate_points", - "utils", - "vector_neuron_layers", ] diff --git a/equiadapt/common/basecanonicalization.py b/equiadapt/common/basecanonicalization.py index d90a066..c860067 100644 --- a/equiadapt/common/basecanonicalization.py +++ b/equiadapt/common/basecanonicalization.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import torch @@ -13,7 +13,7 @@ def __init__(self, canonicalization_network: torch.nn.Module): self.canonicalization_info_dict: Dict[str, torch.Tensor] = {} def forward( - self, x: torch.Tensor, targets: List = None, **kwargs: Any + self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, 
Tuple[torch.Tensor, List]]:
         """
         Forward method for the canonicalization, which takes the input data and
@@ -33,7 +33,7 @@ def forward(
         return self.canonicalize(x, targets, **kwargs)
 
     def canonicalize(
-        self, x: torch.Tensor, targets: List = None, **kwargs: Any
+        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
         """
         This method takes input data and, optionally, targets that need to be canonicalized
@@ -48,7 +48,9 @@ def canonicalize(
         """
         raise NotImplementedError()
 
-    def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def invert_canonicalization(
+        self, x_canonicalized_out: torch.Tensor, **kwargs: Any
+    ) -> torch.Tensor:
         """
         This method takes the output computed on the canonicalized data and
         returns the output for the original data orientation
@@ -69,14 +71,16 @@ def __init__(self, canonicalization_network: torch.nn.Module = torch.nn.Identity
         super().__init__(canonicalization_network)
 
     def canonicalize(
-        self, x: torch.Tensor, targets: List = None, **kwargs: Any
+        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
         if targets:
             return x, targets
         return x
 
-    def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
-        return x
+    def invert_canonicalization(
+        self, x_canonicalized_out: torch.Tensor, **kwargs: Any
+    ) -> torch.Tensor:
+        return x_canonicalized_out
 
     def get_prior_regularization_loss(self) -> torch.Tensor:
         return torch.tensor(0.0)
@@ -135,7 +139,7 @@ def groupactivations_to_groupelementonehot(
         return group_element_onehot
 
     def canonicalize(
-        self, x: torch.Tensor, targets: List = None, **kwargs: Any
+        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
         """
         This method takes input data and
@@ -145,7 +149,9 @@ def canonicalize(
         """
         raise NotImplementedError()
 
-    def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def invert_canonicalization(
+        self, x_canonicalized_out: torch.Tensor, **kwargs: Any
+    ) -> torch.Tensor:
         """
         This method takes the output computed on the canonicalized data and
         returns the output for the original data orientation
@@ -185,7 +191,7 @@ def canonicalizationnetworkout_to_groupelement(
         raise NotImplementedError()
 
     def canonicalize(
-        self, x: torch.Tensor, targets: List = None, **kwargs: Any
+        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
         """
         This method takes input data and
@@ -195,7 +201,9 @@ def canonicalize(
         """
         raise NotImplementedError()
 
-    def invert_canonicalization(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def invert_canonicalization(
+        self, x_canonicalized_out: torch.Tensor, **kwargs: Any
+    ) -> torch.Tensor:
         """
         This method takes the output computed on the canonicalized data and
         returns the output for the original data orientation
diff --git a/equiadapt/images/canonicalization/continuous_group.py b/equiadapt/images/canonicalization/continuous_group.py
index 5a89478..fb47ad7 100644
--- a/equiadapt/images/canonicalization/continuous_group.py
+++ b/equiadapt/images/canonicalization/continuous_group.py
@@ -1,5 +1,6 @@
 import math
-from typing import Any, List, Tuple, Union
+from omegaconf import DictConfig
+from typing import Any, List, Tuple, Union, Optional, Dict
 
 import kornia as K
 import torch
@@ -15,7 +16,7 @@ class 
ContinuousGroupImageCanonicalization(ContinuousGroupCanonicalization): def __init__( self, canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, + canonicalization_hyperparams: DictConfig, in_shape: tuple, ): super().__init__(canonicalization_network) @@ -55,9 +56,9 @@ def __init__( if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape) ) - self.group_info_dict = {} + self.group_info_dict: Dict[str, Any] = {} - def get_groupelement(self, x: torch.Tensor): + def get_groupelement(self, x: torch.Tensor) -> dict: """ This method takes the input image and maps it to the group element @@ -70,7 +71,9 @@ def get_groupelement(self, x: torch.Tensor): """ raise NotImplementedError("get_groupelement method is not implemented") - def transformations_before_canonicalization_network_forward(self, x: torch.Tensor): + def transformations_before_canonicalization_network_forward( + self, x: torch.Tensor + ) -> torch.Tensor: """ This method takes an image as input and returns the pre-canonicalized image @@ -79,7 +82,9 @@ def transformations_before_canonicalization_network_forward(self, x: torch.Tenso x = self.resize_canonization(x) return x - def get_group_from_out_vectors(self, out_vectors: torch.Tensor): + def get_group_from_out_vectors( + self, out_vectors: torch.Tensor + ) -> Tuple[dict, torch.Tensor]: """ This method takes the output of the canonicalization network and returns the group element @@ -128,7 +133,7 @@ def get_group_from_out_vectors(self, out_vectors: torch.Tensor): ) def canonicalize( - self, x: torch.Tensor, targets: List = None, **kwargs: Any + self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ This method takes an image as input and @@ -178,17 +183,18 @@ def canonicalize( return x def invert_canonicalization( - self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = "vector" - ): + self, x_canonicalized_out: torch.Tensor, **kwargs: Any + ) -> torch.Tensor: """ This method takes the output of canonicalized image as input and returns output of the original image """ + induced_rep_type = kwargs.get("induced_rep_type", "vector") return get_action_on_image_features( feature_map=x_canonicalized_out, group_info_dict=self.group_info_dict, - group_element_dict=self.canonicalization_info_dict["group_element"], + group_element_dict=self.canonicalization_info_dict["group_element"], # type: ignore induced_rep_type=induced_rep_type, ) @@ -197,7 +203,7 @@ class SteerableImageCanonicalization(ContinuousGroupImageCanonicalization): def __init__( self, canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, + canonicalization_hyperparams: DictConfig, in_shape: tuple, ): super().__init__( @@ -205,7 +211,7 @@ def __init__( ) self.group_type = canonicalization_network.group_type - def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): + def get_rotation_matrix_from_vector(self, vectors: torch.Tensor) -> torch.Tensor: """ This method takes the input vector and returns the rotation matrix @@ -220,7 +226,7 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): rotation_matrices = torch.stack([v1, v2], dim=1) return rotation_matrices - def get_groupelement(self, x: torch.Tensor): + def get_groupelement(self, x: torch.Tensor) -> dict: """ This method takes the input image and maps it to the group element @@ -232,7 +238,7 @@ def get_groupelement(self, x: torch.Tensor): group_element: group element """ - group_element_dict = 
{} + group_element_dict: Dict[str, Any] = {} x = self.transformations_before_canonicalization_network_forward(x) @@ -251,7 +257,7 @@ def get_groupelement(self, x: torch.Tensor): group_element_representation ) - self.canonicalization_info_dict["group_element"] = group_element_dict + self.canonicalization_info_dict["group_element"] = group_element_dict # type: ignore return group_element_dict @@ -260,7 +266,7 @@ class OptimizedSteerableImageCanonicalization(ContinuousGroupImageCanonicalizati def __init__( self, canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, + canonicalization_hyperparams: DictConfig, in_shape: tuple, ): super().__init__( @@ -268,7 +274,7 @@ def __init__( ) self.group_type = canonicalization_hyperparams.group_type - def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): + def get_rotation_matrix_from_vector(self, vectors: torch.Tensor) -> torch.Tensor: """ This method takes the input vector and returns the rotation matrix @@ -283,7 +289,7 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): rotation_matrices = torch.stack([v1, v2], dim=1) return rotation_matrices - def group_augment(self, x): + def group_augment(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Augmentation of the input images by applying random rotations and, if applicable, reflections, with corresponding transformation matrices. @@ -322,7 +328,7 @@ def group_augment(self, x): x = self.pad(x) # Note: F.affine_grid expects theta of shape (N, 2, 3) for 2D affine transformations - grid = F.affine_grid(rotation_matrices, x.size(), align_corners=False) + grid = F.affine_grid(rotation_matrices, list(x.size()), align_corners=False) augmented_images = F.grid_sample(x, grid, align_corners=False) @@ -336,7 +342,7 @@ def group_augment(self, x): # Return augmented images and the transformation matrices used return augmented_images, rotation_matrices[:, :, :2] - def get_groupelement(self, x: torch.Tensor): + def get_groupelement(self, x: torch.Tensor) -> dict: """ This method takes the input image and maps it to the group element @@ -348,7 +354,7 @@ def get_groupelement(self, x: torch.Tensor): group_element: group element """ - group_element_dict = {} + group_element_dict: Dict[str, Any] = {} batch_size = x.shape[0] @@ -385,7 +391,7 @@ def get_groupelement(self, x: torch.Tensor): self.canonicalization_info_dict["group_element_matrix_representation"] = ( group_element_representations ) - self.canonicalization_info_dict["group_element"] = group_element_dict + self.canonicalization_info_dict["group_element"] = group_element_dict # type: ignore _, group_element_representations_augmented = self.get_group_from_out_vectors( out_vectors_augmented @@ -399,7 +405,7 @@ def get_groupelement(self, x: torch.Tensor): return group_element_dict - def get_optimization_specific_loss(self): + def get_optimization_specific_loss(self) -> torch.Tensor: """ This method returns the optimization specific loss diff --git a/equiadapt/images/canonicalization/discrete_group.py b/equiadapt/images/canonicalization/discrete_group.py index a84c272..0d2ef00 100644 --- a/equiadapt/images/canonicalization/discrete_group.py +++ b/equiadapt/images/canonicalization/discrete_group.py @@ -1,5 +1,6 @@ import math -from typing import List, Tuple, Union +from omegaconf import DictConfig +from typing import List, Tuple, Union, Optional, Any import kornia as K import torch @@ -20,7 +21,7 @@ class DiscreteGroupImageCanonicalization(DiscreteGroupCanonicalization): def __init__( self, 
canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, + canonicalization_hyperparams: DictConfig, in_shape: tuple, ): super().__init__(canonicalization_network) @@ -67,7 +68,7 @@ def __init__( else transforms.Resize(size=canonicalization_hyperparams.resize_shape) ) - def groupactivations_to_groupelement(self, group_activations: torch.Tensor): + def groupactivations_to_groupelement(self, group_activations: torch.Tensor) -> dict: """ This method takes the activations for each group element as input and returns the group element @@ -112,7 +113,7 @@ def groupactivations_to_groupelement(self, group_activations: torch.Tensor): return group_element_dict - def get_group_activations(self, x: torch.Tensor): + def get_group_activations(self, x: torch.Tensor) -> torch.Tensor: """ This method takes an image as input and returns the group activations @@ -122,7 +123,7 @@ def get_group_activations(self, x: torch.Tensor): "the DiscreteGroupImageCanonicalization class" ) - def get_groupelement(self, x: torch.Tensor): + def get_groupelement(self, x: torch.Tensor) -> dict[str, torch.Tensor]: """ This method takes the input image and maps it to the group element @@ -140,12 +141,14 @@ def get_groupelement(self, x: torch.Tensor): if not hasattr(self, "canonicalization_info_dict"): self.canonicalization_info_dict = {} - self.canonicalization_info_dict["group_element"] = group_element_dict + self.canonicalization_info_dict["group_element"] = group_element_dict # type: ignore self.canonicalization_info_dict["group_activations"] = group_activations return group_element_dict - def transformations_before_canonicalization_network_forward(self, x: torch.Tensor): + def transformations_before_canonicalization_network_forward( + self, x: torch.Tensor + ) -> torch.Tensor: """ This method takes an image as input and returns the pre-canonicalized image @@ -155,7 +158,7 @@ def transformations_before_canonicalization_network_forward(self, x: torch.Tenso return x def canonicalize( - self, x: torch.Tensor, targets: List = None + self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ This method takes an image as input and @@ -190,7 +193,7 @@ def canonicalize( targets[t]["boxes"], group_element_dict["rotation"][t], image_width ) targets[t]["masks"] = rotate_masks( - targets[t]["masks"], -group_element_dict["rotation"][t].item() + targets[t]["masks"], -group_element_dict["rotation"][t].item() # type: ignore ) return x, targets @@ -198,16 +201,17 @@ def canonicalize( return x def invert_canonicalization( - self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = "regular" - ): + self, x_canonicalized_out: torch.Tensor, **kwargs: Any + ) -> torch.Tensor: """ This method takes the output of canonicalized image as input and returns output of the original image """ + induced_rep_type = kwargs.get("induced_rep_type", "regular") return get_action_on_image_features( feature_map=x_canonicalized_out, group_info_dict=self.group_info_dict, - group_element_dict=self.canonicalization_info_dict["group_element"], + group_element_dict=self.canonicalization_info_dict["group_element"], # type: ignore induced_rep_type=induced_rep_type, ) @@ -216,7 +220,7 @@ class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization): def __init__( self, canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, + canonicalization_hyperparams: DictConfig, in_shape: tuple, ): super().__init__( @@ -234,7 +238,7 @@ def 
__init__( "num_group": self.num_group, } - def get_group_activations(self, x: torch.Tensor): + def get_group_activations(self, x: torch.Tensor) -> torch.Tensor: """ This method takes an image as input and returns the group activations @@ -250,7 +254,7 @@ class OptimizedGroupEquivariantImageCanonicalization( def __init__( self, canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, + canonicalization_hyperparams: DictConfig, in_shape: tuple, ): super().__init__( @@ -292,7 +296,7 @@ def __init__( def rotate_and_maybe_reflect( self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False - ): + ) -> List[torch.Tensor]: x_augmented_list = [] for degree in degrees: x_rot = self.pad_group_augment(x) @@ -303,7 +307,7 @@ def rotate_and_maybe_reflect( x_augmented_list.append(x_rot) return x_augmented_list - def group_augment(self, x: torch.Tensor): + def group_augment(self, x: torch.Tensor) -> torch.Tensor: degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device) x_augmented_list = self.rotate_and_maybe_reflect(x, degrees) @@ -313,7 +317,7 @@ def group_augment(self, x: torch.Tensor): return torch.cat(x_augmented_list, dim=0) - def get_group_activations(self, x: torch.Tensor): + def get_group_activations(self, x: torch.Tensor) -> torch.Tensor: """ This method takes an image as input and returns the group activations @@ -363,7 +367,7 @@ def get_group_activations(self, x: torch.Tensor): ).T # size (batch_size, group_size) return group_activations - def get_optimization_specific_loss(self): + def get_optimization_specific_loss(self) -> torch.Tensor: vectors = self.canonicalization_info_dict["vector_out"] # compute error to reduce rotation artifacts @@ -372,7 +376,7 @@ def get_optimization_specific_loss(self): vectors_dummy = self.canonicalization_info_dict["vector_out_dummy"] rotation_artifact_error = torch.nn.functional.mse_loss( vectors_dummy, vectors - ) + ) # type: ignore # error to ensure that the vectors are (as much as possible) orthogonal vectors = vectors.reshape(self.num_group, -1, self.out_vector_size).permute( diff --git a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py index 9ed20f0..a58bc3b 100644 --- a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py @@ -6,18 +6,19 @@ RotoReflectionEquivariantConvLift, RotoReflectionEquivariantConv, ) +from typing import Tuple class CustomEquivariantNetwork(nn.Module): def __init__( self, - in_shape, - out_channels, - kernel_size, - group_type="rotation", - num_rotations=4, - num_layers=1, - device="cuda" if torch.cuda.is_available() else "cpu", + in_shape: Tuple[int, int, int, int], + out_channels: int, + kernel_size: int, + group_type: str = "rotation", + num_rotations: int = 4, + num_layers: int = 1, + device: str = "cuda" if torch.cuda.is_available() else "cpu", ): super().__init__() @@ -25,34 +26,34 @@ def __init__( layer_list = [ RotationEquivariantConvLift( in_shape[0], out_channels, kernel_size, num_rotations, device=device - ) + ) # type: ignore ] for i in range(num_layers - 1): - layer_list.append(nn.ReLU()) + layer_list.append(nn.ReLU()) # type: ignore layer_list.append( RotationEquivariantConv( out_channels, out_channels, 1, num_rotations, device=device - ) + ) # type: ignore ) self.eqv_network = nn.Sequential(*layer_list) elif group_type == "roto-reflection": layer_list = [ 
RotoReflectionEquivariantConvLift( in_shape[0], out_channels, kernel_size, num_rotations, device=device - ) + ) # type: ignore ] for i in range(num_layers - 1): - layer_list.append(nn.ReLU()) + layer_list.append(nn.ReLU()) # type: ignore layer_list.append( RotoReflectionEquivariantConv( out_channels, out_channels, 1, num_rotations, device=device - ) + ) # type: ignore ) self.eqv_network = nn.Sequential(*layer_list) else: raise ValueError("group_type must be rotation or roto-reflection for now.") - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x shape: (batch_size, in_channels, height, width) :return: (batch_size, group_size) diff --git a/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py b/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py index 53755a5..e806348 100644 --- a/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py +++ b/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py @@ -8,14 +8,14 @@ class RotationEquivariantConvLift(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size, - num_rotations=4, - stride=1, - padding=0, - bias=True, - device="cuda", + in_channels: int, + out_channels: int, + kernel_size: int, + num_rotations: int = 4, + stride: int = 1, + padding: int = 0, + bias: bool = True, + device: str = "cuda", ): super().__init__() self.weights = nn.Parameter( @@ -26,7 +26,7 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_channels).to(device)) torch.nn.init.zeros_(self.bias) else: - self.bias = None + self.bias = None # type: ignore self.in_channels = in_channels self.out_channels = out_channels self.stride = stride @@ -34,7 +34,9 @@ def __init__( self.num_rotations = num_rotations self.kernel_size = kernel_size - def get_rotated_weights(self, weights, num_rotations=4): + def get_rotated_weights( + self, weights: torch.Tensor, num_rotations: int = 4 + ) -> torch.Tensor: device = weights.device weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1) rotated_weights = K.geometry.rotate( @@ -52,7 +54,7 @@ def get_rotated_weights(self, weights, num_rotations=4): ).transpose(0, 1) return rotated_weights.flatten(0, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x shape: (batch_size, in_channels, height, width) :return: (batch_size, out_channels, num_rotations, height, width) @@ -72,14 +74,14 @@ def forward(self, x): class RotoReflectionEquivariantConvLift(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size, - num_rotations=4, - stride=1, - padding=0, - bias=True, - device="cuda", + in_channels: int, + out_channels: int, + kernel_size: int, + num_rotations: int = 4, + stride: int = 1, + padding: int = 0, + bias: bool = True, + device: str = "cuda", ): super().__init__() num_group_elements = 2 * num_rotations @@ -91,7 +93,7 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_channels).to(device)) torch.nn.init.zeros_(self.bias) else: - self.bias = None + self.bias = None # type: ignore self.in_channels = in_channels self.out_channels = out_channels self.stride = stride @@ -100,7 +102,9 @@ def __init__( self.kernel_size = kernel_size self.num_group_elements = num_group_elements - def get_rotoreflected_weights(self, weights, num_rotations=4): + def get_rotoreflected_weights( + self, weights: torch.Tensor, num_rotations: int = 4 + ) -> torch.Tensor: device = weights.device weights = weights.flatten(0, 
1).unsqueeze(0).repeat(num_rotations, 1, 1, 1) rotated_weights = K.geometry.rotate( @@ -120,7 +124,7 @@ def get_rotoreflected_weights(self, weights, num_rotations=4): ).transpose(0, 1) return rotoreflected_weights.flatten(0, 1) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x shape: (batch_size, in_channels, height, width) :return: (batch_size, out_channels, num_group_elements, height, width) @@ -146,14 +150,14 @@ def forward(self, x): class RotationEquivariantConv(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size, - num_rotations=4, - stride=1, - padding=0, - bias=True, - device="cuda", + in_channels: int, + out_channels: int, + kernel_size: int, + num_rotations: int = 4, + stride: int = 1, + padding: int = 0, + bias: bool = True, + device: str = "cuda", ): super().__init__() self.weights = nn.Parameter( @@ -166,7 +170,7 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_channels).to(device)) torch.nn.init.zeros_(self.bias) else: - self.bias = None + self.bias = None # type: ignore self.in_channels = in_channels self.out_channels = out_channels self.stride = stride @@ -188,7 +192,9 @@ def __init__( 0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32 )[:num_rotations].to(device) - def get_rotated_permuted_weights(self, weights, num_rotations=4): + def get_rotated_permuted_weights( + self, weights: torch.Tensor, num_rotations: int = 4 + ) -> torch.Tensor: weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1, 1) permuted_weights = torch.gather(weights, 2, self.permute_indices_along_group) rotated_permuted_weights = K.geometry.rotate( @@ -214,7 +220,7 @@ def get_rotated_permuted_weights(self, weights, num_rotations=4): ) return rotated_permuted_weights - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x shape: (batch_size, in_channels, num_rotations, height, width) :return: (batch_size, out_channels, num_rotations, height, width) @@ -240,17 +246,17 @@ def forward(self, x): class RotoReflectionEquivariantConv(nn.Module): def __init__( self, - in_channels, - out_channels, - kernel_size, - num_rotations=4, - stride=1, - padding=0, - bias=True, - device="cuda", + in_channels: int, + out_channels: int, + kernel_size: int, + num_rotations: int = 4, + stride: int = 1, + padding: int = 0, + bias: bool = True, + device: str = "cuda", ): super().__init__() - num_group_elements = 2 * num_rotations + num_group_elements: int = 2 * num_rotations self.weights = nn.Parameter( torch.empty( out_channels, in_channels, num_group_elements, kernel_size, kernel_size @@ -261,7 +267,7 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_channels).to(device)) torch.nn.init.zeros_(self.bias) else: - self.bias = None + self.bias = None # type: ignore self.in_channels = in_channels self.out_channels = out_channels self.stride = stride @@ -310,7 +316,9 @@ def __init__( ] ).to(device) - def get_rotoreflected_permuted_weights(self, weights, num_rotations=4): + def get_rotoreflected_permuted_weights( + self, weights: torch.Tensor, num_rotations: int = 4 + ) -> torch.Tensor: weights = ( weights.flatten(0, 1) .unsqueeze(0) @@ -346,7 +354,7 @@ def get_rotoreflected_permuted_weights(self, weights, num_rotations=4): ) return rotoreflected_permuted_weights - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x shape: (batch_size, in_channels, num_group_elements, height, width) :return: (batch_size, out_channels, num_group_elements, height, width) diff --git 
a/equiadapt/images/utils.py b/equiadapt/images/utils.py
index 5446995..3683424 100644
--- a/equiadapt/images/utils.py
+++ b/equiadapt/images/utils.py
@@ -1,9 +1,10 @@
 import kornia as K
 import torch
 from torchvision import transforms
+from typing import List, Tuple
 
 
-def roll_by_gather(feature_map: torch.Tensor, shifts: torch.Tensor):
+def roll_by_gather(feature_map: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
     device = shifts.device
     # assumes 2D array
     batch, channel, group, x_dim, y_dim = feature_map.shape
@@ -22,7 +23,7 @@ def get_action_on_image_features(
     group_info_dict: dict,
     group_element_dict: dict,
     induced_rep_type: str = "regular",
-):
+) -> torch.Tensor:
     """
     This function takes the feature map and the action and returns the feature map
     after the action has been applied
@@ -74,20 +75,22 @@
         raise ValueError("induced_rep_type must be regular, scalar or vector")
 
 
-def flip_boxes(boxes, width):
+def flip_boxes(boxes: torch.Tensor, width: int) -> torch.Tensor:
     boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
     return boxes
 
 
-def flip_masks(masks):
+def flip_masks(masks: torch.Tensor) -> torch.Tensor:
     return masks.flip(-1)
 
 
-def rotate_masks(masks, angle):
+def rotate_masks(masks: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
     return transforms.functional.rotate(masks, angle)
 
 
-def rotate_points(origin, point, angle):
+def rotate_points(
+    origin: List[float], point: torch.Tensor, angle: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
     ox, oy = origin
     px, py = point
 
@@ -96,9 +99,9 @@
     return qx, qy
 
 
-def rotate_boxes(boxes, angle, width):
+def rotate_boxes(boxes: torch.Tensor, angle: torch.Tensor, width: int) -> torch.Tensor:
     # rotate points
-    origin = [width / 2, width / 2]
+    origin: List[float] = [width / 2, width / 2]
     x_min_rot, y_min_rot = rotate_points(origin, boxes[:, :2].T, torch.deg2rad(angle))
     x_max_rot, y_max_rot = rotate_points(origin, boxes[:, 2:].T, torch.deg2rad(angle))
diff --git a/equiadapt/pointcloud/__init__.py b/equiadapt/pointcloud/__init__.py
index b9fb3cd..8d9dc3d 100644
--- a/equiadapt/pointcloud/__init__.py
+++ b/equiadapt/pointcloud/__init__.py
@@ -12,7 +12,6 @@
     VNBilinear,
     VNLeakyReLU,
     VNLinear,
-    VNLinearAndLeakyReLU,
     VNLinearLeakyReLU,
     VNMaxPool,
     VNSmall,
@@ -33,7 +32,6 @@
     "VNBilinear",
     "VNLeakyReLU",
     "VNLinear",
-    "VNLinearAndLeakyReLU",
     "VNLinearLeakyReLU",
     "VNMaxPool",
     "VNSmall",
diff --git a/equiadapt/pointcloud/canonicalization/continuous_group.py b/equiadapt/pointcloud/canonicalization/continuous_group.py
index 42a9e0b..d47fd53 100644
--- a/equiadapt/pointcloud/canonicalization/continuous_group.py
+++ b/equiadapt/pointcloud/canonicalization/continuous_group.py
@@ -1,21 +1,22 @@
 # Note that for now we have only implemented canonicalization for rotation in the pointcloud setting.
 # This is meant to be a proof of concept and we are happy to receive contributions to extend this to other group actions.
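# ---------------------------------------------------------------------------
# [Editor's sketch: illustration only, not part of this patch] The rotation
# canonicalization implemented below follows a predict/orthonormalize/undo
# pattern: an equivariant network predicts vectors from the point cloud,
# gram_schmidt turns them into a rotation matrix, and that rotation is undone
# on the input points. A minimal, hedged sketch; `net`, `canonicalize_rotation`,
# and the exact tensor shapes are assumptions, not symbols from this diff:
#
#     import torch
#     from equiadapt.common.utils import gram_schmidt
#
#     def canonicalize_rotation(x: torch.Tensor, net: torch.nn.Module) -> torch.Tensor:
#         # x: (B, N, 3) point clouds; net(x): equivariant vectors, e.g. (B, 2, 3)
#         frame = gram_schmidt(net(x))  # (B, 3, 3) orthonormal frame
#         # right-multiplying by the frame removes the predicted rotation; whether
#         # a transpose is needed depends on the row/column-vector convention
#         return torch.bmm(x, frame)
# ---------------------------------------------------------------------------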
+from omegaconf import DictConfig
 import torch
 
 from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization
 from equiadapt.common.utils import gram_schmidt
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Tuple, Union, Optional
 
 
 class ContinuousGroupPointcloudCanonicalization(ContinuousGroupCanonicalization):
     def __init__(
         self,
         canonicalization_network: torch.nn.Module,
-        canonicalization_hyperparams: dict,
+        canonicalization_hyperparams: DictConfig,
     ):
         super().__init__(canonicalization_network)
 
-    def get_groupelement(self, x: torch.Tensor):
+    def get_groupelement(self, x: torch.Tensor) -> dict:
         """
         This method takes the input point cloud and maps it to the group element
 
@@ -29,7 +30,7 @@ def get_groupelement(self, x: torch.Tensor):
         raise NotImplementedError("get_groupelement method is not implemented")
 
     def canonicalize(
-        self, x: torch.Tensor, targets: List = None, **kwargs: Any
+        self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]:
         """
         This method takes a point cloud as input and
@@ -63,11 +64,11 @@ class EquivariantPointcloudCanonicalization(ContinuousGroupPointcloudCanonicaliz
     def __init__(
         self,
         canonicalization_network: torch.nn.Module,
-        canonicalization_hyperparams: dict,
+        canonicalization_hyperparams: DictConfig,
     ):
         super().__init__(canonicalization_network, canonicalization_hyperparams)
 
-    def get_groupelement(self, x: torch.Tensor):
+    def get_groupelement(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
         """
         This method takes the input point cloud and maps it to the group element
 
@@ -94,6 +95,6 @@ def get_groupelement(self, x: torch.Tensor):
             group_element_dict["rotation"]
         )
 
-        self.canonicalization_info_dict["group_element"] = group_element_dict
+        self.canonicalization_info_dict["group_element"] = group_element_dict  # type: ignore
 
         return group_element_dict
 
diff --git a/equiadapt/pointcloud/canonicalization_networks/__init__.py b/equiadapt/pointcloud/canonicalization_networks/__init__.py
index c57796a..e2cae21 100644
--- a/equiadapt/pointcloud/canonicalization_networks/__init__.py
+++ b/equiadapt/pointcloud/canonicalization_networks/__init__.py
@@ -12,7 +12,6 @@
     VNBilinear,
     VNLeakyReLU,
     VNLinear,
-    VNLinearAndLeakyReLU,
     VNLinearLeakyReLU,
     VNMaxPool,
     VNSoftplus,
@@ -26,7 +25,6 @@
     "VNBilinear",
     "VNLeakyReLU",
     "VNLinear",
-    "VNLinearAndLeakyReLU",
     "VNLinearLeakyReLU",
     "VNMaxPool",
     "VNSmall",
diff --git a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py
index 82d0f33..0b147b2 100644
--- a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py
+++ b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py
@@ -6,9 +6,11 @@
     VNBatchNorm,
     mean_pool,
 )
+from omegaconf import DictConfig
+from typing import Optional
 
 
-def knn(x, k):
+def knn(x: torch.Tensor, k: int) -> torch.Tensor:
     inner = -2 * torch.matmul(x.transpose(2, 1), x)
     xx = torch.sum(x**2, dim=1, keepdim=True)
     pairwise_distance = -xx - inner - xx.transpose(2, 1)
@@ -17,7 +19,9 @@ def knn(x, k):
     return idx
 
 
-def get_graph_feature_cross(x, k=20, idx=None):
+def get_graph_feature_cross(
+    x: torch.Tensor, k: int = 20, idx: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     batch_size = x.size(0)
     num_points = x.size(3)
     x = x.view(batch_size, -1, num_points)
@@ -47,7 +51,7 @@
 
 class VNSmall(torch.nn.Module):
-    def __init__(self, hyperparams):
+    def __init__(self, 
hyperparams: DictConfig): super().__init__() self.n_knn = hyperparams.n_knn self.pooling = hyperparams.pooling @@ -60,14 +64,14 @@ def __init__(self, hyperparams): if self.pooling == "max": self.pool = VNMaxPool(64 // 3) elif self.pooling == "mean": - self.pool = mean_pool + self.pool = mean_pool # type: ignore else: raise ValueError(f"Pooling type {self.pooling} not supported") # Wild idea -- Just use a linear layer to predict the output # self.conv = VNLinear(3, 12 // 3) - def forward(self, point_cloud): + def forward(self, point_cloud: torch.Tensor) -> torch.Tensor: point_cloud = point_cloud.unsqueeze(1) feat = get_graph_feature_cross(point_cloud, k=self.n_knn) diff --git a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py index 7463d88..4561c01 100644 --- a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py +++ b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py @@ -4,16 +4,17 @@ import torch import torch.nn as nn +from typing import Tuple EPS = 1e-6 class VNLinear(nn.Module): - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int): super(VNLinear, self).__init__() self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] """ @@ -22,13 +23,13 @@ def forward(self, x): class VNBilinear(nn.Module): - def __init__(self, in_channels1, in_channels2, out_channels): + def __init__(self, in_channels1: int, in_channels2: int, out_channels: int): super(VNBilinear, self).__init__() self.map_to_feat = nn.Bilinear( in_channels1, in_channels2, out_channels, bias=False ) - def forward(self, x, labels): + def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] """ @@ -38,7 +39,12 @@ def forward(self, x, labels): class VNSoftplus(nn.Module): - def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.0): + def __init__( + self, + in_channels: int, + share_nonlinearity: bool = False, + negative_slope: float = 0.0, + ): super(VNSoftplus, self).__init__() if share_nonlinearity: self.map_to_dir = nn.Linear(in_channels, 1, bias=False) @@ -46,7 +52,7 @@ def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.0): self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) self.negative_slope = negative_slope - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] """ @@ -68,7 +74,12 @@ def forward(self, x): class VNLeakyReLU(nn.Module): - def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2): + def __init__( + self, + in_channels: int, + share_nonlinearity: bool = False, + negative_slope: float = 0.2, + ): super(VNLeakyReLU, self).__init__() if share_nonlinearity: self.map_to_dir = nn.Linear(in_channels, 1, bias=False) @@ -76,7 +87,7 @@ def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2): self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) self.negative_slope = negative_slope - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] 
""" @@ -93,11 +104,11 @@ def forward(self, x): class VNLinearLeakyReLU(nn.Module): def __init__( self, - in_channels, - out_channels, - dim=5, - share_nonlinearity=False, - negative_slope=0.2, + in_channels: int, + out_channels: int, + dim: int = 5, + share_nonlinearity: bool = False, + negative_slope: float = 0.2, ): super(VNLinearLeakyReLU, self).__init__() self.dim = dim @@ -111,7 +122,7 @@ def __init__( else: self.map_to_dir = nn.Linear(in_channels, out_channels, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] """ @@ -130,64 +141,25 @@ def forward(self, x): return x_out -class VNLinearAndLeakyReLU(nn.Module): - def __init__( - self, - in_channels, - out_channels, - dim=5, - share_nonlinearity=False, - use_batchnorm="norm", - negative_slope=0.2, - ): - super(VNLinearLeakyReLU, self).__init__() - self.dim = dim - self.share_nonlinearity = share_nonlinearity - self.use_batchnorm = use_batchnorm - self.negative_slope = negative_slope - - self.linear = VNLinear(in_channels, out_channels) - self.leaky_relu = VNLeakyReLU( - out_channels, - share_nonlinearity=share_nonlinearity, - negative_slope=negative_slope, - ) - - # BatchNorm - self.use_batchnorm = use_batchnorm - if use_batchnorm != "none": - self.batchnorm = VNBatchNorm(out_channels, dim=dim, mode=use_batchnorm) - - def forward(self, x): - """ - x: point features of shape [B, N_feat, 3, N_samples, ...] - """ - # Conv - x = self.linear(x) - # InstanceNorm - if self.use_batchnorm != "none": - x = self.batchnorm(x) - # LeakyReLU - x_out = self.leaky_relu(x) - return x_out - - class VNBatchNorm(nn.Module): - def __init__(self, num_features, dim): + def __init__(self, num_features: int, dim: int): super(VNBatchNorm, self).__init__() self.dim = dim if dim == 3 or dim == 4: - self.bn = nn.BatchNorm1d(num_features) + self.bn1d = nn.BatchNorm1d(num_features) elif dim == 5: - self.bn = nn.BatchNorm2d(num_features) + self.bn2d = nn.BatchNorm2d(num_features) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] """ # norm = torch.sqrt((x*x).sum(2)) norm = torch.norm(x, dim=2) + EPS - norm_bn = self.bn(norm) + if self.dim == 3 or self.dim == 4: + norm_bn = self.bn1d(norm) + elif self.dim == 5: + norm_bn = self.bn2d(norm) norm = norm.unsqueeze(2) norm_bn = norm_bn.unsqueeze(2) x = x / norm * norm_bn @@ -196,11 +168,11 @@ def forward(self, x): class VNMaxPool(nn.Module): - def __init__(self, in_channels): + def __init__(self, in_channels: int): super(VNMaxPool, self).__init__() self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: point features of shape [B, N_feat, 3, N_samples, ...] 
""" @@ -212,18 +184,18 @@ def forward(self, x): return x_max -def mean_pool(x, dim=-1, keepdim=False): +def mean_pool(x: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor: return x.mean(dim=dim, keepdim=keepdim) class VNStdFeature(nn.Module): def __init__( self, - in_channels, - dim=4, - normalize_frame=False, - share_nonlinearity=False, - negative_slope=0.2, + in_channels: int, + dim: int = 4, + normalize_frame: bool = False, + share_nonlinearity: bool = False, + negative_slope: float = 0.2, ): super(VNStdFeature, self).__init__() self.dim = dim @@ -248,7 +220,7 @@ def __init__( else: self.vn_lin = nn.Linear(in_channels // 4, 3, bias=False) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ x: point features of shape [B, N_feat, 3, N_samples, ...] """ @@ -258,15 +230,12 @@ def forward(self, x): z0 = self.vn_lin(z0.transpose(1, -1)).transpose(1, -1) if self.normalize_frame: - # make z0 orthogonal. u2 = v2 - proj_u1(v2) v1 = z0[:, 0, :] - # u1 = F.normalize(v1, dim=1) - v1_norm = torch.sqrt((v1 * v1).sum(1, keepdims=True)) + v1_norm = torch.sqrt((v1 * v1).sum(1, keepdims=True)) # type: ignore u1 = v1 / (v1_norm + EPS) v2 = z0[:, 1, :] - v2 = v2 - (v2 * u1).sum(1, keepdims=True) * u1 - # u2 = F.normalize(u2, dim=1) - v2_norm = torch.sqrt((v2 * v2).sum(1, keepdims=True)) + v2 = v2 - (v2 * u1).sum(1, keepdims=True) * u1 # type: ignore + v2_norm = torch.sqrt((v2 * v2).sum(1, keepdims=True)) # type: ignore u2 = v2 / (v2_norm + EPS) # compute the cross product of the two output vectors diff --git a/examples/pointcloud/classification/model_utils.py b/examples/pointcloud/classification/model_utils.py index 4b6f2a4..7e40a72 100644 --- a/examples/pointcloud/classification/model_utils.py +++ b/examples/pointcloud/classification/model_utils.py @@ -4,7 +4,7 @@ def get_prediction_network( architecture: str, - hyperparams: DictConfig = None, + hyperparams: DictConfig, ): """ The function returns the prediction network based on the architecture type diff --git a/examples/pointcloud/classification/train.py b/examples/pointcloud/classification/train.py index e0e4ca7..4ecc762 100644 --- a/examples/pointcloud/classification/train.py +++ b/examples/pointcloud/classification/train.py @@ -50,7 +50,7 @@ def train_pointcloud(hyperparams: DictConfig): # initialize wandb wandb.init( - config=OmegaConf.to_container(hyperparams, resolve=True), + config=OmegaConf.to_container(hyperparams, resolve=True), # type: ignore entity=hyperparams["wandb"]["wandb_entity"], project=hyperparams["wandb"]["wandb_project"], dir=hyperparams["wandb"]["wandb_dir"], diff --git a/examples/pointcloud/classification/train_utils.py b/examples/pointcloud/classification/train_utils.py index 1842c42..99045cf 100644 --- a/examples/pointcloud/classification/train_utils.py +++ b/examples/pointcloud/classification/train_utils.py @@ -1,6 +1,6 @@ import dotenv from omegaconf import DictConfig -from typing import Dict, Optional +from typing import Optional import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping @@ -29,21 +29,7 @@ def get_model_pipeline(hyperparams: DictConfig): def get_trainer( hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger ): - if hyperparams.experiment.run_mode == "auto_tune": - trainer = pl.Trainer( - max_epochs=hyperparams.experiment.num_epochs, - accelerator="auto", - auto_scale_batch_size=True, - auto_lr_find=True, - logger=wandb_logger, - callbacks=callbacks, - 
deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, - devices=hyperparams.experiment.num_gpus, - strategy="ddp", - ) - - elif hyperparams.experiment.run_mode == "dryrun": + if hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, @@ -89,15 +75,15 @@ def get_callbacks(hyperparams: DictConfig): return [checkpoint_callback, early_stop_metric_callback] -def get_recursive_hyperparams_identifier(hyperparams: Dict): +def get_recursive_hyperparams_identifier(hyperparams: DictConfig): # get the identifier for the canonicalization network hyperparameters # recursively go through the dictionary and get the values and concatenate them identifier = "" for key, value in hyperparams.items(): if isinstance(value, DictConfig): - identifier += f"_{get_recursive_hyperparams_identifier(value)}_" + identifier += f"_{get_recursive_hyperparams_identifier(value)}_" # type: ignore else: - identifier += f"_{key}_{value}_" + identifier += f"_{key}_{value}_" # type: ignore return identifier diff --git a/examples/pointcloud/part_segmentation/model_utils.py b/examples/pointcloud/part_segmentation/model_utils.py index 92c4a48..40ec7e0 100644 --- a/examples/pointcloud/part_segmentation/model_utils.py +++ b/examples/pointcloud/part_segmentation/model_utils.py @@ -4,7 +4,7 @@ def get_prediction_network( architecture: str, - hyperparams: DictConfig = None, + hyperparams: DictConfig, ): """ The function returns the prediction network based on the architecture type diff --git a/examples/pointcloud/part_segmentation/train.py b/examples/pointcloud/part_segmentation/train.py index 2e7bd77..c9e67df 100644 --- a/examples/pointcloud/part_segmentation/train.py +++ b/examples/pointcloud/part_segmentation/train.py @@ -50,7 +50,7 @@ def train_pointcloud(hyperparams: DictConfig): # initialize wandb wandb.init( - config=OmegaConf.to_container(hyperparams, resolve=True), + config=OmegaConf.to_container(hyperparams, resolve=True), # type: ignore entity=hyperparams["wandb"]["wandb_entity"], project=hyperparams["wandb"]["wandb_project"], dir=hyperparams["wandb"]["wandb_dir"], diff --git a/examples/pointcloud/part_segmentation/train_utils.py b/examples/pointcloud/part_segmentation/train_utils.py index f8463fb..f3c360c 100644 --- a/examples/pointcloud/part_segmentation/train_utils.py +++ b/examples/pointcloud/part_segmentation/train_utils.py @@ -1,6 +1,6 @@ import dotenv from omegaconf import DictConfig -from typing import Dict, Optional +from typing import Optional import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping @@ -30,21 +30,7 @@ def get_model_pipeline(hyperparams: DictConfig): def get_trainer( hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger ): - if hyperparams.experiment.run_mode == "auto_tune": - trainer = pl.Trainer( - max_epochs=hyperparams.experiment.num_epochs, - accelerator="auto", - auto_scale_batch_size=True, - auto_lr_find=True, - logger=wandb_logger, - callbacks=callbacks, - deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, - devices=hyperparams.experiment.num_gpus, - strategy="ddp", - ) - - elif hyperparams.experiment.run_mode == "dryrun": + if hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, @@ -90,15 +76,15 @@ def get_callbacks(hyperparams: DictConfig): return [checkpoint_callback, 
early_stop_metric_callback] -def get_recursive_hyperparams_identifier(hyperparams: Dict): +def get_recursive_hyperparams_identifier(hyperparams: DictConfig): # get the identifier for the canonicalization network hyperparameters # recursively go through the dictionary and get the values and concatenate them identifier = "" for key, value in hyperparams.items(): if isinstance(value, DictConfig): - identifier += f"_{get_recursive_hyperparams_identifier(value)}_" + identifier += f"_{get_recursive_hyperparams_identifier(value)}_" # type: ignore else: - identifier += f"_{key}_{value}_" + identifier += f"_{key}_{value}_" # type: ignore return identifier
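
Note on the invert_canonicalization changes applied throughout this patch: mypy
flags the old overrides because they renamed the base-class parameter and added
their own keyword arguments, so the base method and every override now share the
signature (self, x_canonicalized_out: torch.Tensor, **kwargs: Any), with options
such as induced_rep_type read from kwargs (see the kwargs.get("induced_rep_type",
...) lines in the image canonicalization diffs). A minimal usage sketch; the
names "canonicalizer" and "out" are placeholders, not symbols from this patch:

    # map a prediction made on the canonicalized input back to the original
    # orientation of the data
    out_original = canonicalizer.invert_canonicalization(
        x_canonicalized_out=out,
        induced_rep_type="regular",  # forwarded through **kwargs to the override
    )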