From 85ec07e5b104dcb198bf041652b5f3086384d06f Mon Sep 17 00:00:00 2001 From: sgilmore10 <74676073+sgilmore10@users.noreply.github.com> Date: Wed, 6 Sep 2023 22:38:19 -0400 Subject: [PATCH] GH-37515: [C++] Remove memory address optimization from `ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other)` if the `ChunkedArray` can have `NaN` values (#37579) ### Rationale for this change `ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other)` assumes that if the two `ChunkedArray`s share the same memory address, then they must be equal. However, this optimization doesn't take into account that `NaN` values are not considered equal by default. Consequently, this can lead to surprising, inconsistent results from a user's perspective. For example, `ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other)` and `ChunkedArray::Equals(const ChunkedArray& other)` may return different results. The program below illustrates this inconsistency: ```c++ #include <arrow/api.h> #include <cmath> #include <iostream> #include <memory> #include <sstream> arrow::Result<std::shared_ptr<arrow::ChunkedArray>> make_chunked_array() { arrow::NumericBuilder<arrow::DoubleType> builder; std::shared_ptr<arrow::Array> array; ARROW_RETURN_NOT_OK(builder.AppendValues({0, 1, NAN, 2, 4})); ARROW_RETURN_NOT_OK(builder.Finish(&array)); return arrow::ChunkedArray::Make({array}); } int main(int argc, char *argv[]) { auto maybe_chunked_array = make_chunked_array(); if (!maybe_chunked_array.ok()) { return -1; } auto chunked_array = std::move(maybe_chunked_array).ValueUnsafe(); auto array = chunked_array->chunk(0); std::stringstream stream; stream << "chunked_array contents: "; stream << "\n\n"; stream << chunked_array->ToString(); stream << "\n\n"; stream << "chunked_array->Equals(chunked_array): "; stream << chunked_array->Equals(chunked_array); stream << "chunked_array->Equals(*chunked_array): "; stream << chunked_array->Equals(*chunked_array); std::cout << stream.str() << std::endl; } ``` Here is the output of this program: ```shell chunked_array contents: [ [ 0, 1, nan, 2, 4 ] ] chunked_array->Equals(chunked_array): 1 
chunked_array->Equals(*chunked_array): 0 ``` ### What changes are included in this PR? Updated `ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other)` to only return `true` early IF: - The two share the same address AND - They cannot have `NaN` values If either of those conditions is not satisfied, `ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other)` will do the element-by-element comparison. ### Are these changes tested? Yes. I added a new test case called `EqualsSameAddressWithNaNs` to `chunked_array_test.cc`. ### Are there any user-facing changes? Yes. `ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other)` may return `false` even if the two `ChunkedArray`s have the same memory address. This will only occur if the `ChunkedArray`s contain `NaN` values. * Closes: #37515 Authored-by: Sarah Gilmore Signed-off-by: Sutou Kouhei --- cpp/src/arrow/chunked_array.cc | 24 +++++++++++++++++++++--- cpp/src/arrow/chunked_array_test.cc | 29 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/chunked_array.cc b/cpp/src/arrow/chunked_array.cc index c5e6d7fa4bdf0..12937406e7800 100644 --- a/cpp/src/arrow/chunked_array.cc +++ b/cpp/src/arrow/chunked_array.cc @@ -30,6 +30,7 @@ #include "arrow/pretty_print.h" #include "arrow/status.h" #include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" @@ -111,13 +112,30 @@ bool ChunkedArray::Equals(const ChunkedArray& other) const { .ok(); } -bool ChunkedArray::Equals(const std::shared_ptr& other) const { - if (this == other.get()) { - return true; +namespace { + +bool mayHaveNaN(const arrow::DataType& type) { + if (type.num_fields() == 0) { + return is_floating(type.id()); + } else { + for (const auto& field : type.fields()) { + if (mayHaveNaN(*field->type())) { + return true; + } + } } + return false; +} + +} // namespace + +bool ChunkedArray::Equals(const std::shared_ptr& other) const { if (!other) { return false; } + if 
(this == other.get() && !mayHaveNaN(*type_)) { + return true; + } return Equals(*other.get()); } diff --git a/cpp/src/arrow/chunked_array_test.cc b/cpp/src/arrow/chunked_array_test.cc index 08410b4cd5367..46dccaf3c6b86 100644 --- a/cpp/src/arrow/chunked_array_test.cc +++ b/cpp/src/arrow/chunked_array_test.cc @@ -146,6 +146,35 @@ TEST_F(TestChunkedArray, EqualsDifferingMetadata) { ASSERT_TRUE(left.Equals(right)); } +TEST_F(TestChunkedArray, EqualsSameAddressWithNaNs) { + auto chunk_with_nan1 = ArrayFromJSON(float64(), "[0, 1, 2, NaN]"); + auto chunk_without_nan1 = ArrayFromJSON(float64(), "[3, 4, 5]"); + ArrayVector chunks1 = {chunk_with_nan1, chunk_without_nan1}; + ASSERT_OK_AND_ASSIGN(auto chunked_array_with_nan1, ChunkedArray::Make(chunks1)); + ASSERT_FALSE(chunked_array_with_nan1->Equals(chunked_array_with_nan1)); + + auto chunk_without_nan2 = ArrayFromJSON(float64(), "[6, 7, 8, 9]"); + ArrayVector chunks2 = {chunk_without_nan1, chunk_without_nan2}; + ASSERT_OK_AND_ASSIGN(auto chunked_array_without_nan1, ChunkedArray::Make(chunks2)); + ASSERT_TRUE(chunked_array_without_nan1->Equals(chunked_array_without_nan1)); + + auto int32_array = ArrayFromJSON(int32(), "[0, 1, 2]"); + auto float64_array_with_nan = ArrayFromJSON(float64(), "[0, 1, NaN]"); + ArrayVector arrays1 = {int32_array, float64_array_with_nan}; + std::vector fieldnames = {"Int32Type", "Float64Type"}; + ASSERT_OK_AND_ASSIGN(auto struct_with_nan, StructArray::Make(arrays1, fieldnames)); + ArrayVector chunks3 = {struct_with_nan}; + ASSERT_OK_AND_ASSIGN(auto chunked_array_with_nan2, ChunkedArray::Make(chunks3)); + ASSERT_FALSE(chunked_array_with_nan2->Equals(chunked_array_with_nan2)); + + auto float64_array_without_nan = ArrayFromJSON(float64(), "[0, 1, 2]"); + ArrayVector arrays2 = {int32_array, float64_array_without_nan}; + ASSERT_OK_AND_ASSIGN(auto struct_without_nan, StructArray::Make(arrays2, fieldnames)); + ArrayVector chunks4 = {struct_without_nan}; + ASSERT_OK_AND_ASSIGN(auto 
chunked_array_without_nan2, ChunkedArray::Make(chunks4)); + ASSERT_TRUE(chunked_array_without_nan2->Equals(chunked_array_without_nan2)); +} + TEST_F(TestChunkedArray, SliceEquals) { random::RandomArrayGenerator gen(42);