From 3b1eec5df0814bb0b6af8caf8ba6346d504a4eda Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Thu, 8 Aug 2024 10:39:09 -0400 Subject: [PATCH] Add tests for the eigenvalue solvers --- diptorch/linalg.py | 1 - notebooks/01_linalg.ipynb | 124 +++++++++++++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/diptorch/linalg.py b/diptorch/linalg.py index 1843b4a..b849490 100644 --- a/diptorch/linalg.py +++ b/diptorch/linalg.py @@ -84,7 +84,6 @@ def eigvalsh3( eig3 = q + 2 * p * phi.cos() eig1 = q + 2 * p * (phi + 2 * torch.pi / 3).cos() eig2 = 3 * q - eig1 - eig3 - print(eig1.shape, eig2.shape, eig3.shape) return torch.concat([eig1, eig2, eig3], dim=1) # %% ../notebooks/01_linalg.ipynb 9 diff --git a/notebooks/01_linalg.ipynb b/notebooks/01_linalg.ipynb index b3df068..d0abf13 100644 --- a/notebooks/01_linalg.ipynb +++ b/notebooks/01_linalg.ipynb @@ -159,7 +159,6 @@ " eig3 = q + 2 * p * phi.cos()\n", " eig1 = q + 2 * p * (phi + 2 * torch.pi / 3).cos()\n", " eig2 = 3 * q - eig1 - eig3\n", - " print(eig1.shape, eig2.shape, eig3.shape)\n", " return torch.concat([eig1, eig2, eig3], dim=1)" ] }, @@ -181,6 +180,129 @@ " )" ] }, + { + "cell_type": "markdown", + "id": "339315e4-4234-4a83-bb49-f5fa302d0371", + "metadata": {}, + "source": [ + "### Testing\n", + "\n", + "Our closed-form solvers are numerically equivalent to `torch.linalg.eigvalsh`.\n", + "Unsurprisingly, our implementation is also much faster than PyTorch's solver." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "532bfdff-1dd5-4e4a-8902-e6ff20b77590", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the 2×2 implementation is equivalent to torch's eigvalsh\n", + "A = torch.randn(100, 2, 2, 30, 30)\n", + "A = A + A.transpose(1, 2) # Make A Hermitian\n", + "\n", + "torch.testing.assert_close(\n", + " eigvalsh(A),\n", + " torch.linalg.eigvalsh(A.permute(0, -2, -1, 1, 2)).permute(0, -1, 1, 2),\n", + " rtol=1e-5,\n", + " atol=1e-4,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b654569-01b7-4d66-8653-e38ad8ff3629", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.09 ms ± 157 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit eigvalsh(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58d45054-3dac-4ac2-8cd9-786541854deb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "46.7 ms ± 836 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "B = A.permute(0, -2, -1, 1, 2)\n", + "%timeit torch.linalg.eigvalsh(B)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11f5a9e6-a1e1-43cd-93c7-afa87de88b92", + "metadata": {}, + "outputs": [], + "source": [ + "# Test the 3×3 implementation is equivalent to torch's eigh\n", + "A = torch.randn(100, 3, 3, 30, 30)\n", + "A = A + A.transpose(1, 2) # Make A Hermitian\n", + "\n", + "torch.testing.assert_close(\n", + " eigvalsh(A),\n", + " torch.linalg.eigvalsh(A.permute(0, -2, -1, 1, 2)).permute(0, -1, 1, 2),\n", + " rtol=1e-5,\n", + " atol=1e-4,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3eac5bd1-9704-43c7-9aba-131e95790fcd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5.52 ms ± 51.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "%timeit eigvalsh(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ceecc93-e204-41fe-9c3a-3cf2bbaccb7e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "98.2 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "B = A.permute(0, -2, -1, 1, 2)\n", + "%timeit torch.linalg.eigvalsh(B)" + ] + }, { "cell_type": "code", "execution_count": null,