Skip to content

Commit

Permalink
PR #18170: Clean up hlo_extractor and hlo_bisect dependencies
Browse files Browse the repository at this point in the history
Imported from GitHub PR #18170

//xla/tools:hlo_extractor and //xla/tools/hlo_bisect:hlo_bisect_state depend on //xla/tests:test_utils. This dependency should not exist. Moving the relevant functions to //xla:literal_util.
Copybara import of the project:

--
133e566 by Shraiysh Vaishay <svaishay@nvidia.com>:

Clean up hlo_extractor and hlo_bisect dependencies

//xla/tools:hlo_extractor and //xla/tools/hlo_bisect:hlo_bisect_state
depend on //xla/tests:test_utils. This dependency should not exist.
Moving the relevant functions to //xla:literal_util.

Merging this change closes #18170

COPYBARA_INTEGRATE_REVIEW=#18170 from shraiysh:cleanup_extractor_deps 133e566
PiperOrigin-RevId: 686444491
  • Loading branch information
shraiysh authored and Google-ML-Automation committed Oct 16, 2024
1 parent 8c79d4d commit 27ceec6
Show file tree
Hide file tree
Showing 15 changed files with 325 additions and 330 deletions.
1 change: 0 additions & 1 deletion xla/client/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ cc_library(
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/service",
"//xla/tests:test_utils",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
1 change: 0 additions & 1 deletion xla/client/lib/testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ limitations under the License.
#include "xla/service/service.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/test_utils.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down
277 changes: 277 additions & 0 deletions xla/literal_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,184 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal,
literal.Set<NativeT>(multi_index, scalar.Get<NativeT>({}));
}

// Fills `literal` with FloatT values enumerated in order of increasing
// magnitude. Each pass writes a zero and then walks the finite values by
// incrementing the bit pattern by one, starting from the smallest normal value
// (bfloat16 and types at least as wide as float) or the smallest subnormal
// value (all other types). When a pass reaches a non-finite value the sign is
// flipped and a new pass begins, so duplicates appear once the type does not
// have enough distinct bit patterns for the number of elements.
template <typename FloatT>
void PopulateWithIntNext(Literal* literal) {
  using BitRepT = UnsignedIntegerTypeForSizeType<sizeof(FloatT)>;
  // Subnormals are skipped for bfloat16 and for float32-or-wider types.
  constexpr bool kSkipSubnormals =
      std::is_same_v<FloatT, bfloat16> || sizeof(FloatT) >= sizeof(float);
  const FloatT smallest_magnitude =
      kSkipSubnormals ? std::numeric_limits<FloatT>::min()
                      : std::numeric_limits<FloatT>::denorm_min();

  auto elements = literal->data<FloatT>();
  // `cursor` always points at the next element to be written.
  auto cursor = elements.begin();
  const auto stop = elements.end();

  // `negative` is the sign used for the current pass over the finite values.
  for (bool negative = false; cursor != stop; negative = !negative) {
    // Every pass starts at (signed) zero.
    *cursor =
        negative ? static_cast<FloatT>(-0.0f) : static_cast<FloatT>(0.0f);
    ++cursor;

    // Continue with values of growing magnitude until the buffer is full or
    // the bit pattern steps onto a non-finite value.
    for (auto magnitude = negative ? -smallest_magnitude : smallest_magnitude;
         cursor != stop && Eigen::numext::isfinite(magnitude); ++cursor) {
      *cursor = magnitude;
      // Adding one to the raw bits yields the next representable value.
      const BitRepT successor_bits =
          Eigen::numext::bit_cast<BitRepT>(magnitude) + 1;
      magnitude = Eigen::numext::bit_cast<FloatT>(successor_bits);
    }
  }
}

// Writes the PopulateWithIntNext value sequence into `literal` and then
// randomly permutes the elements using `engine`.
template <typename FloatT>
void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
  PopulateWithIntNext<FloatT>(literal);
  auto elements = literal->data<FloatT>();
  std::shuffle(elements.begin(), elements.end(), *engine);
}

// Fills a floating point literal with random values whose exponent is drawn
// uniformly, i.e. a log-uniform distribution covering approximately the whole
// representable range of FloatT. With a small probability an element is
// instead set to one of a fixed list of special values (signed zeros, +/-1,
// +/-infinity).
template <typename FloatT>
void PopulateWithRandomFullRangeFloatingPointData(Literal* literal,
                                                  std::minstd_rand0* engine) {
  constexpr float kSpecialValueProbability = 1e-6;
  constexpr float kSpecialValues[] = {+0.F,
                                      -0.F,
                                      1.F,
                                      -1.F,
                                      std::numeric_limits<float>::infinity(),
                                      -std::numeric_limits<float>::infinity()};
  constexpr int kNumSpecialValues =
      sizeof(kSpecialValues) / sizeof(kSpecialValues[0]);
  std::uniform_real_distribution<float> special_roll(0, 1);

  // Sampling the exponent uniformly gives the log-uniform value distribution.
  const int lowest_exponent = std::numeric_limits<FloatT>::min_exponent;
  const int highest_exponent = std::numeric_limits<FloatT>::max_exponent;
  std::uniform_real_distribution<double> exponent_dist(lowest_exponent - 1,
                                                       highest_exponent - 1);

  auto elements = literal->data<FloatT>();
  for (FloatT& element : elements) {
    // Each special value individually has a kSpecialValueProbability chance of
    // being chosen, hence the multiplication by the number of special values.
    const bool pick_special =
        special_roll(*engine) < kSpecialValueProbability * kNumSpecialValues;
    if (pick_special) {
      element =
          static_cast<FloatT>(kSpecialValues[(*engine)() % kNumSpecialValues]);
      continue;
    }
    // Draw the sign first, then the exponent, so the engine is consumed in a
    // fixed order.
    const float sign = ((*engine)() % 2 == 0) ? 1 : -1;
    element = static_cast<FloatT>(pow(2, exponent_dist(*engine)) * sign);
  }
}

// Fills `literal` with values sampled uniformly from [-0.1, 0.2). Samples are
// drawn as GeneratorT and then converted to FloatT.
template <typename FloatT, typename GeneratorT>
void PopulateWithRandomFloatingPointData(Literal* literal,
                                         std::minstd_rand0* engine) {
  std::uniform_real_distribution<GeneratorT> distribution(-0.1f, 0.2f);
  auto elements = literal->data<FloatT>();
  for (auto it = elements.begin(); it != elements.end(); ++it) {
    const GeneratorT sample = distribution(*engine);
    *it = static_cast<FloatT>(sample);
  }
}

// Fills a floating point `literal` using one of four strategies, selected in
// this order:
//   * `max_bits_of_precision` set: random values in [-1, 1] that carry at most
//     that many bits of precision. Cannot be combined with `use_large_range`
//     or `no_duplicates` (CHECK-fails).
//   * `no_duplicates`: shuffled enumeration of distinct values.
//   * `use_large_range`: log-uniform samples over the full range of FloatT.
//   * otherwise: uniform samples from a small fixed interval.
// `engine` must be non-null and `literal`'s element type must be FloatT.
template <typename FloatT>
void PopulateWithFloatingPointData(
    Literal* literal, std::minstd_rand0* engine, bool no_duplicates,
    bool use_large_range, std::optional<int64_t> max_bits_of_precision) {
  // Types narrower than float are sampled in float and then converted.
  using ComputeT =
      std::conditional_t<sizeof(FloatT) < sizeof(float), float, FloatT>;
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<FloatT>());
  if (max_bits_of_precision.has_value()) {
    CHECK(!use_large_range) << "Cannot set both use_large_range and "
                               "max_bits_of_precision for floating points.";
    CHECK(!no_duplicates) << "Cannot set both no_duplicates and "
                             "max_bits_of_precision for floating points.";
    // Shift in 64 bits: `1 << n` on a plain int overflows once n reaches the
    // width of int.
    const int64_t bound = int64_t{1} << *max_bits_of_precision;
    std::uniform_int_distribution<int64_t> generator(-bound, bound);
    for (FloatT& value : literal->data<FloatT>()) {
      const int64_t temp = generator(*engine);
      if (temp == 0) {
        // log2(0) is -infinity, which would turn the scaling below into
        // 0 * infinity == NaN. Zero needs no scaling.
        value = static_cast<FloatT>(0.0f);
        continue;
      }
      // We want to generate floating point numbers to a fixed precision, while
      // keeping them between -1 and 1. This preserves their bits of precision
      // while keeping the numbers small.
      value = static_cast<FloatT>(temp * pow(2, -ceil(log2(abs(temp)))));
    }
  } else if (no_duplicates) {
    PopulateWithNoDuplicateData<FloatT>(literal, engine);
  } else if (use_large_range) {
    PopulateWithRandomFullRangeFloatingPointData<FloatT>(literal, engine);
  } else {
    PopulateWithRandomFloatingPointData<FloatT, ComputeT>(literal, engine);
  }
}

// Fills a complex `result` literal by generating two floating point literals
// of the same dimensions (real parts first, then imaginary parts, both drawn
// from `engine`) and zipping them element-wise into ComplexT values.
// `engine` must be non-null and `result`'s element type must be ComplexT.
template <typename ComplexT>
void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine,
                             bool no_duplicates, bool use_large_range) {
  using InnerFloatT = typename ComplexT::value_type;
  CHECK(engine != nullptr);
  CHECK_EQ(result->shape().element_type(),
           primitive_util::NativeToPrimitiveType<ComplexT>());
  Shape floating_point_shape = ShapeUtil::ChangeElementType(
      result->shape(), primitive_util::NativeToPrimitiveType<InnerFloatT>());
  Literal real_lit(floating_point_shape);
  Literal imaginary_lit(floating_point_shape);

  PopulateWithFloatingPointData<InnerFloatT>(
      &real_lit, engine, no_duplicates, use_large_range,
      /*max_bits_of_precision=*/std::nullopt);
  PopulateWithFloatingPointData<InnerFloatT>(
      &imaginary_lit, engine, no_duplicates, use_large_range,
      /*max_bits_of_precision=*/std::nullopt);

  absl::Span<const InnerFloatT> real_data = real_lit.data<InnerFloatT>();
  absl::Span<const InnerFloatT> imaginary_data =
      imaginary_lit.data<InnerFloatT>();
  absl::Span<ComplexT> result_data = result->data<ComplexT>();
  // Index with size_t: an int index is compared against an unsigned size and
  // would overflow for literals with more than INT_MAX elements.
  for (size_t i = 0; i < real_data.size(); ++i) {
    result_data[i] = ComplexT(real_data[i], imaginary_data[i]);
  }
}

// std::uniform_int_distribution is not defined for 8-bit integers, so RngT
// maps any IntT narrower than uint16_t to int16_t (if signed) or uint16_t (if
// unsigned) and leaves every other IntT unchanged. Use RngT<IntT> as the
// distribution's result type and cast the samples back to IntT.
template <typename IntT>
using RngT = std::conditional_t<
sizeof(IntT) < sizeof(uint16_t),
std::conditional_t<std::numeric_limits<IntT>::is_signed, int16_t, uint16_t>,
IntT>;
// Fills an integral `literal` with values drawn from `engine`.
//
// If `no_duplicates` is set and the literal has fewer elements than `max`, the
// elements are the shuffled sequence 0, 1, ..., n-1 (note that this path does
// not consult `min`). Otherwise every element is sampled uniformly from
// [min, max]. `engine` must be non-null and `literal`'s element type must be
// IntT.
template <typename IntT>
void PopulateWithRandomIntegralDataWithBounds(Literal* literal,
                                              std::minstd_rand0* engine,
                                              bool no_duplicates, IntT min,
                                              IntT max) {
  CHECK(engine != nullptr);
  CHECK_EQ(literal->shape().element_type(),
           primitive_util::NativeToPrimitiveType<IntT>());

  const bool enumerate_distinct =
      no_duplicates &&
      ShapeUtil::ElementsIn(literal->shape()) < static_cast<int64_t>(max);
  auto elements = literal->data<IntT>();

  if (enumerate_distinct) {
    std::iota(elements.begin(), elements.end(), static_cast<IntT>(0));
    std::shuffle(elements.begin(), elements.end(), *engine);
    return;
  }

  // Sample in RngT<IntT> because the distribution does not accept 8-bit types.
  using SampleT = RngT<IntT>;
  std::uniform_int_distribution<SampleT> distribution(
      static_cast<SampleT>(min), static_cast<SampleT>(max));
  for (IntT& element : elements) {
    element = static_cast<IntT>(distribution(*engine));
  }
}

} // namespace

/* static */ Literal LiteralUtil::CreateFromDimensions(
Expand Down Expand Up @@ -498,4 +676,103 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal,
return l.GetFirstInteger();
}

// Convenience overload: creates a default-seeded engine when `pseudo_random`
// is true (or passes a null engine when it is false) and forwards to the
// engine-taking overload with no value limits, no sorting, duplicates allowed
// and no precision cap.
absl::StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random,
                                        bool use_large_range) {
  std::unique_ptr<std::minstd_rand0> engine;
  if (pseudo_random) {
    engine = std::make_unique<std::minstd_rand0>();
  }
  return MakeFakeLiteral(shape, engine.get(), /*limit=*/std::nullopt,
                         /*is_sorted=*/false,
                         /*no_duplicates=*/false, use_large_range,
                         /*max_bits_of_precision=*/std::nullopt);
}

// Builds a literal of `shape` filled with fake data.
//
// * Tuple shapes are handled recursively, one element literal per tuple shape,
//   all sharing `engine` and the remaining options.
// * With a null `engine` the literal is created via Literal::CreateFromShape
//   and no random data is generated.
// * Otherwise the data is generated according to the element type:
//   floating point and complex types honour `no_duplicates` and
//   `use_large_range` (floating point also `max_bits_of_precision`); PRED is
//   sampled uniformly from {false, true}; integral types are sampled within
//   `limit` (default: the full range of the type), further clamped by
//   `max_bits_of_precision`, and sorted when `is_sorted` is set.
//
// Returns an Unimplemented error for element types that fall in none of the
// categories above.
absl::StatusOr<Literal> MakeFakeLiteral(
    const Shape& shape, std::minstd_rand0* engine,
    std::optional<std::pair<int64_t, int64_t>> limit, bool is_sorted,
    bool no_duplicates, bool use_large_range,
    std::optional<int64_t> max_bits_of_precision) {
  if (shape.IsTuple()) {
    std::vector<Literal> elements;
    const auto& shape_tuple_shapes = shape.tuple_shapes();
    elements.reserve(shape_tuple_shapes.size());
    for (const Shape& element_shape : shape_tuple_shapes) {
      TF_ASSIGN_OR_RETURN(
          Literal element,
          MakeFakeLiteral(element_shape, engine, limit, is_sorted,
                          no_duplicates, use_large_range,
                          max_bits_of_precision));
      elements.push_back(std::move(element));
    }
    return LiteralUtil::MakeTupleOwned(std::move(elements));
  }
  if (engine == nullptr) {
    return Literal::CreateFromShape(shape);
  }
  // Clear tiles/element size in shape's layout before using it for creating
  // literal.
  Shape new_shape = shape;
  new_shape.mutable_layout()->clear_tiles();
  new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1);
  new_shape.mutable_layout()->set_element_size_in_bits(0);
  Literal literal(new_shape);

  TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch<absl::Status>(
      [&](auto primitive_type_constant) -> absl::Status {
        if constexpr (primitive_util::IsArrayType(primitive_type_constant)) {
          using NativeT = primitive_util::NativeTypeOf<primitive_type_constant>;
          if constexpr (primitive_util::IsFloatingPointType(
                            primitive_type_constant)) {
            PopulateWithFloatingPointData<NativeT>(
                &literal, engine, no_duplicates, use_large_range,
                max_bits_of_precision);
            return absl::OkStatus();
          }
          if constexpr (primitive_type_constant == PRED) {
            std::uniform_int_distribution<int> generator(0, 1);
            TF_CHECK_OK(literal.Populate<bool>(
                [&](absl::Span<const int64_t> /*indices*/) {
                  return generator(*engine);
                }));
            return absl::OkStatus();
          }
          if constexpr (primitive_util::IsIntegralType(
                            primitive_type_constant)) {
            NativeT max = std::numeric_limits<NativeT>::max();
            NativeT min = std::numeric_limits<NativeT>::lowest();
            if (limit.has_value()) {
              max = static_cast<NativeT>(limit->second);
              min = static_cast<NativeT>(limit->first);
            }
            if (max_bits_of_precision.has_value()) {
              // Shift in 64 bits: `1 << n` on a plain int overflows once n
              // reaches the width of int, even when NativeT is 64-bit.
              const int64_t precision_bound = int64_t{1}
                                              << *max_bits_of_precision;
              max = std::min(max, static_cast<NativeT>(precision_bound));
              if (primitive_util::IsSignedIntegralType(
                      primitive_type_constant)) {
                min = std::max(min, static_cast<NativeT>(-precision_bound));
              }
            }
            PopulateWithRandomIntegralDataWithBounds<NativeT>(
                &literal, engine, /*no_duplicate*/ no_duplicates, min, max);
            if (is_sorted) {
              std::sort(literal.data<NativeT>().begin(),
                        literal.data<NativeT>().end());
            }
            return absl::OkStatus();
          }
          if constexpr (primitive_util::IsComplexType(
                            primitive_type_constant)) {
            PopulateWithComplexData<NativeT>(&literal, engine, no_duplicates,
                                             use_large_range);
            return absl::OkStatus();
          }
        }
        return Unimplemented(
            "Unsupported type for fake random literal generation with bounds: "
            "%s",
            ShapeUtil::HumanString(shape));
      },
      shape.element_type()));
  return std::move(literal);
}

} // namespace xla
26 changes: 26 additions & 0 deletions xla/literal_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,32 @@ template <PrimitiveType type, typename T>
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}

// Generates fake data in a literal of the given shape, or returns an error
// status if the element type is currently unhandled for fake data
// generation. When 'pseudo_random' is true the data is drawn from a
// default-constructed std::minstd_rand0 engine; when it is false no engine is
// used and the literal is created with Literal::CreateFromShape instead.
// See the overload below for documentation of use_large_range.
absl::StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
bool pseudo_random = true,
bool use_large_range = false);

// Similar to MakeFakeLiteral above but takes a random number generator engine
// to enable reusing the engine across randomly generated literals. 'limit' is
// an optional pair that contains the min and the max values to be sampled for
// integers (integer format only). 'is_sorted' sorts the sampled data for
// integers (integer format only). 'no_duplicates' indicates that there should
// be no duplicate values in each generated array. This uniqueness is
// best-effort only. Some types (half and bfloat16) are not supported and
// uniqueness cannot be guaranteed if the number of elements exceeds the number
// of different values supported by the type. (floating point format only)
// 'use_large_range' indicates the sampled data is from the full range of the
// floating point format. (floating point format only)
// 'max_bits_of_precision' sets the data to have the given number of bits or
// less (integer or floating point formats only).
absl::StatusOr<Literal> MakeFakeLiteral(
const Shape& shape, std::minstd_rand0* engine,
std::optional<std::pair<int64_t, int64_t>> limit, bool is_sorted,
bool no_duplicates, bool use_large_range,
std::optional<int64_t> max_bits_of_precision);

} // namespace xla

#endif // XLA_LITERAL_UTIL_H_
1 change: 0 additions & 1 deletion xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ xla_cc_test(
"//xla/pjrt:pjrt_executable",
"//xla/service:hlo_proto_cc",
"//xla/tests:literal_test_util",
"//xla/tests:test_utils",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down
8 changes: 3 additions & 5 deletions xla/pjrt/cpu/cpu_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ limitations under the License.

#include "xla/pjrt/cpu/cpu_client.h"

#include "xla/service/hlo.pb.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"

#ifndef _WIN32
#include <unistd.h>
#endif
Expand All @@ -45,12 +41,14 @@ limitations under the License.
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_utils.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/types.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/file_system.h"
Expand Down
Loading

0 comments on commit 27ceec6

Please sign in to comment.