From 309db937990e90f7bf5597cb4b06bff5bf073476 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 8 Aug 2024 11:42:59 -0400 Subject: [PATCH] Fix divide by zero error in eigvalsh3 --- diptorch/linalg.py | 18 ++++++++++-------- notebooks/01_linalg.ipynb | 26 +++++++++++++++++--------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/diptorch/linalg.py b/diptorch/linalg.py index b849490..a9e54fa 100644 --- a/diptorch/linalg.py +++ b/diptorch/linalg.py @@ -62,6 +62,7 @@ def eigvalsh3( jj: torch.Tensor, jk: torch.Tensor, kk: torch.Tensor, + eps: float = 1e-8, ) -> torch.Tensor: """ Compute the eigenvalues of a batched Hermitian 3×3 tensor @@ -70,14 +71,15 @@ def eigvalsh3( Returns eigenvalues in a tensor with shape [1 3 H W D] sorted in ascending order. """ - q = (ii + jj + kk) / 3 - p1 = torch.concat([ij, ik, jk], dim=1).square().sum(1, keepdim=True) - p2 = (torch.concat([ii, jj, kk], dim=1) - q).square().sum( - dim=1, keepdim=True - ) + 2 * p1 - p = (p2 / 6).sqrt() - - r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / p.pow(3) / 2 + diag = torch.concat([ii, jj, kk], dim=1) + triu = torch.concat([ij, ik, jk], dim=1) + + q = diag.sum(dim=1, keepdim=True) / 3 + p1 = triu.square().sum(dim=1, keepdim=True) + p2 = (diag - q).square().sum(dim=1, keepdim=True) + p = ((2 * p1 + p2) / 6).sqrt() + + r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / (p.pow(3) + eps) / 2 r = r.clamp(-1, 1) phi = r.arccos() / 3 diff --git a/notebooks/01_linalg.ipynb b/notebooks/01_linalg.ipynb index d0abf13..6052650 100644 --- a/notebooks/01_linalg.ipynb +++ b/notebooks/01_linalg.ipynb @@ -139,6 +139,7 @@ " jj: torch.Tensor,\n", " jk: torch.Tensor,\n", " kk: torch.Tensor,\n", + " eps: float = 1e-8,\n", ") -> torch.Tensor:\n", " \"\"\"\n", " Compute the eigenvalues of a batched Hermitian 3×3 tensor\n", @@ -147,12 +148,15 @@ " Returns eigenvalues in a tensor with shape [1 3 H W D]\n", " sorted in ascending order.\n", " \"\"\"\n", - " q = (ii + jj + kk) / 3\n", - " p1 = torch.concat([ij, ik, jk], dim=1).square().sum(1, keepdim=True)\n", - " p2 = (torch.concat([ii, jj, kk], dim=1) - q).square().sum(dim=1, keepdim=True) + 2 * p1\n", - " p = (p2 / 6).sqrt()\n", + " diag = torch.concat([ii, jj, kk], dim=1)\n", + " triu = torch.concat([ij, ik, jk], dim=1)\n", + " \n", + " q = diag.sum(dim=1, keepdim=True) / 3\n", + " p1 = triu.square().sum(dim=1, keepdim=True)\n", + " p2 = (diag - q).square().sum(dim=1, keepdim=True)\n", + " p = ((2 * p1 + p2) / 6).sqrt()\n", "\n", - " r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / p.pow(3) / 2\n", + " r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / (p.pow(3) + eps) / 2\n", " r = r.clamp(-1, 1)\n", " phi = r.arccos() / 3\n", "\n", @@ -220,11 +224,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.09 ms ± 157 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "1.61 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ + "# Time diptorch's implementation\n", "%timeit eigvalsh(A)" ] }, @@ -238,11 +243,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "46.7 ms ± 836 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "43.7 ms ± 40.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ + "# Time torch's implementation\n", "B = A.permute(0, -2, -1, 1, 2)\n", "%timeit torch.linalg.eigvalsh(B)" ] @@ -276,11 +282,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "5.52 ms ± 51.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "5.78 ms ± 11 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ + "# Time diptorch's implementation\n", "%timeit eigvalsh(A)" ] }, @@ -294,11 +301,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "98.2 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + "97.8 ms ± 766 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ + "# Time torch's implementation\n", "B = A.permute(0, -2, -1, 1, 2)\n", "%timeit torch.linalg.eigvalsh(B)" ]