forked from apache/arrow
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…apache#36190) ### Rationale for this change Now that the MATLAB interface supports some basic `arrow.array.Array` types, it would be helpful to start building out the tabular types (e.g. `RecordBatch` and `Table`) in parallel. This pull request contains a basic implementation of `arrow.tabular.RecordBatch` (name subject to change). ### What changes are included in this PR? 1. Added new `arrow.tabular.RecordBatch` class that can be constructed from a MATLAB `table`. 2. Added new test class `tRecordBatch`. ### Are these changes tested? Yes. 1. Added new test class `tRecordBatch` containing basic tests for the `arrow.tabular.RecordBatch` class. ### Are there any user-facing changes? Yes. 1. Added new class `arrow.tabular.RecordBatch`. **Example**: ```matlab >> matlabTable = table(uint64([1,2,3]'), [true false true]', [0.1, 0.2, 0.3]', VariableNames=["UInt64", "Boolean", "Float64"]) matlabTable = 3x3 table UInt64 Boolean Float64 ______ _______ _______ 1 true 0.1 2 false 0.2 3 true 0.3 >> arrowRecordBatch = arrow.tabular.RecordBatch(matlabTable) arrowRecordBatch = UInt64: [ 1, 2, 3 ] Boolean: [ true, false, true ] Float64: [ 0.1, 0.2, 0.3 ] >> convertedMatlabTable = table(arrowRecordBatch) convertedMatlabTable = 3x3 table UInt64 Boolean Float64 ______ _______ _______ 1 true 0.1 2 false 0.2 3 true 0.3 >> isequal(matlabTable, convertedMatlabTable) ans = logical 1 ``` 2. Added properties `NumColumns` and `ColumnNames` to `arrow.tabular.RecordBatch`: **Example**: ```matlab >> arrowRecordBatch.NumColumns ans = int32 3 >> arrowRecordBatch.ColumnNames ans = 1x3 string array "UInt64" "Boolean" "Float64" ``` 3. Added `column(i)` method to `arrow.tabular.RecordBatch` to retrieve the `i`th column of a `RecordBatch` as an `arrow.array.Array`. **Example**: ```matlab >> arrowUInt64Array = arrowRecordBatch.column(1) arrowUInt64Array = [ 1, 2, 3 ] >> class(arrowUInt64Array) ans = 'arrow.array.UInt64Array' >> arrowBooleanArray = arrowRecordBatch.column(2) arrowBooleanArray = [ true, false, true ] >> class(arrowBooleanArray) ans = 'arrow.array.UInt64Array' >> arrowFloat64Array = arrowRecordBatch.column(3) arrowFloat64Array = [ 0.1, 0.2, 0.3 ] >> class(arrowFloat64Array) ans = 'arrow.array.Float64Array' ``` 4. Added `toMATLAB` and `table` conversion methods to convert from a `RecordBatch` to a MATLAB `table`. ### Future Directions 1. Implement C++ logic for `toMATLAB` when the Arrow memory for a `RecordBatch` did originate from a MATLAB array (e.g. read from a Parquet file or somewhere else). 2. Add more supported construction interfaces (e.g. `arrow.tabular.RecordBatch(array1, ..., arrayN)`, arrow.tabular.RecordBatch.fromArrays(arrays)`, etc.). 3. Create an `arrow.tabular.Schema` class. Expose this as a public property on the `RecordBatch` class. Create related `arrow.type.Field` and `arrow.type.Type` classes. 4. Create an `arrow.tabular.Table` and related `arrow.array.ChunkedArray` class. 5. Add more `arrow.array.Array` types (e.g. `StringArray`, `TimestampArray`, `Time64Array`). 6. Create a basic workflow example of serializing a `RecordBatch` to disk using an I/O function (e.g. Parquet writing). ### Notes 1. Thanks @ sgilmore10 for your help with this pull request! 2. While writing the tests for `RecordBatch`, we stumbled upon a set of [accidentally committed diff markers] in `UInt64Array.m` or `tUInt64Array.m`. We removed these diff markers in this PR to unblock the `RecordBatch` tests. The unfortunate thing is that this wasn't caught before because MATLAB was simply ignoring the test file `tUInt64Array.m` because it had a syntax error in it. We could choose to explicitly list out all test files in the MATLAB CI workflows to try and avoid similar situations in the future, but this might get unwieldy to maintain over time as we add more tests. We are happy to hear any suggestions from other community members related to this topic. * Closes: apache#36072 Lead-authored-by: Kevin Gurney <kgurney@mathworks.com> Co-authored-by: Kevin Gurney <kevin.p.gurney@gmail.com> Co-authored-by: Sarah Gilmore <sgilmore@mathworks.com> Co-authored-by: Sutou Kouhei <kou@cozmixng.org> Signed-off-by: Sutou Kouhei <kou@clear-code.com>
- Loading branch information
1 parent
bd1ebec
commit 382230d
Showing
12 changed files
with
432 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
#include "libmexclass/proxy/ProxyManager.h" | ||
|
||
#include "arrow/matlab/array/proxy/array.h" | ||
#include "arrow/matlab/error/error.h" | ||
#include "arrow/matlab/tabular/proxy/record_batch.h" | ||
#include "arrow/type.h" | ||
#include "arrow/util/utf8.h" | ||
|
||
namespace arrow::matlab::tabular::proxy { | ||
|
||
RecordBatch::RecordBatch(std::shared_ptr<arrow::RecordBatch> record_batch) : record_batch{record_batch} { | ||
REGISTER_METHOD(RecordBatch, toString); | ||
REGISTER_METHOD(RecordBatch, numColumns); | ||
REGISTER_METHOD(RecordBatch, columnNames); | ||
} | ||
|
||
void RecordBatch::toString(libmexclass::proxy::method::Context& context) { | ||
namespace mda = ::matlab::data; | ||
mda::ArrayFactory factory; | ||
const auto maybe_utf16_string = arrow::util::UTF8StringToUTF16(record_batch->ToString()); | ||
// TODO: Add a helper macro to avoid having to write out an explicit if-statement here when handling errors. | ||
if (!maybe_utf16_string.ok()) { | ||
// TODO: This error message could probably be improved. | ||
context.error = libmexclass::error::Error{error::UNICODE_CONVERSION_ERROR_ID, maybe_utf16_string.status().message()}; | ||
return; | ||
} | ||
auto str_mda = factory.createScalar(*maybe_utf16_string); | ||
context.outputs[0] = str_mda; | ||
} | ||
|
||
libmexclass::proxy::MakeResult RecordBatch::make(const libmexclass::proxy::FunctionArguments& constructor_arguments) { | ||
namespace mda = ::matlab::data; | ||
mda::StructArray opts = constructor_arguments[0]; | ||
const mda::TypedArray<uint64_t> arrow_array_proxy_ids = opts[0]["ArrayProxyIDs"]; | ||
const mda::StringArray column_names = opts[0]["ColumnNames"]; | ||
|
||
std::vector<std::shared_ptr<arrow::Array>> arrow_arrays; | ||
// Retrieve all of the Arrow Array Proxy instances from the libmexclass ProxyManager. | ||
for (const auto& arrow_array_proxy_id : arrow_array_proxy_ids) { | ||
auto proxy = libmexclass::proxy::ProxyManager::getProxy(arrow_array_proxy_id); | ||
auto arrow_array_proxy = std::static_pointer_cast<arrow::matlab::array::proxy::Array>(proxy); | ||
auto arrow_array = arrow_array_proxy->getArray(); | ||
arrow_arrays.push_back(arrow_array); | ||
} | ||
|
||
std::vector<std::shared_ptr<Field>> fields; | ||
for (size_t i = 0; i < arrow_arrays.size(); ++i) { | ||
const auto type = arrow_arrays[i]->type(); | ||
const auto column_name_str = std::u16string(column_names[i]); | ||
const auto maybe_column_name_str = arrow::util::UTF16StringToUTF8(column_name_str); | ||
MATLAB_ERROR_IF_NOT_OK(maybe_column_name_str.status(), error::UNICODE_CONVERSION_ERROR_ID); | ||
fields.push_back(std::make_shared<arrow::Field>(*maybe_column_name_str, type)); | ||
} | ||
|
||
arrow::SchemaBuilder schema_builder; | ||
MATLAB_ERROR_IF_NOT_OK(schema_builder.AddFields(fields), error::SCHEMA_BUILDER_ADD_FIELDS_ERROR_ID); | ||
auto maybe_schema = schema_builder.Finish(); | ||
MATLAB_ERROR_IF_NOT_OK(maybe_schema.status(), error::SCHEMA_BUILDER_FINISH_ERROR_ID); | ||
|
||
const auto schema = *maybe_schema; | ||
const auto num_rows = arrow_arrays.size() == 0 ? 0 : arrow_arrays[0]->length(); | ||
const auto record_batch = arrow::RecordBatch::Make(schema, num_rows, arrow_arrays); | ||
auto record_batch_proxy = std::make_shared<arrow::matlab::tabular::proxy::RecordBatch>(record_batch); | ||
|
||
return record_batch_proxy; | ||
} | ||
|
||
void RecordBatch::numColumns(libmexclass::proxy::method::Context& context) { | ||
namespace mda = ::matlab::data; | ||
mda::ArrayFactory factory; | ||
const auto num_columns = record_batch->num_columns(); | ||
auto num_columns_mda = factory.createScalar(num_columns); | ||
context.outputs[0] = num_columns_mda; | ||
} | ||
|
||
void RecordBatch::columnNames(libmexclass::proxy::method::Context& context) { | ||
namespace mda = ::matlab::data; | ||
mda::ArrayFactory factory; | ||
const int num_columns = record_batch->num_columns(); | ||
|
||
std::vector<mda::MATLABString> column_names; | ||
for (int i = 0; i < num_columns; ++i) { | ||
const auto column_name_utf8 = record_batch->column_name(i); | ||
auto maybe_column_name_utf16 = arrow::util::UTF8StringToUTF16(column_name_utf8); | ||
// TODO: Add a helper macro to avoid having to write out an explicit if-statement here when handling errors. | ||
if (!maybe_column_name_utf16.ok()) { | ||
// TODO: This error message could probably be improved. | ||
context.error = libmexclass::error::Error{error::UNICODE_CONVERSION_ERROR_ID, maybe_column_name_utf16.status().message()}; | ||
return; | ||
} | ||
auto column_name_utf16 = *maybe_column_name_utf16; | ||
const mda::MATLABString matlab_string = mda::MATLABString(std::move(column_name_utf16)); | ||
column_names.push_back(matlab_string); | ||
} | ||
auto column_names_mda = factory.createArray({size_t{1}, static_cast<size_t>(num_columns)}, column_names.begin(), column_names.end()); | ||
context.outputs[0] = column_names_mda; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
#pragma once | ||
|
||
#include "arrow/record_batch.h" | ||
|
||
#include "libmexclass/proxy/Proxy.h" | ||
|
||
namespace arrow::matlab::tabular::proxy { | ||
|
||
class RecordBatch : public libmexclass::proxy::Proxy { | ||
public: | ||
RecordBatch(std::shared_ptr<arrow::RecordBatch> record_batch); | ||
|
||
virtual ~RecordBatch() {} | ||
|
||
static libmexclass::proxy::MakeResult make(const libmexclass::proxy::FunctionArguments& constructor_arguments); | ||
|
||
protected: | ||
void toString(libmexclass::proxy::method::Context& context); | ||
void numColumns(libmexclass::proxy::method::Context& context); | ||
void columnNames(libmexclass::proxy::method::Context& context); | ||
|
||
std::shared_ptr<arrow::RecordBatch> record_batch; | ||
}; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
% Licensed to the Apache Software Foundation (ASF) under one or more | ||
% contributor license agreements. See the NOTICE file distributed with | ||
% this work for additional information regarding copyright ownership. | ||
% The ASF licenses this file to you under the Apache License, Version | ||
% 2.0 (the "License"); you may not use this file except in compliance | ||
% with the License. You may obtain a copy of the License at | ||
% | ||
% http://www.apache.org/licenses/LICENSE-2.0 | ||
% | ||
% Unless required by applicable law or agreed to in writing, software | ||
% distributed under the License is distributed on an "AS IS" BASIS, | ||
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||
% implied. See the License for the specific language governing | ||
% permissions and limitations under the License. | ||
|
||
classdef RecordBatch < matlab.mixin.CustomDisplay & ... | ||
matlab.mixin.Scalar | ||
%arrow.tabular.RecordBatch A tabular data structure representing | ||
% a set of arrow.array.Array objects with a fixed schema. | ||
|
||
properties (Access=private) | ||
ArrowArrays = {}; | ||
end | ||
|
||
properties (Dependent, SetAccess=private, GetAccess=public) | ||
NumColumns | ||
ColumnNames | ||
end | ||
|
||
properties (Access=protected) | ||
Proxy | ||
end | ||
|
||
methods | ||
|
||
function numColumns = get.NumColumns(obj) | ||
numColumns = obj.Proxy.numColumns(); | ||
end | ||
|
||
function columnNames = get.ColumnNames(obj) | ||
columnNames = obj.Proxy.columnNames(); | ||
end | ||
|
||
function arrowArray = column(obj, idx) | ||
arrowArray = obj.ArrowArrays{idx}; | ||
end | ||
|
||
function obj = RecordBatch(T) | ||
obj.ArrowArrays = arrow.tabular.RecordBatch.decompose(T); | ||
columnNames = string(T.Properties.VariableNames); | ||
arrayProxyIDs = arrow.tabular.RecordBatch.getArrowProxyIDs(obj.ArrowArrays); | ||
opts = struct("ArrayProxyIDs", arrayProxyIDs, ... | ||
"ColumnNames", columnNames); | ||
obj.Proxy = libmexclass.proxy.Proxy("Name", "arrow.tabular.proxy.RecordBatch", "ConstructorArguments", {opts}); | ||
end | ||
|
||
function T = table(obj) | ||
matlabArrays = cell(1, numel(obj.ArrowArrays)); | ||
|
||
for ii = 1:numel(obj.ArrowArrays) | ||
matlabArrays{ii} = toMATLAB(obj.ArrowArrays{ii}); | ||
end | ||
|
||
variableNames = matlab.lang.makeUniqueStrings(obj.ColumnNames); | ||
% NOTE: Does not currently handle edge cases like ColumnNames | ||
% matching the table DimensionNames. | ||
T = table(matlabArrays{:}, VariableNames=variableNames); | ||
end | ||
|
||
function T = toMATLAB(obj) | ||
T = obj.table(); | ||
end | ||
|
||
end | ||
|
||
methods (Static) | ||
|
||
function arrowArrays = decompose(T) | ||
% Decompose the input MATLAB table | ||
% input a cell array of equivalent arrow.array.Array | ||
% instances. | ||
arguments | ||
T table | ||
end | ||
|
||
numColumns = width(T); | ||
arrowArrays = cell(1, numColumns); | ||
|
||
% Convert each MATLAB array into a corresponding | ||
% arrow.array.Array. | ||
for ii = 1:numColumns | ||
arrowArrays{ii} = arrow.tabular.RecordBatch.makeArray(T{:, ii}); | ||
end | ||
end | ||
|
||
function arrowArray = makeArray(matlabArray) | ||
% Decompose the input MATLAB table | ||
% input a cell array of equivalent arrow.array.Array | ||
% instances. | ||
|
||
switch class(matlabArray) | ||
case "single" | ||
arrowArray = arrow.array.Float32Array(matlabArray); | ||
case "double" | ||
arrowArray = arrow.array.Float64Array(matlabArray); | ||
case "uint8" | ||
arrowArray = arrow.array.UInt8Array(matlabArray); | ||
case "uint16" | ||
arrowArray = arrow.array.UInt16Array(matlabArray); | ||
case "uint32" | ||
arrowArray = arrow.array.UInt32Array(matlabArray); | ||
case "uint64" | ||
arrowArray = arrow.array.UInt64Array(matlabArray); | ||
case "int8" | ||
arrowArray = arrow.array.Int8Array(matlabArray); | ||
case "int16" | ||
arrowArray = arrow.array.Int16Array(matlabArray); | ||
case "int32" | ||
arrowArray = arrow.array.Int32Array(matlabArray); | ||
case "int64" | ||
arrowArray = arrow.array.Int64Array(matlabArray); | ||
case "logical" | ||
arrowArray = arrow.array.BooleanArray(matlabArray); | ||
otherwise | ||
error("arrow:tabular:recordbatch:UnsupportedMatlabArrayType", ... | ||
"RecordBatch cannot be constructed from a MATLAB array of type '" + class(matlabArray) + "'."); | ||
end | ||
|
||
end | ||
|
||
function proxyIDs = getArrowProxyIDs(arrowArrays) | ||
% Extract the Proxy IDs underlying a cell array of | ||
% arrow.array.Array instances. | ||
proxyIDs = zeros(1, numel(arrowArrays), "uint64"); | ||
|
||
% Convert each MATLAB array into a corresponding | ||
% arrow.array.Array. | ||
for ii = 1:numel(arrowArrays) | ||
proxyIDs(ii) = arrowArrays{ii}.Proxy.ID; | ||
end | ||
end | ||
|
||
end | ||
|
||
methods (Access = private) | ||
function str = toString(obj) | ||
str = obj.Proxy.toString(); | ||
end | ||
end | ||
|
||
methods (Access=protected) | ||
function displayScalarObject(obj) | ||
disp(obj.toString()); | ||
end | ||
end | ||
|
||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.