1919#include " paddle/phi/core/dense_tensor.h"
2020#include " paddle/phi/core/tensor_utils.h"
2121#include " paddle/phi/kernels/empty_kernel.h"
22+ #include " paddle/phi/kernels/expand_kernel.h"
2223#include " paddle/phi/kernels/funcs/broadcast_function.h"
2324#include " paddle/phi/kernels/funcs/eigen/common.h"
2425#include " paddle/phi/kernels/funcs/eigen/eigen_function.h"
2526#include " paddle/phi/kernels/funcs/elementwise_functor.h"
2627#include " paddle/phi/kernels/funcs/slice_utils.h"
27-
2828namespace phi {
2929
3030// check whether the tensor with dimension of second can assign to the
@@ -89,7 +89,6 @@ void SetValueImpl(const Context& dev_ctx,
8989 in_dims, axes, starts_local, ends_local, &steps_local);
9090 auto decrease_slice_dims =
9191 phi::funcs::GetDecreasedDims (slice_dims, decrease_axes);
92-
9392 auto slice_dims_for_assign = decrease_slice_dims;
9493 if (!none_axes.empty ()) {
9594 std::vector<int64_t > slice_dims_with_none;
@@ -115,33 +114,36 @@ void SetValueImpl(const Context& dev_ctx,
115114
116115 slice_dims_for_assign = common::make_ddim (slice_dims_with_none);
117116 }
117+ CheckIsDimsMatch (slice_dims_for_assign, value.dims ());
118+
119+ auto value_shape = phi::vectorize<int64_t >(value.dims ());
120+
121+ DenseTensor value_tensor = Empty<T>(dev_ctx, IntArray{value_shape});
122+ value_tensor = value;
123+ auto it = value_shape.begin ();
124+ while (it != value_shape.end () && *it == 1 ) {
125+ it = value_shape.erase (it);
126+ }
127+ if (value_shape.empty ()) value_shape.push_back (1 );
128+ value_tensor.Resize (phi::make_ddim (value_shape));
129+
130+ auto expand_shape = phi::vectorize<int64_t >(slice_dims_for_assign);
131+ for (size_t i = 0 ; i <= expand_shape.size (); i++) {
132+ if (expand_shape[i] == 0 ) expand_shape[i] = 1 ;
133+ }
134+ if (expand_shape.empty ()) expand_shape.push_back (1 );
135+ DenseTensor expand_tensor = Empty<T>(dev_ctx, IntArray{expand_shape});
118136
119137 auto place = dev_ctx.GetPlace ();
120138 auto & eigen_place = *dev_ctx.eigen_device ();
121139
122- // Here copy data from input to avoid data loss at PE and Graph level.
123- // TODO(liym27): Speed up in the future version.
124- // - Q: Why don't call ShareDataWith to speed up?
125- // - A: Because it's not supported to ShareDataWith on OP's input and output
126- // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
127- // - Q: Why don't delete Input, after all, the input and output are the same
128- // Tensor at program level?
129- // - A: If deleting Input, the graph will be complex, such as there will
130- // be two ops points to the output in graph: op1 -> output <- set_value.
131- // In this case, we have to find a way to handle the running order of
132- // set_value is what we want.
133140 Copy (dev_ctx, in, place, false , out);
141+ ExpandKernel<T, Context>(
142+ dev_ctx, value_tensor, IntArray{expand_shape}, &expand_tensor);
143+ expand_tensor.Resize (slice_dims);
134144
135- DenseTensor slice_tensor =
136- Empty<T>(dev_ctx, IntArray{slice_dims.Get (), slice_dims.size ()});
137- DenseTensor pad_tensor =
138- Empty<T>(dev_ctx, IntArray{in_dims.Get (), in_dims.size ()});
139- auto pad_e = EigenTensor<T, RANK>::From (pad_tensor, in_dims);
140145 auto out_e = EigenTensor<T, RANK>::From (*out);
141- auto slice_e = EigenTensor<T, RANK>::From (slice_tensor, slice_dims);
142-
143- // Step 1: Set the value of out at `_index` to zero
144- slice_e.device (eigen_place) = slice_e.constant (T (0 ));
146+ auto value_e = EigenTensor<T, RANK>::From (expand_tensor);
145147
146148 auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
147149 auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
@@ -164,65 +166,7 @@ void SetValueImpl(const Context& dev_ctx,
164166 }
165167
166168 out_e.stridedSlice (starts_indices, ends_indices, strides_indices)
167- .device (eigen_place) = slice_e;
168-
169- // Step 2: Set a tensor with the same shape as out tensor. And its data at
170- // '_index' is the same as value, and data out of '_index' to zero
171-
172- // - Step 2.1 Set slice tensor with value
173-
174- // NOTE(liym27): [ Why resize slice_tensor here? ]
175- // A: When do broadcasting on slice_tensor and value, the shape of
176- // slice_tensor should be decreased dims.
177- // e.g.
178- // x[:,0] = value
179- // x's shape = [3, 4], value's shape = [3]
180- // We get slice_dims = [3, 1], decrease_slice_dims = [3]
181- // If do broadcasting on Tensor with shape [3, 1] and [3], the result's
182- // shape is [3, 3], which cross the border;
183- // If do broadcasting on Tensor with shape [3] and [3], the result's shape
184- // is [3], which is right.
185-
186- slice_tensor.Resize (slice_dims_for_assign);
187-
188- CheckIsDimsMatch (slice_dims_for_assign, value.dims ());
189-
190- bool is_gpu_place = dev_ctx.GetPlace ().GetType () == phi::AllocationType::GPU;
191- if (is_gpu_place || slice_tensor.dims ().size () >= value.dims ().size ()) {
192- // [Why here we confirm running device]
193- // ElementwiseComputeEx can do broadcasting in two cases:
194- // 1. The place is GPU.
195- // 2. The place is CPU, and the 'x' does not need broadcast.
196- // Please see the note in
197- // paddle/fluid/operators/elementwise/elementwise_op_function.h
198- // So, here we choose different logic depending on the device to avoid
199- // numerical problems, temporarily.
200- //
201- // TODO(zoooo0820): Reimplement logic of set_value to avoid using
202- // elementwise-sub.
203- funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
204- dev_ctx,
205- slice_tensor,
206- value,
207- funcs::SubtractFunctor<T>(),
208- &slice_tensor);
209- } else {
210- funcs::ElementwiseCompute<funcs::InverseSubtractFunctor<T>, T>(
211- dev_ctx,
212- slice_tensor,
213- value,
214- funcs::InverseSubtractFunctor<T>(),
215- &slice_tensor);
216- }
217- slice_tensor.Resize (slice_dims);
218-
219- // - Step 2.2 Pad slice tensor with 0
220- pad_e.device (eigen_place) = pad_e.constant (T (0 ));
221- pad_e.stridedSlice (starts_indices, ends_indices, strides_indices)
222- .device (eigen_place) = slice_e;
223-
224- // Step 3: Set out tensor with value
225- out_e.device (eigen_place) = out_e - pad_e;
169+ .device (eigen_place) = value_e;
226170}
227171
228172template <typename T, typename Context>
0 commit comments