Skip to content

Commit 277e98f

Browse files
tyb0807copybara-github
authored andcommitted
[xla:gpu] Add IsContiguousSlice on Shapes instead of HloInstruction
PiperOrigin-RevId: 615608156
1 parent c53ecca commit 277e98f

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

xla/service/gpu/ir_emission_utils.cc

+7-3
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,17 @@ bool IsContiguousSlice(const HloInstruction& instr) {
216216
// src and dst dimensions match.
217217
const Shape& src_shape = slice->operand(0)->shape();
218218
const Shape& dst_shape = slice->shape();
219+
return IsContiguousSlice(src_shape, dst_shape);
220+
}
221+
222+
bool IsContiguousSlice(const Shape& orig, const Shape& sliced) {
219223
bool sliced_dim_found = false;
220-
for (auto dim : src_shape.layout().minor_to_major()) {
224+
for (auto dim : orig.layout().minor_to_major()) {
221225
if (!sliced_dim_found) {
222-
sliced_dim_found = dst_shape.dimensions(dim) < src_shape.dimensions(dim);
226+
sliced_dim_found = sliced.dimensions(dim) < orig.dimensions(dim);
223227
continue;
224228
}
225-
if (dst_shape.dimensions(dim) != 1) return false;
229+
if (sliced.dimensions(dim) != 1) return false;
226230
}
227231
return true;
228232
}

xla/service/gpu/ir_emission_utils.h

+3
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ bool IsSliceWithUnitStrides(const HloInstruction* instr);
101101
// slice.
102102
bool IsContiguousSlice(const HloInstruction& instr);
103103

104+
// Returns true if `sliced` is a contiguous slice of `orig`.
105+
bool IsContiguousSlice(const Shape& orig, const Shape& sliced);
106+
104107
// Emits code to shuffle data between threads of a warp. This has the same
105108
// semantics as the PTX "shfl.sync.down" instruction but works for values that
106109
// aren't 32 bits in size. The last operand of the emitted "shfl" is

0 commit comments

Comments
 (0)