2 changes: 1 addition & 1 deletion paddle/fluid/pybind/slice_utils.h
@@ -740,7 +740,7 @@ static std::vector<paddle::Tensor> PrepareIndices(
const paddle::Tensor& bool_2_idx,
const paddle::Tensor& bool_index) {
std::vector<paddle::Tensor> indices;
- for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
+ for (int64_t j = 0; j < bool_2_idx.shape()[1]; ++j) {
paddle::Tensor sliced_tensor =
slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});
paddle::Tensor sliced_tensor_c = sliced_tensor.contiguous();
6 changes: 4 additions & 2 deletions paddle/phi/kernels/funcs/gather.cu.h
@@ -155,9 +155,11 @@ __global__ void GatherGPUKernel(const T* input,
int64_t input_index_dim_size,
int64_t size) {
int64_t block_size = blockDim.x;
- int64_t idx = (blockIdx.x * block_size + threadIdx.x) * VecSize;
+ int64_t idx =
+ (static_cast<int64_t>(blockIdx.x) * block_size + threadIdx.x) * VecSize;
int64_t outer_size = outer_dim_size * out_index_dim_size;
- for (; idx < size; idx += gridDim.x * block_size * VecSize) {
+ for (; idx < size;
+ idx += static_cast<int64_t>(gridDim.x) * block_size * VecSize) {
int64_t inner_dim_index = idx / outer_size;
int64_t next_idx = idx % outer_size;
int64_t index_dim_index = next_idx / outer_dim_size;
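The change above is representative of the whole PR: widen one operand to a 64-bit type before the index arithmetic happens. blockIdx.x, blockDim.x, and gridDim.x are unsigned 32-bit values, so a product of two of them is evaluated modulo 2^32 and only then widened, which silently truncates indices once a single launch addresses more than roughly 2^31 elements (the exact threshold depends on whether the intermediate lands in int or unsigned int). In GatherGPUKernel the multiply already involved the int64_t block_size, so the casts there mainly make the intended width explicit; in the kernels changed below, the old expressions really could wrap. A minimal sketch of the failure mode, illustrative and not taken from the PR:

// Illustrative sketch (not from the PR): the product of two unsigned 32-bit
// CUDA built-ins wraps modulo 2^32 before it is widened to int64_t.
__global__ void IndexWidthDemo(int64_t* out) {
  // May wrap: blockIdx.x * blockDim.x is evaluated in 32-bit unsigned math,
  // and only the (possibly wrapped) result is converted to int64_t.
  int64_t wrapped = blockIdx.x * blockDim.x + threadIdx.x;
  // Safe: the cast forces the whole expression into 64-bit arithmetic.
  int64_t widened =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1) {
    out[0] = wrapped;
    out[1] = widened;  // differs from out[0] once the true index reaches 2^32
  }
}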
12 changes: 9 additions & 3 deletions paddle/phi/kernels/funcs/index_elementwise.cu.h
@@ -38,7 +38,7 @@ __global__ void index_elementwise_with_tensor_kernel(const int64_t N,
const func_t f) {
const auto tid = threadIdx.x;
const auto nv = nt * vt;
- auto idx = nv * blockIdx.x + tid;
+ int64_t idx = static_cast<int64_t>(nv) * blockIdx.x + tid;
#pragma unroll
for (int i = 0; i < vt; i++) {
if (idx < N) {
@@ -54,7 +54,7 @@ __global__ void index_elementwise_kernel(const int64_t N,
const func_t f) {
const auto tid = threadIdx.x;
const auto nv = nt * vt;
- auto idx = nv * blockIdx.x + tid;
+ int64_t idx = static_cast<int64_t>(nv) * blockIdx.x + tid;
#pragma unroll
for (int i = 0; i < vt; i++) {
if (idx < N) {
@@ -70,7 +70,7 @@ __global__ void index_put_kernel(const int64_t N,
const func_t f) {
const auto tid = threadIdx.x;
const auto nv = nt * vt;
- auto idx = nv * blockIdx.x + tid;
+ int64_t idx = static_cast<int64_t>(nv) * blockIdx.x + tid;
#pragma unroll
for (int i = 0; i < vt; i++) {
if (idx < N) {
@@ -227,6 +227,12 @@ static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(
return OffsetCalculator<N, uint32_t, signed_strides>(
iter.ndim(), iter.shape().data(), strides.data());
}
+ constexpr bool IsInUint32Range(int64_t value) {
+ return value >= 0 && value <= std::numeric_limits<int32_t>::max();
+ }
+ constexpr bool IsInUint32Range(int64_t v1, int64_t v2) {
+ return IsInUint32Range(v1) && IsInUint32Range(v2);
+ }

} // namespace funcs
} // namespace phi
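The IsInUint32Range helpers added above centralize the precondition the guarded kernels appear to rely on: offsets are computed through OffsetCalculator<N, uint32_t, signed_strides> and int-typed lambda indices, i.e. in 32-bit arithmetic, so every numel involved has to fit that range up front. Despite the name, the bound is std::numeric_limits<int32_t>::max(), the stricter of the two limits, so a value that passes is safe for both signed and unsigned 32-bit index math. A standalone sketch of the behavior (the static_assert harness is illustrative, not part of the PR):

// Standalone sketch mirroring the helpers above; only the checks are new.
#include <cstdint>
#include <limits>

constexpr bool IsInUint32Range(int64_t value) {
  // Bounded by int32_t::max(), the stricter limit, despite the "Uint32" name.
  return value >= 0 && value <= std::numeric_limits<int32_t>::max();
}
constexpr bool IsInUint32Range(int64_t v1, int64_t v2) {
  return IsInUint32Range(v1) && IsInUint32Range(v2);
}

static_assert(IsInUint32Range(0), "zero numel is allowed");
static_assert(IsInUint32Range((int64_t{1} << 31) - 1), "INT32_MAX is allowed");
static_assert(!IsInUint32Range(int64_t{1} << 31), "2^31 is rejected");
static_assert(!IsInUint32Range(-1), "negative values are rejected");
static_assert(!IsInUint32Range(int64_t{5}, int64_t{1} << 40),
              "one oversized operand fails the two-argument check");

In the kernels changed below, these helpers are wrapped in PADDLE_ENFORCE_EQ so an oversized tensor produces an explicit PreconditionNotMet error instead of a silently wrapped 32-bit offset.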
7 changes: 4 additions & 3 deletions paddle/phi/kernels/funcs/index_impl.cu.h
@@ -31,8 +31,8 @@ __global__ void VectorizedIndexKernel(T *out,
size_t numel,
size_t main_offset,
Functor func) {
- size_t data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
- size_t stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+ size_t data_offset = static_cast<size_t>(BLOCK_ID_X) * BLOCK_NUM_X * VecSize;
+ size_t stride = static_cast<size_t>(BLOCK_NUM_X) * GRID_NUM_X * VecSize;
size_t args[VecSize];
T result[VecSize];
for (; data_offset < main_offset; data_offset += stride) {
@@ -69,7 +69,8 @@ void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
int block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
#endif
- size_t main_offset = (numel / (vec_size * block)) * vec_size * block;
+ size_t main_offset =
+ (numel / (vec_size * static_cast<size_t>(block))) * vec_size * block;
switch (vec_size) {
case 4:
VectorizedIndexKernel<T, Functor, 4>
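For context, main_offset rounds numel down to a whole number of vec_size * block tiles: the vectorized main loop covers [0, main_offset) and the remaining elements take the tail path. A small host-side illustration with made-up sizes (not Paddle code):

// Illustrative only: how main_offset splits work between the vectorized
// main loop and the scalar tail. All sizes are example values.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t numel = 10000;  // total elements to generate
  const std::size_t vec_size = 4;   // elements per vectorized access
  const std::size_t block = 256;    // threads per block

  const std::size_t tile = vec_size * block;              // 1024
  const std::size_t main_offset = (numel / tile) * tile;  // 9216

  std::printf("main loop: [0, %zu), tail: %zu elements\n",
              main_offset, numel - main_offset);          // tail: 784
  return 0;
}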
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/index_put_utils.h
@@ -310,7 +310,7 @@ static void CalCompressedDimsWith1AndWithout1(
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
__global__ void range_cuda_kernel(int64_t N, T* out) {
- int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+ int64_t idx = threadIdx.x + static_cast<int64_t>(blockDim.x) * blockIdx.x;

if (idx >= N) {
return;
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/select_impl.cu.h
@@ -156,15 +156,15 @@ __global__ void CumsumOneBlock(const InT *in,
int64_t numel,
int64_t main_offset,
Functor func) {
- int64_t stride = BLOCK_NUM_X * VecSize;
+ int64_t stride = static_cast<int64_t>(BLOCK_NUM_X) * VecSize;
int64_t offset = 0;
OutT pre_cumsum = static_cast<OutT>(0);
for (; offset < main_offset; offset += stride) {
CumsumImpl<InT, OutT, Functor, VecSize, false>(
in + offset, out + offset, &pre_cumsum, stride, func);
}

- int num = numel - offset;
+ int64_t num = numel - offset;
if (num > 0) {
CumsumImpl<InT, OutT, Functor, VecSize, true>(
in + offset, out + offset, &pre_cumsum, num, func);
18 changes: 13 additions & 5 deletions paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
@@ -41,7 +41,7 @@ __global__ void IndexEleGetGradAccKernel(
offset_calc_t offset_calc) {
const int tid = threadIdx.x;
const int nv = nt * vt;
- int idx = nv * blockIdx.x + tid;
+ int64_t idx = nv * static_cast<int64_t>(blockIdx.x) + tid;
#pragma unroll
for (int i = 0; i < vt; i++) {
if (idx < N) {
@@ -112,10 +112,17 @@ void GPUIndexElementwiseGetGrad(const phi::GPUContext& dev_ctx,
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);

const int64_t N = numel;

+ PADDLE_ENFORCE_EQ(true,
+ funcs::IsInUint32Range(N, value.numel()),
+ common::errors::PreconditionNotMet(
+ "the numel of input or output should be in [0, "
+ "std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
- const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
+ const dim3 grid((N + static_cast<int64_t>(block.x) * vt - 1) /
+ (static_cast<int64_t>(block.x) * vt));
auto stream = dev_ctx.stream();

using dtype = funcs::OpaqueType<sizeof(T)>;
@@ -172,11 +179,12 @@ __global__ void IndexingBackwardKernel(const int64_t* sorted_indices,
using opmath_t = typename phi::dtype::MPTypeTrait<scalar_t>::Type;

for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
- int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
+ int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
if (idx < numel &&
(idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])) {
do {
- int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
+ int64_t start_feature =
+ threadIdx.x + static_cast<int64_t>(blockIdx.y) * blockDim.x * SZ;
if (!accumulate && (idx < numel - 1) &&
sorted_indices[idx] == sorted_indices[idx + 1]) {
idx++;
@@ -222,7 +230,7 @@ __global__ void IndexingBackwardKernel(const int64_t* sorted_indices,
static_cast<scalar_t>(weight[ii]);
}
}
- start_feature += gridDim.y * blockDim.x * SZ;
+ start_feature += static_cast<int64_t>(gridDim.y) * blockDim.x * SZ;
}
idx++;
} while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
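The launch setup in GPUIndexElementwiseGetGrad above uses the same nt/vt scheme as the other index kernels: each block of nt = 128 threads handles nt * vt = 512 elements, so the grid size is a ceiling division of N by 512, now evaluated in int64_t. A sketch of that computation with a few sanity checks (illustrative, not from the PR):

// Illustrative: 64-bit ceiling division used to size a grid in which each
// block of nt threads covers nt * vt elements.
#include <cstdint>

constexpr int nt = 128;  // threads per block, matching the kernel above
constexpr int vt = 4;    // elements handled per thread

constexpr int64_t NumBlocks(int64_t n) {
  return (n + static_cast<int64_t>(nt) * vt - 1) /
         (static_cast<int64_t>(nt) * vt);
}

static_assert(NumBlocks(1) == 1, "a single element still needs one block");
static_assert(NumBlocks(512) == 1, "exactly one full tile");
static_assert(NumBlocks(513) == 2, "one element past a tile adds a block");
static_assert(NumBlocks(int64_t{1} << 31) == 4194304, "2^31 elements / 512");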
11 changes: 5 additions & 6 deletions paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
@@ -68,12 +68,11 @@ void GPUIndexElementwiseGetKernel(const phi::GPUContext& dev_ctx,
funcs::make_offset_calculator_put<3>(desired_shape, strides_array);

const int64_t N = output->numel();
- PADDLE_ENFORCE_GE(
- N, 0, common::errors::InvalidArgument("Output numel must >= 0"));
- PADDLE_ENFORCE_LE(
- N,
- std::numeric_limits<int32_t>::max(),
- common::errors::InvalidArgument("Output numel must <= INT32_MAX"));
+ PADDLE_ENFORCE_EQ(true,
+ funcs::IsInUint32Range(N, input.numel()),
+ common::errors::PreconditionNotMet(
+ "the numel of input or output should be in [0, "
+ "std::numeric_limits<int32_t>::max()]"));
constexpr int nt = 128;
constexpr int vt = 4;
const dim3 block(nt);
10 changes: 10 additions & 0 deletions paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
@@ -129,6 +129,11 @@ void GPUIndexElementwisePutGradKernel(
auto index_ptrs = funcs::GetIndexDataPtrs<IndexT>(index);
const char* out_ptr = reinterpret_cast<const char*>(out_grad.data<T>());
char* value_ptr = reinterpret_cast<char*>(value_grad->data<T>());
+ PADDLE_ENFORCE_EQ(true,
+ funcs::IsInUint32Range(value_grad->numel()),
+ common::errors::PreconditionNotMet(
+ "the numel of input or output should be in [0, "
+ "std::numeric_limits<int32_t>::max()]"));
funcs::index_elementwise_with_tensor_kernel<nt, vt>
<<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
const auto offsets = offset_calc.get(idx);
@@ -151,6 +156,11 @@
} else {
auto index_ptrs = funcs::GetIndexDataPtrs<IndexT>(index);
char* out_ptr = reinterpret_cast<char*>(x_grad->data<T>());
+ PADDLE_ENFORCE_EQ(true,
+ funcs::IsInUint32Range(value_grad->numel()),
+ common::errors::PreconditionNotMet(
+ "the numel of input or output should be in [0, "
+ "std::numeric_limits<int32_t>::max()]"));
char* value_ptr = reinterpret_cast<char*>(value_grad->data<T>());
funcs::index_elementwise_with_tensor_kernel<nt, vt>
<<<grid, block, 0, stream>>>(N, [=] __device__(int idx) {
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu
@@ -42,7 +42,7 @@ __global__ void GPUMaskedFillXGradKernel(const T* out_grad,
const int64_t input_len,
const int64_t batch_size,
T* x_grad) {
- int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x);
+ int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx >= (input_len / VecSize)) {
return;
@@ -73,7 +73,7 @@ __global__ void GPUMaskedFillValueGradKernel(const T* out_grad,
const int64_t input_len,
const int64_t batch_size,
T* value_grad) {
- int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x);
+ int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx >= (input_len / VecSize)) {
return;
@@ -243,7 +243,7 @@ void GPUMaskedFillGrad(const phi::GPUContext& dev_ctx,

int64_t input_len = out_grad.numel();
int64_t mask_len = mask.numel();
- int batch_size = input_len / mask_len;
+ int64_t batch_size = input_len / mask_len;

int vec_size = 8;
vec_size = std::min(phi::GetVectorizedSize(out_grad_data), vec_size);
4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/masked_fill_kernel.cu
@@ -69,7 +69,7 @@ __global__ void GPUMaskedFillKernel(const T* input,
const int64_t input_len,
const int64_t batch_size,
T* output) {
- int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x);
+ int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

if (idx >= (input_len / VecSize)) {
return;
@@ -161,7 +161,7 @@ void GPUMaskedFill(const phi::GPUContext& dev_ctx,
const T* value_data = value.data<T>();
int64_t input_len = input.numel();
int64_t mask_len = mask.numel();
- int batch_size = input_len / mask_len;
+ int64_t batch_size = input_len / mask_len;

int vec_size = 8;
vec_size = std::min(phi::GetVectorizedSize(input_data), vec_size);
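Promoting batch_size to int64_t matters because it is the broadcast ratio input_len / mask_len: with a small mask applied to a tensor of more than about 2.1 billion elements, the quotient itself no longer fits in a 32-bit int. A host-side illustration with example sizes (not Paddle code):

// Illustrative only: the broadcast ratio exceeds INT32_MAX when a tiny mask
// is applied to a very large tensor.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t input_len = int64_t{5} * 1000 * 1000 * 1000;  // 5e9 elements
  const int64_t mask_len = 2;                                 // broadcast mask

  const int64_t batch_size = input_len / mask_len;  // 2,500,000,000
  // Storing this quotient in a 32-bit int could not preserve the value,
  // since 2,500,000,000 > INT32_MAX (2,147,483,647).
  std::printf("batch_size = %lld, INT32_MAX = %d\n",
              static_cast<long long>(batch_size), INT32_MAX);
  return 0;
}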