File tree 2 files changed +10
-3
lines changed
2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -216,13 +216,17 @@ bool IsContiguousSlice(const HloInstruction& instr) {
216
216
// src and dst dimensions match.
217
217
const Shape& src_shape = slice->operand (0 )->shape ();
218
218
const Shape& dst_shape = slice->shape ();
219
+ return IsContiguousSlice (src_shape, dst_shape);
220
+ }
221
+
222
+ bool IsContiguousSlice (const Shape& orig, const Shape& sliced) {
219
223
bool sliced_dim_found = false ;
220
- for (auto dim : src_shape .layout ().minor_to_major ()) {
224
+ for (auto dim : orig .layout ().minor_to_major ()) {
221
225
if (!sliced_dim_found) {
222
- sliced_dim_found = dst_shape .dimensions (dim) < src_shape .dimensions (dim);
226
+ sliced_dim_found = sliced .dimensions (dim) < orig .dimensions (dim);
223
227
continue ;
224
228
}
225
- if (dst_shape .dimensions (dim) != 1 ) return false ;
229
+ if (sliced .dimensions (dim) != 1 ) return false ;
226
230
}
227
231
return true ;
228
232
}
Original file line number Diff line number Diff line change @@ -101,6 +101,9 @@ bool IsSliceWithUnitStrides(const HloInstruction* instr);
101
101
// slice.
102
102
bool IsContiguousSlice (const HloInstruction& instr);
103
103
104
+ // Returns true if `sliced` is a contiguous slice of `orig`.
105
+ bool IsContiguousSlice (const Shape& orig, const Shape& sliced);
106
+
104
107
// Emits code to shuffle data between threads of a warp. This has the same
105
108
// semantics as the PTX "shfl.sync.down" instruction but works for values that
106
109
// aren't 32 bits in size. The last operand of the emitted "shfl" is
You can’t perform that action at this time.
0 commit comments