Skip to content

Commit

Permalink
docs for equiadapt/common/utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sibasmarak committed Mar 13, 2024
1 parent 0fca345 commit 35ef33d
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions equiadapt/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
import torch

"""
This module contains utility functions and classes that are used for operations on Lie groups.
The module includes a function for the Gram-Schmidt process, which is used to orthogonalize a set of vectors. This function is implemented in a batch-wise manner, meaning it can process multiple sets of vectors at once.
The module also includes a class for parameterizing Lie groups and their representations.
This class supports several types of Lie groups, including the special orthogonal group (SO(n)),
the special Euclidean group (SE(n)), the orthogonal group (O(n)), and the Euclidean group (E(n)).
The class provides methods for generating the basis of the Lie group, as well as for computing the
group representation given a set of parameters.
Functions:
gram_schmidt(vectors: torch.Tensor) -> torch.Tensor
Classes:
LieParameterization
"""


def gram_schmidt(vectors: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -26,7 +44,8 @@ def gram_schmidt(vectors: torch.Tensor) -> torch.Tensor:


class LieParameterization(torch.nn.Module):
"""A class for parameterizing Lie groups and their representations for a single block.
"""
A class for parameterizing Lie groups and their representations for a single block.
Args:
group_type (str): The type of Lie group (e.g., 'SOn', 'SEn', 'On', 'En').
Expand All @@ -43,7 +62,8 @@ def __init__(self, group_type: str, group_dim: int):
self.group_dim = group_dim

def get_son_bases(self) -> torch.Tensor:
"""Generates the basis of the Lie group of SOn.
"""
Generates the basis of the Lie group of SOn.
Returns:
torch.Tensor: The son basis of shape (num_params, group_dim, group_dim).
Expand All @@ -62,7 +82,8 @@ def get_son_bases(self) -> torch.Tensor:
return son_bases

def get_son_rep(self, params: torch.Tensor) -> torch.Tensor:
"""Computes the representation for SOn group.
"""
Computes the representation for SOn group.
Args:
params (torch.Tensor): Input parameters of shape (batch_size, param_dim).
Expand Down Expand Up @@ -104,7 +125,8 @@ def get_on_rep(
return on_rep

def get_sen_rep(self, params: torch.Tensor) -> torch.Tensor:
"""Computes the representation for SEn group.
"""
Computes the representation for SEn group.
Args:
params (torch.Tensor): Input parameters of shape (batch_size, param_dim).
Expand All @@ -129,14 +151,6 @@ def get_sen_rep(self, params: torch.Tensor) -> torch.Tensor:
def get_en_rep(
self, params: torch.Tensor, reflect_indicators: torch.Tensor
) -> torch.Tensor:
"""Computes the representation for E(n) group.
Args:
params (torch.Tensor): Input parameters of shape (batch_size, param_dim).
Returns:
torch.Tensor: The representation of shape (batch_size, rep_dim, rep_dim).
"""
"""Computes the representation for E(n) group, including both rotations and translations.
Args:
Expand Down

0 comments on commit 35ef33d

Please sign in to comment.