[xla:cpu] Add a PjRt callback to customize XLA:CPU HloModuleConfig
PiperOrigin-RevId: 687685174
ezhulenev authored and Google-ML-Automation committed Oct 20, 2024
1 parent 7e3e155 commit 48a8e38
Showing 11 changed files with 78 additions and 32 deletions.
7 changes: 7 additions & 0 deletions xla/debug_options_flags.cc
@@ -86,6 +86,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
#endif
opts.set_xla_cpu_use_thunk_runtime(true);
opts.set_xla_cpu_parallel_codegen_split_count(32);
opts.set_xla_cpu_copy_insertion_use_region_analysis(false);
opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false);
opts.set_xla_cpu_prefer_vector_width(256);
opts.set_xla_cpu_max_isa("");
@@ -880,6 +881,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_cpu_parallel_codegen_split_count(),
"Split LLVM module into at most this many parts before codegen to enable "
"parallel compilation for the CPU backend."));
flag_list->push_back(tsl::Flag(
"xla_cpu_copy_insertion_use_region_analysis",
bool_setter_for(
&DebugOptions::set_xla_cpu_copy_insertion_use_region_analysis),
debug_options->xla_cpu_copy_insertion_use_region_analysis(),
"Use region based analysis in copy insertion pass."));
flag_list->push_back(tsl::Flag(
"xla_cpu_enable_concurrency_optimized_scheduler",
bool_setter_for(
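
A minimal sketch (not part of this commit) of how the new option could be enabled programmatically; it assumes the existing xla::GetDebugOptionsFromFlags() helper from xla/debug_options_flags.h and the setter generated for the new DebugOptions field:

  #include "xla/debug_options_flags.h"

  // Start from the flag-derived defaults, then flip the new option before
  // attaching the DebugOptions to a compilation request. The same switch can
  // also be toggled via XLA_FLAGS=--xla_cpu_copy_insertion_use_region_analysis=true.
  xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
  debug_options.set_xla_cpu_copy_insertion_use_region_analysis(true);
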
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_cpu_internal.cc
@@ -39,7 +39,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
xla::CpuClientOptions options;
options.cpu_device_count = 4;
PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetTfrtCpuClient(options));
xla::GetTfrtCpuClient(std::move(options)));
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
22 changes: 16 additions & 6 deletions xla/pjrt/cpu/cpu_client.cc
@@ -383,7 +383,7 @@ static int CpuDeviceCount() {
}

absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
const CpuClientOptions& options) {
CpuClientOptions options) {
// Need at least CpuDeviceCount threads to launch one collective.
int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount());
size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count);
@@ -398,7 +398,8 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(

return std::unique_ptr<PjRtClient>(std::make_unique<TfrtCpuClient>(
options.process_id, std::move(devices), std::move(options.collectives),
num_threads, options.asynchronous));
num_threads, options.asynchronous,
std::move(options.customize_hlo_module_config)));
}

// An upper bound on the number of threads to use for intra-op parallelism. It
@@ -419,7 +420,8 @@ static tsl::ThreadOptions GetThreadOptions() {
TfrtCpuClient::TfrtCpuClient(
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives, size_t num_threads,
bool asynchronous)
bool asynchronous,
absl::AnyInvocable<void(HloModuleConfig&)> customize_hlo_module_config)
: process_index_(process_index),
owned_devices_(std::move(devices)),
computation_placer_(std::make_unique<ComputationPlacer>()),
@@ -441,7 +443,8 @@ TfrtCpuClient::TfrtCpuClient(
topology_(TfrtCpuTopologyDescription::Create(
platform_id(), platform_name(), platform_version(), owned_devices_,
cpu::DetectMachineAttributes())),
asynchronous_(asynchronous) {
asynchronous_(asynchronous),
customize_hlo_module_config_(std::move(customize_hlo_module_config)) {
for (const std::unique_ptr<TfrtCpuDevice>& device : owned_devices_) {
devices_.push_back(device.get());
CHECK(
@@ -708,7 +711,8 @@ static absl::StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options,
const ExecutionOptions& execution_options,
const xla::Compiler::CompileOptions& compile_options, int num_threads) {
const xla::Compiler::CompileOptions& compile_options, int num_threads,
absl::AnyInvocable<void(HloModuleConfig&)>& customize_hlo_module_config) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
// Unoptimized HloModuleConfig.
@@ -718,6 +722,11 @@ absl::StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
execution_options.num_replicas(), num_threads,
/*aot_options=*/nullptr));

// Apply the user-provided callback to customize the HloModuleConfig.
if (customize_hlo_module_config) {
customize_hlo_module_config(*hlo_module_config);
}

// Unoptimized HloModule.
const xla::HloModuleProto& hlo_module_proto = computation.proto();
TF_ASSIGN_OR_RETURN(
@@ -826,7 +835,8 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
std::unique_ptr<Executable> cpu_executable,
JitCompile(computation, argument_layout_pointers, build_options,
execution_options, compile_options,
eigen_intraop_device()->getPool()->NumThreads()));
eigen_intraop_device()->getPool()->NumThreads(),
customize_hlo_module_config_));
auto cpu_executable_ptr =
tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable.get());

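
GetTfrtCpuClient now takes CpuClientOptions by value rather than by const reference because absl::AnyInvocable is move-only, so the customize_hlo_module_config callback has to be moved out of the options into the client instead of copied. An illustrative sketch of that constraint (standalone, not from this commit):

  #include <utility>

  #include "absl/functional/any_invocable.h"
  #include "xla/service/hlo_module_config.h"

  absl::AnyInvocable<void(xla::HloModuleConfig&)> callback =
      [](xla::HloModuleConfig& config) { /* adjust config here */ };

  auto moved = std::move(callback);  // OK: AnyInvocable is movable.
  // auto copied = moved;            // Does not compile: AnyInvocable is not copyable.
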
24 changes: 17 additions & 7 deletions xla/pjrt/cpu/cpu_client.h
@@ -60,6 +60,7 @@ limitations under the License.
#include "xla/service/executable.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/util.h"
@@ -255,10 +256,11 @@ class TfrtCpuDevice final : public PjRtDevice {

class TfrtCpuClient final : public PjRtClient {
public:
TfrtCpuClient(int process_index,
std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives,
size_t num_threads, bool asynchronous);
TfrtCpuClient(
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives,
size_t num_threads, bool asynchronous,
absl::AnyInvocable<void(HloModuleConfig&)> customize_hlo_module_config);
~TfrtCpuClient() override;

int process_index() const override { return process_index_; }
@@ -479,6 +481,9 @@ class TfrtCpuClient final : public PjRtClient {
// this client. Only applies to non-parallel computations.
bool asynchronous_;

// A callback to customize the HloModuleConfig for each compiled module.
absl::AnyInvocable<void(HloModuleConfig&)> customize_hlo_module_config_;

// Used to prevent too much parallelism: we will not enqueue next non-parallel
// computation until last one is done within each user thread.
// TODO(yueshengys): Consider moving the enqueuing/ordering logic to JAX via
@@ -709,16 +714,21 @@ struct CpuClientOptions {
// Distributed collectives implementation. Optional. If not provided, an
// in-process collectives implementation will be used.
std::shared_ptr<cpu::CollectivesInterface> collectives;

// If defined, this function will be called on the HloModuleConfig before
// compilation, allowing users to set custom flags.
absl::AnyInvocable<void(HloModuleConfig&)> customize_hlo_module_config;
};

absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
const CpuClientOptions& options);
CpuClientOptions options);

// Deprecated. Use the overload that takes 'options' instead.
inline absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
bool asynchronous) {
CpuClientOptions options;
options.asynchronous = asynchronous;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}

// Deprecated. Use the overload that takes 'options' instead.
Expand All @@ -730,7 +740,7 @@ inline absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
options.cpu_device_count = cpu_device_count;
options.max_inflight_computations_per_device =
max_inflight_computations_per_device;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}

} // namespace xla
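
A minimal usage sketch (not part of this commit) of the new CpuClientOptions::customize_hlo_module_config hook; the callback body below, which flips the debug option introduced in this change, is just one example of what a caller might do:

  #include <memory>
  #include <utility>

  #include "xla/pjrt/cpu/cpu_client.h"

  xla::CpuClientOptions options;
  options.cpu_device_count = 4;
  options.customize_hlo_module_config = [](xla::HloModuleConfig& config) {
    // Copy the module's DebugOptions, adjust them, and write them back.
    xla::DebugOptions debug_options = config.debug_options();
    debug_options.set_xla_cpu_copy_insertion_use_region_analysis(true);
    config.set_debug_options(debug_options);
  };
  absl::StatusOr<std::unique_ptr<xla::PjRtClient>> client =
      xla::GetTfrtCpuClient(std::move(options));
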
4 changes: 3 additions & 1 deletion xla/pjrt/cpu/cpu_client_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
@@ -153,7 +154,8 @@ TEST(TfrtCpuClientTest, HloSnapshot) {

CpuClientOptions cpu_options;
cpu_options.cpu_device_count = 1;
TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(cpu_options));
TF_ASSERT_OK_AND_ASSIGN(auto client,
GetTfrtCpuClient(std::move(cpu_options)));
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kProgram, {}));

3 changes: 2 additions & 1 deletion xla/python/ifrt_proxy/integration_tests/mock_array_test.cc
@@ -124,7 +124,8 @@ class MockArrayTest : public testing::Test {
CpuClientOptions options;
options.asynchronous = true;
options.cpu_device_count = 2;
TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, xla::GetTfrtCpuClient(options));
TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client,
xla::GetTfrtCpuClient(std::move(options)));
auto mock_backend = std::make_unique<MockClient>(
/*delegate=*/xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client)));

21 changes: 11 additions & 10 deletions xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc
@@ -24,16 +24,17 @@ namespace xla {
namespace ifrt {
namespace {

const bool kUnused = (test_util::RegisterClientFactory(
[]() -> absl::StatusOr<std::shared_ptr<Client>> {
CpuClientOptions options;
options.cpu_device_count = 4;
TF_ASSIGN_OR_RETURN(auto pjrt_client,
xla::GetTfrtCpuClient(options));
return std::shared_ptr<Client>(
PjRtClient::Create(std::move(pjrt_client)));
}),
true);
const bool kUnused =
(test_util::RegisterClientFactory(
[]() -> absl::StatusOr<std::shared_ptr<Client>> {
CpuClientOptions options;
options.cpu_device_count = 4;
TF_ASSIGN_OR_RETURN(auto pjrt_client,
xla::GetTfrtCpuClient(std::move(options)));
return std::shared_ptr<Client>(
PjRtClient::Create(std::move(pjrt_client)));
}),
true);

} // namespace
} // namespace ifrt
2 changes: 1 addition & 1 deletion xla/python/xla.cc
@@ -340,7 +340,7 @@ NB_MODULE(xla_extension, m_nb) {
options.collectives = std::move(collectives);
options.process_id = node_id;
std::unique_ptr<PjRtClient> client =
xla::ValueOrThrow(GetTfrtCpuClient(options));
xla::ValueOrThrow(GetTfrtCpuClient(std::move(options)));
ifrt::PjRtClient::CreateOptions ifrt_options;
ifrt_options.pjrt_client =
std::shared_ptr<PjRtClient>(std::move(client));
13 changes: 12 additions & 1 deletion xla/service/cpu/cpu_compiler.cc
@@ -866,7 +866,18 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
// interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<OptimizeInputOutputBufferAlias>(true);
pipeline.AddPass<CopyInsertion>();

// If enabled, we'll use more precise region-based analysis for copy removal.
if (module->config()
.debug_options()
.xla_cpu_copy_insertion_use_region_analysis()) {
pipeline.AddPass<CopyInsertion>(
/*can_share_buffer=*/nullptr,
/*use_region_based_live_range_analysis=*/-1);
} else {
pipeline.AddPass<CopyInsertion>();
}

pipeline.AddPass<HloDCE>();
return pipeline.Run(module).status();
}
4 changes: 3 additions & 1 deletion xla/tests/pjrt_cpu_client_registry.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/tests/pjrt_client_registry.h"

@@ -23,7 +25,7 @@ namespace {
const bool kUnused = (RegisterPjRtClientTestFactory([]() {
CpuClientOptions options;
options.cpu_device_count = 4;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}),
true);

8 changes: 5 additions & 3 deletions xla/xla.proto
@@ -48,9 +48,11 @@ message DebugOptions {
//--------------------------------------------------------------------------//
// XLA:CPU options.
//--------------------------------------------------------------------------//
// go/keep-sorted start newline_separated=yes skip_lines=1

// Use region analysis in copy insertion pass.
bool xla_cpu_copy_insertion_use_region_analysis = 337;

// go/keep-sorted start newline_separated=yes
//
// When true, XLA:CPU uses HLO module scheduler that is optimized for
// extracting concurrency at the cost of extra memory: we extend the live
// ranges of temporaries to allow XLA runtime to schedule independent
@@ -1020,7 +1022,7 @@ message DebugOptions {
// coll.2-done = collective(coll.2-start)
int32 xla_gpu_experimental_parallel_collective_overlap_limit = 336;

// Next id: 337
// Next id: 338

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
