diff --git a/diffdrr/pose.py b/diffdrr/pose.py index ae4903c7c..1cbdadd59 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -19,7 +19,7 @@ class RigidTransform(torch.nn.Module): def __init__(self, matrix): super().__init__() - if specimen.extrinsic.matrix.dim() == 2: + if matrix.dim() == 2: matrix = matrix.unsqueeze(0) self.register_buffer("matrix", matrix) @@ -39,7 +39,7 @@ def inverse(self): t = self.matrix[..., :3, 3] Rinv = R.mT tinv = -torch.einsum("bij, bj -> bi", Rinv, t) - make_matrix(Rinv, tinv) + matrix = make_matrix(Rinv, tinv) return RigidTransform(matrix) def compose(self, T): diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 0ae9716c1..6f85a1dfa 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -95,7 +95,7 @@ "\n", " def __init__(self, matrix):\n", " super().__init__()\n", - " if specimen.extrinsic.matrix.dim() == 2:\n", + " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", " self.register_buffer(\"matrix\", matrix)\n", "\n", @@ -115,7 +115,7 @@ " t = self.matrix[..., :3, 3]\n", " Rinv = R.mT\n", " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", - " make_matrix(Rinv, tinv)\n", + " matrix = make_matrix(Rinv, tinv)\n", " return RigidTransform(matrix)\n", "\n", " def compose(self, T):\n",