diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..7da60d9 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,93 @@ +name: Publish torchTT to pypi + +on: push + +jobs: + build: + name: Build distribution + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish to PyPI + if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/torchTT # Replace with your PyPI project name + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: >- + Sign the Python 🐍 distribution 📦 with Sigstore + and upload them to GitHub Release + needs: + - publish-to-pypi + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v3.0.0 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create + '${{ github.ref_name }}' + --repo '${{ github.repository }}' + --notes "" + - name: Upload artifact signatures to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + # Upload to GitHub Release using the `gh` CLI. + # `dist/` contains the built packages, and the + # sigstore-produced signatures and certificates. + run: >- + gh release upload + '${{ github.ref_name }}' dist/** + --repo '${{ github.repository }}' diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c91b190 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[build-system] +requires = ["setuptools>=61", "setuptools-scm>=8.0", "wheel", "torch>=1.7", "numpy>=1.18", "opt_einsum", "ninja", "scipy>=0.16"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.package-data] +torchtt = ["cpp/*"] + +[tool.setuptools] +py-modules = [] + +[project] +name = "torchTT" +version = "0.1" +description = "Tensor-Train decomposition in pytorch." 
+readme = "README.md" +requires-python = ">=3.7" +dependencies = [ + "torch>=1.7", + "numpy>=1.18", + "opt_einsum", + "scipy>=0.16", + "ninja" +] +license = {file = "LICENSE"} +authors = [ + { name = "Ion Gabriel Ion", email = "ion.ion.gabriel@gmail.com" } +] +keywords = ["pytorch", "tensor-train decomposition"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra -q" +testpaths = [ + "tests", +] \ No newline at end of file diff --git a/pyproject_toml b/pyproject_toml deleted file mode 100644 index ec2c853..0000000 --- a/pyproject_toml +++ /dev/null @@ -1,13 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel", "torch>=1.7", "numpy>=1.18", "opt_einsum"] -build-backend = "setuptools.build_meta" - -[project] -name = "torchTT" -version = "2.0" -authors = [ - {name = "Ion Gabriel Ion", email = "ion.ion.gabriel@gmail.com"}, -] -description = "Tensor-Train decomposition in pytorch." -readme = "README.md" -requires-python = ">=3.7" diff --git a/setup.py b/setup.py index baa347e..619dbdc 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,13 @@ from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext import platform +from warnings import warn + +try: + import torch.utils.cpp_extension + from torch.utils.cpp_extension import BuildExtension, CppExtension +except ImportError: + raise Exception("Torch must be installed before running this setup.") logo_ascii = """ _ _ _____ _____ @@ -10,62 +18,47 @@ """ -try: - from torch.utils.cpp_extension import BuildExtension, CppExtension -except: - raise Exception("Torch has to be installed first") - os_name = platform.system() -print() -print(logo_ascii) -print() - -def python_install(): - - import warnings - warnings.warn("\x1B[33m\nC++ implementation not available. 
Using pure Python.\n\033[0m") - - setup(name='torchTT', - version='2.0', - description='Tensor-Train decomposition in pytorch', - url='https://github.com/ion-g-ion/torchTT', - author='Ion Gabriel Ion', - author_email='ion.ion.gabriel@gmail.com', - license='MIT', - packages=['torchtt'], - install_requires=['numpy>=1.18','torch>=1.7','opt_einsum'], - test_suite='tests', - zip_safe=False) - +print("\n" + logo_ascii + "\n") -if os_name == 'Linux' or os_name == 'Darwin': +if os_name in ['Linux', 'Darwin']: try: - setup(name='torchTT', - version='2.0', - description='Tensor-Train decomposition in pytorch', - url='https://github.com/ion-g-ion/torchTT', - author='Ion Gabriel Ion', - author_email='ion.ion.gabriel@gmail.com', - license='MIT', - packages=['torchtt'], - install_requires=['pytest', 'numpy>=1.18','torch>=1.7','opt_einsum'], - ext_modules=[ - CppExtension('torchttcpp', ['cpp/cpp_ext.cpp'], extra_compile_args=['-lblas', '-llapack', '-std=c++17', '-Wno-c++11-narrowing', '-g', '-w', '-O3']), - ], - cmdclass={ - 'build_ext': BuildExtension - }, - test_suite='tests', - zip_safe=False, - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ]) - except: - python_install() -else: - python_install() - - + # setup( + # # cmdclass={'build_ext': build_ext}, + # ext_modules=[ + # Extension( + # name='torchttcpp', + # sources=['cpp/cpp_ext.cpp'], + # include_dirs=torch.utils.cpp_extension.include_paths()+["cpp"], + # libray_dirs = torch.utils.cpp_extension.library_paths(), + # language='c++', + # extra_compile_args=[ + # '-lblas', '-llapack', '-std=c++17', + # '-Wno-c++11-narrowing', '-g', '-w', '-O3' + # ]) + # ] + # ) + + setup( + cmdclass={'build_ext': BuildExtension}, + ext_modules=[ + CppExtension( + 'torchttcpp', + ['cpp/cpp_ext.cpp'], + include_dirs=["cpp"], + extra_compile_args=[ + '-std=c++17', + '-Wno-c++11-narrowing', '-g', '-w', '-O3' + ], + is_python_module=True # Ensures linking with PyTorch's C++ libraries + ) + ], + ) + except Exception as e: + warn("\x1B[33m\nC++ implementation not available. Falling back to pure Python.\n\033[0m") + print(f"Error: {e}") + setup() +else: + warn("\x1B[33m\nC++ implementation not supported on this OS. 
Falling back to pure Python.\n\033[0m") + setup() diff --git a/torchtt/cpp/BLAS.h b/torchtt/cpp/BLAS.h new file mode 100644 index 0000000..a6828e8 --- /dev/null +++ b/torchtt/cpp/BLAS.h @@ -0,0 +1,238 @@ +#pragma once + +#include +#include + + +extern "C"{ + int dgemm_(char *, char *, int64_t *, int64_t *, int64_t *, double *, double *, int64_t *, double *, int64_t *, double *, double *, int64_t *); + int sgemm_(char *, char *, int64_t *, int64_t *, int64_t *, float *, float *, int64_t *, float *, int64_t *, float *, float *, int64_t *); + double dnrm2_(int64_t *, double *, int64_t *); + float snrm2_(int64_t *, float *, int64_t *); + int daxpy_(int64_t*, double*, double*, int64_t*, double*, int64_t*); + int saxpy_(int64_t*, float*, float*, int64_t*, float*, int64_t*); + void dscal_(int64_t*, double*, double*, int64_t*); + void sscal_(int64_t*, float*, float*, int64_t*); + double ddot_(int64_t*, double*, int64_t*, double*, int64_t*); + float sdot_(int64_t*, float*, int64_t*, float*, int64_t*); + + void dcopy_(int64_t*, double*, int64_t*, double*, int64_t*); + void scopy_(int64_t*, float*, int64_t*, float*, int64_t*); + + void dgesv_(int64_t *, int64_t *, double *, int64_t *, int64_t *, double *, int64_t *, int64_t *); + void sgesv_(int64_t *, int64_t *, float *, int64_t *, int64_t *, float *, int64_t *, int64_t *); + + void domatcopy_(int64_t *m, int64_t *n, double *alpha, double *a, int64_t *lda, double *b, int64_t *ldb); + void somatcopy_(int64_t *m, int64_t *n, float *alpha, float *a, int64_t *lda, float *b, int64_t *ldb); +} + +namespace BLAS{ + /** + * @brief Computes a matrix-matrix product and adds the result to a matrix. + * + * This function computes a matrix-matrix product of the form C = alpha * op(A) * op(B) + beta * C, + * where op(X) = X if trans == 'N', and op(X) = X^T if trans == 'T' or 'C'. + * + * @tparam T The data type of the matrices (float or double). + * @param transA Specifies whether to transpose matrix A ('T') or not ('N'). + * @param transB Specifies whether to transpose matrix B ('T') or not ('N'). + * @param m The number of rows of matrix op(A) and matrix C. + * @param n The number of columns of matrix op(B) and matrix C. + * @param k The number of columns of matrix op(A) and rows of matrix op(B). + * @param alpha The scalar alpha. + * @param A The m x k matrix A (if transA == 'N') or k x m matrix A^T (if transA == 'T'). + * @param LDA The leading dimension of matrix A. LDA >= max(1,m) if transA == 'N' and LDA >= max(1,k) otherwise. + * @param B The k x n matrix B (if transB == 'N') or n x k matrix B^T (if transB == 'T'). + * @param LDB The leading dimension of matrix B. LDB >= max(1,k) if transB == 'N' and LDB >= max(1,n) otherwise. + * @param beta The scalar beta. + * @param C The m x n matrix C. + * @param LDC The leading dimension of matrix C. LDC >= max(1,m). + */ + template + void gemm(char *transA, char *transB, int64_t *m, int64_t *n, int64_t *k, T *alpha, T *A, int64_t *LDA, T *B, int64_t *LDB, T *beta, T *C, int64_t *LDC ); + + /** + * @brief Computes the Euclidean norm of a vector. + * + * This function computes the Euclidean norm of a vector x, defined as ||x||_2 = sqrt(x^T * x). + * + * @tparam T The data type of the vector elements (float or double). + * @param n The number of elements in the vector. + * @param x The vector of length n. + * @param incx The stride between consecutive elements of the vector. incx > 0. + * @return The Euclidean norm of the vector. 
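+ *
+ * A minimal illustrative call (a sketch only; the arguments are passed by pointer, following the
+ * Fortran BLAS calling convention used throughout this header):
+ * @code
+ * int64_t n = 3, inc = 1;
+ * double x[3] = {3.0, 4.0, 0.0};
+ * double nrm = BLAS::nrm2<double>(&n, x, &inc); // yields 5.0
+ * @endcode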
+ */ + template + T nrm2(int64_t *n, T *x, int64_t *incx); + + /** + * @brief Computes a vector-scalar product and adds the result to a vector. + * + * This function computes a vector-scalar product, defined as y = alpha * x + y, where alpha is a scalar + * and x and y are vectors of the same length. The operation is performed in-place, so the result + * overwrites the input vector y. + * + * @tparam T The data type of the vector elements (float or double). + * @param n The number of elements in the vectors. + * @param alpha The scalar value by which to multiply the vector x. + * @param x The vector of length n. + * @param incx The stride between consecutive elements of the vector x. incx > 0. + * @param y The vector of length n to which the result is added. + * @param incy The stride between consecutive elements of the vector y. incy > 0. + */ + template + void axpy(int64_t* N, T* alpha, T* X, int64_t* incX, T* Y, int64_t* incY); + + /** + * @brief Scales a vector by a scalar value. + * + * This function multiplies each element in a vector x by a scalar value alpha, overwriting the + * original values in x. + * + * @tparam T The data type of the elements in the vector. + * @param n The length of the vector. + * @param alpha The scalar value. + * @param x The vector to scale. + * @param incx The stride between consecutive elements in x. + */ + template + void scal(int64_t* n, T* alpha, T* x, int64_t* incx); + + /** + * @brief Computes the dot product of two vectors x and y. + * + * This function computes the dot product of two vectors x and y, which is defined as: + * + * dot(x, y) = sum(x_i * y_i) for i = 1 to n + * + * where n is the length of the vectors and x_i and y_i are the ith elements of x and y, respectively. + * + * @tparam T The data type of the elements in the vectors. + * @param n The length of the vectors. + * @param x The first vector. + * @param incx The stride between consecutive elements in x. + * @param y The second vector. + * @param incy The stride between consecutive elements in y. + * @return The dot product of x and y. 
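+ *
+ * Illustrative sketch of a call (same pointer-based convention as the other wrappers in this header):
+ * @code
+ * int64_t n = 2, inc = 1;
+ * double x[2] = {1.0, 2.0}, y[2] = {3.0, 4.0};
+ * double d = BLAS::dot<double>(&n, x, &inc, y, &inc); // 1*3 + 2*4 = 11
+ * @endcode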
+ */ + template + T dot(int64_t *n, T *x, int64_t *incx, T *y, int64_t *incy); + + template + void copy(int64_t* n, T* x, int64_t* incx, T* y, int64_t* incy); + + template + int gesv(int64_t *, int64_t *, T *, int64_t *, int64_t *, T *, int64_t *, int64_t *); + + + /// matrix multiplication + //specialized for double + template <> + void gemm(char *transA, char *transB, int64_t *m, int64_t *n, int64_t *k, double *alpha, double *A, int64_t *LDA, double *B, int64_t *LDB, double *beta, double *C, int64_t *LDC ){ + dgemm_(transA, transB, m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC); + } + // specialized for float + template <> + void gemm(char *transA, char *transB, int64_t *m, int64_t *n, int64_t *k, float *alpha, float *A, int64_t *LDA, float *B, int64_t *LDB, float *beta, float *C, int64_t *LDC ){ + sgemm_(transA, transB, m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC); + } + + /// norm + // specialized for double + template <> + double nrm2(int64_t *n, double *x, int64_t *incx){ + return dnrm2_(n, x, incx); + } + // specialized for float + template <> + float nrm2(int64_t *n, float *x, int64_t *incx){ + return snrm2_(n, x, incx); + } + + /// Multiplication with scalar + // specialized for double + template <> + void axpy(int64_t* N, double* alpha, double* X, int64_t* incX, double* Y, int64_t* incY) { + daxpy_(N, alpha, X, incX, Y, incY); + } + // specialized for float + template <> + void axpy(int64_t* N, float* alpha, float* X, int64_t* incX, float* Y, int64_t* incY) { + saxpy_(N, alpha, X, incX, Y, incY); + } + + /// scale a vector + // Specialization for double + template <> + void scal(int64_t* n, double* alpha, double* x, int64_t* incx) { + dscal_(n, alpha, x, incx); + } + // Specialization for float + template <> + void scal(int64_t* n, float* alpha, float* x, int64_t* incx) { + sscal_(n, alpha, x, incx); + } + + /// dot product + // specialized for double + template <> + double dot(int64_t *n, double *x, int64_t *incx, double *y, int64_t *incy){ + return ddot_(n, x, incx, y, incy); + } + // specialized for float + template <> + float dot(int64_t *n, float *x, int64_t *incx, float *y, int64_t *incy){ + return sdot_(n, x, incx, y, incy); + } + + template <> + void copy(int64_t* n, double* x, int64_t* incx, double* y, int64_t* incy){ + dcopy_(n, x, incx, y, incy); + } + + template <> + void copy(int64_t* n, float* x, int64_t* incx, float* y, int64_t* incy){ + scopy_(n, x, incx, y, incy); + } + + +} + +namespace LAPACK{ + + template + int64_t gesv(int64_t n, int64_t nrhs, T * A, int64_t lda, int64_t *ipiv, T *B, int64_t ldb); + + + template <> + int64_t gesv(int64_t n, int64_t nrhs, double * A, int64_t lda, int64_t *ipiv, double *B, int64_t ldb){ + int64_t info; + + if(ipiv != nullptr) + dgesv_(&n, &nrhs, A, &lda, ipiv, B, &ldb, &info); + else + { + int64_t *IPIV = new int64_t[n]; + dgesv_(&n, &nrhs, A, &lda, IPIV, B, &ldb, &info); + delete [] IPIV; + } + + return info; + } + + template <> + int64_t gesv(int64_t n, int64_t nrhs, float * A, int64_t lda, int64_t *ipiv, float *B, int64_t ldb){ + int64_t info; + + if(ipiv != nullptr) + sgesv_(&n, &nrhs, A, &lda, ipiv, B, &ldb, &info); + else + { + int64_t *IPIV = new int64_t[n]; + sgesv_(&n, &nrhs, A, &lda, IPIV, B, &ldb, &info); + delete [] IPIV; + } + + return info; + } +} \ No newline at end of file diff --git a/torchtt/cpp/amen_divide.h b/torchtt/cpp/amen_divide.h new file mode 100644 index 0000000..dc4a75d --- /dev/null +++ b/torchtt/cpp/amen_divide.h @@ -0,0 +1,549 @@ +#include "define.h" +#include "ortho.h" +#include +#include 
"matvecs.h" +#include "gmres.h" + +//torch::NoGradGuard no_grad; +/** + * @brief Compute thelocal matvec product in the AMEn: lsr,smnS,LSR,rnR->lmL + * + * @param[in] Phi_right The right interface + * @param[in] Phi_left The left interface + * @param[in] coreA The corre of the TT operator + * @param[in] core The core vector + * @return at::Tensor + */ +at::Tensor local_product(at::Tensor &Phi_right, at::Tensor &Phi_left, at::Tensor &coreA, at::Tensor &core){ + + + // rnR,lsr->nRls + auto tmp1 = at::tensordot(core, Phi_left, {0}, {2}); + // nRls,smnS->RlmS + auto tmp2 = at::tensordot(tmp1, coreA, {0,3},{2,0}); + // RlmS,LSR->lmL + return at::tensordot(tmp2, Phi_right, {0,3}, {2,1}); + +} + +/** + * @brief Compute the phi backwards for the form dot(left,A @ right) + * + * @param[in] Phi_now the current phi. Has shape r1_k+1 x R_k+1 x r2_k+1 + * @param[in] core_left the core on the left. Has shape r1_k x N_k x r1_k+1 + * @param[in] core_A the core of the matrix. Has shape R_k x N_k x N_k x R_k + * @param[in] core_right the core to the right. Has shape r2_k x N_k x r2_k+1 + * @return at::Tensor the following phi (backward). Has shape r1_k x R_k x r2_k + */ +at::Tensor compute_phi_bck_A(at::Tensor &Phi_now, at::Tensor &core_left, at::Tensor &core_A, at::Tensor &core_right){ + at::Tensor Phi; + // 5 GEMM lML,LSR->lMSR sMNS,rNR,lMSR->lsr + // 6 TDOT lMSR,sMNS->lRsN rNR,lRsN->lsr + // 5 TDOT lRsN,rNR->lsr lsr->lsr + //Phi = oe.contract('LSR,lML,sMNS,rNR->lsr',Phi_now,core_left,core_A,core_right) + Phi = at::tensordot(core_left, Phi_now, {2}, {0}); + Phi = at::tensordot(Phi, core_A, {1,2}, {1,3}); // lRsN + return at::tensordot(Phi, core_right, {1,3}, {2,1}); +} + +/** + * @brief Compute the phi forward for the form dot(left,A @ right) + * + * @param[in] Phi_now the current phi. Has shape r1_k x R_k x r2_k + * @param[in] core_left the core on the left. Has shape r1_k x N_k x r1_k+1 + * @param[in] core_A the core of the matrix. Has shape R_k x N_k x N_k x R_k + * @param[in] core_right the core to the right. Has shape r2_k x N_k x r2_k+1 + * @return at::Tensor the following phi (backward). Has shape r1_k+1 x R_k+1 x r2_k+1 + */ +at::Tensor compute_phi_fwd_A(at::Tensor &Phi_now, at::Tensor &core_left, at::Tensor &core_A, at::Tensor &core_right){ + at::Tensor Phi_next; +// 5 GEMM lML,lsr->MLsr sMNS,rNR,MLsr->LSR +// 6 TDOT MLsr,sMNS->LrNS rNR,LrNS->LSR +// 5 TDOT LrNS,rNR->LSR LSR->LSR + //Phi_next = oe.contract('lsr,lML,sMNS,rNR->LSR',Phi_now,core_left,core_A,core_right) + Phi_next = at::tensordot(core_left, Phi_now, {0}, {0}); // MLsr + Phi_next = at::tensordot(Phi_next, core_A, {0,2}, {1,0}); // LrNS + Phi_next = at::tensordot(Phi_next, core_right, {1,2}, {0,1}); + return Phi_next; +} + +/** + * @brief Compute the backward phi `BR,bnB,rnR->br` + * + * @param[in] Phi_now the current Phi. Has shape rb_k+1 x r_k+1 + * @param[in] core_b the core of the rhs. Has shape rb_k x N_k x rb_k+1 + * @param[in] core the current core. Has shape r_k x N_k x r_k+1 + * @return at::Tensor Has shape rb_k x r_k + */ +at::Tensor compute_phi_bck_rhs(at::Tensor &Phi_now, at::Tensor &core_b, at::Tensor &core){ + at::Tensor Phi; + Phi = at::tensordot(core_b, Phi_now, {2}, {0}); + Phi = at::tensordot(Phi, core, {1,2}, {1,2}); + return Phi; +} + +/** + * @brief Compute the forward phi `br,bnB,rnR->BR` + * + * @param[in] Phi_now the current Phi. Has shape rb_k x r_k + * @param[in] core_rhs the core of the rhs. Has shape rb_k x N_k+1 x rb_k+1 + * @param[in] core the current core. 
Has shape r_k x N_k x r_k+1 + * @return at::Tensor Has shape rb_k+1 x r_k+1 + */ +at::Tensor compute_phi_fwd_rhs(at::Tensor &Phi_now, at::Tensor &core_rhs, at::Tensor &core){ + + at::Tensor Phi_next = at::tensordot(Phi_now, core_rhs, {0}, {0}); + Phi_next = at::tensordot(Phi_next, core, {0,1}, {0,1}); + return Phi_next; +} + +/** + * @brief AMEn solve implementation in C++. + * + * @param[in] A_cores + * @param[in] b_cores + * @param[in] x0_cores + * @param[in] N + * @param[in] rA + * @param[in] rb + * @param[in] r_x0 + * @param[in] nswp + * @param[in] eps + * @param[in] rmax + * @param[in] max_full + * @param[in] kickrank + * @param[in] kick2 + * @param[in] local_iterations + * @param[in] resets + * @param[in] verbose + * @param[in] preconditioner + * @return std::vector TT cores of the solution + */ +std::vector amen_solve( + std::vector &A_cores, + std::vector &b_cores, + std::vector &x0_cores, + std::vector N, + std::vector rA, + std::vector rb, + std::vector r_x0, + uint64_t nswp, + double eps, + uint64_t rmax, + uint64_t max_full, + uint64_t kickrank, + uint64_t kick2, + uint64_t local_iterations, + uint64_t resets, + bool verbose, + int preconditioner) +{ + + torch::NoGradGuard no_grad; + + if(verbose) + { + std::cout << "Starting AMEn solve with:"; + std::cout << "\n\tepsilon : " << eps; + std::cout << "\n\tsweeps : " << nswp; + std::cout << "\n\tlocal iterations : " << local_iterations; + std::cout << "\n\tresets : " << resets; + char prec_char = (preconditioner == 0 ? 'N' : ( preconditioner == 1 ? 'C' : 'R')); + std::cout << "\n\tlocal preconditioner : " << prec_char; + std::cout << std::endl << std::endl; + } + + //at::TensorBase::device dtype = A_cores[0].dtype; + auto options = A_cores[0].options(); + uint64_t d = N.size(); + std::vector x_cores; + + if(x0_cores.size() == 0){ + for(int i = 0; i < d; i++) + x_cores.push_back(torch::ones({1,N[i],1}, options)); + } + else + x_cores = x0_cores; + std::vector rx = r_x0; + + std::vector rz(d+1); + rz[0] = 1; + rz[d] = 1; + for(int i=1;i z_cores(d); + for(int i=0;i Phiz(d+1); + std::vector Phiz_b(d+1); + std::vector Phis(d+1); + std::vector Phis_b(d+1); + Phiz[0] = at::ones({1,1,1}, options); + Phiz_b[0] = at::ones({1,1}, options); + Phis[0] = at::ones({1,1,1}, options); + Phis_b[0] = at::ones({1,1}, options); + Phiz[d] = at::ones({1,1,1}, options); + Phiz_b[d] = at::ones({1,1}, options); + Phis[d] = at::ones({1,1,1}, options); + Phis_b[d] = at::ones({1,1}, options); + + double *normA = new double[d-1]; + double *normb = new double[d-1]; + double *normx = new double[d-1]; + for(int k=0;k tme_swp, tme_total; + if(verbose) + tme_total = std::chrono::high_resolution_clock::now(); + int swp; + for(swp=0;swp0;k--){ + if(!last){ + at::Tensor cz_new; + if(swp>0){ + at::Tensor czA = local_product(Phiz[k+1], Phiz[k], A_cores[k], x_cores[k]); + at::Tensor czy = at::tensordot(Phiz_b[k], b_cores[k], {0}, {0}); + czy = at::tensordot(czy, Phiz_b[k+1], {2}, {0}); + czy *= nrmsc; + czy -= czA; + std::tuple USV = at::linalg_svd(czy.reshape({czy.sizes()[0],-1}), false); + uint64_t temp = kickrank < std::get<2>(USV).sizes()[0] ? 
kickrank : std::get<2>(USV).sizes()[0]; + cz_new = std::get<2>(USV).index({ torch::indexing::Slice(0, temp), torch::indexing::Ellipsis}).t(); + if(k < d-1) + cz_new = at::cat({cz_new,torch::randn({cz_new.sizes()[0], kick2}, options)}, 1); + } + else{ + cz_new = z_cores[k].reshape({rz[k],-1}).t(); + } + + at::Tensor Qz; + std::tie(Qz, std::ignore) = at::linalg_qr(cz_new); + rz[k] = Qz.sizes()[1]; + z_cores[k] = (Qz.t()).reshape({rz[k], N[k], rz[k+1]}); + } + + + + if(swp>0) + nrmsc = nrmsc * normA[k-1] * normx[k-1] / normb[k-1]; + + auto core = x_cores[k].reshape({rx[k],N[k]*rx[k+1]}).t(); + + std::tuple QR = at::linalg_qr(core); + + auto core_prev = at::tensordot(x_cores[k-1], std::get<1>(QR).t(), {2}, {0}); + rx[k] = std::get<0>(QR).sizes()[1]; + + double current_norm = torch::norm(core_prev).item(); + if(current_norm > 0) + core_prev /= current_norm; + else + current_norm = 1.0; + normx[k-1] = normx[k-1] * current_norm; + + x_cores[k] = (std::get<0>(QR).t()).reshape({rx[k], N[k], rx[k+1]}).clone(); + x_cores[k-1] = core_prev.clone(); + + + Phis[k] = compute_phi_bck_A(Phis[k+1],x_cores[k],A_cores[k],x_cores[k]); + Phis_b[k] = compute_phi_bck_rhs(Phis_b[k+1],b_cores[k],x_cores[k]); + + + double norm = torch::norm(Phis[k]).item(); + norm = norm>0 ? norm : 0.0; + normA[k-1] = norm; + Phis[k] = Phis[k] / norm; + + norm = torch::norm(Phis_b[k]).item(); + norm = norm>0 ? norm : 0.0; + normb[k-1] = norm; + Phis_b[k] = Phis_b[k] / norm; + + // norm correction + nrmsc = nrmsc * normb[k-1] / (normA[k-1] * normx[k-1]); + + // compute phis_z + if(!last){ + Phiz[k] = compute_phi_bck_A(Phiz[k+1], z_cores[k], A_cores[k], x_cores[k]) / normA[k-1]; + Phiz_b[k] = compute_phi_bck_rhs(Phiz_b[k+1], b_cores[k], z_cores[k]) / normb[k-1]; + } + } + double max_res = 0; + double max_dx = 0; + + for(int k = 0; k(); + + // residuals + double real_tol = (eps/std::sqrt(d))/damp; + + // direct local solver or iterative + bool use_full = rx[k]*N[k]*rx[k+1] < max_full; + at::Tensor solution_now; + double res_old, res_new; + + at::Tensor B; + auto Op = AMENsolveMV(); + + if(use_full){ + if(verbose) + std::cout << "\t\tChoosing direct solver (local size " << rx[k]*N[k]*rx[k+1] << ")..." << std::endl; + auto Bp = at::tensordot(A_cores[k], Phis[k+1], {3}, {1}); // smnS,LSR->smnLR + Bp = at::tensordot(Phis[k], Bp, {1}, {0}); // lsr,smnLR->lrmnLR + B = Bp.permute({0,2,4,1,3,5}).reshape({rx[k]*N[k]*rx[k+1], rx[k]*N[k]*rx[k+1]}); + + solution_now = at::linalg_solve(B, rhs); + + res_old = torch::norm(at::linalg_matmul(B, previous_solution) - rhs).item() / norm_rhs; + res_new = torch::norm(at::linalg_matmul(B, solution_now) - rhs).item() / norm_rhs; + } + else{ + std::chrono::time_point tme_local; + if(verbose) { + std::cout << "\t\tChoosing iterative solver (local size " << rx[k]*N[k]*rx[k+1] << ")..." 
<({rx[k], N[k], rx[k+1]}))); + Op.setter(Phis[k], Phis[k+1], A_cores[k],shape_now, preconditioner, options); + + double eps_local = real_tol * norm_rhs; + + auto drhs = rhs - Op.matvec(previous_solution, false); + eps_local /= torch::norm(drhs).item(); + + int flag; + int nit; + + at::Tensor ps = 0.0 * previous_solution; + gmres(solution_now, flag, nit, Op, drhs, ps, drhs.sizes()[0], local_iterations, eps_local, resets ); + + if(preconditioner!=NO_PREC){ + solution_now = Op.apply_prec(solution_now.reshape(shape_now)); + } + solution_now = solution_now.reshape({-1,1}); + + solution_now += previous_solution; + res_old = torch::norm(Op.matvec(previous_solution, false)-rhs).item()/norm_rhs; + res_new = torch::norm(Op.matvec(solution_now, false)-rhs).item()/norm_rhs; + + if(verbose){ + std::cout<<"\t\tFinished with flag " << flag << " after " << nit << " iterations with relres " << res_new << " (from " << eps_local << ")" << std::endl; + auto duration = (double)(std::chrono::duration_cast(std::chrono::high_resolution_clock::now()-tme_local)).count() /1000.0 ; + std::cout<<"\t\tTime needed " << duration << " ms" << std::endl; + } + } + if(verbose && res_old/res_new < damp && res_new > real_tol) + std::cout << "WARNING: residual increase. res_old " << res_old << ", res_new " << res_new << ", " << real_tol << std::endl; + + auto dx = torch::norm(solution_now - previous_solution).item() / torch::norm(solution_now).item(); + + if(verbose) + std::cout << "\t\tdx = " << dx << ", res_now = " << res_new << ", res_old = " << res_old << std::endl; + + max_dx = dx < max_dx ? max_dx : dx; + max_res = max_res < res_old ? res_old : max_res; + + solution_now = solution_now.reshape({rx[k]*N[k], rx[k+1]}); + + at::Tensor u,s,v; + uint64_t r; + + if(k0){ + auto solution = at::linalg_matmul(u.index({torch::indexing::Ellipsis, torch::indexing::Slice(0,r,1)}) * s.index({torch::indexing::Slice(0,r,1)}), v.index({torch::indexing::Slice(0,r,1), torch::indexing::Ellipsis})); + + double res; + if(use_full) + res = torch::norm(at::linalg_matmul(B, solution.reshape({-1,1})) - rhs).item() / norm_rhs; + else{ + auto tmp_tens = solution.reshape({-1,1}); + res = torch::norm(Op.matvec(tmp_tens, false)-rhs).item()/norm_rhs; + } + + if(res>(res_new > real_tol*damp ? res_new : real_tol*damp)) + break; + --r; + } + ++r; + + r = (r(); + + if(norm_now>0) + v = v / norm_now; + else + norm_now = 1.0; + + normx[k] = normx[k] * norm_now; + + x_cores[k] = u.reshape({rx[k], N[k], r}).clone(); + x_cores[k+1] = v.reshape({r, N[k+1], rx[k+2]}).clone(); + rx[k+1] = r; + + + + Phis[k+1] = compute_phi_fwd_A(Phis[k], x_cores[k], A_cores[k], x_cores[k]); + Phis_b[k+1] = compute_phi_fwd_rhs(Phis_b[k], b_cores[k],x_cores[k]); + + // ... and norms + auto norm = torch::norm(Phis[k+1]).item(); + norm = norm>0 ? norm : 1.0; + normA[k] = norm; + Phis[k+1] = Phis[k+1] / norm; + norm = torch::norm(Phis_b[k+1]).item(); + norm = norm>0 ? 
norm : 1.0; + normb[k] = norm; + Phis_b[k+1] = Phis_b[k+1] / norm; + + // norm correction + nrmsc = nrmsc * normb[k] / ( normA[k] * normx[k] ); + + + // next phiz + if(!last){ + Phiz[k+1] = compute_phi_fwd_A(Phiz[k], z_cores[k], A_cores[k], x_cores[k]) / normA[k]; + Phiz_b[k+1] = compute_phi_fwd_rhs(Phiz_b[k], b_cores[k],z_cores[k]) / normb[k]; + } + } + else{ + auto usv = at::linalg_matmul(u * s.index({torch::indexing::Slice(0,r,1)}), v.index({torch::indexing::Slice(0,r,1), torch::indexing::Ellipsis}).t()); + x_cores[k] = usv.reshape({rx[k],N[k],rx[k+1]}); + } + + + + } + if(verbose){ + std::cout << "Solution rank [ "; + for(auto rr: rx) std::cout << rr << " "; + std::cout << "]" << std::endl; + std::cout << "Maxres " << max_res << std::endl; + auto diff_time = std::chrono::high_resolution_clock::now() - tme_swp; + std::cout << "Time " << (double)(std::chrono::duration_cast(diff_time)).count()/1000.0 << " ms"<(diff_time)).count()/1000000.0 << " seconds" << std::endl << std::endl; + } + + double norm_x = 0.0; + for(int i=0;i +#include "matvecs.h" +#include "gmres.h" + +//torch::NoGradGuard no_grad; +/** + * @brief Compute thelocal matvec product in the AMEn: lsr,smnS,LSR,rnR->lmL + * + * @param[in] Phi_right The right interface + * @param[in] Phi_left The left interface + * @param[in] coreA The corre of the TT operator + * @param[in] core The core vector + * @return at::Tensor + */ +at::Tensor local_product(at::Tensor &Phi_right, at::Tensor &Phi_left, at::Tensor &coreA, at::Tensor &core){ + + + // rnR,lsr->nRls + auto tmp1 = at::tensordot(core, Phi_left, {0}, {2}); + // nRls,smnS->RlmS + auto tmp2 = at::tensordot(tmp1, coreA, {0,3},{2,0}); + // RlmS,LSR->lmL + return at::tensordot(tmp2, Phi_right, {0,3}, {2,1}); + +} + +/** + * @brief Compute the phi backwards for the form dot(left,A @ right) + * + * @param[in] Phi_now the current phi. Has shape r1_k+1 x R_k+1 x r2_k+1 + * @param[in] core_left the core on the left. Has shape r1_k x N_k x r1_k+1 + * @param[in] core_A the core of the matrix. Has shape R_k x N_k x N_k x R_k + * @param[in] core_right the core to the right. Has shape r2_k x N_k x r2_k+1 + * @return at::Tensor the following phi (backward). Has shape r1_k x R_k x r2_k + */ +at::Tensor compute_phi_bck_A(at::Tensor &Phi_now, at::Tensor &core_left, at::Tensor &core_A, at::Tensor &core_right){ + at::Tensor Phi; + // 5 GEMM lML,LSR->lMSR sMNS,rNR,lMSR->lsr + // 6 TDOT lMSR,sMNS->lRsN rNR,lRsN->lsr + // 5 TDOT lRsN,rNR->lsr lsr->lsr + //Phi = oe.contract('LSR,lML,sMNS,rNR->lsr',Phi_now,core_left,core_A,core_right) + Phi = at::tensordot(core_left, Phi_now, {2}, {0}); + Phi = at::tensordot(Phi, core_A, {1,2}, {1,3}); // lRsN + return at::tensordot(Phi, core_right, {1,3}, {2,1}); +} + +/** + * @brief Compute the phi forward for the form dot(left,A @ right) + * + * @param[in] Phi_now the current phi. Has shape r1_k x R_k x r2_k + * @param[in] core_left the core on the left. Has shape r1_k x N_k x r1_k+1 + * @param[in] core_A the core of the matrix. Has shape R_k x N_k x N_k x R_k + * @param[in] core_right the core to the right. Has shape r2_k x N_k x r2_k+1 + * @return at::Tensor the following phi (backward). 
Has shape r1_k+1 x R_k+1 x r2_k+1 + */ +at::Tensor compute_phi_fwd_A(at::Tensor &Phi_now, at::Tensor &core_left, at::Tensor &core_A, at::Tensor &core_right){ + at::Tensor Phi_next; +// 5 GEMM lML,lsr->MLsr sMNS,rNR,MLsr->LSR +// 6 TDOT MLsr,sMNS->LrNS rNR,LrNS->LSR +// 5 TDOT LrNS,rNR->LSR LSR->LSR + //Phi_next = oe.contract('lsr,lML,sMNS,rNR->LSR',Phi_now,core_left,core_A,core_right) + Phi_next = at::tensordot(core_left, Phi_now, {0}, {0}); // MLsr + Phi_next = at::tensordot(Phi_next, core_A, {0,2}, {1,0}); // LrNS + Phi_next = at::tensordot(Phi_next, core_right, {1,2}, {0,1}); + return Phi_next; +} + +/** + * @brief Compute the backward phi `BR,bnB,rnR->br` + * + * @param[in] Phi_now the current Phi. Has shape rb_k+1 x r_k+1 + * @param[in] core_b the core of the rhs. Has shape rb_k x N_k x rb_k+1 + * @param[in] core the current core. Has shape r_k x N_k x r_k+1 + * @return at::Tensor Has shape rb_k x r_k + */ +at::Tensor compute_phi_bck_rhs(at::Tensor &Phi_now, at::Tensor &core_b, at::Tensor &core){ + at::Tensor Phi; + Phi = at::tensordot(core_b, Phi_now, {2}, {0}); + Phi = at::tensordot(Phi, core, {1,2}, {1,2}); + return Phi; +} + +/** + * @brief Compute the forward phi `br,bnB,rnR->BR` + * + * @param[in] Phi_now the current Phi. Has shape rb_k x r_k + * @param[in] core_rhs the core of the rhs. Has shape rb_k x N_k+1 x rb_k+1 + * @param[in] core the current core. Has shape r_k x N_k x r_k+1 + * @return at::Tensor Has shape rb_k+1 x r_k+1 + */ +at::Tensor compute_phi_fwd_rhs(at::Tensor &Phi_now, at::Tensor &core_rhs, at::Tensor &core){ + + at::Tensor Phi_next = at::tensordot(Phi_now, core_rhs, {0}, {0}); + Phi_next = at::tensordot(Phi_next, core, {0,1}, {0,1}); + return Phi_next; +} + +/** + * @brief AMEn solve implementation in C++. + * + * @param[in] A_cores + * @param[in] b_cores + * @param[in] x0_cores + * @param[in] N + * @param[in] rA + * @param[in] rb + * @param[in] r_x0 + * @param[in] nswp + * @param[in] eps + * @param[in] rmax + * @param[in] max_full + * @param[in] kickrank + * @param[in] kick2 + * @param[in] local_iterations + * @param[in] resets + * @param[in] verbose + * @param[in] preconditioner + * @return std::vector TT cores of the solution + */ +std::vector amen_solve( + std::vector &A_cores, + std::vector &b_cores, + std::vector &x0_cores, + std::vector N, + std::vector rA, + std::vector rb, + std::vector r_x0, + uint64_t nswp, + double eps, + uint64_t rmax, + uint64_t max_full, + uint64_t kickrank, + uint64_t kick2, + uint64_t local_iterations, + uint64_t resets, + bool verbose, + int preconditioner) +{ + + torch::NoGradGuard no_grad; + + if(verbose) + { + std::cout << "Starting AMEn solve with:"; + std::cout << "\n\tepsilon : " << eps; + std::cout << "\n\tsweeps : " << nswp; + std::cout << "\n\tlocal iterations : " << local_iterations; + std::cout << "\n\tresets : " << resets; + char prec_char = (preconditioner == 0 ? 'N' : ( preconditioner == 1 ? 
'C' : 'R')); + std::cout << "\n\tlocal preconditioner : " << prec_char; + std::cout << std::endl << std::endl; + } + + //at::TensorBase::device dtype = A_cores[0].dtype; + auto options = A_cores[0].options(); + uint64_t d = N.size(); + std::vector x_cores; + + if(x0_cores.size() == 0){ + for(int i = 0; i < d; i++) + x_cores.push_back(torch::ones({1,N[i],1}, options)); + } + else + x_cores = x0_cores; + std::vector rx = r_x0; + + std::vector rz(d+1); + rz[0] = 1; + rz[d] = 1; + for(int i=1;i z_cores(d); + for(int i=0;i Phiz(d+1); + std::vector Phiz_b(d+1); + std::vector Phis(d+1); + std::vector Phis_b(d+1); + Phiz[0] = at::ones({1,1,1}, options); + Phiz_b[0] = at::ones({1,1}, options); + Phis[0] = at::ones({1,1,1}, options); + Phis_b[0] = at::ones({1,1}, options); + Phiz[d] = at::ones({1,1,1}, options); + Phiz_b[d] = at::ones({1,1}, options); + Phis[d] = at::ones({1,1,1}, options); + Phis_b[d] = at::ones({1,1}, options); + + double *normA = new double[d-1]; + double *normb = new double[d-1]; + double *normx = new double[d-1]; + for(int k=0;k tme_swp, tme_total; + if(verbose) + tme_total = std::chrono::high_resolution_clock::now(); + int swp; + for(swp=0;swp0;k--){ + if(!last){ + at::Tensor cz_new; + if(swp>0){ + at::Tensor czA = local_product(Phiz[k+1], Phiz[k], A_cores[k], x_cores[k]); + at::Tensor czy = at::tensordot(Phiz_b[k], b_cores[k], {0}, {0}); + czy = at::tensordot(czy, Phiz_b[k+1], {2}, {0}); + czy *= nrmsc; + czy -= czA; + std::tuple USV = at::linalg_svd(czy.reshape({czy.sizes()[0],-1}), false); + uint64_t temp = kickrank < std::get<2>(USV).sizes()[0] ? kickrank : std::get<2>(USV).sizes()[0]; + cz_new = std::get<2>(USV).index({ torch::indexing::Slice(0, temp), torch::indexing::Ellipsis}).t(); + if(k < d-1) + cz_new = at::cat({cz_new,torch::randn({cz_new.sizes()[0], kick2}, options)}, 1); + } + else{ + cz_new = z_cores[k].reshape({rz[k],-1}).t(); + } + + at::Tensor Qz; + std::tie(Qz, std::ignore) = at::linalg_qr(cz_new); + rz[k] = Qz.sizes()[1]; + z_cores[k] = (Qz.t()).reshape({rz[k], N[k], rz[k+1]}); + } + + + + if(swp>0) + nrmsc = nrmsc * normA[k-1] * normx[k-1] / normb[k-1]; + + auto core = x_cores[k].reshape({rx[k],N[k]*rx[k+1]}).t(); + + std::tuple QR = at::linalg_qr(core); + + auto core_prev = at::tensordot(x_cores[k-1], std::get<1>(QR).t(), {2}, {0}); + rx[k] = std::get<0>(QR).sizes()[1]; + + double current_norm = torch::norm(core_prev).item(); + if(current_norm > 0) + core_prev /= current_norm; + else + current_norm = 1.0; + normx[k-1] = normx[k-1] * current_norm; + + x_cores[k] = (std::get<0>(QR).t()).reshape({rx[k], N[k], rx[k+1]}).clone(); + x_cores[k-1] = core_prev.clone(); + + + Phis[k] = compute_phi_bck_A(Phis[k+1],x_cores[k],A_cores[k],x_cores[k]); + Phis_b[k] = compute_phi_bck_rhs(Phis_b[k+1],b_cores[k],x_cores[k]); + + + double norm = torch::norm(Phis[k]).item(); + norm = norm>0 ? norm : 0.0; + normA[k-1] = norm; + Phis[k] = Phis[k] / norm; + + norm = torch::norm(Phis_b[k]).item(); + norm = norm>0 ? 
norm : 0.0; + normb[k-1] = norm; + Phis_b[k] = Phis_b[k] / norm; + + // norm correction + nrmsc = nrmsc * normb[k-1] / (normA[k-1] * normx[k-1]); + + // compute phis_z + if(!last){ + Phiz[k] = compute_phi_bck_A(Phiz[k+1], z_cores[k], A_cores[k], x_cores[k]) / normA[k-1]; + Phiz_b[k] = compute_phi_bck_rhs(Phiz_b[k+1], b_cores[k], z_cores[k]) / normb[k-1]; + } + } + double max_res = 0; + double max_dx = 0; + + for(int k = 0; k(); + + // residuals + double real_tol = (eps/std::sqrt(d))/damp; + + // direct local solver or iterative + bool use_full = rx[k]*N[k]*rx[k+1] < max_full; + at::Tensor solution_now; + double res_old, res_new; + + at::Tensor B; + auto Op = AMENsolveMV(); + + if(use_full){ + if(verbose) + std::cout << "\t\tChoosing direct solver (local size " << rx[k]*N[k]*rx[k+1] << ")..." << std::endl; + auto Bp = at::tensordot(A_cores[k], Phis[k+1], {3}, {1}); // smnS,LSR->smnLR + Bp = at::tensordot(Phis[k], Bp, {1}, {0}); // lsr,smnLR->lrmnLR + B = Bp.permute({0,2,4,1,3,5}).reshape({rx[k]*N[k]*rx[k+1], rx[k]*N[k]*rx[k+1]}); + + solution_now = at::linalg_solve(B, rhs); + + res_old = torch::norm(at::linalg_matmul(B, previous_solution) - rhs).item() / norm_rhs; + res_new = torch::norm(at::linalg_matmul(B, solution_now) - rhs).item() / norm_rhs; + } + else{ + std::chrono::time_point tme_local; + if(verbose) { + std::cout << "\t\tChoosing iterative solver (local size " << rx[k]*N[k]*rx[k+1] << ")..." <({rx[k], N[k], rx[k+1]}))); + Op.setter(Phis[k], Phis[k+1], A_cores[k],shape_now, preconditioner, options); + + double eps_local = real_tol * norm_rhs; + + auto drhs = rhs - Op.matvec(previous_solution, false); + eps_local /= torch::norm(drhs).item(); + + int flag; + int nit; + + at::Tensor ps = 0.0 * previous_solution; + gmres(solution_now, flag, nit, Op, drhs, ps, drhs.sizes()[0], local_iterations, eps_local, resets ); + + if(preconditioner!=NO_PREC){ + solution_now = Op.apply_prec(solution_now.reshape(shape_now)); + } + solution_now = solution_now.reshape({-1,1}); + + solution_now += previous_solution; + res_old = torch::norm(Op.matvec(previous_solution, false)-rhs).item()/norm_rhs; + res_new = torch::norm(Op.matvec(solution_now, false)-rhs).item()/norm_rhs; + + if(verbose){ + std::cout<<"\t\tFinished with flag " << flag << " after " << nit << " iterations with relres " << res_new << " (from " << eps_local << ")" << std::endl; + auto duration = (double)(std::chrono::duration_cast(std::chrono::high_resolution_clock::now()-tme_local)).count() /1000.0 ; + std::cout<<"\t\tTime needed " << duration << " ms" << std::endl; + } + } + if(verbose && res_old/res_new < damp && res_new > real_tol) + std::cout << "WARNING: residual increase. res_old " << res_old << ", res_new " << res_new << ", " << real_tol << std::endl; + + auto dx = torch::norm(solution_now - previous_solution).item() / torch::norm(solution_now).item(); + + if(verbose) + std::cout << "\t\tdx = " << dx << ", res_now = " << res_new << ", res_old = " << res_old << std::endl; + + max_dx = dx < max_dx ? max_dx : dx; + max_res = max_res < res_old ? 
res_old : max_res; + + solution_now = solution_now.reshape({rx[k]*N[k], rx[k+1]}); + + at::Tensor u,s,v; + uint64_t r; + + if(k0){ + auto solution = at::linalg_matmul(u.index({torch::indexing::Ellipsis, torch::indexing::Slice(0,r,1)}) * s.index({torch::indexing::Slice(0,r,1)}), v.index({torch::indexing::Slice(0,r,1), torch::indexing::Ellipsis})); + + double res; + if(use_full) + res = torch::norm(at::linalg_matmul(B, solution.reshape({-1,1})) - rhs).item() / norm_rhs; + else{ + auto tmp_tens = solution.reshape({-1,1}); + res = torch::norm(Op.matvec(tmp_tens, false)-rhs).item()/norm_rhs; + } + + if(res>(res_new > real_tol*damp ? res_new : real_tol*damp)) + break; + --r; + } + ++r; + + r = (r(); + + if(norm_now>0) + v = v / norm_now; + else + norm_now = 1.0; + + normx[k] = normx[k] * norm_now; + + x_cores[k] = u.reshape({rx[k], N[k], r}).clone(); + x_cores[k+1] = v.reshape({r, N[k+1], rx[k+2]}).clone(); + rx[k+1] = r; + + + + Phis[k+1] = compute_phi_fwd_A(Phis[k], x_cores[k], A_cores[k], x_cores[k]); + Phis_b[k+1] = compute_phi_fwd_rhs(Phis_b[k], b_cores[k],x_cores[k]); + + // ... and norms + auto norm = torch::norm(Phis[k+1]).item(); + norm = norm>0 ? norm : 1.0; + normA[k] = norm; + Phis[k+1] = Phis[k+1] / norm; + norm = torch::norm(Phis_b[k+1]).item(); + norm = norm>0 ? norm : 1.0; + normb[k] = norm; + Phis_b[k+1] = Phis_b[k+1] / norm; + + // norm correction + nrmsc = nrmsc * normb[k] / ( normA[k] * normx[k] ); + + + // next phiz + if(!last){ + Phiz[k+1] = compute_phi_fwd_A(Phiz[k], z_cores[k], A_cores[k], x_cores[k]) / normA[k]; + Phiz_b[k+1] = compute_phi_fwd_rhs(Phiz_b[k], b_cores[k],z_cores[k]) / normb[k]; + } + } + else{ + auto usv = at::linalg_matmul(u * s.index({torch::indexing::Slice(0,r,1)}), v.index({torch::indexing::Slice(0,r,1), torch::indexing::Ellipsis}).t()); + x_cores[k] = usv.reshape({rx[k],N[k],rx[k+1]}); + } + + + + } + if(verbose){ + std::cout << "Solution rank [ "; + for(auto rr: rx) std::cout << rr << " "; + std::cout << "]" << std::endl; + std::cout << "Maxres " << max_res << std::endl; + auto diff_time = std::chrono::high_resolution_clock::now() - tme_swp; + std::cout << "Time " << (double)(std::chrono::duration_cast(diff_time)).count()/1000.0 << " ms"<(diff_time)).count()/1000000.0 << " seconds" << std::endl << std::endl; + } + + double norm_x = 0.0; + for(int i=0;i &cores, std::vector &shape, std::vector &rank, double epsilon, uint64_t rmax){ + uint64_t d = cores.size(); + + if(d<=1){ + return ; + } + + lr_orthogonal(cores, shape, rank); + double eps = epsilon / std::sqrt((double)d-1.0); + + + + at::Tensor core_now, core_next; + + core_now = cores[d-1].reshape({rank[d-1],shape[d-1]*rank[d]}); + + for(int i=d-1;i>0;i--){ + + core_next = cores[i-1].reshape({rank[i-1]*shape[i-1], rank[i]}); + + std::tuple USV = at::linalg_svd(core_now, false); + + int rc = rank_chop(std::get<1>(USV), eps * torch::norm(std::get<1>(USV)).item()); + + int rnew = rmax < rc ? 
rmax : rc; + + rank[i] = rnew; + + auto US = std::get<0>(USV).index({torch::indexing::Ellipsis, torch::indexing::Slice(0, rnew)}).matmul(torch::diag(std::get<1>(USV).index({torch::indexing::Slice(0, rnew)}))); + core_next = core_next.matmul(US); + + cores[i-1] = core_next.reshape({rank[i-1], shape[i-1], rank[i]}); + cores[i] = std::get<2>(USV).index({ torch::indexing::Slice(0, rnew), torch::indexing::Ellipsis}).reshape({rank[i], shape[i], rank[i+1]}); + + core_now = core_next.reshape({rank[i-1],-1}); + + + } + +} \ No newline at end of file diff --git a/torchtt/cpp/cpp_ext.cpp b/torchtt/cpp/cpp_ext.cpp new file mode 100644 index 0000000..75afe20 --- /dev/null +++ b/torchtt/cpp/cpp_ext.cpp @@ -0,0 +1,19 @@ +#include "full.h" +#include "amen_solve.h" +#include "compression.h" +#include "dmrg_mv.h" + +/// Functions from cpp to import in python +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("tt_full", &full, "TT to full"); + m.def("amen_solve", &amen_solve, "AMEn solve"); + m.def("round_this", &round_this, "Implace rounding"); + m.def("dmrg_mv", &dmrg_mv, "DMRG matrix vector product"); +} + + + + + + + diff --git a/torchtt/cpp/define.h b/torchtt/cpp/define.h new file mode 100644 index 0000000..00add3a --- /dev/null +++ b/torchtt/cpp/define.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "BLAS.h" + +// torch::NoGradGuard no_grad; + +#define NO_PREC 0 +#define C_PREC 1 +#define R_PREC 2 + + + diff --git a/torchtt/cpp/dmrg_mv.h b/torchtt/cpp/dmrg_mv.h new file mode 100644 index 0000000..b6d46d1 --- /dev/null +++ b/torchtt/cpp/dmrg_mv.h @@ -0,0 +1,231 @@ +#include "define.h" +#include "ortho.h" +#include + +/** + * @brief Fast matvec `y = Ax` using DMRG optimization + * + * @param A_cores the cores of the TT matrix `A` + * @param x_cores the cores of `x` + * @param y0_cores initial guess for `y`. If an emptyvector is provided the guess is random + * @param M shape of `y` + * @param N shape of `x` + * @param rx rank of `x` + * @param ry0 rank of the initial guess of `y` + * @param nswp number of sweeps + * @param eps relative accuracy + * @param rmax maximum rank + * @param kickrank kickrank + * @param r_enlage how much we enlarge the rank compared to the previous step + * @param verb show debug info. 
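+ * @note Sketch of the iteration implemented below: the cores of `y` are optimized two at a time.
+ *       For each pair of neighbouring cores, the local supercore `W` of `A @ x` is contracted from
+ *       the interface tensors, compared with its previous value to estimate convergence, truncated
+ *       by SVD (via rank_chop) to the target accuracy, and the ranks `ry` are enlarged or reduced
+ *       adaptively before moving on.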
+ * @return std::vector + */ +std::vector dmrg_mv( + std::vector A_cores, + std::vector x_cores, + std::vector y0_cores, + std::vector M, + std::vector N, + std::vector rx, + std::vector ry0, + int64_t nswp, double eps, + int64_t rmax, + int64_t kickrank, + bool verb) +{ + auto options = A_cores[0].options(); + int64_t d = N.size(); + + std::vector y_cores(d); + std::vector ry(d+1); + + + if(y0_cores.size() == 0){ + + for(int i = 1; i < d; ++i) + ry[i] = 2; + ry[0] = 1; + ry[d] = 1; + for(int i=0; i < d; ++i) + y_cores[i] = torch::randn({ry[i], M[i], ry[i + 1]}, options); + } + else{ + y_cores = y0_cores; + ry = ry0; + } + + std::vector Phis(d+1); + + Phis[0] = at::ones({1,1,1}, options); + Phis[d] = at::ones({1,1,1}, options); + + std::vector delta_cores(d-1); + std::vector delta_cores_prev(d-1); + std::vector r_enlarge(d); + + for(int i = 0; i < d-1; ++i) + { + r_enlarge[i] = 2; + delta_cores[i] = 1.0; + delta_cores_prev[i] = 1.0; + } + r_enlarge[d-1] = 2; + + bool last = false; + + std::chrono::time_point tme_swp, tme_total; + + if(verb) + tme_total = std::chrono::high_resolution_clock::now(); + + int swp; + for(swp = 0; swp < nswp; ++swp) + { + if(verb) + std::cout << "Sweep " << swp << std::endl; + + for(int k = d-1; k > 0; --k) + { + auto core = y_cores[k].permute({1,2,0}).reshape({M[k]*ry[k+1], ry[k]}); + at::Tensor Q,R; + std::tie(Q,R) = at::linalg_qr(core.contiguous()); + int64_t rnew = core.sizes()[0] < core.sizes()[1] ? core.sizes()[0] : core.sizes()[1]; + y_cores[k] = Q.t().reshape({rnew, M[k], -1}).contiguous(); + ry[k] = rnew; + auto core_next = at::tensordot(y_cores[k-1].reshape({-1, y_cores[k-1].sizes()[2]}), R, {1}, {1}); + y_cores[k-1] = core_next.reshape({-1, M[k-1], rnew}); + + + auto Phi = at::tensordot(Phis[k+1], at::conj(x_cores[k]), {2}, {2}); + Phi = at::tensordot(at::conj(A_cores[k]), Phi, {2,3}, {3,1}); + Phi = at::tensordot(y_cores[k], Phi, {1,2}, {1,2}); + + //auto Phi = at::einsum("ijk,mnk->ijmn",{Phis[k+1],at::conj(x_cores[k])}); + //Phi = at::einsum("ijkl,mlnk->ijmn",{at::conj(A_cores[k]),Phi}); + //Phi = at::einsum("ijkl,mjk->mil",{Phi,y_cores[k]}); + + Phis[k] = Phi.contiguous().clone(); + + } + + for(int k = 0; k ijlm",{Phis[k],at::conj(x_cores[k])}); + // W1 = at::einsum("ijkl,mikn->mjln",{at::conj(A_cores[k]),W1}); + // + // auto W2 = at::einsum("ijk,mnk->njmi",{Phis[k+2],at::conj(x_cores[k+1])}); + // W2 = at::einsum("ijkl,klmn->ijmn",{at::conj(A_cores[k+1]),W2}); + // + // W = at::einsum("ijkl,kmln->ijmn",{W1,W2}); + + } + else + W = at::conj(W_prev); + + double b = torch::norm(W).cpu().item(); + if ( b >0 ) + { + double a = torch::norm(W-at::conj(W_prev)).cpu().item(); + delta_cores[k] = a/b; + } + else + delta_cores[k] = 0; + + if( delta_cores[k] / delta_cores_prev[k] >= 1 && delta_cores[k]>eps) + r_enlarge[k] += 1; + + if( delta_cores[k]/delta_cores_prev[k] < 0.1 && delta_cores[k] 1 ? r_enlarge[k]-1 : 1; + + at::Tensor U,S,V; + + std::tie(U, S, V) = at::linalg_svd(W.reshape({W.sizes()[0]*W.sizes()[1], -1}), false); + + int64_t r_new = rank_chop(S.cpu(), b*eps/(std::pow((double)d, last ? 0.5 : 1.5)) ); //<<<<<<<<<<<< + + // r_new = rank_chop(S.cpu().numpy(),(b.cpu()*eps/(d**(0.5 if last else 1.5))).numpy()) + + if(!last) + r_new += r_enlarge[k]; + + r_new = std::min(std::min(r_new, S.sizes()[0]), rmax); + r_new = r_new > 1 ? 
r_new : 1; + + at::Tensor W1 = U.index({torch::indexing::Ellipsis, torch::indexing::Slice(0,r_new,1)}); + at::Tensor W2 = V.index({torch::indexing::Slice(0,r_new,1), torch::indexing::Ellipsis}).t() * S.index({torch::indexing::Slice(0, r_new, 1)}); + + if( swp < nswp - 1) + { + at::Tensor Rmat; + auto tmp_tens = at::cat({W1, torch::randn({W1.sizes()[0], kickrank}, options)}, 1); + std::tie(W1, Rmat) = at::linalg_qr(tmp_tens); + W2 = at::cat({W2, torch::zeros({W2.sizes()[0], kickrank}, options)}, 1); + W2 = at::tensordot(Rmat, W2, {1},{1}); + r_new = W1.sizes()[1]; + } + else + W2 = W2.t(); + + if(verb) + { + std::cout << "\tcore " << k << ": delta " << delta_cores[k] << " rank " << ry[k+1] << " -> " << r_new << std::endl; + } + ry[k+1] = r_new; + + y_cores[k] = at::conj(W1.reshape({ry[k], M[k], r_new})); + y_cores[k+1] = at::conj(W2.reshape({r_new, M[k+1], ry[k+2]})); + + // auto Wc = at::conj(at::tensordot(y_cores[k]), y_cores[k+1], {2}, {0})); + + auto Phi_next = at::tensordot(Phis[k], at::conj(x_cores[k]), {2}, {0}); + Phi_next = at::tensordot(Phi_next, at::conj(A_cores[k]), {1,2}, {0, 2}); // result ilmn + Phi_next = at::tensordot(y_cores[k], Phi_next, {0,1}, {0,2}); + Phi_next = Phi_next.permute({0,2,1}); + + // auto Phi_next = at::einsum("ijk,kmn->ijmn",{Phis[k],at::conj(x_cores[k])}); // # shape rk-1 x rAk-1 x Nk x rxk + // Phi_next = at::einsum("ijkl,jmkn->imnl",{Phi_next,at::conj(A_cores[k])}); //# shape rk-1 x Mk x rAk x rxk + // Phi_next = at::einsum("ijm,ijkl->mkl",{y_cores[k],Phi_next}); + + Phis[k+1] = Phi_next.contiguous().clone(); + } + + if(last) + break; + + auto max_delta = std::max_element(delta_cores.begin(), delta_cores.end()); + + if(*max_delta < eps) + last = true; + + delta_cores_prev = delta_cores; + } + + + if(verb){ + auto diff_time = std::chrono::high_resolution_clock::now() - tme_total; + std::cout << std::endl << "Finished after " << (swp < nswp ? swp+1 : swp) <<" sweeps and "<< (double)(std::chrono::duration_cast(diff_time)).count()/1000000.0 << " seconds" << std::endl << std::endl; + } + + + return y_cores; +} \ No newline at end of file diff --git a/torchtt/cpp/full.h b/torchtt/cpp/full.h new file mode 100644 index 0000000..ae03e45 --- /dev/null +++ b/torchtt/cpp/full.h @@ -0,0 +1,12 @@ +#include "define.h" + +torch::Tensor full(std::vector cores) +{ + torch::Tensor t = cores[0].index({0, torch::indexing::Ellipsis}); + for (int i = 1; i < cores.size(); i++) + { + + t = torch::tensordot(t, cores[i], {i}, {0}); + } + return t.index({torch::indexing::Ellipsis, 0}); +} diff --git a/torchtt/cpp/gmres.h b/torchtt/cpp/gmres.h new file mode 100644 index 0000000..16f93b1 --- /dev/null +++ b/torchtt/cpp/gmres.h @@ -0,0 +1,310 @@ +#include "define.h" +#include +#include + + + +/** + * @brief givensrotation. + * + * @tparam T typename (double or float). + * @param[in] v1 the first value. + * @param[in] v2 the second value. 
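+ *
+ * The returned pair (c, s) satisfies c = v1 / sqrt(v1^2 + v2^2) and s = v2 / sqrt(v1^2 + v2^2),
+ * so that the rotation [c s; -s c] applied to (v1, v2) zeroes the second component.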
+ * @return std::tuple + */ +template std::tuple givens_rotation(T v1, T v2){ + T den = std::sqrt(v1*v1+v2*v2); + return std::make_tuple(v1/den, v2/den); +} + +template void apply_givens_rotation_cpu(T *h, T *cs, T *sn, uint64_t k, T &cs_k, T &sn_k){ + + for(int i = 0; i < k-1; ++i){ + T temp = cs[i]* h[i] + sn[i] * h[i+1]; + h[i+1] = -sn[i] * h[i] + cs[i] * h[i+1]; + h[i] = temp; + } + std::tie(cs_k, sn_k) = givens_rotation(h[k-1], h[k]); + + h[k-1] = cs_k * h[k-1] + sn_k * h[k]; + h[k] = 0.0; +} + + +template +void gmres_single(at::Tensor &solution, int &flag, int &nit, AMENsolveMV &Op, at::Tensor &rhs, at::Tensor &x0, uint64_t size, uint64_t iters, T threshold){ + + bool converged = false; + + at::Tensor r = rhs - Op.matvec(x0); + + T b_norm = torch::norm(rhs).item(); + T error = torch::norm(r).item() / b_norm; + + T * sn = new T[iters]; + T * cs = new T[iters]; + T * e1 = new T[iters+1]; + for (int i=0; i(); + + if(r_norm<=0){ + flag = 1; + nit = 0; + solution = x0.clone(); + // free memory + delete [] sn; + delete [] cs; + delete [] e1; + return; + } + + std::vector Q; + Q.push_back(r.squeeze() / r_norm); + + at::Tensor H, beta; + if(std::is_same::value){ + auto options = torch::TensorOptions().dtype(torch::kFloat32); + beta = torch::zeros(iters+1, options); + H = torch::zeros({iters+1, iters}, options); + } + else{ + auto options = torch::TensorOptions().dtype(torch::kFloat64); + beta = torch::zeros(iters+1, options); + H = torch::zeros({iters+1, iters}, options); + } + + auto betaA = beta.accessor(); + auto HA = H.accessor(); + betaA[0] = r_norm; + + int k; + for(k = 0; k(diff_time)).count()/1000 << std::endl; + + // ts = std::chrono::high_resolution_clock::now(); + + // #pragma omp parallel for num_threads(32) + for(int i=0;i(); + q -= (HA[i][k] * Q[i]).reshape({-1,1}); + } + // diff_time = std::chrono::high_resolution_clock::now() - ts; + // std::cout << " PROJ " << (double)(std::chrono::duration_cast(diff_time)).count()/1000 << std::endl; + + // ts = std::chrono::high_resolution_clock::now(); + T h = torch::norm(q).item(); + + q /= h; + + HA[k+1][k] = h; + Q.push_back(q.clone().squeeze()); + + T c,s; + at::Tensor htemp = H.index({torch::indexing::Slice(0,k+2), k}).contiguous(); + apply_givens_rotation_cpu(htemp.data_ptr(), cs, sn, k+1, c, s); + H.index_put_({torch::indexing::Slice(0,k+2), k}, htemp); + cs[k] = c; + sn[k] = s; + + betaA[k+1] = -sn[k]*betaA[k]; + betaA[k] = cs[k]*betaA[k]; + error = std::abs(betaA[k+1])/b_norm; + // diff_time = std::chrono::high_resolution_clock::now() - ts; + // std::cout << " REST " << (double)(std::chrono::duration_cast(diff_time)).count()/1000 << std::endl; + if(error<=threshold) + { + flag = 1; + break; + } + } + k = k(); + + nit = k+1; + // free memory + delete [] sn; + delete [] cs; + delete [] e1; + +} + +template +void gmres(at::Tensor &solution, int &flag, int &nit, AMENsolveMV &Op, at::Tensor &rhs, at::Tensor &x0, uint64_t size, uint64_t max_iters, T threshold, uint64_t resets ){ + nit = 0; + flag = 0; + + auto xs = x0; + for(int r =0;r(solution, flag, nowit, Op, rhs, xs, size, max_iters, threshold); + nit+=nowit; + if(flag==1){ + break; + } + xs = solution.clone(); + } +} + + +void gmres_double_cpu(double *solution, + int &flag, + int &nit, + std::function matvec, + double *rhs, + int64_t size, + int64_t max_iters, + double threshold, + int64_t resets, + bool debug) +{ + + nit = 0; + flag = 0; + + int64_t inc1 = 1; + char transN = 'N'; + double alpha1 = 1.0; + double alpham1 = -1.0; + + double *sn = new double[max_iters]; + double *cs = 
new double[max_iters]; + + double *Q = nullptr; + //double *q = new double[size]; + double *H = new double[max_iters*(max_iters+1)]; + double *beta = new double[max_iters+1]; + double *work1 = new double [max_iters+1]; + + int64_t *piv_tmp = new int64_t[size]; + + double b_norm; + double error; + + b_norm = BLAS::nrm2(&size, rhs, &inc1); + + if(b_norm <= 0) + { + double alpha0 = 0.0; + BLAS::scal(&size, &alpha0, solution, &inc1); + nit = 1; + flag = 1; + } + else + { + + if(Q == nullptr) + Q = new double[size*(max_iters+1)]; + + for(uint64_t r=0; r(&size, &alpham1, Q, &inc1); + BLAS::axpy(&size, &alpha1, rhs, &inc1, Q, &inc1); + + auto r_norm = BLAS::nrm2(&size, Q, &inc1); + + if( ! r_norm>0 ) + { + flag = 1; + nit = 0; + break; + } + + double tmp = 1/r_norm; + BLAS::scal(&size, &tmp, Q, &inc1); + + //if(Q == nullptr) + // Q = new double[size*(max_iters+1)]; + + // fill with 0 + std::fill_n(beta, max_iters+1, 0); + std::fill_n(cs, max_iters+1, 0); + std::fill_n(sn, max_iters+1, 0); + std::fill_n(H, (max_iters+1)*max_iters, 0); + + + error = r_norm / b_norm; + beta[0] = r_norm; + + for(k = 0; k>> + double c,s; + apply_givens_rotation_cpu(H+k*(max_iters+1), cs, sn, k+1, c, s); + cs[k] = c; + sn[k] = s; + + beta[k+1] = -sn[k]*beta[k]; + beta[k] = cs[k]*beta[k]; + error = std::abs(beta[k+1])/b_norm; + + if(debug) + std::cout << "Iteration " << k << " error " << error << std::endl; + if(error<=threshold) + { + flag = 1; + break; + } + } + + k = k +class AMENsolveMV_cpu{ + +private: + T * Phi_left; + T * Phi_right; + T * coreA; + T * J; + int prec; + T * work1; + T * work2; + int64_t r,R,n,s,S,l,L; + +public: + ~AMENsolveMV_cpu(){ + delete [] Phi_left; + delete [] Phi_right; + delete [] coreA; + delete [] work1; + delete [] work2; + + if(!prec) + delete [] J; + } + + + AMENsolveMV_cpu(at::Tensor &Phi_left, at::Tensor &Phi_right, at::Tensor & coreA, int prec) + { + int64_t inc1 = 1; + int64_t size; + + s = coreA.sizes()[0]; + n = coreA.sizes()[1]; + S = coreA.sizes()[3]; + l = Phi_left.sizes()[0]; + r = Phi_left.sizes()[2]; + L = Phi_right.sizes()[0]; + R = Phi_right.sizes()[2]; + + if(this->prec == C_PREC){ + auto Jl = at::tensordot(at::diagonal(Phi_left,0,0,2), coreA, {0}, {0}); + auto Jr = at::diagonal(Phi_right, 0, 0, 2); + auto J = at::linalg_inv(at::tensordot(Jl,Jr,{3},{0}).permute({0,3,1,2})); + // TODO: switch to column major + + size = r*R*n*n; + this->J = new T[size]; + BLAS::copy(&size, J.data_ptr(), &inc1, this->J, &inc1); + } + else if(this->prec == R_PREC){ + auto Jl = at::tensordot(at::diagonal(Phi_left,0,0,2), coreA, {0},{0}); // sd,smnS->dmnS + auto Jt = at::tensordot(Jl, Phi_right, {3}, {1}); // dmnS,LSR->dmnLR + Jt = Jt.permute({0, 1, 3, 2, 4}); + auto sh = Jt.sizes(); + auto Jt2 = Jt.reshape({-1, Jt.sizes()[1]*Jt.sizes()[2], Jt.sizes()[3]*Jt.sizes()[4]}); + auto J = at::linalg_inv(Jt2).reshape(sh); + // TODO : switch to column major + + size = r*n*L*n*R; + this->J = new T[size]; + BLAS::copy(&size, J.data_ptr(), &inc1, this->J, &inc1); + + } + + // TODO: switch to column major !!! 
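+        // Descriptive note: the two interface tensors are permuted and copied into plain
+        // contiguous buffers below, so that matvec() can evaluate the contraction
+        // lsr,smnS,LSR,rnR->lmL as a chain of GEMM calls on raw memory.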
+ + auto Pl = Phi_left.contiguous(); + auto Pr = Phi_right.contiguous(); + + Pl = Pl.permute({0,2,1}).contiguous(); + Pr = Pr.permute({2,1,0}).contiguous(); + + this->coreA = new T[s*n*n*S]; + this->Phi_left = new T[l*r*s]; + this->Phi_right = new T[R*S*L]; + size = std::max(r*n*S*L,s*n*r*L); + this->work1 = new T[size]; + this->work2 = new T[size]; + + // copy everything + + size = s*n*n*S; + BLAS::copy(&size, coreA.contiguous().data_ptr(), &inc1, this->coreA, &inc1); + size = l*r*s; + BLAS::copy(&size, Pl.data_ptr(), &inc1, this->Phi_left, &inc1); + size = R*S*L; + BLAS::copy(&size, Pr.data_ptr(), &inc1, this->Phi_right, &inc1); + + + } + + void matvec(T *in, T *out) + { + //w = tn.einsum('lsr,smnS,LSR,rnR->lmL',self.Phi_left,self.coreA,self.Phi_right,x) + //w = tn.einsum('rsl,smnS,RSL,rnR->lmL',self.Phi_left,self.coreA,self.Phi_right,x) + char tN = 'N'; + char tC = 'C'; + T alpha1 = 1.0; + T alpha0 = 0.0; + int64_t M, N, K; + + M = r*n; + N = S*L; + K = R; + BLAS::gemm(&tN, &tN, &M, &N, &K, &alpha1, in, &M, this->Phi_right, &K, &alpha0, this->work1, &M); // work1 is now rnSL + // >>>>> transpose !!!!! + + // work is now LrnS + M = L*r; + N = s*n; + K = n*S; + BLAS::gemm(&tN, &tC, &M, &N, &K, &alpha1, this->work1, &M, this->coreA, &N, &alpha0, this->work2, &M); + // work2 is now Lrsm + + // >>>>>> transpose !!! + + //work is now rsmL + M = l; + N = n*L; + K = r*s; + BLAS::gemm(&tN, &tN, &M, &N, &K, &alpha1, this->work2, &M, this->Phi_left, &K, &alpha0, out, &M); + } + +}; + +template class AMENsolveMV{ + +private: + at::Tensor Phi_left; + at::Tensor Phi_right; + at::Tensor coreA; + at::Tensor J; + int prec; + at::IntArrayRef shape; + at::TensorOptions options; + + T * Phi_left_ptr; + T * Phi_right_ptr; + T * coreA_ptr; + T * J_ptr; + T * work1_ptr; + T * work2_ptr; + int64_t r,R,n,s,S,l,L; +public: + AMENsolveMV(){ + ; + } + + void setter(at::Tensor &Phi_left, at::Tensor &Phi_right, at::Tensor & coreA, at::IntArrayRef shape, int prec, at::TensorOptions options){ + this->prec = prec; + this->options = options; + this->shape = shape; + + this->Phi_left = Phi_left; //torch::from_blob(Phi_left.contiguous().data_ptr(), Phi_left.sizes(), options); + this->Phi_right = Phi_right; //torch::from_blob(Phi_right.contiguous().data_ptr(), Phi_right.sizes(), options); + this->coreA = coreA; // torch::from_blob(coreA.contiguous().data_ptr(), coreA.sizes(), options); + if(this->prec == C_PREC){ + auto Jl = at::tensordot(at::diagonal(Phi_left,0,0,2), coreA, {0}, {0}); + auto Jr = at::diagonal(Phi_right, 0, 0, 2); + this->J = at::linalg_inv(at::tensordot(Jl,Jr,{3},{0}).permute({0,3,1,2})); + } + else if(this->prec == R_PREC){ + auto Jl = at::tensordot(at::diagonal(Phi_left,0,0,2), coreA, {0},{0}); // sd,smnS->dmnS + auto Jt = at::tensordot(Jl, Phi_right, {3}, {1}); // dmnS,LSR->dmnLR + Jt = Jt.permute({0, 1, 3, 2, 4}); + auto sh = Jt.sizes(); + auto Jt2 = Jt.reshape({-1, Jt.sizes()[1]*Jt.sizes()[2], Jt.sizes()[3]*Jt.sizes()[4]}); + this->J = at::linalg_inv(Jt2).reshape(sh); + //std::cout << "READY 1" <prec == C_PREC) { + uint64_t s0,s1,s2; + s0 = sol.sizes()[0]; + s1 = sol.sizes()[1]; + s2 = sol.sizes()[2]; + + at::Tensor tmp = sol.permute({0,2,1}).reshape({s0, s2, s1, 1}); + ret = at::linalg_matmul(this->J, tmp).permute({0,2,1,3}).reshape({s0, s1, s2}); + } + else if(this->prec == R_PREC){ + ret = at::einsum("rnR,rmLnR->rmL", {sol, this->J}); + + } + + return ret; + } + + at::Tensor matvec(at::Tensor &x, bool use_prec = true){ + at::Tensor tmp; + + if(!use_prec || this->prec == NO_PREC){ + tmp = 
x.reshape(this->shape); + } + else + { + tmp = apply_prec(x.reshape(this->shape)); + } +// w = tn.einsum('lsr,smnS,LSR,rnR->lmL',self.Phi_left,self.coreA,self.Phi_right,x) + auto w = at::tensordot(tmp, this->Phi_left, {0}, {2}); + auto w2 = at::tensordot(w, this->coreA, {0,3}, {2,0}); + auto w3 = at::tensordot(w2, this->Phi_right, {0,3}, {2,1}); + return w3.reshape({this->shape[0]*this->shape[1]*this->shape[2],1}); + + } + + void matvec_cpu(T *in, T *out){ + + + } +}; \ No newline at end of file diff --git a/torchtt/cpp/ortho.h b/torchtt/cpp/ortho.h new file mode 100644 index 0000000..855c865 --- /dev/null +++ b/torchtt/cpp/ortho.h @@ -0,0 +1,103 @@ +#ifndef ORTHO +#define ORTHO +#include "define.h" + +void perform_QR(at::Tensor &Q, at::Tensor &R, at::Tensor &M){ + at::linalg_qr_out(Q,R,M); +} + +/** + * @brief chop the rank up to a prescribed accuracy. + * + * @param s the singular values vactor. + * @param eps the relative accuracy. + * @return int + */ +int rank_chop(torch::Tensor s, double eps) +{ + int n = s.sizes()[0]; + int r = n - 1; + if (torch::norm(s).item() == 0.0) + return 1; + + if (eps <= 0.0) + return r; + + double *ss = (double *)s.data_ptr(); + + while (r > 0) + { + double sum = 0.0; + + for (int k = r; k < n; k++) + sum += ss[k] * ss[k]; + + if (sum >= eps * eps) + break; + + r--; + } + r++; + r = r > 0 ? r : 1; + + return r; +} + +void rl_orthogonal_this(std::vector &cores, std::vector &shape, std::vector &rank){ + + uint64_t d = shape.size(); + + + at::Tensor core_now; + + + for(int i=d-1;i>0;i--){ + core_now = cores[i].reshape({cores[i].sizes()[0], cores[i].sizes()[1]* cores[i].sizes()[2]}).t(); + + // perform QR + // perform_QR(Q,R,core_now); + std::tuple QR = at::linalg_qr(core_now); + + + uint64_t r_new; // = core_now.sizes()[0] < core_now.sizes()[1] ? core_now.sizes()[0] : core_now.sizes()[1]; + r_new = std::get<1>(QR).sizes()[0]; + + cores[i] = std::get<0>(QR).t().reshape({r_new,shape[i],-1}); + rank[i] = r_new; + + cores[i-1] = (cores[i-1].reshape({-1,cores[i-1].sizes()[2]}).matmul(std::get<1>(QR).t())).reshape({cores[i-1].sizes()[0],shape[i-1],-1}); + + } + + +} + + + +void lr_orthogonal(std::vector &cores, std::vector &shape, std::vector &rank){ + int d = shape.size(); + + at::Tensor core_now; + + + + for(int i=0;i QR = at::linalg_qr(core_now); + + rank[i+1] = std::get<0>(QR).sizes()[1]; + + cores[i] = std::get<0>(QR).reshape({rank[i], shape[i], -1}); + + cores[i+1] = (std::get<1>(QR).matmul(cores[i+1].reshape({cores[i+1].sizes()[0],-1}))).reshape({cores[i].sizes()[2], shape[i+1],-1}); + + } + + + +} + +#endif \ No newline at end of file diff --git a/torchtt/cpp/test_gmres.cpp b/torchtt/cpp/test_gmres.cpp new file mode 100644 index 0000000..c62e40d --- /dev/null +++ b/torchtt/cpp/test_gmres.cpp @@ -0,0 +1,253 @@ +#include "BLAS.h" +#include +#include +#include +#include +#include + +#define N 20000 + +/** + * @brief givensrotation. + * + * @tparam T typename (double or float). + * @param[in] v1 the first value. + * @param[in] v2 the second value. 
+ * @return std::tuple + */ +template std::tuple givens_rotation(T v1, T v2){ + T den = std::sqrt(v1*v1+v2*v2); + return std::make_tuple(v1/den, v2/den); +} + +template void apply_givens_rotation_cpu(T *h, T *cs, T *sn, uint64_t k, T &cs_k, T &sn_k){ + + for(int i = 0; i < k-1; ++i){ + T temp = cs[i]* h[i] + sn[i] * h[i+1]; + h[i+1] = -sn[i] * h[i] + cs[i] * h[i+1]; + h[i] = temp; + } + std::tie(cs_k, sn_k) = givens_rotation(h[k-1], h[k]); + + h[k-1] = cs_k * h[k-1] + sn_k * h[k]; + h[k] = 0.0; +} + + +void gmres_double_cpu(double *solution, + int &flag, + int &nit, + std::function matvec, + double *rhs, + int64_t size, + int64_t max_iters, + double threshold, + int64_t resets, + bool debug) +{ + + nit = 0; + flag = 0; + + int64_t inc1 = 1; + char transN = 'N'; + double alpha1 = 1.0; + double alpham1 = -1.0; + + double *sn = new double[max_iters]; + double *cs = new double[max_iters]; + + double *Q = nullptr; + //double *q = new double[size]; + double *H = new double[max_iters*(max_iters+1)]; + double *beta = new double[max_iters+1]; + double *work1 = new double [max_iters+1]; + + int64_t *piv_tmp = new int64_t[size]; + + double b_norm; + double error; + + b_norm = BLAS::nrm2(&size, rhs, &inc1); + + if(b_norm <= 0) + { + double alpha0 = 0.0; + BLAS::scal(&size, &alpha0, solution, &inc1); + nit = 1; + flag = 1; + } + else + { + + if(Q == nullptr) + Q = new double[size*(max_iters+1)]; + + for(uint64_t r=0; r(&size, &alpham1, Q, &inc1); + BLAS::axpy(&size, &alpha1, rhs, &inc1, Q, &inc1); + + auto r_norm = BLAS::nrm2(&size, Q, &inc1); + + if( ! r_norm>0 ) + { + flag = 1; + nit = 0; + break; + } + + double tmp = 1/r_norm; + BLAS::scal(&size, &tmp, Q, &inc1); + + //if(Q == nullptr) + // Q = new double[size*(max_iters+1)]; + + // fill with 0 + std::fill_n(beta, max_iters+1, 0); + std::fill_n(cs, max_iters+1, 0); + std::fill_n(sn, max_iters+1, 0); + std::fill_n(H, (max_iters+1)*max_iters, 0); + + + error = r_norm / b_norm; + beta[0] = r_norm; + + for(k = 0; k>> + double c,s; + apply_givens_rotation_cpu(H+k*(max_iters+1), cs, sn, k+1, c, s); + cs[k] = c; + sn[k] = s; + + beta[k+1] = -sn[k]*beta[k]; + beta[k] = cs[k]*beta[k]; + error = std::abs(beta[k+1])/b_norm; + + if(debug) + std::cout << "Iteration " << k << " error " << error << std::endl; + if(error<=threshold) + { + flag = 1; + break; + } + } + + k = k start, end; + + start = std::chrono::system_clock::now(); + + int nit, flag; + gmres_double_cpu(solution, + flag, + nit, + mv, + rhs, + N, + 100, + 1e-8, + 20, + false); + + end = std::chrono::system_clock::now(); + + std::chrono::duration elapsed_seconds = end - start; + std::time_t end_time = std::chrono::system_clock::to_time_t(end); + + mv(solution, tmp); + + int64_t sz = N; + int64_t inc1 = 1; + double m1 = -1; + + BLAS::axpy(&sz, &m1, rhs, &inc1, tmp, &inc1); + double nrm = BLAS::nrm2(&sz, tmp, &inc1) / N; + + std::cout << "Elapsed time: " << elapsed_seconds.count() << "s\n"; + std::cout << "Flag " << flag << " number of iterations " << nit <
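For orientation, the Givens helpers defined above (givens_rotation and apply_givens_rotation_cpu) keep the Arnoldi least-squares problem triangular: each new Hessenberg column is first rotated by all previously stored (cs, sn) pairs, then a fresh rotation is chosen so the subdiagonal entry becomes zero, after which |beta[k+1]|/||b|| is the current residual estimate. A minimal standalone check of just that step, with the file name and driver chosen only for illustration and not part of the package, could look like:

// givens_check.cpp, illustrative only; mirrors givens_rotation and
// apply_givens_rotation_cpu from the header above on a single column.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

std::tuple<double, double> givens_rotation(double v1, double v2) {
    double den = std::sqrt(v1 * v1 + v2 * v2);
    return std::make_tuple(v1 / den, v2 / den);
}

void apply_givens_rotation_cpu(double *h, double *cs, double *sn,
                               uint64_t k, double &cs_k, double &sn_k) {
    // rotate the new column by all previous rotations ...
    for (uint64_t i = 0; i + 1 < k; ++i) {
        double temp = cs[i] * h[i] + sn[i] * h[i + 1];
        h[i + 1]    = -sn[i] * h[i] + cs[i] * h[i + 1];
        h[i]        = temp;
    }
    // ... then compute the new rotation that zeroes the subdiagonal entry h[k]
    std::tie(cs_k, sn_k) = givens_rotation(h[k - 1], h[k]);
    h[k - 1] = cs_k * h[k - 1] + sn_k * h[k];
    h[k]     = 0.0;
}

int main() {
    // one Hessenberg column with entries H(0,0) = 3 and H(1,0) = 4
    std::vector<double> h = {3.0, 4.0};
    std::vector<double> cs(1), sn(1);
    apply_givens_rotation_cpu(h.data(), cs.data(), sn.data(), 1, cs[0], sn[0]);
    std::cout << h[0] << " " << h[1] << "\n";   // prints 5 0
}

Compiled with an ordinary g++ -std=c++17, this prints 5 0: the rotation (c, s) = (0.6, 0.8) maps the column (3, 4) to (5, 0), which is exactly what the GMRES routines above rely on when they update beta[k] and beta[k+1].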
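The local operator in AMENsolveMV::matvec realises the contraction w = einsum('lsr,smnS,LSR,rnR->lmL', Phi_left, coreA, Phi_right, x) as a chain of three pairwise tensordots. A small self-contained sketch that checks the chain against the direct einsum, illustrative only, assuming a working libtorch and arbitrary small sizes:

// contraction_check.cpp, illustrative only; not part of the library build.
#include <torch/torch.h>
#include <iostream>

int main() {
    int64_t l = 2, s = 3, r = 4, m = 5, n = 5, S = 3, L = 2, R = 4;
    auto Phi_left  = torch::randn({l, s, r},    torch::kFloat64);
    auto coreA     = torch::randn({s, m, n, S}, torch::kFloat64);
    auto Phi_right = torch::randn({L, S, R},    torch::kFloat64);
    auto x         = torch::randn({r, n, R},    torch::kFloat64);

    auto ref = at::einsum("lsr,smnS,LSR,rnR->lmL", {Phi_left, coreA, Phi_right, x});

    // same chain of pairwise contractions as in AMENsolveMV::matvec above
    auto w  = at::tensordot(x,  Phi_left,  {0},    {2});    // (n, R, l, s)
    auto w2 = at::tensordot(w,  coreA,     {0, 3}, {2, 0}); // (R, l, m, S)
    auto w3 = at::tensordot(w2, Phi_right, {0, 3}, {2, 1}); // (l, m, L)

    std::cout << (ref - w3).abs().max().item<double>() << std::endl; // ~1e-15
}

Writing the contraction as explicit pairwise tensordots keeps every intermediate visible and does not rely on einsum's default contraction order, which is the usual reason for spelling out the local matvec this way.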
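rank_chop in ortho.h determines the truncation rank from a vector of singular values: it returns the smallest r such that the discarded tail energy sum_{k>=r} s_k^2 stays below eps^2, with a minimum rank of 1. Note that, as written, the comparison is against eps^2 directly, so a caller who wants a relative accuracy is expected to scale eps by the norm of s beforehand. A plain-C++ mirror of the criterion, with names and values chosen only for illustration:

// rank_chop_demo.cpp, illustrative only; mirrors the truncation criterion of
// rank_chop() from ortho.h on a plain std::vector of singular values.
#include <iostream>
#include <vector>

int rank_chop_plain(const std::vector<double> &s, double eps) {
    int n = static_cast<int>(s.size());
    int r = n - 1;
    if (eps <= 0.0) return r;
    // walk backwards until the discarded tail energy reaches eps^2
    while (r > 0) {
        double tail = 0.0;
        for (int k = r; k < n; ++k) tail += s[k] * s[k];
        if (tail >= eps * eps) break;
        --r;
    }
    ++r;
    return r > 0 ? r : 1;
}

int main() {
    std::vector<double> s = {1.0, 0.5, 1e-4, 1e-5};
    std::cout << rank_chop_plain(s, 1e-3) << "\n"; // prints 2
}

With s = (1, 0.5, 1e-4, 1e-5) and eps = 1e-3, the discarded tail (1e-4)^2 + (1e-5)^2, roughly 1.01e-8, is below eps^2 = 1e-6, so the rank is chopped to 2; this is the value the QR/SVD sweeps in ortho.h would use for the truncated core.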