Skip to content

Commit

Permalink
Add tests for the eigenvalue solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Aug 8, 2024
1 parent 49e72f5 commit 3b1eec5
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
1 change: 0 additions & 1 deletion diptorch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def eigvalsh3(
eig3 = q + 2 * p * phi.cos()
eig1 = q + 2 * p * (phi + 2 * torch.pi / 3).cos()
eig2 = 3 * q - eig1 - eig3
print(eig1.shape, eig2.shape, eig3.shape)
return torch.concat([eig1, eig2, eig3], dim=1)

# %% ../notebooks/01_linalg.ipynb 9
Expand Down
124 changes: 123 additions & 1 deletion notebooks/01_linalg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@
" eig3 = q + 2 * p * phi.cos()\n",
" eig1 = q + 2 * p * (phi + 2 * torch.pi / 3).cos()\n",
" eig2 = 3 * q - eig1 - eig3\n",
" print(eig1.shape, eig2.shape, eig3.shape)\n",
" return torch.concat([eig1, eig2, eig3], dim=1)"
]
},
Expand All @@ -181,6 +180,129 @@
" )"
]
},
{
"cell_type": "markdown",
"id": "339315e4-4234-4a83-bb49-f5fa302d0371",
"metadata": {},
"source": [
"### Testing\n",
"\n",
"Our closed-form solvers are numerically equivalent to `torch.linalg.eigvalsh`.\n",
"Unsurprisingly, our implementation is also much faster than PyTorch's solver."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "532bfdff-1dd5-4e4a-8902-e6ff20b77590",
"metadata": {},
"outputs": [],
"source": [
"# Test the 2×2 implementation is equivalent to torch's eigvalsh\n",
"A = torch.randn(100, 2, 2, 30, 30)\n",
"A = A + A.transpose(1, 2) # Make A Hermitian\n",
"\n",
"torch.testing.assert_close(\n",
" eigvalsh(A),\n",
" torch.linalg.eigvalsh(A.permute(0, -2, -1, 1, 2)).permute(0, -1, 1, 2),\n",
" rtol=1e-5,\n",
" atol=1e-4,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b654569-01b7-4d66-8653-e38ad8ff3629",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.09 ms ± 157 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit eigvalsh(A)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58d45054-3dac-4ac2-8cd9-786541854deb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"46.7 ms ± 836 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"B = A.permute(0, -2, -1, 1, 2)\n",
"%timeit torch.linalg.eigvalsh(B)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11f5a9e6-a1e1-43cd-93c7-afa87de88b92",
"metadata": {},
"outputs": [],
"source": [
"# Test the 3×3 implementation is equivalent to torch's eigh\n",
"A = torch.randn(100, 3, 3, 30, 30)\n",
"A = A + A.transpose(1, 2) # Make A Hermitian\n",
"\n",
"torch.testing.assert_close(\n",
" eigvalsh(A),\n",
" torch.linalg.eigvalsh(A.permute(0, -2, -1, 1, 2)).permute(0, -1, 1, 2),\n",
" rtol=1e-5,\n",
" atol=1e-4,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3eac5bd1-9704-43c7-9aba-131e95790fcd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.52 ms ± 51.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit eigvalsh(A)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ceecc93-e204-41fe-9c3a-3cf2bbaccb7e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"98.2 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"B = A.permute(0, -2, -1, 1, 2)\n",
"%timeit torch.linalg.eigvalsh(B)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 3b1eec5

Please sign in to comment.