Skip to content

Commit f3694d7

Browse files
committed
feat: add support for querying Python Arrow tables directly
1 parent cda94fc commit f3694d7

26 files changed

+1287
-98
lines changed

programs/local/ArrowSchema.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "ArrowSchema.h"
2+
3+
#include <base/defines.h>
4+
5+
namespace CHDB
6+
{
7+
8+
/// Invokes the schema's release callback if one is set.
ArrowSchemaWrapper::~ArrowSchemaWrapper()
{
    /// A default-constructed wrapper has `release == nullptr`; there is nothing to do then.
    if (arrow_schema.release == nullptr)
        return;

    arrow_schema.release(&arrow_schema);
    /// Debug-only check that the callback cleared the `release` field.
    chassert(!arrow_schema.release);
}
16+
17+
} // namespace CHDB

programs/local/ArrowSchema.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <arrow/c/abi.h>
4+
5+
namespace CHDB
6+
{
7+
8+
class ArrowSchemaWrapper
9+
{
10+
public:
11+
ArrowSchema arrow_schema;
12+
13+
ArrowSchemaWrapper()
14+
{
15+
arrow_schema.release = nullptr;
16+
}
17+
18+
~ArrowSchemaWrapper();
19+
};
20+
21+
} // namespace CHDB
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
#include "ArrowStreamWrapper.h"
2+
#include "PyArrowTable.h"
3+
#include "PybindWrapper.h"
4+
#include "PythonImporter.h"
5+
6+
#include <Common/Exception.h>
7+
#include <base/defines.h>
8+
#include <pybind11/gil.h>
9+
#include <unordered_set>
10+
11+
namespace DB
12+
{
13+
14+
namespace ErrorCodes
15+
{
16+
extern const int PY_EXCEPTION_OCCURED;
17+
}
18+
19+
}
20+
21+
namespace py = pybind11;
22+
using namespace DB;
23+
24+
namespace CHDB
25+
{
26+
27+
/// ArrowSchemaWrapper implementation
28+
/// Invokes the schema's release callback if one is set.
ArrowSchemaWrapper::~ArrowSchemaWrapper()
{
    /// `release == nullptr` means the wrapper is empty (default-constructed or moved-from).
    if (arrow_schema.release == nullptr)
        return;

    arrow_schema.release(&arrow_schema);
}
35+
36+
ArrowSchemaWrapper::ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept
37+
: arrow_schema(other.arrow_schema)
38+
{
39+
other.arrow_schema.release = nullptr;
40+
}
41+
42+
/// Move assignment: releases the schema currently held (if any), takes over the struct
/// from `other` and leaves `other` empty. Self-assignment is a no-op.
ArrowSchemaWrapper & ArrowSchemaWrapper::operator=(ArrowSchemaWrapper && other) noexcept
{
    if (this == &other)
        return *this;

    /// Let go of our own schema before it gets overwritten.
    if (arrow_schema.release != nullptr)
        arrow_schema.release(&arrow_schema);

    arrow_schema = other.arrow_schema;
    other.arrow_schema.release = nullptr;
    return *this;
}
55+
56+
/// ArrowArrayWrapper implementation
57+
/// Invokes the array's release callback if one is set.
ArrowArrayWrapper::~ArrowArrayWrapper()
{
    /// `release == nullptr` means the wrapper is empty (reset or moved-from).
    if (arrow_array.release == nullptr)
        return;

    arrow_array.release(&arrow_array);
}
64+
65+
ArrowArrayWrapper::ArrowArrayWrapper(ArrowArrayWrapper && other) noexcept
66+
: arrow_array(other.arrow_array)
67+
{
68+
other.arrow_array.release = nullptr;
69+
}
70+
71+
/// Move assignment: releases the array currently held (if any), takes over the struct
/// from `other` and leaves `other` empty. Self-assignment is a no-op.
ArrowArrayWrapper & ArrowArrayWrapper::operator=(ArrowArrayWrapper && other) noexcept
{
    if (this == &other)
        return *this;

    /// Let go of our own array before it gets overwritten.
    if (arrow_array.release != nullptr)
        arrow_array.release(&arrow_array);

    arrow_array = other.arrow_array;
    other.arrow_array.release = nullptr;
    return *this;
}
84+
85+
/// ArrowArrayStreamWrapper implementation
86+
/// Invokes the stream's release callback if one is set.
ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper()
{
    /// `release == nullptr` means the wrapper is empty (default-constructed or moved-from).
    if (arrow_array_stream.release == nullptr)
        return;

    arrow_array_stream.release(&arrow_array_stream);
}
93+
94+
ArrowArrayStreamWrapper::ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept
95+
: arrow_array_stream(other.arrow_array_stream)
96+
{
97+
other.arrow_array_stream.release = nullptr;
98+
}
99+
100+
/// Move assignment: releases the stream currently held (if any), takes over the struct
/// from `other` and leaves `other` empty. Self-assignment is a no-op.
ArrowArrayStreamWrapper & ArrowArrayStreamWrapper::operator=(ArrowArrayStreamWrapper && other) noexcept
{
    if (this == &other)
        return *this;

    /// Let go of our own stream before it gets overwritten.
    if (arrow_array_stream.release != nullptr)
        arrow_array_stream.release(&arrow_array_stream);

    arrow_array_stream = other.arrow_array_stream;
    other.arrow_array_stream.release = nullptr;
    return *this;
}
113+
114+
/// Fetches the stream's schema into `schema`.
///
/// A schema already held by `schema` is released first, because the stream's
/// `get_schema` callback overwrites the output struct and the old contents would
/// otherwise never be released.
///
/// Throws Exception(PY_EXCEPTION_OCCURED) when the stream is not valid, when the
/// `get_schema` callback returns a non-zero code, or when the returned schema has no
/// release callback.
void ArrowArrayStreamWrapper::getSchema(ArrowSchemaWrapper& schema)
{
    if (!isValid())
    {
        throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "ArrowArrayStream is not valid");
    }

    /// Do not leak a schema the caller's wrapper may still own.
    if (schema.arrow_schema.release)
    {
        schema.arrow_schema.release(&schema.arrow_schema);
        /// Keep the wrapper empty even if the callback did not clear the field itself.
        schema.arrow_schema.release = nullptr;
    }

    if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema) != 0)
    {
        throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED,
            "Failed to get schema from ArrowArrayStream: {}", getError());
    }

    if (!schema.arrow_schema.release)
    {
        throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Released schema returned from ArrowArrayStream");
    }
}
132+
133+
std::unique_ptr<ArrowArrayWrapper> ArrowArrayStreamWrapper::getNextChunk()
134+
{
135+
chassert(isValid());
136+
137+
auto chunk = std::make_unique<ArrowArrayWrapper>();
138+
139+
/// Get next non-empty chunk, skipping empty ones
140+
do
141+
{
142+
chunk->reset();
143+
if (arrow_array_stream.get_next(&arrow_array_stream, &chunk->arrow_array) != 0)
144+
{
145+
throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED,
146+
"Failed to get next chunk from ArrowArrayStream: {}", getError());
147+
}
148+
149+
/// Check if we've reached the end of the stream
150+
if (!chunk->arrow_array.release)
151+
{
152+
return nullptr;
153+
}
154+
}
155+
while (chunk->arrow_array.length == 0);
156+
157+
return chunk;
158+
}
159+
160+
const char* ArrowArrayStreamWrapper::getError()
161+
{
162+
if (!isValid())
163+
{
164+
return "ArrowArrayStream is not valid";
165+
}
166+
167+
return arrow_array_stream.get_last_error(&arrow_array_stream);
168+
}
169+
170+
/// Builds an Arrow array stream from a Python object.
///
/// The GIL is acquired for the whole call. Only objects classified as
/// PyArrowObjectType::Table are supported; any other type results in
/// Exception(PY_EXCEPTION_OCCURED). A Python error raised along the way is rethrown as
/// Exception(PY_EXCEPTION_OCCURED) carrying the Python error text.
std::unique_ptr<ArrowArrayStreamWrapper> PyArrowStreamFactory::createFromPyObject(
    py::object & py_obj,
    const Names & column_names)
{
    py::gil_scoped_acquire gil;

    try
    {
        auto arrow_object_type = PyArrowTable::getArrowType(py_obj);

        if (arrow_object_type == PyArrowObjectType::Table)
            return createFromTable(py_obj, column_names);

        throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED,
            "Unsupported PyArrow object type: {}", arrow_object_type);
    }
    catch (const py::error_already_set & py_error)
    {
        throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED,
            "Failed to convert PyArrow object to arrow array stream: {}", py_error.what());
    }
}
195+
196+
/// Exports `table` as an Arrow array stream by going through a pyarrow dataset scanner.
///
/// The caller must hold the GIL (checked with chassert). When `column_names` is not
/// empty, the names that also appear in the table's schema are passed to the scanner as
/// the `columns` argument; names missing from the schema are silently dropped, and if
/// none remain no `columns` argument is passed at all.
std::unique_ptr<ArrowArrayStreamWrapper> PyArrowStreamFactory::createFromTable(
    py::object & table,
    const Names & column_names)
{
    chassert(py::gil_check());

    py::handle source_table(table);
    auto & importer = PythonImporter::ImportCache();
    auto make_dataset = importer.pyarrow.dataset().attr("dataset");

    auto dataset = make_dataset(source_table);
    /// `scanner` is looked up on the dataset's class and later called with the dataset as first argument.
    py::object make_scanner = dataset.attr("__class__").attr("scanner");

    py::dict scanner_options;
    if (!column_names.empty())
    {
        /// Export the table's schema through the C interface to learn its column names.
        ArrowSchemaWrapper table_schema;
        auto schema_object = source_table.attr("schema");
        auto export_schema = schema_object.attr("_export_to_c");
        export_schema(reinterpret_cast<uint64_t>(&table_schema.arrow_schema));

        const auto & exported = table_schema.arrow_schema;

        /// Names of all top-level children that carry a name.
        std::unordered_set<std::string> known_columns;
        if (exported.children)
        {
            for (int64_t idx = 0; idx < exported.n_children; ++idx)
            {
                const auto * child = exported.children[idx];
                if (!child || !child->name)
                    continue;

                known_columns.insert(child->name);
            }
        }

        /// Keep the requested order, but only for columns the schema actually has.
        py::list requested;
        for (const auto & column : column_names)
        {
            if (!known_columns.contains(column))
                continue;

            requested.append(column);
        }

        /// Pass a projection only when at least one requested column exists.
        if (requested.size() != 0)
            scanner_options["columns"] = requested;
    }

    auto scanner = make_scanner(dataset, **scanner_options);

    auto reader = scanner.attr("to_reader")();
    auto stream = std::make_unique<ArrowArrayStreamWrapper>();
    auto export_stream = reader.attr("_export_to_c");
    /// The reader fills the C stream struct located at the given address.
    export_stream(reinterpret_cast<uint64_t>(&stream->arrow_array_stream));
    return stream;
}
254+
255+
} // namespace CHDB
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <arrow/c/abi.h>
5+
#include <pybind11/pybind11.h>
6+
#include <Core/Names.h>
7+
8+
namespace CHDB
9+
{
10+
11+
/// Wrapper for Arrow C Data Interface structures with RAII resource management
12+
class ArrowSchemaWrapper
13+
{
14+
public:
15+
ArrowSchema arrow_schema;
16+
17+
ArrowSchemaWrapper() {
18+
arrow_schema.release = nullptr;
19+
}
20+
21+
~ArrowSchemaWrapper();
22+
23+
/// Non-copyable but moveable
24+
ArrowSchemaWrapper(const ArrowSchemaWrapper &) = delete;
25+
ArrowSchemaWrapper & operator=(const ArrowSchemaWrapper &) = delete;
26+
ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept;
27+
ArrowSchemaWrapper & operator=(ArrowSchemaWrapper && other) noexcept;
28+
};
29+
30+
class ArrowArrayWrapper
31+
{
32+
public:
33+
ArrowArray arrow_array;
34+
35+
ArrowArrayWrapper()
36+
{
37+
reset();
38+
}
39+
40+
~ArrowArrayWrapper();
41+
42+
void reset()
43+
{
44+
arrow_array.length = 0;
45+
arrow_array.release = nullptr;
46+
}
47+
48+
/// Non-copyable but moveable
49+
ArrowArrayWrapper(const ArrowArrayWrapper &) = delete;
50+
ArrowArrayWrapper & operator=(const ArrowArrayWrapper &) = delete;
51+
ArrowArrayWrapper(ArrowArrayWrapper && other) noexcept;
52+
ArrowArrayWrapper & operator=(ArrowArrayWrapper && other) noexcept;
53+
};
54+
55+
class ArrowArrayStreamWrapper
56+
{
57+
public:
58+
ArrowArrayStream arrow_array_stream;
59+
60+
ArrowArrayStreamWrapper() {
61+
arrow_array_stream.release = nullptr;
62+
}
63+
64+
~ArrowArrayStreamWrapper();
65+
66+
// Non-copyable but moveable
67+
ArrowArrayStreamWrapper(const ArrowArrayStreamWrapper&) = delete;
68+
ArrowArrayStreamWrapper& operator=(const ArrowArrayStreamWrapper&) = delete;
69+
ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept;
70+
ArrowArrayStreamWrapper& operator=(ArrowArrayStreamWrapper&& other) noexcept;
71+
72+
/// Get schema from the stream
73+
void getSchema(ArrowSchemaWrapper& schema);
74+
75+
/// Get next chunk from the stream
76+
std::unique_ptr<ArrowArrayWrapper> getNextChunk();
77+
78+
/// Get last error message
79+
const char* getError();
80+
81+
/// Check if stream is valid
82+
bool isValid() const { return arrow_array_stream.release != nullptr; }
83+
};
84+
85+
/// Factory class for creating ArrowArrayStream from Python objects
86+
class PyArrowStreamFactory
87+
{
88+
public:
89+
static std::unique_ptr<ArrowArrayStreamWrapper> createFromPyObject(
90+
pybind11::object & py_obj,
91+
const DB::Names & column_names);
92+
93+
private:
94+
static std::unique_ptr<ArrowArrayStreamWrapper> createFromTable(
95+
pybind11::object & table,
96+
const DB::Names & column_names);
97+
};
98+
99+
} // namespace CHDB

0 commit comments

Comments
 (0)