Skip to content

Commit 0a2a9a7

Browse files
committed
fix
1 parent f394b4e commit 0a2a9a7

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

csrc/device_lower/analysis/index_compute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1377,7 +1377,7 @@ IterDomain* getLogicalIDToTraverse(
13771377
const std::vector<Val*>& consumer_all_ids) {
13781378
const auto& logical_ids =
13791379
GpuLower::current()->caMap()->getLogicalDomainsOfIdGroup(
1380-
id, IdMappingMode::PERMISSIVE);
1380+
id, IdMappingMode::ALMOSTEXACT);
13811381
if (logical_ids.empty()) {
13821382
return nullptr;
13831383
}

tests/cpp/test_indexing.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5167,4 +5167,46 @@ TEST_F(IndexingTest, PerDimLogicalIndices) {
51675167
lower.run();
51685168
}
51695169

5170+
// Test named after issue 3299. Builds and runs a fusion with the following
// chain of shape transformations on a contiguous, concrete-sized input:
//
//   tv0: [128000, 1024]
//   tv1: reshape            -> [128000, 8, 128]
//   tv2: permute {1, 0, 2}  -> [8, 128000, 128]
//   tv3: broadcast at dim 1 -> [8, 1, 128000, 128]
//   tv4: expand dim 1 to 4  -> [8, 4, 128000, 128]
//   tv5: reshape            -> [32, 128000, 128]
//
// The fusion is executed through FusionExecutorCache on CUDA device 0 and the
// result is checked with testValidate.
TEST_F(IndexingTest, Issue3299) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  // Single source of truth for the input extent; also used for the ATen input.
  const std::vector<int64_t> input_shape{128000, 1024};
  auto tv0 = makeContigConcreteTensor(input_shape);
  fusion.addInput(tv0);

  // Split the inner 1024 extent into 8 x 128.
  std::vector<Val*> split_sizes{
      IrBuilder::create<Val>(128000L),
      IrBuilder::create<Val>(8L),
      IrBuilder::create<Val>(128L)};
  auto tv1 = reshape(tv0, split_sizes);

  // Swap the two leading dimensions, then insert a broadcast at position 1.
  auto tv2 = permute(tv1, {1, 0, 2});
  auto tv3 = broadcast(tv2, {false, true, false, false});

  // Expand only the broadcast dimension to 4; -1 leaves an extent unchanged.
  std::vector<Val*> expanded_sizes{
      IrBuilder::create<Val>(-1L),
      IrBuilder::create<Val>(4L),
      IrBuilder::create<Val>(-1L),
      IrBuilder::create<Val>(-1L)};
  auto tv4 = expand(tv3, expanded_sizes);

  // Merge the leading 8 and the expanded 4 into a single extent of 32.
  std::vector<Val*> merged_sizes{
      IrBuilder::create<Val>(32L),
      IrBuilder::create<Val>(128000L),
      IrBuilder::create<Val>(128L)};
  auto tv5 = reshape(tv4, merged_sizes);
  fusion.addOutput(tv5);

  // Ownership of the fusion moves into the executor cache from here on.
  FusionExecutorCache fec(std::move(fusion_ptr));

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn(input_shape, options);
  std::vector<c10::IValue> inputs = {t0};

  auto outputs = fec.runFusionWithInputs(inputs);

  testValidate(fec.fusion(), outputs, inputs, __LINE__, __FILE__);
}
5211+
51705212
} // namespace nvfuser

0 commit comments

Comments
 (0)