Skip to content

Commit 2fff532

Browse files
Move some strategy generation utilities from auto_sharding_dot_handler.cc to
auto_sharding_strategy.h with the intention of using the utilities more broadly throughout the codebase. PiperOrigin-RevId: 731094359
1 parent d87634f commit 2fff532

File tree

3 files changed

+55
-40
lines changed

3 files changed

+55
-40
lines changed

xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc

+6-40
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,6 @@ namespace xla {
5858
namespace spmd {
5959
namespace {
6060

61-
using MeshDimSet = StableSet<int>;
62-
using DimMap = StableMap</*tensor dim*/ int, /*mesh dims*/ MeshDimSet>;
63-
6461
// Contains base functionality common to both DotHandler and ConvHandler.
6562
class HandlerBase {
6663
protected:
@@ -145,33 +142,6 @@ class HandlerBase {
145142
std::optional<HloSharding> GetShardingFromUser(const HloSharding& lhs_spec,
146143
const HloSharding& rhs_spec);
147144

148-
// Given a set of tensor dims, and a set of mesh dims, enumerates all mappings
149-
// where a subset of all tensor dims is mapped to a subset of mesh dims, such
150-
// that each tensor dim is mapped to at most mesh dim, and no two tensor dims
151-
// are mapped to the same mesh dim.
152-
void Enumerate(std::function<void(const DimMap&)> split_func, int tensor_rank,
153-
int current_mesh_dim_idx,
154-
const std::vector<int>& unassigned_mesh_dims,
155-
const DimMap& current_dim_map) {
156-
if (current_mesh_dim_idx == unassigned_mesh_dims.size()) {
157-
split_func(current_dim_map);
158-
return;
159-
}
160-
// Current mesh dim is not assigned to any tensor dim
161-
Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1,
162-
unassigned_mesh_dims, current_dim_map);
163-
164-
for (int i = 0; i < tensor_rank; ++i) {
165-
DimMap updated_dim_map = current_dim_map;
166-
if (!updated_dim_map[i].empty() && !option_.allow_mixed_mesh_shape) {
167-
continue;
168-
}
169-
updated_dim_map[i].insert(unassigned_mesh_dims[current_mesh_dim_idx]);
170-
Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1,
171-
unassigned_mesh_dims, updated_dim_map);
172-
}
173-
}
174-
175145
bool IsMeshDimSetNonTrivial(const MeshDimSet& mesh_dim_set) {
176146
return absl::c_any_of(mesh_dim_set, [&](int mesh_dim) {
177147
return device_mesh_.dim(mesh_dim) > 1;
@@ -732,9 +702,8 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding(
732702
/*compute_cost=*/0, communication_cost_fn);
733703
};
734704

735-
Enumerate(split_func, reduction_dims.size(),
736-
/*current_mesh_dim_idx=*/0, unused_mesh_dims,
737-
/*current_dim_map=*/{});
705+
Enumerate(split_func, reduction_dims.size(), unused_mesh_dims,
706+
option_.allow_mixed_mesh_shape);
738707
}
739708

740709
void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand(
@@ -836,8 +805,7 @@ absl::Status DotHandler::RegisterStrategies() {
836805
[&](const DimMap& output_dim_map) {
837806
GenerateDotShardingStrategiesFromOutputSharding(output_dim_map);
838807
},
839-
ins_->shape().rank(), /*current_mesh_dim_idx=*/0, all_mesh_dims,
840-
/*current_dim_map=*/{});
808+
ins_->shape().rank(), all_mesh_dims, option_.allow_mixed_mesh_shape);
841809
SortStrategies();
842810
return absl::OkStatus();
843811
}
@@ -957,8 +925,7 @@ absl::Status ConvHandler::RegisterStrategies() {
957925
[&](const DimMap& output_dim_map) {
958926
GenerateConvolutionShardingStrategiesFromOutputSharding(output_dim_map);
959927
},
960-
2, /*current_mesh_dim_idx=*/0, all_mesh_dims,
961-
/*current_dim_map=*/{});
928+
2, all_mesh_dims, option_.allow_mixed_mesh_shape);
962929

963930
SortStrategies();
964931
return absl::OkStatus();
@@ -997,9 +964,8 @@ void ConvHandler::SplitDepthwise(bool forward) {
997964
};
998965
std::vector<int> all_mesh_dims(device_mesh_.num_dimensions());
999966
std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0);
1000-
Enumerate(split_func, ins_->shape().rank(), /*current_mesh_dim_idx=*/0,
1001-
all_mesh_dims,
1002-
/*current_dim_map=*/{});
967+
Enumerate(split_func, ins_->shape().rank(), all_mesh_dims,
968+
option_.allow_mixed_mesh_shape);
1003969
}
1004970

1005971
} // namespace

xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc

+38
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <cmath>
2020
#include <cstddef>
2121
#include <cstdint>
22+
#include <functional>
2223
#include <memory>
2324
#include <optional>
2425
#include <string>
@@ -64,6 +65,43 @@ limitations under the License.
6465
namespace xla {
6566
namespace spmd {
6667

68+
void EnumerateHelper(std::function<void(const DimMap&)> split_func,
69+
int tensor_rank, int current_mesh_dim_idx,
70+
const std::vector<int>& unassigned_mesh_dims,
71+
const DimMap& current_dim_map,
72+
bool allow_mixed_mesh_shape) {
73+
if (current_mesh_dim_idx == unassigned_mesh_dims.size()) {
74+
split_func(current_dim_map);
75+
return;
76+
}
77+
// Current mesh dim is not assigned to any tensor dim
78+
EnumerateHelper(split_func, tensor_rank, current_mesh_dim_idx + 1,
79+
unassigned_mesh_dims, current_dim_map,
80+
allow_mixed_mesh_shape);
81+
82+
for (int i = 0; i < tensor_rank; ++i) {
83+
DimMap updated_dim_map = current_dim_map;
84+
if (!updated_dim_map[i].empty() && !allow_mixed_mesh_shape) {
85+
continue;
86+
}
87+
updated_dim_map[i].insert(unassigned_mesh_dims[current_mesh_dim_idx]);
88+
EnumerateHelper(split_func, tensor_rank, current_mesh_dim_idx + 1,
89+
unassigned_mesh_dims, updated_dim_map,
90+
allow_mixed_mesh_shape);
91+
}
92+
}
93+
94+
// Map tensor dims from [0, tensor_shape.rank() - 1] to (at most one, or more,
95+
// depending on the value of allow_mixed_mesh_shape) mesh dims.
96+
void Enumerate(std::function<void(const DimMap&)> split_func,
97+
int64_t tensor_rank,
98+
const std::vector<int>& unassigned_mesh_dims,
99+
bool allow_mixed_mesh_shape) {
100+
EnumerateHelper(split_func, tensor_rank, /*current_mesh_dim_idx=*/0,
101+
unassigned_mesh_dims,
102+
/*current_dim_map=*/{}, allow_mixed_mesh_shape);
103+
}
104+
67105
bool LeafVectorsAreConsistent(const std::vector<ShardingStrategy>& one,
68106
const std::vector<ShardingStrategy>& two) {
69107
if (one.size() != two.size()) return false;

xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h

+11
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <algorithm>
2020
#include <cstddef>
2121
#include <cstdint>
22+
#include <functional>
2223
#include <iterator>
2324
#include <memory>
2425
#include <optional>
@@ -387,6 +388,16 @@ using AssociativeDotPairs =
387388
// The set of all alias pairs
388389
using AliasSet = StableSet<std::pair<NodeIdx, NodeIdx>>;
389390

391+
// Utilities for creating sharding objects
392+
using MeshDimSet = StableSet<int>;
393+
using DimMap = StableMap</*tensor dim*/ int, /*mesh dims*/ MeshDimSet>;
394+
395+
// Map tensor dims from [0, tensor_shape.rank() - 1] to (at most one, or more,
396+
// depending on the value of allow_mixed_mesh_shape) mesh dims.
397+
void Enumerate(std::function<void(const DimMap&)> split_func,
398+
int64_t tensor_rank,
399+
const std::vector<int>& unassigned_mesh_dims,
400+
bool allow_mixed_mesh_shape);
390401
} // namespace spmd
391402
} // namespace xla
392403
#endif // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_

0 commit comments

Comments
 (0)