Skip to content

Commit

Permalink
Modified: pydantic declarations
Browse files Browse the repository at this point in the history
  • Loading branch information
KanishkNavale committed Feb 4, 2024
1 parent ab1557e commit 64ee23a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions heimdall/datatypes/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class CameraIntrinsics(BaseDataClass):
cx: float
cy: float

def __init__(self, fx: float, fy: float, cx: float, cy: float) -> None:
super().__init__(fx=fx, fy=fy, cx=cx, cy=cy)

@property
def matrix(self) -> np.ndarray:
return np.array([self.fx, self.fy, self.cx, self.cy])
Expand All @@ -18,3 +21,6 @@ def matrix(self) -> np.ndarray:
class CameraInformation(BaseDataClass):
intrinsics: CameraIntrinsics
pose: Pose

def __init__(self, intrinsics: CameraIntrinsics, pose: Pose) -> None:
super().__init__(intrinsics=intrinsics, pose=pose)
8 changes: 8 additions & 0 deletions heimdall/datatypes/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ class Image(BaseDataClass):
rgb: np.ndarray | torch.Tensor
depth: np.ndarray | torch.Tensor
camera_information: CameraInformation

def __init__(
self,
rgb: np.ndarray | torch.Tensor,
depth: np.ndarray | torch.Tensor,
camera_information: CameraInformation,
) -> None:
super().__init__(rgb=rgb, depth=depth, camera_information=camera_information)
6 changes: 6 additions & 0 deletions heimdall/datatypes/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
class SO3(BaseDataClass):
matrix: torch.Tensor

def __init__(self, matrix: torch.Tensor) -> None:
super().__init__(matrix=matrix)

def __post_init__(self, *kwargs) -> None:
if self.matrix.ndim > 3 or self.matrix.ndim < 2:
raise ValueError("SO3 matrix must be of shape (3, 3) or (N, 3, 3))")
Expand Down Expand Up @@ -52,6 +55,9 @@ class Pose(BaseDataClass):
translation: torch.Tensor
SO3: SO3

def __init__(self, translation: torch.Tensor, SO3: SO3) -> None:
super().__init__(translation=translation, SO3=SO3)

def __post_init__(self, *kwargs) -> None:
if self.translation.ndim > 2 or self.translation.ndim < 1:
raise ValueError("Translation vector must be of shape (3,) or (N, 3)")
Expand Down

0 comments on commit 64ee23a

Please sign in to comment.