From 2fea185fb251a3bf93fb2914f6fdf83bc1038e26 Mon Sep 17 00:00:00 2001 From: Johan Mabille Date: Tue, 4 Feb 2025 17:53:10 +0100 Subject: [PATCH] Added record bact import from/export to struct_array (#336) Added record bacth import from/export to struct_array --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alexis Placet --- include/sparrow/array_api.hpp | 12 +++++ .../buffer/dynamic_bitset/dynamic_bitset.hpp | 4 +- include/sparrow/record_batch.hpp | 52 +++++++++++++++++++ src/array.cpp | 10 ++++ src/record_batch.cpp | 32 ++++++++++++ test/test_record_batch.cpp | 48 ++++++++++++++--- 6 files changed, 148 insertions(+), 10 deletions(-) diff --git a/include/sparrow/array_api.hpp b/include/sparrow/array_api.hpp index 5a0fa7e7c..e8193012b 100644 --- a/include/sparrow/array_api.hpp +++ b/include/sparrow/array_api.hpp @@ -143,6 +143,18 @@ namespace sparrow */ SPARROW_API enum data_type data_type() const; + /** + * @returns the name of the \ref array or an empty + * string if the array does not have a name. + */ + SPARROW_API std::optional name() const; + + /** + * Sets the name of the array to \ref name. + * @param name The new name of the array. + */ + SPARROW_API void set_name(std::optional name); + /** * Checks if the array has no element, i.e. whether size() == 0. */ diff --git a/include/sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp b/include/sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp index 2d75d7634..eb6da0653 100644 --- a/include/sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp +++ b/include/sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp @@ -177,7 +177,7 @@ namespace sparrow && std::unsigned_integral>) ) && (!std::same_as, std::string> && !std::same_as, std::string_view> - && !std::same_as); + && !std::same_as, const char*>); template validity_bitmap ensure_validity_bitmap(std::size_t size, R&& validity_input) @@ -185,4 +185,4 @@ namespace sparrow return detail::ensure_validity_bitmap_impl(size, std::forward(validity_input)); } -} // namespace sparrow \ No newline at end of file +} // namespace sparrow diff --git a/include/sparrow/record_batch.hpp b/include/sparrow/record_batch.hpp index 38c42ff7e..3ad2de246 100644 --- a/include/sparrow/record_batch.hpp +++ b/include/sparrow/record_batch.hpp @@ -22,6 +22,7 @@ #include #include "sparrow/array.hpp" +#include "sparrow/layout/struct_layout/struct_array.hpp" #include "sparrow/utils/contracts.hpp" #if defined(__cpp_lib_format) @@ -65,6 +66,17 @@ namespace sparrow requires(std::convertible_to, std::string> and std::same_as, array>) record_batch(NR&& names, CR&& columns); + /* + * Constructs a @ref record_batch from a range of arrays. Each array + * must have a name: if \c arr is an array, then \c arr.name(), must + * not return an empty string. + * + * @param comumns An input range of arrays + */ + template + requires std::same_as, array> + record_batch(CR&& columns); + /** * Constructs a record_batch from a list of \c std::pair. * @@ -72,6 +84,14 @@ namespace sparrow */ SPARROW_API record_batch(initializer_type init); + /** + * Construct a record batch from the given struct array. + * The array must owns its internal arrow structures. + * + * @param ar An input struct array + */ + SPARROW_API record_batch(struct_array&& ar); + SPARROW_API record_batch(const record_batch&); SPARROW_API record_batch& operator=(const record_batch&); @@ -129,6 +149,13 @@ namespace sparrow */ SPARROW_API column_range columns() const; + /** + * Moves the internal columns of the record batch into a struct_array + * object. The record batch is empty anymore after calling this + * method. + */ + SPARROW_API struct_array extract_struct_array(); + private: template @@ -169,6 +196,31 @@ namespace sparrow SPARROW_ASSERT_TRUE(check_consistency()); } + namespace detail + { + std::vector get_names(const std::vector& array_list) + { + const auto names = array_list + | std::views::transform( + [](const array& ar) + { + return ar.name().value(); + } + ); + return {names.begin(), names.end()}; + } + } + + template + requires std::same_as, array> + record_batch::record_batch(CR&& columns) + : m_name_list(detail::get_names(columns)) + , m_array_list(to_vector(std::move(columns))) + { + init_array_map(); + SPARROW_ASSERT_TRUE(check_consistency()); + } + template std::vector record_batch::to_vector(R&& range) const { diff --git a/src/array.cpp b/src/array.cpp index 2b08f14cb..5d6419052 100644 --- a/src/array.cpp +++ b/src/array.cpp @@ -47,6 +47,16 @@ namespace sparrow return p_array->data_type(); } + std::optional array::name() const + { + return get_arrow_proxy().name(); + } + + void array::set_name(std::optional name) + { + get_arrow_proxy().set_name(name); + } + bool array::empty() const { return size() == size_type(0); diff --git a/src/record_batch.cpp b/src/record_batch.cpp index 6c7ed8a20..2854dc229 100644 --- a/src/record_batch.cpp +++ b/src/record_batch.cpp @@ -35,6 +35,25 @@ namespace sparrow SPARROW_ASSERT_TRUE(check_consistency()); } + record_batch::record_batch(struct_array&& arr) + { + SPARROW_ASSERT_TRUE(owns_arrow_array(arr)); + SPARROW_ASSERT_TRUE(owns_arrow_schema(arr)); + + auto [struct_arr, struct_sch] = extract_arrow_structures(std::move(arr)); + auto n_children = static_cast(struct_arr.n_children); + m_name_list.reserve(n_children); + m_array_list.reserve(n_children); + for (std::size_t i = 0; i < n_children; ++i) + { + array arr(move_array(*(struct_arr.children[i])), move_schema(*(struct_sch.children[i]))); + m_name_list.push_back(std::string(arr.name().value())); + m_array_list.push_back(std::move(arr)); + } + init_array_map(); + SPARROW_ASSERT_TRUE(check_consistency()); + } + record_batch::record_batch(const record_batch& rhs) : m_name_list(rhs.m_name_list) , m_array_list(rhs.m_array_list) @@ -97,6 +116,16 @@ namespace sparrow return std::ranges::ref_view(m_array_list); } + struct_array record_batch::extract_struct_array() + { + for (std::size_t i = 0; i < m_name_list.size(); ++i) + { + m_array_list[i].set_name(m_name_list[i]); + } + m_array_map.clear(); + return struct_array(std::move(m_array_list)); + } + void record_batch::init_array_map() { m_array_map.clear(); @@ -114,6 +143,9 @@ namespace sparrow "The size of the names and of the array list must be the same" ); + auto iter = std::find(m_name_list.begin(), m_name_list.end(), ""); + SPARROW_ASSERT(iter == m_name_list.end(), "A column can not have an empty name"); + const auto unique_names = std::unordered_set(m_name_list.begin(), m_name_list.end()); SPARROW_ASSERT(unique_names.size() == m_name_list.size(), "The names of the columns must be unique"); diff --git a/test/test_record_batch.cpp b/test/test_record_batch.cpp index db646e048..8528c10cc 100644 --- a/test/test_record_batch.cpp +++ b/test/test_record_batch.cpp @@ -25,15 +25,22 @@ namespace sparrow { primitive_array pr0( std::ranges::iota_view{std::size_t(0), std::size_t(data_size)} - | std::views::transform( - [](auto i) - { - return static_cast(i); - } - ) + | std::views::transform( + [](auto i) + { + return static_cast(i); + } + ), + "column0" + ); + primitive_array pr1( + std::ranges::iota_view{std::int32_t(4), 4 + std::int32_t(data_size)}, + "column1" + ); + primitive_array pr2( + std::ranges::iota_view{std::int32_t(2), 2 + std::int32_t(data_size)}, + "column2" ); - primitive_array pr1(std::ranges::iota_view{std::int32_t(4), 4 + std::int32_t(data_size)}); - primitive_array pr2(std::ranges::iota_view{std::int32_t(2), 2 + std::int32_t(data_size)}); std::vector arr_list = {array(std::move(pr0)), array(std::move(pr1)), array(std::move(pr2))}; return arr_list; @@ -71,6 +78,21 @@ namespace sparrow CHECK_EQ(record.nb_columns(), 3u); CHECK_EQ(record.nb_rows(), 10u); } + + SUBCASE("from column list") + { + record_batch record(make_array_list(col_size)); + CHECK_EQ(record.nb_columns(), 3u); + CHECK_EQ(record.nb_rows(), 10u); + CHECK_FALSE(std::ranges::equal(record.names(), make_name_list())); + } + + SUBCASE("from struct array") + { + record_batch record0(struct_array(make_array_list(col_size))); + record_batch record1(make_array_list(col_size)); + CHECK_EQ(record0, record1); + } } TEST_CASE("operator==") @@ -164,6 +186,16 @@ namespace sparrow CHECK(res); } + TEST_CASE("extract_struct_array") + { + struct_array arr(make_array_list(col_size)); + struct_array control(arr); + + record_batch r(std::move(arr)); + auto extr = r.extract_struct_array(); + CHECK_EQ(extr, control); + } + #if defined(__cpp_lib_format) TEST_CASE("formatter") {