From 19bde5cc27fe158a4cb6a94e098a924a37d6e97a Mon Sep 17 00:00:00 2001 From: Bilal Aamer <52858537+bilal-aamer@users.noreply.github.com> Date: Tue, 15 Aug 2023 23:49:46 +0000 Subject: [PATCH] put_along_axis commit --- .../frontends/paddle/tensor/manipulation.py | 16 ++++++++ .../test_tensor/test_manipulation.py | 37 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/manipulation.py b/ivy/functional/frontends/paddle/tensor/manipulation.py index 7d0fa71ebe270..110cec20e3055 100644 --- a/ivy/functional/frontends/paddle/tensor/manipulation.py +++ b/ivy/functional/frontends/paddle/tensor/manipulation.py @@ -169,3 +169,19 @@ def take_along_axis(arr, indices, axis): @to_ivy_arrays_and_back def rot90(x, k=1, axes=(0, 1), name=None): return ivy.rot90(x, k=k, axes=axes) + + +@with_supported_dtypes( + { + "2.5.1 and below": ( + "int32", + "int64", + "float32", + "float64", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def put_along_axis(arr, indices, values, axis, reduce="assign"): + return ivy.put_along_axis(arr, indices, values, axis) \ No newline at end of file diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py index b28dfbc686720..ffa5abb30180f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_manipulation.py @@ -674,3 +674,40 @@ def test_paddle_rot90( k=k, axes=tuple(axes), ) + + +# put_along_axis +@handle_frontend_test( + fn_tree="numpy.put_along_axis", + dtype_x_indices_axis=helpers.array_indices_put_along_axis( + array_dtypes=helpers.get_dtypes(kind="valid"), + indices_dtypes=["int32", "int64"], + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, + ), + test_with_out=st.just(False), +) +def test_numpy_put_along_axis( + *, + dtype_x_indices_axis, + test_flags, + frontend, + fn_tree, + on_device, + backend_fw, +): + dtypes, x, indices, axis, values, _ = dtype_x_indices_axis + helpers.test_frontend_function( + input_dtypes=dtypes, + test_flags=test_flags, + backend_to_test=backend_fw, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + arr=x, + indices=indices, + axis=axis, + values=values, + ) \ No newline at end of file