Skip to content

Commit

Permalink
fix paddle 2.6.0 problem with zipping iterables of tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Jan 2, 2024
1 parent 3f28537 commit 91d564e
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion ivy/functional/backends/paddle/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def _gather(params1):
return _gather(params)


@with_unsupported_device_and_dtypes(
{"2.5.2 and below": {"cpu": ("bfloat16", "float16")}},
backend_version,
)
def gather_nd(
params: paddle.Tensor,
indices: paddle.Tensor,
Expand Down Expand Up @@ -269,6 +273,8 @@ def gather_nd(
indices_shape = indices.shape
batch_shape = params_shape[:batch_dims]
batch_size = paddle.prod(batch_shape, [0]).numpy().tolist()
if isinstance(batch_size, int):
batch_size = [batch_size]
index_internal_ndims = indices.ndim - batch_dims - 1
indices_internal_shape = indices_shape[batch_dims:-1]

Expand Down Expand Up @@ -647,7 +653,11 @@ def _vmap(*args, **kwargs):

# vectorisation - applying map_fn if only one arg provided as reduce requires
# two elements to begin with.
arr_results = [func(*arrays) for arrays in zip(*args)]
arr_results = []
for arrays in zip(*args):
arrays = [a if a.shape != [] else a.unsqueeze(0) for a in arrays]
arr_results.append(func(*arrays))

res = paddle_backend.concat(arr_results)

if out_axes:
Expand Down

0 comments on commit 91d564e

Please sign in to comment.