Skip to content

Commit

Permalink
Added record batch import from/export to struct_array (#336)
Browse files Browse the repository at this point in the history
Added record batch import from/export to struct_array

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alexis Placet <alexis.placet@gmail.com>
  • Loading branch information
3 people authored Feb 4, 2025
1 parent bab08c1 commit 2fea185
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 10 deletions.
12 changes: 12 additions & 0 deletions include/sparrow/array_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ namespace sparrow
*/
SPARROW_API enum data_type data_type() const;

/**
* @returns the name of the \ref array, or an empty
* optional if the array does not have a name.
*/
SPARROW_API std::optional<std::string_view> name() const;

/**
* Sets the name of the array to \ref name.
* @param name The new name of the array.
*/
SPARROW_API void set_name(std::optional<std::string_view> name);

/**
* Checks if the array has no element, i.e. whether size() == 0.
*/
Expand Down
4 changes: 2 additions & 2 deletions include/sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ namespace sparrow
&& std::unsigned_integral<std::ranges::range_value_t<T>>) )
&& (!std::same_as<std::remove_cvref_t<T>, std::string>
&& !std::same_as<std::remove_cvref_t<T>, std::string_view>
&& !std::same_as<T, const char*>);
&& !std::same_as<std::decay_t<T>, const char*>);

/**
 * Produces a validity_bitmap for \c size elements from \c validity_input.
 *
 * This is a thin public entry point: both arguments are handed to
 * detail::ensure_validity_bitmap_impl (the input is perfectly forwarded)
 * and its result is returned unchanged.
 *
 * @tparam R A type satisfying the validity_bitmap_input concept.
 * @param size The size passed on to the implementation function.
 * @param validity_input The validity description to convert.
 * @return The validity_bitmap built by detail::ensure_validity_bitmap_impl.
 */
template <validity_bitmap_input R>
validity_bitmap ensure_validity_bitmap(std::size_t size, R&& validity_input)
{
return detail::ensure_validity_bitmap_impl(size, std::forward<R>(validity_input));
}

} // namespace sparrow
} // namespace sparrow
52 changes: 52 additions & 0 deletions include/sparrow/record_batch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "sparrow/array.hpp"
#include "sparrow/layout/struct_layout/struct_array.hpp"
#include "sparrow/utils/contracts.hpp"

#if defined(__cpp_lib_format)
Expand Down Expand Up @@ -65,13 +66,32 @@ namespace sparrow
requires(std::convertible_to<std::ranges::range_value_t<NR>, std::string> and std::same_as<std::ranges::range_value_t<CR>, array>)
record_batch(NR&& names, CR&& columns);

/*
* Constructs a @ref record_batch from a range of arrays. Each array
* must have a name: if \c arr is an array, then \c arr.name() must
* not return an empty string.
*
* @param columns An input range of arrays
*/
template <std::ranges::input_range CR>
requires std::same_as<std::ranges::range_value_t<CR>, array>
record_batch(CR&& columns);

/**
* Constructs a record_batch from a list of \c std::pair<name_type, array>.
*
* @param init a list of pair "name - array".
*/
SPARROW_API record_batch(initializer_type init);

/**
* Construct a record batch from the given struct array.
* The array must own its internal arrow structures.
*
* @param ar An input struct array
*/
SPARROW_API record_batch(struct_array&& ar);

SPARROW_API record_batch(const record_batch&);
SPARROW_API record_batch& operator=(const record_batch&);

Expand Down Expand Up @@ -129,6 +149,13 @@ namespace sparrow
*/
SPARROW_API column_range columns() const;

/**
* Moves the internal columns of the record batch into a struct_array
* object. The record batch is empty after calling this
* method.
*/
SPARROW_API struct_array extract_struct_array();

private:

template <class U, class R>
Expand Down Expand Up @@ -169,6 +196,31 @@ namespace sparrow
SPARROW_ASSERT_TRUE(check_consistency());
}

namespace detail
{
    /**
     * Collects the names of the arrays in \c array_list, preserving order.
     *
     * Every array must carry a name: \c array::name() returns an optional,
     * and \c value() is called on it, so a nameless array makes this
     * function throw \c std::bad_optional_access.
     *
     * Declared \c inline because it is a non-template function defined in
     * a header; without it, every translation unit including this header
     * would emit its own definition and the link would fail.
     *
     * @param array_list The arrays whose names are gathered.
     * @return One name per array, in the same order as \c array_list.
     */
    inline std::vector<record_batch::name_type> get_names(const std::vector<array>& array_list)
    {
        const auto names = array_list
                           | std::views::transform(
                               [](const array& ar)
                               {
                                   return ar.name().value();
                               }
                           );
        return {names.begin(), names.end()};
    }
}

/*
 * Builds a record batch from a range of named arrays. The column names
 * are taken from the arrays themselves (see detail::get_names), so each
 * array must have a name.
 *
 * \c columns is a forwarding reference, so it is passed on with
 * std::forward rather than std::move: an lvalue argument is then handed
 * to to_vector as an lvalue instead of being unconditionally cast to an
 * rvalue.
 */
template <std::ranges::input_range CR>
    requires std::same_as<std::ranges::range_value_t<CR>, array>
record_batch::record_batch(CR&& columns)
    // The names must be read before the columns are handed over.
    // NOTE(review): this relies on m_name_list being declared before
    // m_array_list in the class, which is not visible here - confirm.
    : m_name_list(detail::get_names(columns))
    , m_array_list(to_vector<array>(std::forward<CR>(columns)))
{
    init_array_map();
    SPARROW_ASSERT_TRUE(check_consistency());
}

template <class U, class R>
std::vector<U> record_batch::to_vector(R&& range) const
{
Expand Down
10 changes: 10 additions & 0 deletions src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ namespace sparrow
return p_array->data_type();
}

// Returns the name reported by the underlying arrow proxy; the optional
// is passed through unchanged.
std::optional<std::string_view> array::name() const
{
    const auto& proxy = get_arrow_proxy();
    return proxy.name();
}

// Forwards the new (possibly absent) name to the underlying arrow proxy.
void array::set_name(std::optional<std::string_view> name)
{
    auto&& proxy = get_arrow_proxy();
    proxy.set_name(name);
}

bool array::empty() const
{
return size() == size_type(0);
Expand Down
32 changes: 32 additions & 0 deletions src/record_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,25 @@ namespace sparrow
SPARROW_ASSERT_TRUE(check_consistency());
}

// Builds a record batch by taking over the children of a struct array.
// Each child becomes one column, named after the child's own name; a
// child without a name makes `value()` throw std::bad_optional_access.
// The struct array is expected to own its Arrow array and schema
// (checked by the assertions below).
record_batch::record_batch(struct_array&& arr)
{
    SPARROW_ASSERT_TRUE(owns_arrow_array(arr));
    SPARROW_ASSERT_TRUE(owns_arrow_schema(arr));

    auto [struct_arr, struct_sch] = extract_arrow_structures(std::move(arr));
    const auto n_children = static_cast<std::size_t>(struct_arr.n_children);
    m_name_list.reserve(n_children);
    m_array_list.reserve(n_children);
    for (std::size_t i = 0; i < n_children; ++i)
    {
        // Named `child` so that it does not shadow the parameter `arr`,
        // which has already been moved from at this point.
        array child(move_array(*(struct_arr.children[i])), move_schema(*(struct_sch.children[i])));
        m_name_list.push_back(std::string(child.name().value()));
        m_array_list.push_back(std::move(child));
    }
    // NOTE(review): struct_arr / struct_sch are the extracted parent
    // structures; nothing here releases them once the children have been
    // moved out. Confirm whether they need an explicit release.
    init_array_map();
    SPARROW_ASSERT_TRUE(check_consistency());
}

record_batch::record_batch(const record_batch& rhs)
: m_name_list(rhs.m_name_list)
, m_array_list(rhs.m_array_list)
Expand Down Expand Up @@ -97,6 +116,16 @@ namespace sparrow
return std::ranges::ref_view(m_array_list);
}

// Moves the columns of this record batch into a struct_array.
// Each column is first given its record-batch name, so the names travel
// with the arrays. Afterwards the record batch is left empty: the column
// names, the columns and the lookup map are all reset, keeping the three
// members consistent with each other.
struct_array record_batch::extract_struct_array()
{
    for (std::size_t i = 0; i < m_name_list.size(); ++i)
    {
        m_array_list[i].set_name(m_name_list[i]);
    }

    // Take the columns out first, then put every member into a known
    // empty state (a moved-from vector is only "valid but unspecified").
    auto columns = std::move(m_array_list);
    m_array_list.clear();
    m_array_map.clear();
    m_name_list.clear();

    return struct_array(std::move(columns));
}

void record_batch::init_array_map()
{
m_array_map.clear();
Expand All @@ -114,6 +143,9 @@ namespace sparrow
"The size of the names and of the array list must be the same"
);

auto iter = std::find(m_name_list.begin(), m_name_list.end(), "");
SPARROW_ASSERT(iter == m_name_list.end(), "A column can not have an empty name");

const auto unique_names = std::unordered_set<name_type>(m_name_list.begin(), m_name_list.end());
SPARROW_ASSERT(unique_names.size() == m_name_list.size(), "The names of the columns must be unique");

Expand Down
48 changes: 40 additions & 8 deletions test/test_record_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ namespace sparrow
{
primitive_array<std::uint16_t> pr0(
std::ranges::iota_view{std::size_t(0), std::size_t(data_size)}
| std::views::transform(
[](auto i)
{
return static_cast<std::uint16_t>(i);
}
)
| std::views::transform(
[](auto i)
{
return static_cast<std::uint16_t>(i);
}
),
"column0"
);
primitive_array<std::int32_t> pr1(
std::ranges::iota_view{std::int32_t(4), 4 + std::int32_t(data_size)},
"column1"
);
primitive_array<std::int32_t> pr2(
std::ranges::iota_view{std::int32_t(2), 2 + std::int32_t(data_size)},
"column2"
);
primitive_array<std::int32_t> pr1(std::ranges::iota_view{std::int32_t(4), 4 + std::int32_t(data_size)});
primitive_array<std::int32_t> pr2(std::ranges::iota_view{std::int32_t(2), 2 + std::int32_t(data_size)});

std::vector<array> arr_list = {array(std::move(pr0)), array(std::move(pr1)), array(std::move(pr2))};
return arr_list;
Expand Down Expand Up @@ -71,6 +78,21 @@ namespace sparrow
CHECK_EQ(record.nb_columns(), 3u);
CHECK_EQ(record.nb_rows(), 10u);
}

SUBCASE("from column list")
{
    // Built from arrays only: the column names come from the arrays.
    record_batch batch(make_array_list(col_size));
    CHECK_EQ(batch.nb_columns(), 3u);
    CHECK_EQ(batch.nb_rows(), 10u);
    // The array names are expected to differ from make_name_list().
    CHECK_FALSE(std::ranges::equal(batch.names(), make_name_list()));
}

SUBCASE("from struct array")
{
    // A batch built through a struct_array must equal one built
    // directly from the same list of columns.
    record_batch from_struct(struct_array(make_array_list(col_size)));
    record_batch from_columns(make_array_list(col_size));
    CHECK_EQ(from_struct, from_columns);
}
}

TEST_CASE("operator==")
Expand Down Expand Up @@ -164,6 +186,16 @@ namespace sparrow
CHECK(res);
}

TEST_CASE("extract_struct_array")
{
    // Round trip: struct_array -> record_batch -> struct_array.
    struct_array source(make_array_list(col_size));
    // Copy kept aside for comparison, since `source` is moved from below.
    struct_array expected(source);

    record_batch batch(std::move(source));
    auto extracted = batch.extract_struct_array();
    CHECK_EQ(extracted, expected);
}

#if defined(__cpp_lib_format)
TEST_CASE("formatter")
{
Expand Down

0 comments on commit 2fea185

Please sign in to comment.