Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Nov 7, 2023
1 parent df7c735 commit 3e815d0
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 104 deletions.
131 changes: 42 additions & 89 deletions librapid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,56 +103,23 @@ def Array(*args, **kwargs):
backend = CPU

if backend == CPU:
cpuTypeList = {
float32: _librapid.ArrayFloatCPU,
float64: _librapid.ArrayDoubleCPU,
int32: _librapid.ArrayInt32CPU,
int64: _librapid.ArrayInt64CPU,
uint32: _librapid.ArrayUInt32CPU,
uint64: _librapid.ArrayUInt64CPU,
# Complex32: _librapid.ArrayComplexFloatCPU,
# Complex64: _librapid.ArrayComplexDoubleCPU,
}

arrayType = cpuTypeList.get(dtype, None)
arrayType = __cpuTypeList.get(dtype, None)
elif backend == OpenCL:
if not _librapid.hasOpenCL():
raise RuntimeError("OpenCL is not supported in this build "
"of librapid. Ensure OpenCL is installed "
"on your system and reinstall librapid "
"from source.")

openclTypeList = {
float32: _librapid.ArrayFloatOpenCL,
float64: _librapid.ArrayDoubleOpenCL,
int32: _librapid.ArrayInt32OpenCL,
int64: _librapid.ArrayInt64OpenCL,
uint32: _librapid.ArrayUInt32OpenCL,
uint64: _librapid.ArrayUInt64OpenCL,
# Complex32: _librapid.ArrayComplexFloatOpenCL,
# Complex64: _librapid.ArrayComplexDoubleOpenCL,
}

arrayType = openclTypeList.get(dtype, None)
arrayType = __openclTypeList.get(dtype, None)
elif backend == CUDA:
if not _librapid.hasCUDA():
raise RuntimeError("CUDA is not supported in this build "
"of librapid. Ensure CUDA is installed "
"on your system and reinstall librapid "
"from source.")

cudaTypeList = {
float32: _librapid.ArrayFloatCUDA,
float64: _librapid.ArrayDoubleCUDA,
int32: _librapid.ArrayInt32CUDA,
int64: _librapid.ArrayInt64CUDA,
uint32: _librapid.ArrayUInt32CUDA,
uint64: _librapid.ArrayUInt64CUDA,
# Complex32: _librapid.ArrayComplexFloatCUDA,
# Complex64: _librapid.ArrayComplexDoubleCUDA,
}

arrayType = cudaTypeList.get(dtype, None)
arrayType = __cudaTypeList.get(dtype, None)
else:
raise ValueError(f"Unknown backend {backend}")

Expand All @@ -167,58 +134,6 @@ def Array(*args, **kwargs):
raise RuntimeError("Unknown error")


def isArray(obj):
"""
Checks if an object is an Array.
Parameters
----------
obj : object
The object to check.
Returns
-------
bool
True if the object is an Array, False otherwise.
"""

if type(obj) in [
_librapid.ArrayFloatCPU,
_librapid.ArrayDoubleCPU,
_librapid.ArrayInt32CPU,
_librapid.ArrayInt64CPU,
_librapid.ArrayUInt32CPU,
_librapid.ArrayUInt64CPU,
# _librapid.ArrayComplexFloatCPU,
# _librapid.ArrayComplexDoubleCPU
]:
return True

if _librapid.hasOpenCL() and type(obj) in [
_librapid.ArrayFloatOpenCL,
_librapid.ArrayDoubleOpenCL,
_librapid.ArrayInt32OpenCL,
_librapid.ArrayInt64OpenCL,
_librapid.ArrayUInt32OpenCL,
_librapid.ArrayUInt64OpenCL,
# _librapid.ArrayComplexFloatOpenCL,
# _librapid.ArrayComplexDoubleOpenCL
]:
return True

if _librapid.hasCUDA() and type(obj) in [
_librapid.ArrayFloatCUDA,
_librapid.ArrayDoubleCUDA,
_librapid.ArrayInt32CUDA,
_librapid.ArrayInt64CUDA,
_librapid.ArrayUInt32CUDA,
_librapid.ArrayUInt64CUDA,
# _librapid.ArrayComplexFloatCUDA,
# _librapid.ArrayComplexDoubleCUDA
]:
return True


def hasOpenCL():
"""
Checks if OpenCL is supported.
Expand Down Expand Up @@ -320,7 +235,6 @@ def getSeed():
exp2 = _librapid.exp2
exp10 = _librapid.exp10


float32 = DataType("float32", 4)
float64 = DataType("float64", 8)
int32 = DataType("int32", 4)
Expand All @@ -330,6 +244,45 @@ def getSeed():
# Complex32 = DataType("Complex32", 8)
# Complex64 = DataType("Complex64", 16)

__cpuTypeList = {
int32: _librapid.ArrayInt32CPU,
int64: _librapid.ArrayInt64CPU,
# uint32: _librapid.ArrayUInt32CPU,
# uint64: _librapid.ArrayUInt64CPU,
float32: _librapid.ArrayFloatCPU,
float64: _librapid.ArrayDoubleCPU,
# Complex32: _librapid.ArrayComplexFloatCPU,
# Complex64: _librapid.ArrayComplexDoubleCPU,
}

if _librapid.hasOpenCL():
__openclTypeList = {
int32: _librapid.ArrayInt32OpenCL,
int64: _librapid.ArrayInt64OpenCL,
# uint32: _librapid.ArrayUInt32OpenCL,
# uint64: _librapid.ArrayUInt64OpenCL,
float32: _librapid.ArrayFloatOpenCL,
float64: _librapid.ArrayDoubleOpenCL,
# Complex32: _librapid.ArrayComplexFloatOpenCL,
# Complex64: _librapid.ArrayComplexDoubleOpenCL,
}
else:
__openclTypeList = None

if _librapid.hasCUDA():
__cudaTypeList = {
int32: _librapid.ArrayInt32CUDA,
int64: _librapid.ArrayInt64CUDA,
# uint32: _librapid.ArrayUInt32CUDA,
# uint64: _librapid.ArrayUInt64CUDA,
float32: _librapid.ArrayFloatCUDA,
float64: _librapid.ArrayDoubleCUDA,
# Complex32: _librapid.ArrayComplexFloatCUDA,
# Complex64: _librapid.ArrayComplexDoubleCUDA,
}
else:
__cudaTypeList = None

CPU = Backend("CPU")
OpenCL = Backend("OpenCL")
CUDA = Backend("CUDA")
6 changes: 3 additions & 3 deletions librapid/bindings/generators/arrayGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

for scalar in [("int32_t", "Int32"),
("int64_t", "Int64"),
("uint32_t", "UInt32"),
("uint64_t", "UInt64"),
# ("uint32_t", "UInt32"),
# ("uint64_t", "UInt64"),
("float", "Float"),
("double", "Double"),
# ("lrc::Complex<float>", "ComplexFloat"),
Expand Down Expand Up @@ -58,7 +58,7 @@ def generateFunctionsForArray(config):
# Move constructor
# function.Function(
# name="__init__",
#  args=[
# args=[
# argument.Argument(
# name="other",
# type=generateCppArrayType(config),
Expand Down
4 changes: 2 additions & 2 deletions librapid/bindings/generators/generalArrayViewGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

for scalar in [("int32_t", "Int32"),
("int64_t", "Int64"),
("uint32_t", "UInt32"),
("uint64_t", "UInt64"),
# ("uint32_t", "UInt32"),
# ("uint64_t", "UInt64"),
("float", "Float"),
("double", "Double"),
# ("lrc::Complex<float>", "ComplexFloat"),
Expand Down
10 changes: 0 additions & 10 deletions librapid/include/librapid/array/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,16 +804,6 @@ namespace librapid {
return true;
}

/// \sa shapesMatch
// template<typename First, typename Second, typename... Rest>
// requires(typetraits::IsSizeType<First>::value && typetraits::IsSizeType<Second>::value
//&& (typetraits::IsSizeType<Rest>::value && ...)) LIBRAPID_NODISCARD LIBRAPID_INLINE bool
//shapesMatch(const First &first, const Second &second, const Rest &...shapes) { if constexpr
//(sizeof...(Rest) == 0) { return first == second; } else { return first == second &&
//shapesMatch(first, shapes...);
// }
// }

namespace detail {
template<typename First, typename Second>
struct ShapeTypeHelperImpl {
Expand Down

0 comments on commit 3e815d0

Please sign in to comment.