Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:cpu] Add a PjRt callback to customize XLA:CPU HloModuleConfig #18521

Merged
merged 1 commit into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
#endif
opts.set_xla_cpu_use_thunk_runtime(true);
opts.set_xla_cpu_parallel_codegen_split_count(32);
opts.set_xla_cpu_copy_insertion_use_region_analysis(false);
opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false);
opts.set_xla_cpu_prefer_vector_width(256);
opts.set_xla_cpu_max_isa("");
Expand Down Expand Up @@ -880,6 +881,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_cpu_parallel_codegen_split_count(),
"Split LLVM module into at most this many parts before codegen to enable "
"parallel compilation for the CPU backend."));
flag_list->push_back(tsl::Flag(
"xla_cpu_copy_insertion_use_region_analysis",
bool_setter_for(
&DebugOptions::set_xla_cpu_copy_insertion_use_region_analysis),
debug_options->xla_cpu_copy_insertion_use_region_analysis(),
"Use region based analysis in copy insertion pass."));
flag_list->push_back(tsl::Flag(
"xla_cpu_enable_concurrency_optimized_scheduler",
bool_setter_for(
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_cpu_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) {
xla::CpuClientOptions options;
options.cpu_device_count = 4;
PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetTfrtCpuClient(options));
xla::GetTfrtCpuClient(std::move(options)));
args->client = pjrt::CreateWrapperClient(std::move(client));
return nullptr;
}
Expand Down
22 changes: 16 additions & 6 deletions xla/pjrt/cpu/cpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ static int CpuDeviceCount() {
}

absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
const CpuClientOptions& options) {
CpuClientOptions options) {
// Need at least CpuDeviceCount threads to launch one collective.
int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount());
size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count);
Expand All @@ -398,7 +398,8 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(

return std::unique_ptr<PjRtClient>(std::make_unique<TfrtCpuClient>(
options.process_id, std::move(devices), std::move(options.collectives),
num_threads, options.asynchronous));
num_threads, options.asynchronous,
std::move(options.customize_hlo_module_config)));
}

// An upper bound on the number of threads to use for intra-op parallelism. It
Expand All @@ -419,7 +420,8 @@ static tsl::ThreadOptions GetThreadOptions() {
TfrtCpuClient::TfrtCpuClient(
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives, size_t num_threads,
bool asynchronous)
bool asynchronous,
std::function<void(HloModuleConfig&)> customize_hlo_module_config)
: process_index_(process_index),
owned_devices_(std::move(devices)),
computation_placer_(std::make_unique<ComputationPlacer>()),
Expand All @@ -441,7 +443,8 @@ TfrtCpuClient::TfrtCpuClient(
topology_(TfrtCpuTopologyDescription::Create(
platform_id(), platform_name(), platform_version(), owned_devices_,
cpu::DetectMachineAttributes())),
asynchronous_(asynchronous) {
asynchronous_(asynchronous),
customize_hlo_module_config_(std::move(customize_hlo_module_config)) {
for (const std::unique_ptr<TfrtCpuDevice>& device : owned_devices_) {
devices_.push_back(device.get());
CHECK(
Expand Down Expand Up @@ -708,7 +711,8 @@ static absl::StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options,
const ExecutionOptions& execution_options,
const xla::Compiler::CompileOptions& compile_options, int num_threads) {
const xla::Compiler::CompileOptions& compile_options, int num_threads,
std::function<void(HloModuleConfig&)> customize_hlo_module_config) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
// Unoptimized HloModuleConfig.
Expand All @@ -718,6 +722,11 @@ static absl::StatusOr<std::unique_ptr<xla::Executable>> JitCompile(
execution_options.num_replicas(), num_threads,
/*aot_options=*/nullptr));

// Apply the user-provided callback to customize the HloModuleConfig.
if (customize_hlo_module_config) {
customize_hlo_module_config(*hlo_module_config);
}

// Unoptimized HloModule.
const xla::HloModuleProto& hlo_module_proto = computation.proto();
TF_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -826,7 +835,8 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> TfrtCpuClient::Compile(
std::unique_ptr<Executable> cpu_executable,
JitCompile(computation, argument_layout_pointers, build_options,
execution_options, compile_options,
eigen_intraop_device()->getPool()->NumThreads()));
eigen_intraop_device()->getPool()->NumThreads(),
customize_hlo_module_config_));
auto cpu_executable_ptr =
tensorflow::down_cast<cpu::CpuExecutable*>(cpu_executable.get());

Expand Down
24 changes: 17 additions & 7 deletions xla/pjrt/cpu/cpu_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "xla/service/executable.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/util.h"
Expand Down Expand Up @@ -255,10 +256,11 @@ class TfrtCpuDevice final : public PjRtDevice {

class TfrtCpuClient final : public PjRtClient {
public:
TfrtCpuClient(int process_index,
std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives,
size_t num_threads, bool asynchronous);
TfrtCpuClient(
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives,
size_t num_threads, bool asynchronous,
std::function<void(HloModuleConfig&)> customize_hlo_module_config);
~TfrtCpuClient() override;

int process_index() const override { return process_index_; }
Expand Down Expand Up @@ -479,6 +481,9 @@ class TfrtCpuClient final : public PjRtClient {
// this client. Only applies to non-parallel computations.
bool asynchronous_;

// A callback to customize the HloModuleConfig for each compiled module.
std::function<void(HloModuleConfig&)> customize_hlo_module_config_;

// Used to prevent too much parallelism: we will not enqueue next non-parallel
// computation until last one is done within each user thread.
// TODO(yueshengys): Consider moving the enqueuing/ordering logic to JAX via
Expand Down Expand Up @@ -709,16 +714,21 @@ struct CpuClientOptions {
// Distributed collectives implementation. Optional. If not provided, an
// in-process collectives implementation will be used.
std::shared_ptr<cpu::CollectivesInterface> collectives;

// If defined this function will be called on the HloModuleConfig before
// compilation, and allows users to set custom flags.
std::function<void(HloModuleConfig&)> customize_hlo_module_config;
};

absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
const CpuClientOptions& options);
CpuClientOptions options);

// Deprecated. Use the overload that takes 'options' instead.
inline absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
bool asynchronous) {
CpuClientOptions options;
options.asynchronous = asynchronous;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}

// Deprecated. Use the overload that takes 'options' instead.
Expand All @@ -730,7 +740,7 @@ inline absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
options.cpu_device_count = cpu_device_count;
options.max_inflight_computations_per_device =
max_inflight_computations_per_device;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}

} // namespace xla
Expand Down
4 changes: 3 additions & 1 deletion xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
Expand Down Expand Up @@ -153,7 +154,8 @@ TEST(TfrtCpuClientTest, HloSnapshot) {

CpuClientOptions cpu_options;
cpu_options.cpu_device_count = 1;
TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(cpu_options));
TF_ASSERT_OK_AND_ASSIGN(auto client,
GetTfrtCpuClient(std::move(cpu_options)));
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
ParseAndReturnUnverifiedModule(kProgram, {}));

Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/cpu/pjrt_client_test_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/pjrt_client_test.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client_test.h"

namespace xla {
namespace {
Expand All @@ -23,7 +23,7 @@ namespace {
const bool kUnused = (RegisterTestClientFactory([]() {
CpuClientOptions options;
options.cpu_device_count = 4;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}),
true);

Expand Down
3 changes: 2 additions & 1 deletion xla/python/ifrt_proxy/integration_tests/mock_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ class MockArrayTest : public testing::Test {
CpuClientOptions options;
options.asynchronous = true;
options.cpu_device_count = 2;
TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, xla::GetTfrtCpuClient(options));
TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client,
xla::GetTfrtCpuClient(std::move(options)));
auto mock_backend = std::make_unique<MockClient>(
/*delegate=*/xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client)));

Expand Down
21 changes: 11 additions & 10 deletions xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ namespace xla {
namespace ifrt {
namespace {

const bool kUnused = (test_util::RegisterClientFactory(
[]() -> absl::StatusOr<std::shared_ptr<Client>> {
CpuClientOptions options;
options.cpu_device_count = 4;
TF_ASSIGN_OR_RETURN(auto pjrt_client,
xla::GetTfrtCpuClient(options));
return std::shared_ptr<Client>(
PjRtClient::Create(std::move(pjrt_client)));
}),
true);
const bool kUnused =
(test_util::RegisterClientFactory(
[]() -> absl::StatusOr<std::shared_ptr<Client>> {
CpuClientOptions options;
options.cpu_device_count = 4;
TF_ASSIGN_OR_RETURN(auto pjrt_client,
xla::GetTfrtCpuClient(std::move(options)));
return std::shared_ptr<Client>(
PjRtClient::Create(std::move(pjrt_client)));
}),
true);

} // namespace
} // namespace ifrt
Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ NB_MODULE(xla_extension, m_nb) {
options.collectives = std::move(collectives);
options.process_id = node_id;
std::unique_ptr<PjRtClient> client =
xla::ValueOrThrow(GetTfrtCpuClient(options));
xla::ValueOrThrow(GetTfrtCpuClient(std::move(options)));
ifrt::PjRtClient::CreateOptions ifrt_options;
ifrt_options.pjrt_client =
std::shared_ptr<PjRtClient>(std::move(client));
Expand Down
13 changes: 12 additions & 1 deletion xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,18 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
// interfering with the rewrites.
pipeline.AddPass<HloDCE>();
pipeline.AddPass<OptimizeInputOutputBufferAlias>(true);
pipeline.AddPass<CopyInsertion>();

// If enabled we'll use more precise region based analysis for copy removal.
if (module->config()
.debug_options()
.xla_cpu_copy_insertion_use_region_analysis()) {
pipeline.AddPass<CopyInsertion>(
/*can_share_buffer=*/nullptr,
/*use_region_based_live_range_analysis=*/-1);
} else {
pipeline.AddPass<CopyInsertion>();
}

pipeline.AddPass<HloDCE>();
return pipeline.Run(module).status();
}
Expand Down
4 changes: 3 additions & 1 deletion xla/tests/pjrt_cpu_client_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/tests/pjrt_client_registry.h"

Expand All @@ -23,7 +25,7 @@ namespace {
const bool kUnused = (RegisterPjRtClientTestFactory([]() {
CpuClientOptions options;
options.cpu_device_count = 4;
return GetTfrtCpuClient(options);
return GetTfrtCpuClient(std::move(options));
}),
true);

Expand Down
8 changes: 5 additions & 3 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ message DebugOptions {
//--------------------------------------------------------------------------//
// XLA:CPU options.
//--------------------------------------------------------------------------//
// go/keep-sorted start newline_separated=yes skip_lines=1

// Use region analysis in copy insertion pass.
bool xla_cpu_copy_insertion_use_region_analysis = 337;

// go/keep-sorted start newline_separated=yes
//
// When true, XLA:CPU uses HLO module scheduler that is optimized for
// extracting concurrency at the cost of extra memory: we extend the live
// ranges of temporaries to allow XLA runtime to schedule independent
Expand Down Expand Up @@ -1020,7 +1022,7 @@ message DebugOptions {
// coll.2-done = collective(coll.2-start)
int32 xla_gpu_experimental_parallel_collective_overlap_limit = 336;

// Next id: 337
// Next id: 338

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
Expand Down
Loading