From 64ee23a614ed34a928ec311462477b8b7e54526b Mon Sep 17 00:00:00 2001 From: KanishkNavale Date: Sun, 4 Feb 2024 17:18:14 +0100 Subject: [PATCH] Modified: pydantic declarations --- heimdall/datatypes/camera.py | 6 ++++++ heimdall/datatypes/image.py | 8 ++++++++ heimdall/datatypes/pose.py | 6 ++++++ 3 files changed, 20 insertions(+) diff --git a/heimdall/datatypes/camera.py b/heimdall/datatypes/camera.py index 3f612ce..43fcdb9 100644 --- a/heimdall/datatypes/camera.py +++ b/heimdall/datatypes/camera.py @@ -10,6 +10,9 @@ class CameraIntrinsics(BaseDataClass): cx: float cy: float + def __init__(self, fx: float, fy: float, cx: float, cy: float) -> None: + super().__init__(fx=fx, fy=fy, cx=cx, cy=cy) + @property def matrix(self) -> np.ndarray: return np.array([self.fx, self.fy, self.cx, self.cy]) @@ -18,3 +21,6 @@ def matrix(self) -> np.ndarray: class CameraInformation(BaseDataClass): intrinsics: CameraIntrinsics pose: Pose + + def __init__(self, intrinsics: CameraIntrinsics, pose: Pose) -> None: + super().__init__(intrinsics=intrinsics, pose=pose) diff --git a/heimdall/datatypes/image.py b/heimdall/datatypes/image.py index 04449f0..4acc363 100644 --- a/heimdall/datatypes/image.py +++ b/heimdall/datatypes/image.py @@ -9,3 +9,11 @@ class Image(BaseDataClass): rgb: np.ndarray | torch.Tensor depth: np.ndarray | torch.Tensor camera_information: CameraInformation + + def __init__( + self, + rgb: np.ndarray | torch.Tensor, + depth: np.ndarray | torch.Tensor, + camera_information: CameraInformation, + ) -> None: + super().__init__(rgb=rgb, depth=depth, camera_information=camera_information) diff --git a/heimdall/datatypes/pose.py b/heimdall/datatypes/pose.py index 049fb8f..afb7530 100644 --- a/heimdall/datatypes/pose.py +++ b/heimdall/datatypes/pose.py @@ -9,6 +9,9 @@ class SO3(BaseDataClass): matrix: torch.Tensor + def __init__(self, matrix: torch.Tensor) -> None: + super().__init__(matrix=matrix) + def __post_init__(self, *kwargs) -> None: if self.matrix.ndim > 3 or self.matrix.ndim < 2: raise ValueError("SO3 matrix must be of shape (3, 3) or (N, 3, 3))") @@ -52,6 +55,9 @@ class Pose(BaseDataClass): translation: torch.Tensor SO3: SO3 + def __init__(self, translation: torch.Tensor, SO3: SO3) -> None: + super().__init__(translation=translation, SO3=SO3) + def __post_init__(self, *kwargs) -> None: if self.translation.ndim > 2 or self.translation.ndim < 1: raise ValueError("Translation vector must be of shape (3,) or (N, 3)")