
Resize scheduler: enable vectorization #3694


Merged
merged 39 commits on Jan 15, 2025

Changes from 32 commits

Commits (39)
944dc65
Add ResizeHeuristic
naoyam Jan 6, 2025
8ec8e25
cleanup
naoyam Jan 7, 2025
88b2008
Disable the move-pad preseg pass when the resize scheduler is enabled
naoyam Jan 7, 2025
c0e3e91
Merge remote-tracking branch 'origin/disable_move_pad_when_resize_sch…
naoyam Jan 7, 2025
fde550b
remove unnecessary check
naoyam Jan 7, 2025
f4d2a3d
Merge remote-tracking branch 'origin/main' into resize_scheduler_mino…
naoyam Jan 7, 2025
65ab3fb
Relax broadcast constraint in scheduleLoopDomainsLike
naoyam Jan 7, 2025
c3b3daf
Merge branch 'allow_broadcast_ref' into resize_scheduler_test
naoyam Jan 7, 2025
61bcbb5
Merge branch 'resize_scheduler_minor_update_heuristic' into resize_sc…
naoyam Jan 7, 2025
eed894b
Merge remote-tracking branch 'origin/main' into resize_scheduler_test
naoyam Jan 8, 2025
92823d0
bcast check
naoyam Jan 8, 2025
dd0f134
Schedule loop domains such that reshape transforms are cancelled
naoyam Jan 8, 2025
a7513c5
Merge remote-tracking branch 'origin/cancel_reshape' into resize_sche…
naoyam Jan 8, 2025
ed100ac
fix
naoyam Jan 9, 2025
0c2c335
comment
naoyam Jan 9, 2025
80bcffb
Don't try to index broadcast IDs as they should always be zero and there
naoyam Jan 9, 2025
3b5c621
repro
naoyam Jan 9, 2025
3f79de9
cleanup
naoyam Jan 9, 2025
70df6fa
Merge remote-tracking branch 'origin/main' into resize_scheduler_test
naoyam Jan 10, 2025
deeb24b
Merge remote-tracking branch 'origin/skip_trying_indexing_broadcast_i…
naoyam Jan 10, 2025
5d42ace
reorder tensors like largest input
naoyam Jan 10, 2025
536a006
Merge remote-tracking branch 'origin/cancel_reshape' into resize_sche…
naoyam Jan 10, 2025
ab4fcd3
Merge remote-tracking branch 'origin/main' into resize_scheduler_reorder
naoyam Jan 10, 2025
afa1cab
enable vectorization
naoyam Jan 10, 2025
ad9cb1c
Merge remote-tracking branch 'origin/main' into resize_scheduler_vec
naoyam Jan 10, 2025
6c5da61
Merge branch 'main' into resize_scheduler_reorder
naoyam Jan 14, 2025
9edd096
comment
naoyam Jan 14, 2025
41310f2
comment
naoyam Jan 14, 2025
f7ea6b5
Merge branch 'main' into resize_scheduler_vec
naoyam Jan 14, 2025
ef2cbf9
Merge branch 'resize_scheduler_reorder' into resize_scheduler_vec
naoyam Jan 14, 2025
67df58a
cleanup
naoyam Jan 14, 2025
8aa1716
comment
naoyam Jan 14, 2025
1bec6f1
comment
naoyam Jan 15, 2025
fbc734a
test WAR
naoyam Jan 15, 2025
923c7e1
cleanup
naoyam Jan 15, 2025
0510ec6
Use reference tv for vectorization analysis
naoyam Jan 15, 2025
d89742d
PR feedback
naoyam Jan 15, 2025
714e30d
fix
naoyam Jan 15, 2025
071eaca
Merge branch 'main' into resize_scheduler_vec
naoyam Jan 15, 2025
116 changes: 104 additions & 12 deletions csrc/scheduler/resize.cpp
@@ -211,6 +211,33 @@ std::unique_ptr<HeuristicParams> ResizeScheduler::computeHeuristics(
params->split_grid_x_dim =
ceilDiv(max_num_elms, bdimx) > ResizeParams::max_gdimx;

const auto largest_input =
getLargestTensor(fusion->inputs(), runtime_info).first;
if (largest_input != nullptr) {
int64_t index_of_largest_input = std::distance(
fusion->inputs().begin(),
std::find(
fusion->inputs().begin(), fusion->inputs().end(), largest_input));
params->largest_input = index_of_largest_input;
} else {
params->largest_input = -1;
}

// Vectorization based on the largest input if there's any input
// tv. This is because the current heuristics are designed to
// optimize the read performance. The largest output is used if
// there's no input.
auto ref_tv_for_vectorization =
largest_input != nullptr ? largest_input : largest_output;
// Only consider the innermost dimension to vectorize for now.
// TODO: Consider vectorizing merged IDs, not just the innermost
params->vectorization_factor = vectorize_helper::getVectorizationFactor(
runtime_info,
ref_tv_for_vectorization,
data_cache,
(int64_t)ref_tv_for_vectorization->getLogicalDomain().size() - 1,
{});

return params;
}
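
// Illustrative sketch (hypothetical helper, not part of this diff):
// getVectorizationFactor essentially picks the largest power-of-two
// factor whose total width fits a 16-byte (128-bit) access and that
// evenly divides the innermost extent. The real helper also accounts
// for pointer alignment and contiguity via runtime_info, which this
// standalone version ignores.
#include <cstdint>

int64_t pickVectorizationFactor(int64_t innermost_extent, int64_t dtype_size) {
  constexpr int64_t kMaxVecBytes = 16; // assume 128-bit vector access
  int64_t factor = 1;
  while (factor * 2 * dtype_size <= kMaxVecBytes &&
         innermost_extent % (factor * 2) == 0) {
    factor *= 2;
  }
  return factor; // e.g., extent = 1024 with fp32 (4 bytes) -> 4
}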

@@ -251,6 +278,17 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
ir_utils::replaceValInExprInputs(resize_tensor_op, inp_tv, inp_tv_copy);
}

TensorView* largest_input = nullptr;
if (resize_params->largest_input >= 0) {
largest_input =
fusion->inputs().at(resize_params->largest_input)->as<TensorView>();

// The tensors are going to be reordered to align with the largest
// input. To make it work, merge operations for reshape should be
// cancelled.
scheduler_tools::cancelReshapeInLoopDomains(largest_input);
}
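
// Hypothetical illustration (not code from this PR): given
//   tv0 [i0, i1]                          // largest input
//   tv1 = reshape(tv0, {i0, i1}, {i0 * i1});
// cancelling the reshape's merge restores a two-ID loop domain on tv1
// and its consumers, so the reorder below can align every tensor with
// tv0's allocation order.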

for (auto expr : fusion->exprs()) {
if (!expr->isOneOf<SliceOp, PadOp>()) {
continue;
@@ -265,26 +303,66 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
// Just simple scheduling for now.
// TODO: Do something smarter. Can just use the pointwise scheduler?

const int64_t bdimx = 128;
// Reorder tensors to align with the largest input. This is expected
// to improve the memory read performance, while the write
// performance could be lowered. Optimizing the read performance
// should generally be more important, but a more robust decision
// would be needed.
if (largest_input != nullptr) {
std::vector<IterDomain*> ref_alloc;
ref_alloc.reserve(largest_input->getMaybeAllocationDomain().size());
std::copy_if(
largest_input->getMaybeAllocationDomain().begin(),
largest_input->getMaybeAllocationDomain().end(),
std::back_inserter(ref_alloc),
[](IterDomain* alloc_id) {
return !alloc_id->isBroadcast() && !alloc_id->isReduction() &&
!alloc_id->isDeviceDim();
});

// Reorder the reference as the allocation domain of the largest fusion
// input
scheduler_utils::reorderTensorLike(ref_tv, ref_alloc);
}

// Make sure the DID ID is located at the outermost position
const auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);

// Schedule only the remaining IDs
ref_tv->flatten(outermost_pos);
// [..., I0]
const int64_t vec_factor = resize_params->vectorization_factor;

ref_tv->split(-1, bdimx);
ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
// [..., I0/bdimx, bdimx(TIDx)]
const int64_t bdimx = 128;

int64_t next_innermost_pos = -1;
// [..., ...]
//       ^
//       +--- next_innermost_pos

if (vec_factor > 1) {
ref_tv->split(-1, vec_factor);
--next_innermost_pos;
// [..., vec_factor]
//   ^
//   +--- next_innermost_pos
}

ref_tv->flatten(outermost_pos, next_innermost_pos);
// [..., I0, vec_factor]
//       ^
//       +--- next_innermost_pos

ref_tv->split(next_innermost_pos, bdimx);
ref_tv->axis(next_innermost_pos)->parallelize(ParallelType::TIDx);
--next_innermost_pos;
// [..., I0/bdimx, bdimx(TIDx), vec_factor]
//       ^
//       +--- next_innermost_pos

if (resize_params->split_grid_x_dim) {
ref_tv->split(-2, ResizeParams::max_gdimx);
// [..., I0/bdimx/max_gdimx, max_gdimx, bdimx(TIDx)]
ref_tv->split(next_innermost_pos, ResizeParams::max_gdimx);
// [..., I0/bdimx/max_gdimx, max_gdimx, bdimx(TIDx), vec_factor]
}
ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
// [..., I0/bdimx/max_gdimx, max_gdimx(BIDx), bdimx(TIDx)] or
// [..., I0/bdimx(BIDx), bdimx(TIDx)]
ref_tv->axis(next_innermost_pos)->parallelize(ParallelType::BIDx);
// [..., I0/bdimx/max_gdimx, max_gdimx(BIDx), bdimx(TIDx), vec_factor]
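
// Illustrative sketch (hypothetical function, not part of this diff):
// ignoring the max_gdimx split, the scheduled loop structure
// [I0/bdimx (BIDx), bdimx (TIDx), vec_factor] maps a flat element
// index to (blockIdx.x, threadIdx.x, vector lane) as follows.
#include <cstdint>

struct ResizeIndex {
  int64_t bidx; // BIDx
  int64_t tidx; // TIDx
  int64_t lane; // position within one vectorized access
};

ResizeIndex decomposeFlatIndex(int64_t flat, int64_t bdimx, int64_t vec_factor) {
  ResizeIndex idx;
  idx.lane = flat % vec_factor; // innermost: contiguous vector lanes
  int64_t rest = flat / vec_factor;
  idx.tidx = rest % bdimx;      // next: the TIDx-parallel split
  idx.bidx = rest / bdimx;      // remainder becomes BIDx
  return idx;
}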

// Propagate the reference to the other tensors. Note that the
// update flag is enabled so as to work around the resize propagation
@@ -297,6 +375,20 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
ref_tv->getLoopDomain(),
/*update_loop_domain_only=*/true);

if (vec_factor > 1) {
const auto tvs_to_vectorize =
scheduler_utils::getInputsOutputsWithInnerDim(ref_tv, true, true);
for (auto tv_to_vectorize : tvs_to_vectorize) {
if (tv_to_vectorize->isFusionInput()) {
for (auto consumer_tv : ir_utils::consumerTvsOf(tv_to_vectorize)) {
consumer_tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
} else {
tv_to_vectorize->axis(-1)->parallelize(ParallelType::Vectorize);
}
}
}
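
// Note (explanatory, not part of this diff): a fusion input has no
// defining expression inside the kernel, so its loads are issued by
// the expressions that consume it. Marking each consumer's innermost
// axis as Vectorize is therefore what turns those loads into
// vectorized accesses.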

inlineMost();

markAliases(fusion);
12 changes: 10 additions & 2 deletions csrc/scheduler/resize_heuristic.h
@@ -23,6 +23,10 @@ class ResizeParams : public HeuristicParams {
// Split grid x dimension
bool split_grid_x_dim = false;

int64_t largest_input = 0;

int64_t vectorization_factor = 1;

static constexpr int64_t max_gdimx = (1L << 31) - 1L;

using HeuristicParams::HeuristicParams;
@@ -34,15 +38,19 @@
return false;
}
bool attr_equal = other->cparams == cparams &&
other->split_grid_x_dim == split_grid_x_dim;
other->split_grid_x_dim == split_grid_x_dim &&
other->largest_input == largest_input &&
other->vectorization_factor == vectorization_factor;
return attr_equal;
}

std::string toString() const override {
std::stringstream ss;
ss << "\n===== Resize Parameters ========\n"
<< (tag.empty() ? "" : "Tag: ") << tag << " Resize Characteristics:\n"
<< " split grid x dim: " << split_grid_x_dim << "\n";
<< " split grid x dim: " << split_grid_x_dim << "\n"
<< " index of largest input: " << largest_input << "\n"
<< " vectorization factor: " << vectorization_factor << "\n";
ss << "====================================\n";
return ss.str();
}
77 changes: 77 additions & 0 deletions csrc/scheduler/utils.cpp
@@ -2666,6 +2666,83 @@ int64_t reorderDevicesToOuter(TensorView* tv) {
return (int64_t)old2new.size();
}

void reorderTensorLike(
TensorView* target_tv,
const std::vector<IterDomain*>& ref) {
const auto& tv_loop_domain = target_tv->getLoopDomain();

IdModel id_model(target_tv->fusion(), /*build_graphs=*/false);
const auto& graph = id_model.buildBroadcastGraph();

ValGroups target_groups = graph.toGroups(tv_loop_domain);

ValGroups ref_groups = graph.toGroups(ref);

// Traverse from the reference to the target tv. The reference is
// not guaranteed to cover all loop IDs of the target, so
// require_all_to_visited needs to be false.
auto path = ValGraphBFS::getExprGroupsBetween(
graph,
ref_groups,
target_groups,
/*require_all_to_visited=*/false)
.first;

// Traverse the expr path to create an ordered list of ID groups
std::deque<ValGroup> ordered_domain{
ref_groups.vector().begin(), ref_groups.vector().end()};

for (const auto& [expr_g, dir] : path) {
auto inputs = getInputsOfExpr(
expr_g, dir, ValGraphInputs(graph), ValGraphOutputs(graph));
auto outputs = getOutputsOfExpr(
expr_g, dir, ValGraphInputs(graph), ValGraphOutputs(graph));

// Insert the outputs at the position of the innermost input
std::deque<ValGroup>::iterator innermost_it = ordered_domain.end();
for (auto it = inputs.rbegin(); it != inputs.rend(); ++it) {
innermost_it =
std::find(ordered_domain.begin(), ordered_domain.end(), *it);
NVF_ERROR(innermost_it != ordered_domain.end());
break;
}
ordered_domain.insert(innermost_it, outputs.begin(), outputs.end());

// Removes the inputs
for (const auto& inp : inputs) {
ordered_domain.erase(
std::remove(ordered_domain.begin(), ordered_domain.end(), inp),
ordered_domain.end());
}
}

std::unordered_map<int64_t, int64_t> old2new;

// Place IDs that do not appear in ref at the outermost positions
int64_t new_id_pos = 0;
for (const auto i : c10::irange(tv_loop_domain.size())) {
const auto& loop_id_group = graph.toGroup(tv_loop_domain.at(i));
auto it =
std::find(ordered_domain.begin(), ordered_domain.end(), loop_id_group);
if (it == ordered_domain.end()) {
old2new.emplace((int64_t)i, new_id_pos);
++new_id_pos;
}
}
for (const auto i : c10::irange(tv_loop_domain.size())) {
const auto& loop_id_group = graph.toGroup(tv_loop_domain.at(i));
auto it =
std::find(ordered_domain.begin(), ordered_domain.end(), loop_id_group);
if (it != ordered_domain.end()) {
int64_t new_pos =
(int64_t)std::distance(ordered_domain.begin(), it) + new_id_pos;
old2new.emplace((int64_t)i, new_pos);
}
}

target_tv->reorder(old2new);
}
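
// Worked example (hypothetical): target loop domain [a, b, c] with
// ref = [c, a]. ordered_domain starts as [c, a]; b is not covered by
// ref, so it is placed outermost, and the covered IDs follow in
// ordered_domain order: old2new = {0 -> 2, 1 -> 0, 2 -> 1}, yielding
// the reordered loop domain [b, c, a].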

} // namespace scheduler_utils

} // namespace nvfuser
4 changes: 4 additions & 0 deletions csrc/scheduler/utils.h
@@ -745,5 +745,9 @@ inline int64_t nLogicalDims(const TensorView* tv) {
return tv_n_dims;
}

// Reorder the loop domain of a given tensor to align with a given list of
// reference IDs. Non-matching loop IDs are placed at the outermost
// positions.
void reorderTensorLike(TensorView* tv, const std::vector<IterDomain*>& ref);
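
// Minimal usage sketch (hypothetical tensors): reorder tv's loop
// domain to follow the allocation order of a reference tensor, e.g.
//   scheduler_utils::reorderTensorLike(
//       tv, ref_tv->getMaybeAllocationDomain());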

} // namespace scheduler_utils
} // namespace nvfuser