From 48a8e38f5203e77f2bf6b7c6bf681e7f090809ac Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Sat, 19 Oct 2024 14:40:01 -0700 Subject: [PATCH] [xla:cpu] Add a PjRt callback to customize XLA:CPU HloModuleConfig PiperOrigin-RevId: 687685174 --- xla/debug_options_flags.cc | 7 ++++++ xla/pjrt/c/pjrt_c_api_cpu_internal.cc | 2 +- xla/pjrt/cpu/cpu_client.cc | 22 ++++++++++++----- xla/pjrt/cpu/cpu_client.h | 24 +++++++++++++------ xla/pjrt/cpu/cpu_client_test.cc | 4 +++- .../integration_tests/mock_array_test.cc | 3 ++- .../pjrt_ifrt/tfrt_cpu_client_test_lib.cc | 21 ++++++++-------- xla/python/xla.cc | 2 +- xla/service/cpu/cpu_compiler.cc | 13 +++++++++- xla/tests/pjrt_cpu_client_registry.cc | 4 +++- xla/xla.proto | 8 ++++--- 11 files changed, 78 insertions(+), 32 deletions(-) diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index ec758ff3ff90ad..e0d66a1f649cdb 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -86,6 +86,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { #endif opts.set_xla_cpu_use_thunk_runtime(true); opts.set_xla_cpu_parallel_codegen_split_count(32); + opts.set_xla_cpu_copy_insertion_use_region_analysis(false); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); opts.set_xla_cpu_prefer_vector_width(256); opts.set_xla_cpu_max_isa(""); @@ -880,6 +881,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_cpu_parallel_codegen_split_count(), "Split LLVM module into at most this many parts before codegen to enable " "parallel compilation for the CPU backend.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_copy_insertion_use_region_analysis", + bool_setter_for( + &DebugOptions::set_xla_cpu_copy_insertion_use_region_analysis), + debug_options->xla_cpu_copy_insertion_use_region_analysis(), + "Use region based analysis in copy insertion pass.")); flag_list->push_back(tsl::Flag( "xla_cpu_enable_concurrency_optimized_scheduler", bool_setter_for( diff --git 
a/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index 24c1e56c3e9724..43b2d283ddf564 100644 --- a/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -39,7 +39,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { xla::CpuClientOptions options; options.cpu_device_count = 4; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, - xla::GetTfrtCpuClient(options)); + xla::GetTfrtCpuClient(std::move(options))); args->client = pjrt::CreateWrapperClient(std::move(client)); return nullptr; } diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index 71b0b7c9ab5d37..c46b57da801732 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -383,7 +383,7 @@ static int CpuDeviceCount() { } absl::StatusOr> GetTfrtCpuClient( - const CpuClientOptions& options) { + CpuClientOptions options) { // Need at least CpuDeviceCount threads to launch one collective. int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count); @@ -398,7 +398,8 @@ absl::StatusOr> GetTfrtCpuClient( return std::unique_ptr(std::make_unique( options.process_id, std::move(devices), std::move(options.collectives), - num_threads, options.asynchronous)); + num_threads, options.asynchronous, + std::move(options.customize_hlo_module_config))); } // An upper bound on the number of threads to use for intra-op parallelism. 
It @@ -419,7 +420,8 @@ static tsl::ThreadOptions GetThreadOptions() { TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, std::shared_ptr collectives, size_t num_threads, - bool asynchronous) + bool asynchronous, + absl::AnyInvocable customize_hlo_module_config) : process_index_(process_index), owned_devices_(std::move(devices)), computation_placer_(std::make_unique()), @@ -441,7 +443,8 @@ TfrtCpuClient::TfrtCpuClient( topology_(TfrtCpuTopologyDescription::Create( platform_id(), platform_name(), platform_version(), owned_devices_, cpu::DetectMachineAttributes())), - asynchronous_(asynchronous) { + asynchronous_(asynchronous), + customize_hlo_module_config_(std::move(customize_hlo_module_config)) { for (const std::unique_ptr& device : owned_devices_) { devices_.push_back(device.get()); CHECK( @@ -708,7 +711,8 @@ static absl::StatusOr> JitCompile( const absl::Span argument_layouts, const ExecutableBuildOptions& build_options, const ExecutionOptions& execution_options, - const xla::Compiler::CompileOptions& compile_options, int num_threads) { + const xla::Compiler::CompileOptions& compile_options, int num_threads, + absl::AnyInvocable& customize_hlo_module_config) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape()); // Unoptimized HloModuleConfig. @@ -718,6 +722,11 @@ static absl::StatusOr> JitCompile( execution_options.num_replicas(), num_threads, /*aot_options=*/nullptr)); + // Apply the user-provided callback to customize the HloModuleConfig. + if (customize_hlo_module_config) { + customize_hlo_module_config(*hlo_module_config); + } + // Unoptimized HloModule. 
const xla::HloModuleProto& hlo_module_proto = computation.proto(); TF_ASSIGN_OR_RETURN( @@ -826,7 +835,8 @@ absl::StatusOr> TfrtCpuClient::Compile( std::unique_ptr cpu_executable, JitCompile(computation, argument_layout_pointers, build_options, execution_options, compile_options, - eigen_intraop_device()->getPool()->NumThreads())); + eigen_intraop_device()->getPool()->NumThreads(), + customize_hlo_module_config_)); auto cpu_executable_ptr = tensorflow::down_cast(cpu_executable.get()); diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index d4d4954b557a08..d7742cf99960c5 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -60,6 +60,7 @@ limitations under the License. #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" @@ -255,10 +256,11 @@ class TfrtCpuDevice final : public PjRtDevice { class TfrtCpuClient final : public PjRtClient { public: - TfrtCpuClient(int process_index, - std::vector> devices, - std::shared_ptr collectives, - size_t num_threads, bool asynchronous); + TfrtCpuClient( + int process_index, std::vector> devices, + std::shared_ptr collectives, + size_t num_threads, bool asynchronous, + absl::AnyInvocable customize_hlo_module_config); ~TfrtCpuClient() override; int process_index() const override { return process_index_; } @@ -479,6 +481,9 @@ class TfrtCpuClient final : public PjRtClient { // this client. Only applies to non-parallel computations. bool asynchronous_; + // A callback to customize the HloModuleConfig for each compiled module. + absl::AnyInvocable customize_hlo_module_config_; + // Used to prevent too much parallelism: we will not enqueue next non-parallel // computation until last one is done within each user thread. 
// TODO(yueshengys): Consider moving the enqueuing/ordering logic to JAX via @@ -709,16 +714,21 @@ struct CpuClientOptions { // Distributed collectives implementation. Optional. If not provided, an // in-process collectives implementation will be used. std::shared_ptr collectives; + + // If defined this function will be called on the HloModuleConfig before + // compilation, and allows users to set custom flags. + absl::AnyInvocable customize_hlo_module_config; }; + absl::StatusOr> GetTfrtCpuClient( - const CpuClientOptions& options); + CpuClientOptions options); // Deprecated. Use the overload that takes 'options' instead. inline absl::StatusOr> GetTfrtCpuClient( bool asynchronous) { CpuClientOptions options; options.asynchronous = asynchronous; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); } // Deprecated. Use the overload that takes 'options' instead. @@ -730,7 +740,7 @@ inline absl::StatusOr> GetTfrtCpuClient( options.cpu_device_count = cpu_device_count; options.max_inflight_computations_per_device = max_inflight_computations_per_device; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); } } // namespace xla diff --git a/xla/pjrt/cpu/cpu_client_test.cc b/xla/pjrt/cpu/cpu_client_test.cc index a0c7599ff817a2..21f33067a714bf 100644 --- a/xla/pjrt/cpu/cpu_client_test.cc +++ b/xla/pjrt/cpu/cpu_client_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -153,7 +154,8 @@ TEST(TfrtCpuClientTest, HloSnapshot) { CpuClientOptions cpu_options; cpu_options.cpu_device_count = 1; - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(cpu_options)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetTfrtCpuClient(std::move(cpu_options))); TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnUnverifiedModule(kProgram, {})); diff --git a/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc index 6545ebbab3e8e7..9fda694e648727 100644 --- a/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc +++ b/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -124,7 +124,8 @@ class MockArrayTest : public testing::Test { CpuClientOptions options; options.asynchronous = true; options.cpu_device_count = 2; - TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, xla::GetTfrtCpuClient(options)); + TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, + xla::GetTfrtCpuClient(std::move(options))); auto mock_backend = std::make_unique( /*delegate=*/xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client))); diff --git a/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index 356c2f9e5d2f3a..4fb1ca36e33a50 100644 --- a/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -24,16 +24,17 @@ namespace xla { namespace ifrt { namespace { -const bool kUnused = (test_util::RegisterClientFactory( - []() -> absl::StatusOr> { - CpuClientOptions options; - options.cpu_device_count = 4; - TF_ASSIGN_OR_RETURN(auto pjrt_client, - xla::GetTfrtCpuClient(options)); - return std::shared_ptr( - PjRtClient::Create(std::move(pjrt_client))); - }), - true); +const bool kUnused = + (test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + CpuClientOptions options; + options.cpu_device_count = 4; + TF_ASSIGN_OR_RETURN(auto pjrt_client, + 
xla::GetTfrtCpuClient(std::move(options))); + return std::shared_ptr( + PjRtClient::Create(std::move(pjrt_client))); + }), + true); } // namespace } // namespace ifrt diff --git a/xla/python/xla.cc b/xla/python/xla.cc index ccf54660c88001..22237751257969 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -340,7 +340,7 @@ NB_MODULE(xla_extension, m_nb) { options.collectives = std::move(collectives); options.process_id = node_id; std::unique_ptr client = - xla::ValueOrThrow(GetTfrtCpuClient(options)); + xla::ValueOrThrow(GetTfrtCpuClient(std::move(options))); ifrt::PjRtClient::CreateOptions ifrt_options; ifrt_options.pjrt_client = std::shared_ptr(std::move(client)); diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index 1b5da63773d85f..e9b6fdab8a4821 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -866,7 +866,18 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( // interfering with the rewrites. pipeline.AddPass(); pipeline.AddPass(true); - pipeline.AddPass(); + + // If enabled we'll use more precise region based analysis for copy removal. + if (module->config() + .debug_options() + .xla_cpu_copy_insertion_use_region_analysis()) { + pipeline.AddPass( + /*can_share_buffer=*/nullptr, + /*use_region_based_live_range_analysis=*/-1); + } else { + pipeline.AddPass(); + } + pipeline.AddPass(); return pipeline.Run(module).status(); } diff --git a/xla/tests/pjrt_cpu_client_registry.cc b/xla/tests/pjrt_cpu_client_registry.cc index 540cf3d59ff6b8..a9205b640b7e3c 100644 --- a/xla/tests/pjrt_cpu_client_registry.cc +++ b/xla/tests/pjrt_cpu_client_registry.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include <utility> + #include "xla/pjrt/cpu/cpu_client.h" #include "xla/tests/pjrt_client_registry.h" @@ -23,7 +25,7 @@ namespace { const bool kUnused = (RegisterPjRtClientTestFactory([]() { CpuClientOptions options; options.cpu_device_count = 4; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); }), true); diff --git a/xla/xla.proto b/xla/xla.proto index 30e7a0a5a8dc35..846c5df7654ddd 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -48,9 +48,11 @@ message DebugOptions { //--------------------------------------------------------------------------// // XLA:CPU options. //--------------------------------------------------------------------------// + // go/keep-sorted start newline_separated=yes skip_lines=1 + + // Use region analysis in copy insertion pass. + bool xla_cpu_copy_insertion_use_region_analysis = 337; - // go/keep-sorted start newline_separated=yes - // // When true, XLA:CPU uses HLO module scheduler that is optimized for // extracting concurrency at the cost of extra memory: we extend the live // ranges of temporaries to allow XLA runtime to schedule independent @@ -1020,7 +1022,7 @@ message DebugOptions { // coll.2-done = collective(coll.2-start) int32 xla_gpu_experimental_parallel_collective_overlap_limit = 336; - // Next id: 337 + // Next id: 338 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.