diff --git a/ivy/functional/backends/paddle/general.py b/ivy/functional/backends/paddle/general.py index 60f32855f4992..8bdb2a2a86cbc 100644 --- a/ivy/functional/backends/paddle/general.py +++ b/ivy/functional/backends/paddle/general.py @@ -227,6 +227,10 @@ def _gather(params1): return _gather(params) +@with_unsupported_device_and_dtypes( + {"2.5.2 and below": {"cpu": ("bfloat16", "float16")}}, + backend_version, +) def gather_nd( params: paddle.Tensor, indices: paddle.Tensor, @@ -269,6 +273,8 @@ def gather_nd( indices_shape = indices.shape batch_shape = params_shape[:batch_dims] batch_size = paddle.prod(batch_shape, [0]).numpy().tolist() + if isinstance(batch_size, int): + batch_size = [batch_size] index_internal_ndims = indices.ndim - batch_dims - 1 indices_internal_shape = indices_shape[batch_dims:-1] @@ -647,7 +653,11 @@ def _vmap(*args, **kwargs): # vectorisation - applying map_fn if only one arg provided as reduce requires # two elements to begin with. - arr_results = [func(*arrays) for arrays in zip(*args)] + arr_results = [] + for arrays in zip(*args): + arrays = [a if a.shape != [] else a.unsqueeze(0) for a in arrays] + arr_results.append(func(*arrays)) + res = paddle_backend.concat(arr_results) if out_axes: