From ec3e01a5328549d4dcc522ef7fde193bb2e50514 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Thu, 9 Nov 2023 12:37:54 +0000 Subject: [PATCH] Complex matmul --- .../bindings/generators/arrayGenerator.py | 50 +++++++------------ librapid/cxxblas/auxiliary/complex.tcc | 7 +++ librapid/include/librapid/math/complex.hpp | 7 +++ 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/librapid/bindings/generators/arrayGenerator.py b/librapid/bindings/generators/arrayGenerator.py index bffa169ff..026f411e8 100644 --- a/librapid/bindings/generators/arrayGenerator.py +++ b/librapid/bindings/generators/arrayGenerator.py @@ -15,8 +15,8 @@ ("uint64_t", "UInt64"), ("float", "Float"), ("double", "Double"), - ("lrc::Complex", "ComplexFloat"), - ("lrc::Complex", "ComplexDouble") + # ("lrc::Complex", "ComplexFloat"), + # ("lrc::Complex", "ComplexDouble") ]: for backend in ["CPU"]: # ["CPU", "OpenCL", "CUDA"]: arrayTypes.append({ @@ -55,18 +55,6 @@ def generateFunctionsForArray(config): ] ), - # Move constructor - # function.Function( - # name="__init__", - # args=[ - # argument.Argument( - # name="other", - # type=generateCppArrayType(config), - # move=True - # ) - # ] - # ), - # Shape function.Function( name="__init__", @@ -113,23 +101,23 @@ def generateFunctionsForArray(config): ), ] - # Static fromData (n dimensions) - for n in range(1, 9): - cppType = ("std::vector<" * n) + config['scalar'] + (">" * n) - - methods.append( - function.Function( - name="__init__", - args=[ - argument.Argument( - name=f"array{n}D", - type=cppType, - const=True, - ref=True, - ) - ] - ) - ) + # # Static fromData (n dimensions) + # for n in range(1, 9): + # cppType = ("std::vector<" * n) + config['scalar'] + (">" * n) + # + # methods.append( + # function.Function( + # name="__init__", + # args=[ + # argument.Argument( + # name=f"array{n}D", + # type=cppType, + # const=True, + # ref=True, + # ) + # ] + # ) + # ) methods += [ # Get item diff --git a/librapid/cxxblas/auxiliary/complex.tcc b/librapid/cxxblas/auxiliary/complex.tcc index 687d52c6c..3d0cfac4a 100644 --- a/librapid/cxxblas/auxiliary/complex.tcc +++ b/librapid/cxxblas/auxiliary/complex.tcc @@ -36,7 +36,14 @@ #include #include "cxxblas/auxiliary/complex.h" +namespace librapid { + template + class Complex; +} + namespace cxxblas { + template + librapid::Complex conjugate(const librapid::Complex &val); template typename cxxblas::RestrictTo::value, const T &>::Type diff --git a/librapid/include/librapid/math/complex.hpp b/librapid/include/librapid/math/complex.hpp index f4c8da8a6..ab5ace962 100644 --- a/librapid/include/librapid/math/complex.hpp +++ b/librapid/include/librapid/math/complex.hpp @@ -2095,6 +2095,13 @@ namespace librapid { } // namespace typetraits } // namespace librapid +namespace cxxblas { + template + librapid::Complex conjugate(const librapid::Complex &val) { + return librapid::conj(val); + } +} // namespace cxxblas + // Support FMT printing #ifdef FMT_API template