Skip to content

Commit

Permalink
Complex matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Nov 9, 2023
1 parent a54cc1e commit ec3e01a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 31 deletions.
50 changes: 19 additions & 31 deletions librapid/bindings/generators/arrayGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
("uint64_t", "UInt64"),
("float", "Float"),
("double", "Double"),
("lrc::Complex<float>", "ComplexFloat"),
("lrc::Complex<double>", "ComplexDouble")
# ("lrc::Complex<float>", "ComplexFloat"),
# ("lrc::Complex<double>", "ComplexDouble")
]:
for backend in ["CPU"]: # ["CPU", "OpenCL", "CUDA"]:
arrayTypes.append({
Expand Down Expand Up @@ -55,18 +55,6 @@ def generateFunctionsForArray(config):
]
),

# Move constructor
# function.Function(
# name="__init__",
# args=[
# argument.Argument(
# name="other",
# type=generateCppArrayType(config),
# move=True
# )
# ]
# ),

# Shape
function.Function(
name="__init__",
Expand Down Expand Up @@ -113,23 +101,23 @@ def generateFunctionsForArray(config):
),
]

# Static fromData (n dimensions)
for n in range(1, 9):
cppType = ("std::vector<" * n) + config['scalar'] + (">" * n)

methods.append(
function.Function(
name="__init__",
args=[
argument.Argument(
name=f"array{n}D",
type=cppType,
const=True,
ref=True,
)
]
)
)
# # Static fromData (n dimensions)
# for n in range(1, 9):
# cppType = ("std::vector<" * n) + config['scalar'] + (">" * n)
#
# methods.append(
# function.Function(
# name="__init__",
# args=[
# argument.Argument(
# name=f"array{n}D",
# type=cppType,
# const=True,
# ref=True,
# )
# ]
# )
# )

methods += [
# Get item
Expand Down
7 changes: 7 additions & 0 deletions librapid/cxxblas/auxiliary/complex.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@
#include <complex>
#include "cxxblas/auxiliary/complex.h"

namespace librapid {
template<typename T>
class Complex;
}

namespace cxxblas {
template<typename T>
librapid::Complex<T> conjugate(const librapid::Complex<T> &val);

template<typename T>
typename cxxblas::RestrictTo<std::is_arithmetic<T>::value, const T &>::Type
Expand Down
7 changes: 7 additions & 0 deletions librapid/include/librapid/math/complex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,13 @@ namespace librapid {
} // namespace typetraits
} // namespace librapid

namespace cxxblas {
template<typename T>
librapid::Complex<T> conjugate(const librapid::Complex<T> &val) {
return librapid::conj(val);
}
} // namespace cxxblas

// Support FMT printing
#ifdef FMT_API
template<typename T, typename Char>
Expand Down

0 comments on commit ec3e01a

Please sign in to comment.