Skip to content

Commit

Permalink
Fix divide by zero error in eigvalsh3
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Aug 8, 2024
1 parent 5c05c04 commit 309db93
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
18 changes: 10 additions & 8 deletions diptorch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def eigvalsh3(
jj: torch.Tensor,
jk: torch.Tensor,
kk: torch.Tensor,
eps: float = 1e-8,
) -> torch.Tensor:
"""
Compute the eigenvalues of a batched Hermitian 3×3 tensor
Expand All @@ -70,14 +71,15 @@ def eigvalsh3(
Returns eigenvalues in a tensor with shape [1 3 H W D]
sorted in ascending order.
"""
q = (ii + jj + kk) / 3
p1 = torch.concat([ij, ik, jk], dim=1).square().sum(1, keepdim=True)
p2 = (torch.concat([ii, jj, kk], dim=1) - q).square().sum(
dim=1, keepdim=True
) + 2 * p1
p = (p2 / 6).sqrt()

r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / p.pow(3) / 2
diag = torch.concat([ii, jj, kk], dim=1)
triu = torch.concat([ij, ik, jk], dim=1)

q = diag.sum(dim=1, keepdim=True) / 3
p1 = triu.square().sum(dim=1, keepdim=True)
p2 = (diag - q).square().sum(dim=1, keepdim=True)
p = ((2 * p1 + p2) / 6).sqrt()

r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / (p.pow(3) + eps) / 2
r = r.clamp(-1, 1)
phi = r.arccos() / 3

Expand Down
26 changes: 17 additions & 9 deletions notebooks/01_linalg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
" jj: torch.Tensor,\n",
" jk: torch.Tensor,\n",
" kk: torch.Tensor,\n",
" eps: float = 1e-8,\n",
") -> torch.Tensor:\n",
" \"\"\"\n",
" Compute the eigenvalues of a batched Hermitian 3×3 tensor\n",
Expand All @@ -147,12 +148,15 @@
" Returns eigenvalues in a tensor with shape [1 3 H W D]\n",
" sorted in ascending order.\n",
" \"\"\"\n",
" q = (ii + jj + kk) / 3\n",
" p1 = torch.concat([ij, ik, jk], dim=1).square().sum(1, keepdim=True)\n",
" p2 = (torch.concat([ii, jj, kk], dim=1) - q).square().sum(dim=1, keepdim=True) + 2 * p1\n",
" p = (p2 / 6).sqrt()\n",
" diag = torch.concat([ii, jj, kk], dim=1)\n",
" triu = torch.concat([ij, ik, jk], dim=1)\n",
" \n",
" q = diag.sum(dim=1, keepdim=True) / 3\n",
" p1 = triu.square().sum(dim=1, keepdim=True)\n",
" p2 = (diag - q).square().sum(dim=1, keepdim=True)\n",
" p = ((2 * p1 + p2) / 6).sqrt()\n",
"\n",
" r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / p.pow(3) / 2\n",
" r = deth3(ii - q, ij, ik, jj - q, jk, kk - q) / (p.pow(3) + eps) / 2\n",
" r = r.clamp(-1, 1)\n",
" phi = r.arccos() / 3\n",
"\n",
Expand Down Expand Up @@ -220,11 +224,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"3.09 ms ± 157 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"1.61 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"# Time diptorch's implementation\n",
"%timeit eigvalsh(A)"
]
},
Expand All @@ -238,11 +243,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"46.7 ms ± 836 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"43.7 ms ± 40.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"# Time torch's implementation\n",
"B = A.permute(0, -2, -1, 1, 2)\n",
"%timeit torch.linalg.eigvalsh(B)"
]
Expand Down Expand Up @@ -276,11 +282,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"5.52 ms ± 51.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
"5.78 ms ± 11 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"# Time diptorch's implementation\n",
"%timeit eigvalsh(A)"
]
},
Expand All @@ -294,11 +301,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"98.2 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
"97.8 ms ± 766 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"# Time torch's implementation\n",
"B = A.permute(0, -2, -1, 1, 2)\n",
"%timeit torch.linalg.eigvalsh(B)"
]
Expand Down

0 comments on commit 309db93

Please sign in to comment.