@@ -58,9 +58,6 @@ namespace xla {
58
58
namespace spmd {
59
59
namespace {
60
60
61
- using MeshDimSet = StableSet<int >;
62
- using DimMap = StableMap</* tensor dim*/ int , /* mesh dims*/ MeshDimSet>;
63
-
64
61
// Contains base functionality common to both DotHandler and ConvHandler.
65
62
class HandlerBase {
66
63
protected:
@@ -145,33 +142,6 @@ class HandlerBase {
145
142
std::optional<HloSharding> GetShardingFromUser (const HloSharding& lhs_spec,
146
143
const HloSharding& rhs_spec);
147
144
148
- // Given a set of tensor dims, and a set of mesh dims, enumerates all mappings
149
- // where a subset of all tensor dims is mapped to a subset of mesh dims, such
150
- // that each tensor dim is mapped to at most mesh dim, and no two tensor dims
151
- // are mapped to the same mesh dim.
152
- void Enumerate (std::function<void (const DimMap&)> split_func, int tensor_rank,
153
- int current_mesh_dim_idx,
154
- const std::vector<int>& unassigned_mesh_dims,
155
- const DimMap& current_dim_map) {
156
- if (current_mesh_dim_idx == unassigned_mesh_dims.size ()) {
157
- split_func (current_dim_map);
158
- return ;
159
- }
160
- // Current mesh dim is not assigned to any tensor dim
161
- Enumerate (split_func, tensor_rank, current_mesh_dim_idx + 1 ,
162
- unassigned_mesh_dims, current_dim_map);
163
-
164
- for (int i = 0 ; i < tensor_rank; ++i) {
165
- DimMap updated_dim_map = current_dim_map;
166
- if (!updated_dim_map[i].empty () && !option_.allow_mixed_mesh_shape ) {
167
- continue ;
168
- }
169
- updated_dim_map[i].insert (unassigned_mesh_dims[current_mesh_dim_idx]);
170
- Enumerate (split_func, tensor_rank, current_mesh_dim_idx + 1 ,
171
- unassigned_mesh_dims, updated_dim_map);
172
- }
173
- }
174
-
175
145
bool IsMeshDimSetNonTrivial (const MeshDimSet& mesh_dim_set) {
176
146
return absl::c_any_of (mesh_dim_set, [&](int mesh_dim) {
177
147
return device_mesh_.dim (mesh_dim) > 1 ;
@@ -732,9 +702,8 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding(
732
702
/* compute_cost=*/ 0 , communication_cost_fn);
733
703
};
734
704
735
- Enumerate (split_func, reduction_dims.size (),
736
- /* current_mesh_dim_idx=*/ 0 , unused_mesh_dims,
737
- /* current_dim_map=*/ {});
705
+ Enumerate (split_func, reduction_dims.size (), unused_mesh_dims,
706
+ option_.allow_mixed_mesh_shape );
738
707
}
739
708
740
709
void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand (
@@ -836,8 +805,7 @@ absl::Status DotHandler::RegisterStrategies() {
836
805
[&](const DimMap& output_dim_map) {
837
806
GenerateDotShardingStrategiesFromOutputSharding (output_dim_map);
838
807
},
839
- ins_->shape ().rank (), /* current_mesh_dim_idx=*/ 0 , all_mesh_dims,
840
- /* current_dim_map=*/ {});
808
+ ins_->shape ().rank (), all_mesh_dims, option_.allow_mixed_mesh_shape );
841
809
SortStrategies ();
842
810
return absl::OkStatus ();
843
811
}
@@ -957,8 +925,7 @@ absl::Status ConvHandler::RegisterStrategies() {
957
925
[&](const DimMap& output_dim_map) {
958
926
GenerateConvolutionShardingStrategiesFromOutputSharding (output_dim_map);
959
927
},
960
- 2 , /* current_mesh_dim_idx=*/ 0 , all_mesh_dims,
961
- /* current_dim_map=*/ {});
928
+ 2 , all_mesh_dims, option_.allow_mixed_mesh_shape );
962
929
963
930
SortStrategies ();
964
931
return absl::OkStatus ();
@@ -997,9 +964,8 @@ void ConvHandler::SplitDepthwise(bool forward) {
997
964
};
998
965
std::vector<int > all_mesh_dims (device_mesh_.num_dimensions ());
999
966
std::iota (all_mesh_dims.begin (), all_mesh_dims.end (), 0 );
1000
- Enumerate (split_func, ins_->shape ().rank (), /* current_mesh_dim_idx=*/ 0 ,
1001
- all_mesh_dims,
1002
- /* current_dim_map=*/ {});
967
+ Enumerate (split_func, ins_->shape ().rank (), all_mesh_dims,
968
+ option_.allow_mixed_mesh_shape );
1003
969
}
1004
970
1005
971
} // namespace
0 commit comments