Skip to content

Commit

Permalink
[XLA:GPU] Implement TiledHloInstruction graph in SymbolicTileAnalysis.
Browse files Browse the repository at this point in the history
SymbolicTileAnalysis create a new tiled HLO node for each unique (HLO instruction, indexing map) pair. Tiled instructions are stored in def-before-use order for easier access.

This representation makes it easier to write codegen, because each instruction knows about it's operands and we can efficiently cache emitter values. Those also allows to do some degree of CSE, but it can be CSEd more for concrete tile sizes.

PiperOrigin-RevId: 618788735
  • Loading branch information
olegshyshkov authored and copybara-github committed Mar 25, 2024
1 parent c6028ab commit 24a91cb
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 130 deletions.
4 changes: 1 addition & 3 deletions xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -560,15 +560,13 @@ cc_library(
":indexing_analysis",
":indexing_map",
":tile_analysis",
"//xla:shape_util",
"//xla:status",
"//xla/hlo/ir:hlo",
"//xla/service:instruction_fusion",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
Expand Down
177 changes: 83 additions & 94 deletions xla/service/gpu/model/symbolic_tile_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,31 @@ limitations under the License.
#include "xla/service/gpu/model/symbolic_tile_analysis.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/model/indexing_analysis.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/tile_analysis.h"
#include "xla/shape.h"
#include "xla/service/instruction_fusion.h"
#include "xla/status.h"

namespace xla {
Expand All @@ -51,106 +53,109 @@ using ::mlir::AffineMap;
using ::mlir::MLIRContext;
using ::mlir::SmallVector;

struct HloAndPath {
const HloInstruction* hlo;
SymbolicTileAnalysis::InstructionPathFromRoot path;
};

} // namespace

/*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation(
const HloComputation& computation, MLIRContext* ctx) {
absl::flat_hash_map<InstructionPathFromRoot, SymbolicTile>
symbolic_tile_from_path;
ConstHloInstructionMap<absl::flat_hash_set<InstructionPathFromRoot>>
paths_from_root_to_instruction;
absl::flat_hash_map<const InstructionPathFromRoot, IndexingMap>
indexing_map_from_path;
std::queue<HloAndPath> to_process;

const HloInstruction* root = computation.root_instruction();
paths_from_root_to_instruction.insert({root, {{}}});

to_process.push(HloAndPath{root, /*path=*/{}});
indexing_map_from_path.insert({{}, CreateIdentityMap(root->shape(), ctx)});

while (!to_process.empty()) {
const HloAndPath hlo_and_path = to_process.front();
to_process.pop();

const HloInstruction* hlo = hlo_and_path.hlo;
std::vector<std::unique_ptr<TiledHloInstruction>> tiled_hlo_instructions;
absl::flat_hash_map<std::pair<const HloInstruction*, IndexingMap>,
TiledHloInstruction*>
tiled_hlo_instructions_map;

absl::flat_hash_map<TiledHloInstruction*, int64_t> topological_order;

std::function<std::variant<TiledHloInstruction*, FusionDecision>(
const HloInstruction*, IndexingMap)>
get_tiled_hlo_instruction;

// Create a new tiled hlo instruction or return existing instruction from
// cache for the given hlo and indexing map.
get_tiled_hlo_instruction = [&](const HloInstruction* hlo,
IndexingMap indexing_map)
-> std::variant<TiledHloInstruction*, FusionDecision> {
auto key = std::make_pair(hlo, indexing_map);

auto it = tiled_hlo_instructions_map.find(key);
if (it != tiled_hlo_instructions_map.end()) {
return it->second;
}

// Bail out on instructions that are known to cause problems down the line.
// This is not an inherent limitation of the approach, but simply issues
// to be resolved in the current implementation.
// Bail out on instructions that are known to cause problems down the
// line. This is not an inherent limitation of the approach, but simply
// issues to be resolved in the current implementation.
if (hlo->opcode() == HloOpcode::kDot ||
hlo->opcode() == HloOpcode::kReshape ||
hlo->opcode() == HloOpcode::kBitcast ||
hlo->opcode() == HloOpcode::kConcatenate) {
return absl::StrCat("Bailing out on ", hlo->ToString()).c_str();
return FusionDecision{} << "Bailing out on " << hlo->ToString();
}

// Bail out on instructions that do not output a single array.
if (!hlo->shape().IsArray()) {
return absl::StrCat(hlo->ToString(), " outputs more than a single array")
.c_str();
return FusionDecision{} << hlo->ToString()
<< " outputs more than a single array";
}

auto hlo_indexing_map_it = indexing_map_from_path.find(hlo_and_path.path);
CHECK(hlo_indexing_map_it != indexing_map_from_path.end());

const IndexingMap& hlo_indexing_map = hlo_indexing_map_it->second;
std::optional<SymbolicTile> symbolic_tile =
SymbolicTile::FromIndexingMap(hlo_indexing_map);
auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map);
if (!symbolic_tile.has_value()) {
return absl::StrCat("Failed to compute symbolic tile for ",
hlo_indexing_map.ToString(), " for HLO ",
hlo->ToString())
.c_str();
return FusionDecision{} << "Failed to compute symbolic tile for "
<< indexing_map.ToString() << " for HLO "
<< hlo->ToString();
}
symbolic_tile_from_path.insert({hlo_and_path.path, symbolic_tile.value()});

tiled_hlo_instructions.push_back(std::make_unique<TiledHloInstruction>(
hlo, std::move(indexing_map), std::move(*symbolic_tile)));

auto tiled_hlo_instruction = tiled_hlo_instructions.back().get();

std::optional<HloInstructionIndexing> operands_indexing =
ComputeOutputToInputIndexing(hlo, /*output_id=*/0, ctx);
ComputeOutputToInputIndexing(tiled_hlo_instruction->hlo,
/*output_id=*/0, ctx);

if (!operands_indexing.has_value()) {
return absl::StrCat("Failed to compute operands indexing for ",
hlo->ToString())
.c_str();
return FusionDecision{} << "Failed to compute operands indexing for "
<< tiled_hlo_instruction->hlo->ToString();
}

int operand_id = 0;
for (auto [operand, operand_indexing_map_set] :
llvm::zip(hlo->operands(), operands_indexing->indexing_maps)) {
// Assign hlo_indexing_map again, since the reference may have been
// invalidated by the insertion below.
const IndexingMap& hlo_indexing_map =
indexing_map_from_path.at(hlo_and_path.path);
llvm::zip(tiled_hlo_instruction->hlo->operands(),
operands_indexing->indexing_maps)) {
CHECK_EQ(operand_indexing_map_set.size(), 1);

IndexingMap operand_indexing_map = ComposeIndexingMaps(
hlo_indexing_map, *operand_indexing_map_set.begin());
IndexingMap operand_indexing_map =
ComposeIndexingMaps(tiled_hlo_instruction->indexing_map,
*operand_indexing_map_set.begin());

InstructionPathFromRoot operand_path = InstructionPathFromRoot(
hlo_and_path.path.begin(), hlo_and_path.path.end());
operand_path.push_back(operand_id);
auto tiled_operand_or =
get_tiled_hlo_instruction(operand, std::move(operand_indexing_map));

indexing_map_from_path.insert({operand_path, operand_indexing_map});
to_process.push(HloAndPath{operand, operand_path});

if (paths_from_root_to_instruction.find(operand) ==
paths_from_root_to_instruction.end()) {
paths_from_root_to_instruction.insert({operand, {operand_path}});
} else {
paths_from_root_to_instruction.at(operand).insert(operand_path);
if (auto fusion_decison =
std::get_if<FusionDecision>(&tiled_operand_or)) {
return *fusion_decison;
}

++operand_id;
tiled_hlo_instruction->operands.push_back(
std::get<TiledHloInstruction*>(tiled_operand_or));
}

topological_order[tiled_hlo_instruction] = topological_order.size();
tiled_hlo_instructions_map.emplace(key, tiled_hlo_instruction);
return tiled_hlo_instruction;
};

const HloInstruction* root = computation.root_instruction();
auto tiled_root =
get_tiled_hlo_instruction(root, CreateIdentityMap(root->shape(), ctx));
if (auto* fusion_decision = std::get_if<FusionDecision>(&tiled_root)) {
return *fusion_decision;
}

return SymbolicTileAnalysis(symbolic_tile_from_path,
paths_from_root_to_instruction, ctx);
// Order instructions in def-before-use order.
absl::c_sort(tiled_hlo_instructions, [&](const auto& i1, const auto& i2) {
return topological_order.at(i1.get()) < topological_order.at(i2.get());
});

return SymbolicTileAnalysis(std::move(tiled_hlo_instructions), ctx);
}

namespace {
Expand Down Expand Up @@ -181,39 +186,23 @@ std::vector<int64_t> EvaluateTileMap(AffineMap affine_map,
} // namespace

std::vector<int64_t> SymbolicTileAnalysis::TileOffsets(
absl::Nonnull<const HloInstruction*> hlo,
const InstructionPathFromRoot& path) const {
const TiledHloInstruction& tiled_hlo) const {
CHECK(tile_parameters_.has_value());
CHECK(paths_from_root_to_instruction_.find(hlo) !=
paths_from_root_to_instruction_.end());
CHECK(paths_from_root_to_instruction_.at(hlo).find(path) !=
paths_from_root_to_instruction_.at(hlo).end());
return EvaluateTileMap(symbolic_tile_from_path_.at(path).offset_map(),
return EvaluateTileMap(tiled_hlo.symbolic_tile.offset_map(),
*tile_parameters_);
}

// TODO(bchetioui): remove dependency on stride and offset parameters.
std::vector<int64_t> SymbolicTileAnalysis::TileSizes(
absl::Nonnull<const HloInstruction*> hlo,
const InstructionPathFromRoot& path) const {
const TiledHloInstruction& tiled_hlo) const {
CHECK(tile_parameters_.has_value());
CHECK(paths_from_root_to_instruction_.find(hlo) !=
paths_from_root_to_instruction_.end());
CHECK(paths_from_root_to_instruction_.at(hlo).find(path) !=
paths_from_root_to_instruction_.at(hlo).end());
return EvaluateTileMap(symbolic_tile_from_path_.at(path).size_map(),
*tile_parameters_);
return EvaluateTileMap(tiled_hlo.symbolic_tile.size_map(), *tile_parameters_);
}

std::vector<int64_t> SymbolicTileAnalysis::TileStrides(
absl::Nonnull<const HloInstruction*> hlo,
const InstructionPathFromRoot& path) const {
const TiledHloInstruction& tiled_hlo) const {
CHECK(tile_parameters_.has_value());
CHECK(paths_from_root_to_instruction_.find(hlo) !=
paths_from_root_to_instruction_.end());
CHECK(paths_from_root_to_instruction_.at(hlo).find(path) !=
paths_from_root_to_instruction_.at(hlo).end());
return EvaluateTileMap(symbolic_tile_from_path_.at(path).stride_map(),
return EvaluateTileMap(tiled_hlo.symbolic_tile.stride_map(),
*tile_parameters_);
}

Expand Down
66 changes: 45 additions & 21 deletions xla/service/gpu/model/symbolic_tile_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ limitations under the License.
#define XLA_SERVICE_GPU_MODEL_SYMBOLIC_TILE_ANALYSIS_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/gpu/model/indexing_map.h"
#include "xla/service/gpu/model/tile_analysis.h"
#include "xla/service/instruction_fusion.h"

Expand All @@ -37,6 +37,29 @@ class SymbolicTileAnalysis;
using SymbolicTileAnalysisOrError =
std::variant<SymbolicTileAnalysis, FusionDecision>;

// A node in the tiled representation of an HLO computation. During tiling and
// codegen an HLO instruction may need to be emitted multiple times with
// different tiling parameters.
struct TiledHloInstruction {
// Pointer to the original HLO instruction.
const HloInstruction* hlo;

// Indexing map from the computation root to this instruction output.
IndexingMap indexing_map;

// Symbolic tile derived from the indexing map.
SymbolicTile symbolic_tile;

// Operands of the instruction in the tiled computation graph.
std::vector<TiledHloInstruction*> operands;

TiledHloInstruction(const HloInstruction* hlo, IndexingMap indexing_map,
SymbolicTile symbolic_tile)
: hlo(hlo),
indexing_map(std::move(indexing_map)),
symbolic_tile(std::move(symbolic_tile)) {}
};

// Constructs and holds symbolic tiles for all the instructions within a
// computation. We may hold several different symbolic tiles for the same
// instruction if the instruction is indexed in several different ways in order
Expand All @@ -59,43 +82,44 @@ class SymbolicTileAnalysis {
// Evaluates the tile offsets of an instruction from the analyzed computation
// following the provided path from the root. Tile parameters must have been
// set before calling this method.
std::vector<int64_t> TileOffsets(absl::Nonnull<const HloInstruction*> hlo,
const InstructionPathFromRoot& path) const;
std::vector<int64_t> TileOffsets(const TiledHloInstruction& tiled_hlo) const;
// Evaluates the tile sizes of an instruction from the analyzed computation
// following the provided path from the root. Tile parameters must have been
// set before calling this method.
std::vector<int64_t> TileSizes(absl::Nonnull<const HloInstruction*> hlo,
const InstructionPathFromRoot& path) const;
std::vector<int64_t> TileSizes(const TiledHloInstruction& tiled_hlo) const;
// Evaluates the tile strides of an instruction from the analyzed computation
// following the provided path from the root. Tile parameters must have been
// set before calling this method.
std::vector<int64_t> TileStrides(absl::Nonnull<const HloInstruction*> hlo,
const InstructionPathFromRoot& path) const;
std::vector<int64_t> TileStrides(const TiledHloInstruction& tiled_hlo) const;

// Populates input tile sizes. This is a prerequisite in order to extract
// concrete values using `TileOffsets`, `TileSizes`, and `TileStrides`.
void SetTileSizes(absl::Span<int64_t const> sizes);

// Returns the tiled root instruction.
const TiledHloInstruction* GetRoot() const {
return tiled_hlo_instructions_.back().get();
}

// Returns the tiled HLO instructions in def-before-use order.
const std::vector<std::unique_ptr<TiledHloInstruction>>&
GetTiledHloInstructions() const {
return tiled_hlo_instructions_;
}

// Return the underlying MLIRContext.
mlir::MLIRContext* GetMLIRContext() const { return context_; };

private:
SymbolicTileAnalysis(
absl::flat_hash_map<InstructionPathFromRoot, SymbolicTile>
symbolic_tile_from_path,
ConstHloInstructionMap<absl::flat_hash_set<InstructionPathFromRoot>>
paths_from_root_to_instruction,
std::vector<std::unique_ptr<TiledHloInstruction>> tiled_hlo_instructions,
mlir::MLIRContext* context)
: symbolic_tile_from_path_(symbolic_tile_from_path),
paths_from_root_to_instruction_(paths_from_root_to_instruction),
: tiled_hlo_instructions_(std::move(tiled_hlo_instructions)),
context_(context) {}

absl::flat_hash_map<InstructionPathFromRoot, SymbolicTile>
symbolic_tile_from_path_;
// Maps each instruction in the analyzed computation to a set containing all
// the possible paths from the root instruction to the key instruction.
ConstHloInstructionMap<absl::flat_hash_set<InstructionPathFromRoot>>
paths_from_root_to_instruction_;
// The tiled HLO instructions in def-before-use order.
std::vector<std::unique_ptr<TiledHloInstruction>> tiled_hlo_instructions_;

mlir::MLIRContext* context_;
// Optionally set tile parameters. These parameters can be set by calling
// `SetTileParameters`, and correspond to the output tile for the analyzed
Expand Down
Loading

0 comments on commit 24a91cb

Please sign in to comment.