diff --git a/paddle_geometric/utils/_scatter.py b/paddle_geometric/utils/_scatter.py index a02661e..fbe8768 100644 --- a/paddle_geometric/utils/_scatter.py +++ b/paddle_geometric/utils/_scatter.py @@ -78,14 +78,14 @@ def scatter( return src.new_zeros(size) index = broadcast(index, src, dim) - return src.new_zeros(size).scatter_add_(dim=dim, index=index, src=src) + return src.new_zeros(size).scatter_add(dim=dim, index=index, src=src) if reduce == 'mean': if index.numel() == 0: return src.new_zeros(size) count = paddle.zeros(dim_size, device=src.place) - count.scatter_add_( + count.scatter_add( dim=0, index=index, src=paddle.ones(src.shape[dim], device=src.place), @@ -93,7 +93,7 @@ def scatter( count = count.clip(min=1) index = broadcast(index, src, dim) - out = src.new_zeros(size).scatter_add_(dim=dim, index=index, src=src) + out = src.new_zeros(size).scatter_add(dim=dim, index=index, src=src) return out / broadcast(count, out, dim) @@ -173,7 +173,7 @@ def broadcast(src: paddle.Tensor, ref: paddle.Tensor, dim: int) -> paddle.Tensor: dim = ref.dim() + dim if dim < 0 else dim size = (1, ) * dim + (-1, ) + (1, ) * (ref.dim() - dim - 1) - return src.view(size).expand_as(y=ref) + return src.reshape(size).expand_as(y=ref) def scatter_argmax(