Skip to content

Commit

Permalink
[StableHLO] Add get_compiler_ir(stage='stablehlo[_serialized]') for…
Browse files Browse the repository at this point in the history
… dumping StableHLO from a TF function.

PiperOrigin-RevId: 686152556
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Oct 15, 2024
1 parent 2871d40 commit e4ca19f
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 39 deletions.
21 changes: 21 additions & 0 deletions xla/hlo/translate/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@bazel_skylib//rules:build_test.bzl", "build_test")
load("//xla:xla.bzl", "xla_cc_binary")
load("//xla/tsl:tsl.bzl", "internal_visibility")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand All @@ -11,6 +12,26 @@ package(
licenses = ["notice"],
)

cc_library(
name = "portable_api",
srcs = ["portable_api.cc"],
hdrs = ["portable_api.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
"//xla/mlir_hlo:hlo_dialect_registration",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@stablehlo//:register",
"@tsl//tsl/platform:statusor",
],
)

build_test(
name = "xla-translate_build_test",
targets = [
Expand Down
4 changes: 4 additions & 0 deletions xla/hlo/translate/hlo_to_mhlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,15 @@ cc_library(
":hlo_module_importer",
"//xla:status_macros",
"//xla/mlir/utils:error_util",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/llvm_ir:llvm_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
22 changes: 22 additions & 0 deletions xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ limitations under the License.

#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"

#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h"
#include "xla/mlir/utils/error_util.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -69,4 +75,20 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToMlirHlo(
return module;
}

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo(
mlir::MLIRContext& ctx, const xla::HloModule* hlo_module) {
TF_ASSIGN_OR_RETURN(
mlir::OwningOpRef<mlir::ModuleOp> module,
ConvertHloToMlirHlo(ctx, hlo_module, /*import_all_computations=*/true,
/*flatten_computation_args_result=*/true));

mlir::BaseScopedDiagnosticHandler diag_handler(&ctx);
mlir::PassManager pm(&ctx);
pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
if (failed(pm.run(*module))) {
return diag_handler.ConsumeStatus();
}
return std::move(module);
}

} // namespace xla
4 changes: 4 additions & 0 deletions xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module,
bool import_all_computations = false,
bool flatten_computation_args_result = false);

// Entrypoint for HLO to StableHLO conversion.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ConvertHloToStablehlo(
mlir::MLIRContext& ctx, const xla::HloModule* hlo_module);

} // namespace xla

#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_
71 changes: 71 additions & 0 deletions xla/hlo/translate/portable_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/hlo/translate/portable_api.h"

#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/Register.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
#include "xla/mlir_hlo/mhlo/IR/register.h"
#include "tsl/platform/statusor.h"

namespace xla {

std::string PrintModule(mlir::ModuleOp module) {
std::string s;
llvm::raw_string_ostream os(s);
mlir::OpPrintingFlags flags;
flags.enableDebugInfo();
module->print(os, flags);
return s;
}

void LoadHloDialects(mlir::MLIRContext& context) {
mlir::DialectRegistry registry;
mlir::stablehlo::registerAllDialects(registry);
mlir::mhlo::registerAllMhloDialects(registry);
context.appendDialectRegistry(registry);
}

absl::StatusOr<std::string> SerializeUsingBytecode(mlir::ModuleOp module) {
std::string bytecode;
llvm::raw_string_ostream os(bytecode);
mlir::BytecodeWriterConfig config;
if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) {
return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed");
}
return bytecode;
}

absl::StatusOr<std::string> ConvertHloToStablehlo(
xla::HloModule const& hlo_module, bool emit_bytecode) {
mlir::MLIRContext context;
LoadHloDialects(context);
TF_ASSIGN_OR_RETURN(auto module, ConvertHloToStablehlo(context, &hlo_module));
if (emit_bytecode) return SerializeUsingBytecode(*module);
return PrintModule(*module);
}

} // namespace xla
35 changes: 35 additions & 0 deletions xla/hlo/translate/portable_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_HLO_TRANSLATE_PORTABLE_API_H_
#define XLA_HLO_TRANSLATE_PORTABLE_API_H_

#include <string>

#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"

// This file is a portable version of the HLO API.
// Is offers a string API passthrough for MLIR datatypes and is intended
// to offer a safe means of using StableHLO opaquely in non-MLIR code.

namespace xla {

absl::StatusOr<std::string> ConvertHloToStablehlo(
xla::HloModule const& hlo_module, bool emit_bytecode = false);

}

#endif // XLA_HLO_TRANSLATE_PORTABLE_API_H_
47 changes: 8 additions & 39 deletions xla/python/mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
Expand All @@ -36,8 +35,6 @@ limitations under the License.
#include "nanobind/nanobind.h"
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/string_view.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/Serialization.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/hlo/builder/xla_computation.h"
Expand All @@ -58,34 +55,6 @@ namespace nb = nanobind;
namespace xla {
namespace {

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ParseModule(
mlir::MLIRContext* context, std::string_view str) {
mlir::OwningOpRef<mlir::ModuleOp> module;
context->loadDialect<mlir::func::FuncDialect>();
context->loadDialect<mlir::mhlo::MhloDialect>();
context->loadDialect<mlir::chlo::ChloDialect>();
context->loadDialect<mlir::sparse_tensor::SparseTensorDialect>();
context->loadDialect<mlir::stablehlo::StablehloDialect>();
context->loadDialect<mlir::sdy::SdyDialect>();

mlir::DialectRegistry registry;
mlir::func::registerAllExtensions(registry);
context->appendDialectRegistry(registry);

mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
module = mlir::parseSourceString<mlir::ModuleOp>(
llvm::StringRef(str.data(), str.size()), context);
if (!module) {
return diagnostic_handler.ConsumeStatus();
}
if (failed(module->verifyInvariants())) {
VLOG(1) << "MLIR verification failed.";
module->dump();
return diagnostic_handler.ConsumeStatus();
}
return module;
}

std::string PrintModule(mlir::ModuleOp module) {
std::string s;
llvm::raw_string_ostream os(s);
Expand Down Expand Up @@ -144,7 +113,7 @@ absl::StatusOr<XlaComputation> PyMlirModuleToXlaComputation(
std::string_view mlir_module, bool use_tuple_args, bool return_tuple) {
mlir::MLIRContext context;
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseModule(&context, mlir_module));
ParseMlirModuleString(mlir_module, context));
XlaComputation computation;
// SDY dialect may be part of the module which XLA doesn't know about.
TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module));
Expand All @@ -159,13 +128,13 @@ absl::StatusOr<nb::bytes> PyMhloToStablehlo(std::string_view mlir_module) {
if (VLOG_IS_ON(3)) context.disableMultithreading();
// JAX can be customized in a way that involves operations from custom
// dialects showing up in JAX IR.
// `ParseModule` won't know about these dialects, but that's fine since we
// just want to convert MHLO ops to StableHLO ops here and leave everything
// else unchanged.
// `ParseMlirModuleString` won't know about these dialects, but that's fine
// since we just want to convert MHLO ops to StableHLO ops here and leave
// everything else unchanged.
// In order to achieve that, we're allowing unregistered dialects here.
context.allowUnregisteredDialects(true);
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseModule(&context, mlir_module));
ParseMlirModuleString(mlir_module, context));
mlir::PassManager pm(&context);
if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm);
pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
Expand All @@ -186,8 +155,8 @@ absl::StatusOr<nb::bytes> PyStablehloToMhlo(const nb::bytes& mlir_module) {
context.allowUnregisteredDialects(true);
TF_ASSIGN_OR_RETURN(
mlir::OwningOpRef<mlir::ModuleOp> module,
ParseModule(&context,
std::string_view(mlir_module.c_str(), mlir_module.size())));
ParseMlirModuleString(
std::string_view(mlir_module.c_str(), mlir_module.size()), context));
mlir::PassManager pm(&context);
if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm);
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
Expand All @@ -206,7 +175,7 @@ absl::StatusOr<nb::bytes> PySerializePortableArtifact(
mlir::MLIRContext context;
if (VLOG_IS_ON(3)) context.disableMultithreading();
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseModule(&context, mlir_module));
ParseMlirModuleString(mlir_module, context));

// Serialize portable artifact
TF_ASSIGN_OR_RETURN(
Expand Down

0 comments on commit e4ca19f

Please sign in to comment.