Skip to content

Commit

Permalink
[XLA] Add a utility to extract the non contracting dimensions from a dot
Browse files Browse the repository at this point in the history
operand and use it in some places found through code search.

PiperOrigin-RevId: 686333391
  • Loading branch information
blakehechtman authored and Google-ML-Automation committed Oct 16, 2024
1 parent 918e7cf commit 6c6e570
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 60 deletions.
20 changes: 6 additions & 14 deletions xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1129,20 +1129,12 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault {
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
dnums.rhs_batch_dimensions_size());

DimensionVector lhs_non_contracting_dims;
DimensionVector rhs_non_contracting_dims;
for (int64_t i = 0; i < lhs_rank; i++) {
if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) &&
!absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) {
lhs_non_contracting_dims.push_back(i);
}
}
for (int64_t i = 0; i < rhs_rank; i++) {
if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) &&
!absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) {
rhs_non_contracting_dims.push_back(i);
}
}
DimensionVector lhs_non_contracting_dims =
GetNonContractingDims(lhs_rank, dnums.lhs_contracting_dimensions(),
dnums.lhs_batch_dimensions());
DimensionVector rhs_non_contracting_dims =
GetNonContractingDims(rhs_rank, dnums.rhs_contracting_dimensions(),
dnums.rhs_batch_dimensions());

DimensionVector contracting_dim_sizes;
contracting_dim_sizes.reserve(dnums.lhs_contracting_dimensions_size());
Expand Down
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,7 @@ cc_library(
deps = [
":shape_inference",
"//xla:status_macros",
"//xla:util",
"//xla/hlo/ir:hlo",
],
)
Expand Down
47 changes: 24 additions & 23 deletions xla/service/dot_as_convolution_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/shape_inference.h"
#include "xla/status_macros.h"
#include "xla/util.h"

namespace xla {
namespace dot_as_convolution_util {
Expand Down Expand Up @@ -202,30 +203,30 @@ DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) {
dnums.contracting_dims.back().output = -1;
dnums.contracting_dims.back().spatial_dim = -1;
}
for (int64_t i = 0; i < dot->operand(0)->shape().rank(); ++i) {
if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) &&
!absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) {
dnums.lhs_non_contracting_dims.emplace_back();
dnums.lhs_non_contracting_dims.back().lhs = i;
dnums.lhs_non_contracting_dims.back().rhs = -1;
dnums.lhs_non_contracting_dims.back().output =
dot_dim_numbs.lhs_batch_dimensions_size() +
dnums.lhs_non_contracting_dims.size() - 1;
dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
}
for (auto i :
GetNonContractingDims(dot->operand(0)->shape().rank(),
dot_dim_numbs.lhs_contracting_dimensions(),
dot_dim_numbs.lhs_batch_dimensions())) {
dnums.lhs_non_contracting_dims.emplace_back();
dnums.lhs_non_contracting_dims.back().lhs = i;
dnums.lhs_non_contracting_dims.back().rhs = -1;
dnums.lhs_non_contracting_dims.back().output =
dot_dim_numbs.lhs_batch_dimensions_size() +
dnums.lhs_non_contracting_dims.size() - 1;
dnums.lhs_non_contracting_dims.back().spatial_dim = -1;
}
for (int64_t i = 0; i < dot->operand(1)->shape().rank(); ++i) {
if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) &&
!absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) {
dnums.rhs_non_contracting_dims.emplace_back();
dnums.rhs_non_contracting_dims.back().lhs = -1;
dnums.rhs_non_contracting_dims.back().rhs = i;
dnums.rhs_non_contracting_dims.back().output =
dot_dim_numbs.lhs_batch_dimensions_size() +
dnums.lhs_non_contracting_dims.size() +
dnums.rhs_non_contracting_dims.size() - 1;
dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
}
for (auto i :
GetNonContractingDims(dot->operand(1)->shape().rank(),
dot_dim_numbs.rhs_contracting_dimensions(),
dot_dim_numbs.rhs_batch_dimensions())) {
dnums.rhs_non_contracting_dims.emplace_back();
dnums.rhs_non_contracting_dims.back().lhs = -1;
dnums.rhs_non_contracting_dims.back().rhs = i;
dnums.rhs_non_contracting_dims.back().output =
dot_dim_numbs.lhs_batch_dimensions_size() +
dnums.lhs_non_contracting_dims.size() +
dnums.rhs_non_contracting_dims.size() - 1;
dnums.rhs_non_contracting_dims.back().spatial_dim = -1;
}

dnums.lhs_shape_rank = dot->operand(0)->shape().rank();
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,6 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
Expand Down
16 changes: 4 additions & 12 deletions xla/service/gpu/matmul_indexing_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include <cstdint>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/types/span.h"
Expand All @@ -34,19 +33,12 @@ namespace gpu {
absl::StatusOr<std::vector<int64_t>> GetNonContractingDims(
const Shape& shape, absl::Span<const int64_t> batch_dims,
absl::Span<const int64_t> contracting_dims) {
std::vector<int64_t> non_contracting_dims;
// This is O(rank**2), but we expect rank to be small.
for (int64_t dim = 0; dim < shape.rank(); ++dim) {
bool is_batch = absl::c_count(batch_dims, dim) != 0;
bool is_contracting = absl::c_count(contracting_dims, dim) != 0;
TF_RET_CHECK(!(is_batch && is_contracting));
if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim);
}
auto nc =
::xla::GetNonContractingDims(shape.rank(), contracting_dims, batch_dims);

TF_RET_CHECK(batch_dims.size() + contracting_dims.size() +
non_contracting_dims.size() ==
TF_RET_CHECK(batch_dims.size() + contracting_dims.size() + nc.size() ==
shape.rank());
return non_contracting_dims;
return std::vector<int64_t>(nc.begin(), nc.end());
}

const tsl::protobuf::RepeatedField<int64_t>& BatchDimensionsForOperand(
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ cc_library(
":numeric_options",
":scratch_allocator",
":stream",
"//xla:util",
"//xla/tsl/lib/strings:proto_serialization",
"//xla/tsl/protobuf:dnn_proto_cc",
"@com_google_absl//absl/algorithm:container",
Expand Down
14 changes: 4 additions & 10 deletions xla/stream_executor/dnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/stream_executor/numeric_options.h"
#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/tsl/protobuf/dnn.pb.h"
#include "xla/util.h"
#include "tsl/platform/ml_dtypes.h"

namespace stream_executor {
Expand Down Expand Up @@ -594,16 +595,9 @@ std::string TensorDescriptor::ToString() const {

absl::StatusOr<std::vector<int64_t>>
MatmulTensorDescriptor::GetNonContractingDims() const {
std::vector<int64_t> non_contracting_dims;
for (int64_t dim = 0; dim < tensor_.dimensions().size(); ++dim) {
bool is_batch = absl::c_count(batch_dimension_numbers_, dim) != 0;
bool is_contracting = absl::c_count(contracting_dim_, dim) != 0;
if (is_batch && is_contracting)
return absl::InternalError(
"A dimension cannot be both a batch dimension and a contracting "
"dimension.");
if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim);
}
auto nc = xla::GetNonContractingDims(
tensor_.dimensions().size(), contracting_dim_, batch_dimension_numbers_);
std::vector<int64_t> non_contracting_dims(nc.begin(), nc.end());

if (batch_dimension_numbers_.size() + contracting_dim_.size() +
non_contracting_dims.size() !=
Expand Down
14 changes: 14 additions & 0 deletions xla/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,20 @@ ConvertedDimensionNumbers ConvertDimensionNumbers(
absl::c_sort(dimensions.to_dimensions);
return dimensions;
}

DimensionVector GetNonContractingDims(
int64_t rank, absl::Span<const int64_t> contracting_dim_numbers,
absl::Span<const int64_t> batch_dim_numbers) {
DimensionVector non_contracting_dim_numbers;
for (int64_t i = 0; i < rank; ++i) {
if (!absl::c_linear_search(contracting_dim_numbers, i) &&
!absl::c_linear_search(batch_dim_numbers, i)) {
non_contracting_dim_numbers.push_back(i);
}
}
return non_contracting_dim_numbers;
}

std::string SanitizeFileName(std::string file_name) {
for (char& c : file_name) {
if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {
Expand Down
6 changes: 6 additions & 0 deletions xla/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,12 @@ ConvertedDimensionNumbers ConvertDimensionNumbers(
absl::Span<const int64_t> from_dimensions,
absl::Span<const int64_t> from_sizes, absl::Span<const int64_t> to_sizes);

// Returns non contracting dimensions for a dot operand based on rank, batch and
// contracting dimension numbers.
DimensionVector GetNonContractingDims(
int64_t rank, absl::Span<const int64_t> contracting_dim_numbers,
absl::Span<const int64_t> batch_dim_numbers);

// Removes illegal characters from filenames.
std::string SanitizeFileName(std::string file_name);

Expand Down

0 comments on commit 6c6e570

Please sign in to comment.