From e32891253230483bf8828c155d9ab1dfca3ac2ff Mon Sep 17 00:00:00 2001 From: Filipe Maia Date: Mon, 24 Oct 2016 00:15:39 +0200 Subject: [PATCH] Add the out argument to dot. Fixes #24. --- afnumpy/linalg/linalg.py | 12 ++++++------ tests/test_linalg.py | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/afnumpy/linalg/linalg.py b/afnumpy/linalg/linalg.py index 9cd68d7..a680b65 100644 --- a/afnumpy/linalg/linalg.py +++ b/afnumpy/linalg/linalg.py @@ -5,6 +5,7 @@ from afnumpy import asarray, sqrt, abs from afnumpy.lib import asfarray from .. import private_utils as pu +from ..decorators import * def isComplexType(t): return issubclass(t, complexfloating) @@ -13,14 +14,14 @@ def vdot(a, b): s = arrayfire.dot(arrayfire.conjg(a.flat.d_array), b.flat.d_array) return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()] -# TODO: Implement multidimensional dot +@outufunc def dot(a, b): # Arrayfire requires that the types match for dot and matmul res_dtype = numpy.result_type(a,b) a = a.astype(res_dtype, copy=False) b = b.astype(res_dtype, copy=False) if a.ndim == 1 and b.ndim == 1: - s = arrayfire.dot((a.flat.d_array), b.flat.d_array) + s = arrayfire.dot(a.d_array, b.d_array) return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()] a_shape = a.shape @@ -33,7 +34,9 @@ def dot(a, b): if a.ndim == 2 and b.ndim == 2: # Notice the order of the arguments to matmul. It's not a bug! s = arrayfire.matmul(b.d_array, a.d_array) - return afnumpy.ndarray(pu.af_shape(s), dtype=pu.typemap(s.dtype()), af_array=s) + return afnumpy.ndarray(pu.af_shape(s), dtype=pu.typemap(s.dtype()), + af_array=s) + # Multidimensional dot is done with loops # Calculate the shape of the result array @@ -43,9 +46,6 @@ def dot(a, b): b_shape.pop(-2) res_shape = a_shape + b_shape - # Initialize the output array - res = afnumpy.empty(res_shape, dtype=res_dtype) - # Make sure the arrays are at least 3D if a.ndim < 3: a = a.reshape((1,)+a.shape) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 5c99e72..6629d56 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -42,6 +42,12 @@ def test_dot_2D(): a = afnumpy.array(b) fassert(afnumpy.dot(a,a), numpy.dot(b,b)) + a = numpy.random.random((3,3))+numpy.random.random((3,3))*1.0j + b = numpy.random.random((3,3)) + fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b)) + out = afnumpy.array(a) + fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b),out=out), numpy.dot(a,b)) + def test_dot_3D(): b = numpy.random.random((3,3,3))+numpy.random.random((3,3,3))*1.0j a = afnumpy.array(b)