
Commit

Add preliminary support for DCN sharding.
PiperOrigin-RevId: 685979091
Google-ML-Automation committed Oct 15, 2024
1 parent b527ed5 commit b910a38
Showing 5 changed files with 86 additions and 5 deletions.
32 changes: 32 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -3993,6 +3993,16 @@ void RecordPassEndAndDumpModule(absl::Time start_time,
   DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding");
 }

+std::vector<int> FindAllIndices(std::vector<int64_t> vec, int64_t element) {
+  std::vector<int> result;
+  for (int i = 0; i < vec.size(); ++i) {
+    if (vec[i] == element) {
+      result.push_back(i);
+    }
+  }
+  return result;
+}
+
 absl::StatusOr<bool> AutoSharding::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
@@ -4123,6 +4133,7 @@ absl::StatusOr<bool> AutoSharding::Run(
   std::vector<std::string> mesh_shape_error_messages(mesh_shapes.size());
   for (size_t i = 0; i < mesh_shapes.size(); ++i) {
     VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]);
+
     AutoShardingOption this_option = option_;
     this_option.device_mesh_shape = mesh_shapes[i];
     if (this_option.device_mesh_shape.size() !=
@@ -4135,6 +4146,27 @@ absl::StatusOr<bool> AutoSharding::Run(
     this_option.solver_timeout_in_seconds /= mesh_shapes.size();
     LOG(INFO) << "Setting solver timeout per mesh shape to "
               << this_option.solver_timeout_in_seconds << " seconds.";
+
+    // Try to infer DCN axis if the HLO is multi-slice.
+    // TODO(b/372720563) Improve this DCN axis inference. Currently, we assume
+    // there is only one DCN axis, and that there is no ICI axis with the same
+    // size as the DCN axis.
+    if (option_.num_dcn_slices.has_value() && *option_.num_dcn_slices > 1) {
+      std::vector<int> dcn_indices =
+          FindAllIndices(mesh_shapes[i], *option_.num_dcn_slices);
+      if (dcn_indices.empty()) {
+        VLOG(1) << " Mesh shape does not contain DCN axis.";
+        continue;
+      }
+
+      if (dcn_indices.size() > 1) {
+        LOG(WARNING)
+            << "Could not infer a unique DCN axis. Choosing one randomly.";
+      }
+      this_option.device_mesh_alpha[dcn_indices[0]] = kDcnDeviceMeshAlpha;
+      this_option.device_mesh_beta[dcn_indices[0]] = kDcnDeviceMeshBeta;
+    }
+
     auto pass = std::make_unique<AutoShardingImplementation>(this_option);
     std::unique_ptr<HloModule> module_clone = CloneModule(module);
     absl::StatusOr<bool> pass_result =
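The inference added above is small enough to trace by hand. The following standalone sketch (illustrative only, not part of this commit; FindAllIndices is re-declared locally and the alpha/beta values mirror the constants added to auto_sharding_option.h) shows what happens for the 8x16 mesh and 8 slices used by the new test: the size-8 dimension is picked as the DCN axis and gets the DCN penalties, while the size-16 ICI axis keeps the defaults.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the FindAllIndices helper added above: collect every mesh
// dimension whose size equals `element`.
std::vector<int> FindAllIndices(const std::vector<int64_t>& vec,
                                int64_t element) {
  std::vector<int> result;
  for (int i = 0; i < static_cast<int>(vec.size()); ++i) {
    if (vec[i] == element) result.push_back(i);
  }
  return result;
}

int main() {
  // ICI defaults and the 10x DCN penalties (illustrative local names).
  constexpr double kIciAlpha = 1.0, kIciBeta = 1.0;
  constexpr double kDcnAlpha = 10.0, kDcnBeta = 10.0;

  const std::vector<int64_t> mesh_shape = {8, 16};
  const int64_t num_dcn_slices = 8;

  std::vector<double> alpha(mesh_shape.size(), kIciAlpha);
  std::vector<double> beta(mesh_shape.size(), kIciBeta);

  // The inferred DCN axis is the mesh dimension whose size matches the
  // slice count; if several match, the first one is used.
  const std::vector<int> dcn_indices = FindAllIndices(mesh_shape, num_dcn_slices);
  if (!dcn_indices.empty()) {
    alpha[dcn_indices[0]] = kDcnAlpha;
    beta[dcn_indices[0]] = kDcnBeta;
  }

  // Prints: axis 0 -> alpha 10, beta 10 / axis 1 -> alpha 1, beta 1.
  for (size_t i = 0; i < mesh_shape.size(); ++i) {
    std::cout << "axis " << i << " -> alpha " << alpha[i] << ", beta "
              << beta[i] << "\n";
  }
  return 0;
}
```

In the pass itself, these penalties land in this_option.device_mesh_alpha and this_option.device_mesh_beta for the candidate mesh shape before the solver runs.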
4 changes: 3 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -115,8 +115,10 @@ class AutoSharding : public HloModulePass {

   std::vector<int64_t> GetChosenDeviceMeshShape() { return chosen_mesh_shape_; }

- private:
+ protected:
   AutoShardingOption option_;
+
+ private:
   // Stores the optimal value of the objective the solver found.
   double solver_optimal_objective_value_ = -1.0;
   // Stores the optimal mesh shape found.
10 changes: 8 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
@@ -141,6 +141,10 @@ std::string AutoShardingOption::ToString() const {
   lines.push_back(absl::StrCat("insert_resharding_reshapes_for_non_dot_ops: ",
                                insert_resharding_reshapes_for_non_dot_ops));

+  if (num_dcn_slices.has_value()) {
+    lines.push_back(absl::StrCat("num_dcn_slices: ", *num_dcn_slices));
+  }
+
   return absl::StrJoin(lines, "\n");
 }

@@ -164,14 +168,16 @@ absl::Status AutoShardingOption::CheckAndSetup() {
   if (device_mesh_alpha.empty()) {
     // Generates simple device_mesh_alpha based on the size of
     // device_mesh_shape.
-    device_mesh_alpha = std::vector(device_mesh_shape.size(), kDeviceMeshAlpha);
+    device_mesh_alpha =
+        std::vector(device_mesh_shape.size(), kIciDeviceMeshAlpha);
     VLOG(0) << "Using default values for device_mesh_alpha: "
             << absl::StrJoin(device_mesh_alpha, ",");
   }
   if (device_mesh_beta.empty()) {
     // Generates simple device_mesh_beta based on the size of
     // device_mesh_shape.
-    device_mesh_beta = std::vector(device_mesh_shape.size(), kDeviceMeshBeta);
+    device_mesh_beta =
+        std::vector(device_mesh_shape.size(), kIciDeviceMeshBeta);
     VLOG(0) << "Using default values for device_mesh_beta: "
             << absl::StrJoin(device_mesh_beta, ",");
   }
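Once num_dcn_slices is set, the option dump reflects it. A minimal, hypothetical driver (not part of the commit) that exercises the new ToString() branch:

```cpp
#include <iostream>

#include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h"

int main() {
  xla::AutoShardingOption option;
  option.enable = true;
  option.device_mesh_shape = {8, 16};
  option.num_dcn_slices = 8;
  // The dump now includes a "num_dcn_slices: 8" line when the field is set,
  // and omits it otherwise.
  std::cout << option.ToString() << std::endl;
  return 0;
}
```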
12 changes: 10 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_option.h
@@ -17,15 +17,20 @@ limitations under the License.
 #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_OPTION_H_

 #include <cstdint>
+#include <optional>
 #include <string>
 #include <vector>

 #include "absl/status/status.h"

 namespace xla {

-static constexpr double kDeviceMeshAlpha = 1.0;
-static constexpr double kDeviceMeshBeta = 1.0;
+static constexpr double kIciDeviceMeshAlpha = 1.0;
+static constexpr double kIciDeviceMeshBeta = 1.0;
+// By default, assume that DCN communication is 10 times slower than ICI
+// communication
+static constexpr double kDcnDeviceMeshAlpha = 10.0;
+static constexpr double kDcnDeviceMeshBeta = 10.0;
 static constexpr double kOverbudgetCoeff = 1e6;

 // Options for the autosharding pass
@@ -202,6 +207,9 @@ struct AutoShardingOption {
   // ops in a principled manner.
   bool insert_resharding_reshapes_for_non_dot_ops = false;

+  // The number of slices used
+  std::optional<int64_t> num_dcn_slices = std::nullopt;
+
   // Prints a debug string.
   std::string ToString() const;

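For context on the constants this file gains (kIciDeviceMeshAlpha/Beta = 1.0, kDcnDeviceMeshAlpha/Beta = 10.0): under the usual alpha-beta communication model (an assumption here; the solver's actual cost functions are not part of this diff), alpha acts as a per-message latency weight and beta as a per-byte bandwidth weight, so a 10x alpha and beta make a collective along the inferred DCN axis look roughly an order of magnitude more expensive than the same collective along an ICI axis. A toy comparison with a hypothetical AxisCost helper:

```cpp
#include <iostream>

// Toy alpha-beta cost sketch (assumed model: cost ~ alpha + beta * bytes).
// AxisCost is a hypothetical helper; the values mirror the kIci*/kDcn*
// constants added to auto_sharding_option.h.
double AxisCost(double alpha, double beta, double bytes) {
  return alpha + beta * bytes;
}

int main() {
  const double bytes = 1024.0;
  std::cout << "ICI axis: " << AxisCost(1.0, 1.0, bytes) << "\n"     // 1025
            << "DCN axis: " << AxisCost(10.0, 10.0, bytes) << "\n";  // 10250
  return 0;
}
```

The intended effect is that the solver favors sharding strategies whose resharding and collectives stay on ICI axes and avoid the inferred DCN axis where possible.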
33 changes: 33 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -2955,6 +2955,39 @@ ENTRY matmul {
   EXPECT_TRUE(changed);
 }

+TEST_F(AutoShardingTest, SimpleDCNTest) {
+  constexpr absl::string_view kHloString = R"(
+HloModule module
+%func (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+ENTRY %entry {
+  %param0 = f32[32,8192]{1,0} parameter(0)
+  %param1 = f32[] parameter(1)
+  %reduce = f32[32]{0} reduce(f32[32,8192]{1,0} %param0, f32[] %param1), dimensions={1}, to_apply=%func
+})";
+  AutoShardingOption option;
+  option.enable = true;
+  option.solve_nd_sharding_iteratively = false;
+  option.allow_mixed_mesh_shape = false;
+  option.device_mesh_shape = {8, 16};
+  option.num_dcn_slices = 8;
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(kHloString));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get()));
+  VLOG(5) << module->ToString();
+  EXPECT_TRUE(changed);
+  const HloInstruction* slice = FindInstruction(module.get(), "reduce");
+  EXPECT_NE(slice, nullptr);
+  EXPECT_THAT(slice,
+              op::Sharding("{devices=[8,16]<=[128] last_tile_dim_replicate}"));
+}
+
 TEST(NormalizeTest, NormalizeHandlesNegativeCosts) {
   EdgeReshardingCostMatrix edge_cost(2, 2);
   edge_cost(0, 0).communication_cost = -100;
