diff --git a/.github/workflows/build_linux_arm64_wheels-gh.yml b/.github/workflows/build_linux_arm64_wheels-gh.yml
index 37cf62ad904..e661ad544aa 100644
--- a/.github/workflows/build_linux_arm64_wheels-gh.yml
+++ b/.github/workflows/build_linux_arm64_wheels-gh.yml
@@ -138,6 +138,7 @@ jobs:
       - name: Run libchdb stub in examples dir
         run: |
           bash -x ./examples/runStub.sh
+          bash -x ./examples/runArrowTest.sh
       - name: Check ccache statistics
         run: |
           ccache -s
diff --git a/.github/workflows/build_linux_x86_wheels.yml b/.github/workflows/build_linux_x86_wheels.yml
index 3ff06698d27..d3b4ca61efe 100644
--- a/.github/workflows/build_linux_x86_wheels.yml
+++ b/.github/workflows/build_linux_x86_wheels.yml
@@ -138,6 +138,7 @@ jobs:
       - name: Run libchdb stub in examples dir
         run: |
          bash -x ./examples/runStub.sh
+          bash -x ./examples/runArrowTest.sh
       - name: Check ccache statistics
         run: |
           ccache -s
diff --git a/.github/workflows/build_macos_arm64_wheels.yml b/.github/workflows/build_macos_arm64_wheels.yml
index 96ef0b988a6..aeb71d2fc41 100644
--- a/.github/workflows/build_macos_arm64_wheels.yml
+++ b/.github/workflows/build_macos_arm64_wheels.yml
@@ -22,8 +22,21 @@ on:
 jobs:
   build_universal_wheel:
     name: Build Universal Wheel (macOS ARM64)
-    runs-on: macos-13-xlarge
+    runs-on: macos-14-xlarge
     steps:
+      - name: Check machine architecture
+        run: |
+          echo "=== Machine Architecture Information ==="
+          echo "Machine type: $(uname -m)"
+          echo "Architecture: $(arch)"
+          echo "System info: $(uname -a)"
+          echo "Hardware info:"
+          system_profiler SPHardwareDataType | grep "Chip\|Processor"
+          if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then
+            echo "This is an ARM64 (Apple Silicon) machine"
+          else
+            echo "This is an x86_64 (Intel) machine"
+          fi
       - name: Setup pyenv
         run: |
           curl https://pyenv.run | bash
@@ -79,7 +92,7 @@ jobs:
           brew install ca-certificates lz4 mpdecimal openssl@3 readline sqlite xz z3 zstd
           brew install --ignore-dependencies llvm@19
           brew install git ninja libtool gettext gcc binutils grep findutils nasm
-          brew install --build-from-source ccache
+          brew install ccache || echo "ccache installation failed, continuing without it"
           brew install go
           cd /usr/local/opt/ && sudo rm -f llvm && sudo ln -sf llvm@19 llvm
           export PATH=$(brew --prefix llvm@19)/bin:$PATH
@@ -97,7 +110,7 @@ jobs:
       - name: ccache
         uses: hendrikmuhs/ccache-action@v1.2
         with:
-          key: macos-13-xlarge
+          key: macos-14-xlarge
           max-size: 5G
           append-timestamp: true
       - name: Run chdb/build.sh
@@ -141,6 +154,7 @@ jobs:
       - name: Run libchdb stub in examples dir
         run: |
           bash -x ./examples/runStub.sh
+          bash -x ./examples/runArrowTest.sh
       - name: Keep killall ccache and wait for ccache to finish
         if: always()
         run: |
diff --git a/.github/workflows/build_macos_x86_wheels.yml b/.github/workflows/build_macos_x86_wheels.yml
index 85ebe048c87..35cc72e284d 100644
--- a/.github/workflows/build_macos_x86_wheels.yml
+++ b/.github/workflows/build_macos_x86_wheels.yml
@@ -22,8 +22,21 @@ on:
 jobs:
   build_universal_wheel:
     name: Build Universal Wheel (macOS x86_64)
-    runs-on: macos-13
+    runs-on: macos-14-large
     steps:
+      - name: Check machine architecture
+        run: |
+          echo "=== Machine Architecture Information ==="
+          echo "Machine type: $(uname -m)"
+          echo "Architecture: $(arch)"
+          echo "System info: $(uname -a)"
+          echo "Hardware info:"
+          system_profiler SPHardwareDataType | grep "Chip\|Processor"
+          if sysctl -n hw.optional.arm64 2>/dev/null | grep -q "1"; then
+            echo "This is an ARM64 (Apple Silicon) machine"
+          else
+            echo "This is an x86_64 (Intel) machine"
+          fi
       - name: Setup pyenv
         run: |
           curl https://pyenv.run | bash
@@ -79,7 +92,7 @@ jobs:
           brew install ca-certificates lz4 mpdecimal openssl@3 readline sqlite xz z3 zstd
           brew install --ignore-dependencies llvm@19
           brew install git ninja libtool gettext gcc binutils grep findutils nasm
-          brew install --build-from-source ccache
+          brew install ccache || echo "ccache installation failed, continuing without it"
           brew install go
           cd /usr/local/opt/ && sudo rm -f llvm && sudo ln -sf llvm@19 llvm
           export PATH=$(brew --prefix llvm@19)/bin:$PATH
@@ -97,7 +110,7 @@ jobs:
       - name: ccache
         uses: hendrikmuhs/ccache-action@v1.2
         with:
-          key: macos-13-x86_64
+          key: macos-14-x86_64
           max-size: 5G
           append-timestamp: true
      - name: Run chdb/build.sh
@@ -142,6 +155,7 @@ jobs:
       - name: Run libchdb stub in examples dir
         run: |
           bash -x ./examples/runStub.sh
+          bash -x ./examples/runArrowTest.sh
       - name: Keep killall ccache and wait for ccache to finish
         if: always()
         run: |
diff --git a/README.md b/README.md
index 83760e75590..c8052829995 100644
--- a/README.md
+++ b/README.md
@@ -416,11 +416,7 @@ chDB automatically converts Python dictionary objects to ClickHouse JSON types f
    ```
    - Columns are converted to `String` if sampling finds non-dictionary values.
-2. **Arrow Table**
-   - `struct` type columns are automatically mapped to JSON columns.
-   - Nested structures preserve type information.
-
-3. **chdb.PyReader**
+2. **chdb.PyReader**
    - Implement custom schema mapping in `get_schema()`:
      ```python
      def get_schema(self):
diff --git a/chdb/__init__.py b/chdb/__init__.py
index 0674a46927c..b4132409aa8 100644
--- a/chdb/__init__.py
+++ b/chdb/__init__.py
@@ -103,7 +103,9 @@ def to_arrowTable(res):
         raise ImportError("Failed to import pyarrow or pandas") from None
     if len(res) == 0:
         return pa.Table.from_batches([], schema=pa.schema([]))
-    return pa.RecordBatchFileReader(res.bytes()).read_all()
+
+    memview = res.get_memview()
+    return pa.RecordBatchFileReader(memview.view()).read_all()
 
 
 # return pandas dataframe
diff --git a/chdb/state/sqlitelike.py b/chdb/state/sqlitelike.py
index 7694cb42ece..e743e1f722b 100644
--- a/chdb/state/sqlitelike.py
+++ b/chdb/state/sqlitelike.py
@@ -62,7 +62,9 @@ def to_arrowTable(res):
         raise ImportError("Failed to import pyarrow or pandas") from None
     if len(res) == 0:
         return pa.Table.from_batches([], schema=pa.schema([]))
-    return pa.RecordBatchFileReader(res.bytes()).read_all()
+
+    memview = res.get_memview()
+    return pa.RecordBatchFileReader(memview.view()).read_all()
 
 
 # return pandas dataframe
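
The new test below implements the producer side of Arrow's C stream interface. For orientation, a consumer drives such a stream roughly as follows; this is a hedged sketch of the Arrow C ABI contract, not chDB code, and the `drain` helper is illustrative:

    #include <arrow/c/abi.h>

    // Sketch: the ArrowArrayStream consumer contract assumed by the test below.
    static void drain(struct ArrowArrayStream * s)
    {
        struct ArrowSchema schema;
        if (s->get_schema(s, &schema) == 0 && schema.release)
            schema.release(&schema);   // consumer releases the schema when done

        for (;;)
        {
            struct ArrowArray batch;
            if (s->get_next(s, &batch) != 0)
                break;                 // error; details via s->get_last_error(s)
            if (!batch.release)
                break;                 // release == NULL marks end-of-stream
            /* ... consume batch.length rows ... */
            batch.release(&batch);     // consumer owns each returned batch
        }
        s->release(s);                 // finally release the stream itself
    }
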
diff --git a/examples/chdbArrowTest.c b/examples/chdbArrowTest.c
new file mode 100644
index 00000000000..f91d4d4c8c4
--- /dev/null
+++ b/examples/chdbArrowTest.c
@@ -0,0 +1,1000 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <errno.h>
+
+#include "../programs/local/chdb.h"
+#include "../contrib/arrow/cpp/src/arrow/c/abi.h"
+
+// Custom ArrowArrayStream implementation data
+typedef struct CustomStreamData
+{
+    bool schema_sent;
+    size_t current_row;
+    size_t total_rows;
+    size_t batch_size;
+    char* last_error;
+} CustomStreamData;
+
+// Function to initialize CustomStreamData
+static void init_custom_stream_data(CustomStreamData * data)
+{
+    data->schema_sent = false;
+    data->current_row = 0;
+    data->total_rows = 1000000;
+    data->batch_size = 10000;
+    data->last_error = NULL;
+}
+
+// Reset the stream to allow reading from the beginning
+static void reset_custom_stream_data(CustomStreamData * data)
+{
+    data->current_row = 0;
+    if (data->last_error) {
+        free(data->last_error);
+        data->last_error = NULL;
+    }
+}
+
+// Release function prototypes
+static void release_schema_child(struct ArrowSchema * s);
+static void release_schema_main(struct ArrowSchema * s);
+static void release_id_array(struct ArrowArray * arr);
+static void release_string_array(struct ArrowArray * arr);
+static void release_main_array(struct ArrowArray * arr);
+
+// Helper function to find minimum of two values
+static size_t min_size_t(size_t a, size_t b)
+{
+    return (a < b) ? a : b;
+}
+
+// Release function implementations
+static void release_schema_child(struct ArrowSchema * s)
+{
+    s->release = NULL;
+}
+
+static void release_schema_main(struct ArrowSchema * s)
+{
+    if (s->children)
+    {
+        for (int64_t i = 0; i < s->n_children; i++)
+        {
+            if (s->children[i] && s->children[i]->release)
+            {
+                s->children[i]->release(s->children[i]);
+            }
+            free(s->children[i]);
+        }
+        free(s->children);
+    }
+    s->release = NULL;
+}
+
+static void release_id_array(struct ArrowArray * arr)
+{
+    if (arr->buffers)
+    {
+        free((void*)(uintptr_t)arr->buffers[1]); // free data buffer
+        free((void**)(uintptr_t)arr->buffers);
+    }
+    arr->release = NULL;
+}
+
+static void release_string_array(struct ArrowArray * arr)
+{
+    if (arr->buffers)
+    {
+        free((void*)(uintptr_t)arr->buffers[1]); // free offset buffer
+        free((void*)(uintptr_t)arr->buffers[2]); // free data buffer
+        free((void**)(uintptr_t)arr->buffers);
+    }
+    arr->release = NULL;
+}
+
+static void release_main_array(struct ArrowArray * arr)
+{
+    if (arr->children)
+    {
+        for (int64_t i = 0; i < arr->n_children; i++)
+        {
+            if (arr->children[i] && arr->children[i]->release)
+            {
+                arr->children[i]->release(arr->children[i]);
+            }
+            free(arr->children[i]);
+        }
+        free(arr->children);
+    }
+    if (arr->buffers) {
+        free((void**)(uintptr_t)arr->buffers);
+    }
+    arr->release = NULL;
+}
+
+// Helper function to create schema with 2 columns: id(int64), value(string)
+static void create_schema(struct ArrowSchema * schema)
+{
+    schema->format = "+s"; // struct format
+    schema->name = NULL;
+    schema->metadata = NULL;
+    schema->flags = 0;
+    schema->n_children = 2;
+    schema->children = (struct ArrowSchema**)malloc(2 * sizeof(struct ArrowSchema*));
+    schema->dictionary = NULL;
+    schema->release = release_schema_main;
+
+    // Field 0: id (int64)
+    schema->children[0] = (struct ArrowSchema*)malloc(sizeof(struct ArrowSchema));
+    schema->children[0]->format = "l"; // int64
+    schema->children[0]->name = "id";
+    schema->children[0]->metadata = NULL;
+    schema->children[0]->flags = 0;
+    schema->children[0]->n_children = 0;
+    schema->children[0]->children = NULL;
+    schema->children[0]->dictionary = NULL;
+    schema->children[0]->release = release_schema_child;
+
+    // Field 1: value (string)
+    schema->children[1] = (struct ArrowSchema*)malloc(sizeof(struct ArrowSchema));
+    schema->children[1]->format = "u"; // utf8 string
+    schema->children[1]->name = "value";
+    schema->children[1]->metadata = NULL;
+    schema->children[1]->flags = 0;
+    schema->children[1]->n_children = 0;
+    schema->children[1]->children = NULL;
+    schema->children[1]->dictionary = NULL;
+    schema->children[1]->release = release_schema_child;
+}
+
+// Helper function to create a batch of data
+static void create_batch(struct ArrowArray * array, size_t start_row, size_t batch_size)
+{
+    struct ArrowArray * id_array;
+    struct ArrowArray * str_array;
+    int64_t * id_data;
+    int32_t * offsets;
+    size_t total_str_len;
+    char ** strings;
+    char * str_data;
+    size_t pos;
+    size_t i;
+
+    // Main array structure
+    array->length = batch_size;
+    array->null_count = 0;
+    array->offset = 0;
+    array->n_buffers = 1;
+    array->n_children = 2;
+    array->buffers = (const void **)malloc(1 * sizeof(void *));
+    array->buffers[0] = NULL; // validity buffer (no nulls)
+    array->children = (struct ArrowArray **)malloc(2 * sizeof(struct ArrowArray *));
+    array->dictionary = NULL;
+
+    // Create id column (int64)
+    array->children[0] = (struct ArrowArray *)malloc(sizeof(struct ArrowArray));
+    id_array = array->children[0];
+    id_array->length = batch_size;
+    id_array->null_count = 0;
+    id_array->offset = 0;
+    id_array->n_buffers = 2;
+    id_array->n_children = 0;
+    id_array->buffers = (const void **)malloc(2 * sizeof(void *));
+    id_array->buffers[0] = NULL; // validity buffer
+
+    // Allocate and fill id data
+    id_data = (int64_t *)malloc(batch_size * sizeof(int64_t));
+    for (i = 0; i < batch_size; i++)
+        id_data[i] = start_row + i;
+
+    id_array->buffers[1] = id_data; // data buffer
+    id_array->children = NULL;
+    id_array->dictionary = NULL;
+    id_array->release = release_id_array;
+
+    // Create value column (string)
+    array->children[1] = (struct ArrowArray *)malloc(sizeof(struct ArrowArray));
+    str_array = array->children[1];
+    str_array->length = batch_size;
+    str_array->null_count = 0;
+    str_array->offset = 0;
+    str_array->n_buffers = 3;
+    str_array->n_children = 0;
+    str_array->buffers = (const void **)malloc(3 * sizeof(void *));
+    str_array->buffers[0] = NULL; // validity buffer
+
+    // Create offset buffer (int32)
+    offsets = (int32_t *)malloc((batch_size + 1) * sizeof(int32_t));
+    offsets[0] = 0;
+
+    // Calculate total string length and create strings
+    total_str_len = 0;
+    strings = (char **)malloc(batch_size * sizeof(char *));
+    for (i = 0; i < batch_size; i++)
+    {
+        char buffer[64];
+        size_t len;
+        snprintf(buffer, sizeof(buffer), "value_%zu", start_row + i);
+        len = strlen(buffer);
+        strings[i] = (char*)malloc(len + 1);
+        strcpy(strings[i], buffer);
+        total_str_len += len;
+        offsets[i + 1] = total_str_len;
+    }
+    str_array->buffers[1] = offsets; // offset buffer
+
+    // Create data buffer
+    str_data = (char*)malloc(total_str_len);
+    pos = 0;
+    for (i = 0; i < batch_size; i++)
+    {
+        size_t len = strlen(strings[i]);
+        memcpy(str_data + pos, strings[i], len);
+        pos += len;
+        free(strings[i]);
+    }
+    free(strings);
+    str_array->buffers[2] = str_data; // data buffer
+
+    str_array->children = NULL;
+    str_array->dictionary = NULL;
+    str_array->release = release_string_array;
+
+    // Main array release function
+    array->release = release_main_array;
+}
+
+// Callback function to get schema
+static int custom_get_schema(struct ArrowArrayStream* stream, struct ArrowSchema* out)
+{
+    (void)stream; // Suppress unused parameter warning
+    create_schema(out);
+    return 0;
+}
+
+// Callback function to get next array
+static int custom_get_next(struct ArrowArrayStream * stream, struct ArrowArray * out)
+{
+    CustomStreamData * data;
+    size_t remaining_rows;
+    size_t batch_size;
+
+    data = (CustomStreamData *)stream->private_data;
+    if (!data)
+        return EINVAL;
+
+    // Check if we've reached the end of the stream
+    if (data->current_row >= data->total_rows)
+    {
+        // End of stream - set release to NULL to indicate no more data
+        out->release = NULL;
+        return 0;
+    }
+
+    // Calculate batch size for this iteration
+    remaining_rows = data->total_rows - data->current_row;
+    batch_size = min_size_t(data->batch_size, remaining_rows);
+
+    // Create the batch
+    create_batch(out, data->current_row, batch_size);
+
+    data->current_row += batch_size;
+    return 0;
+}
+
+// Callback function to get last error
+static const char * custom_get_last_error(struct ArrowArrayStream * stream)
+{
+    CustomStreamData * data = (CustomStreamData *)stream->private_data;
+    if (!data || !data->last_error)
+        return NULL;
+
+    return data->last_error;
+}
+
+// Callback function to release stream resources
+static void custom_release(struct ArrowArrayStream * stream)
+{
+    if (stream->private_data)
+    {
+        CustomStreamData * data = (CustomStreamData *)stream->private_data;
+        if (data->last_error)
+        {
+            free(data->last_error);
+        }
+        free(data);
+        stream->private_data = NULL;
+    }
+    stream->release = NULL;
+}
+
+// Helper function to reset the ArrowArrayStream for reuse
+static void reset_arrow_stream(struct ArrowArrayStream * stream)
+{
+    if (stream && stream->private_data)
+    {
+        CustomStreamData * data = (CustomStreamData *)stream->private_data;
+        reset_custom_stream_data(data);
+        printf("✓ ArrowArrayStream has been reset, ready for re-reading\n");
+    }
+}
+
+//===--------------------------------------------------------------------===//
+// Unit Test Utilities
+//===--------------------------------------------------------------------===//
+
+static void test_assert(bool condition, const char * test_name, const char * message)
+{
+    if (condition)
+    {
+        printf("✓ PASS: %s\n", test_name);
+    }
+    else
+    {
+        printf("✗ FAIL: %s", test_name);
+        if (message && strlen(message) > 0)
+        {
+            printf(" - %s", message);
+        }
+        printf("\n");
+        exit(1);
+    }
+}
+
+static void test_assert_chdb_state(chdb_state state, const char * operation_name)
+{
+    char message[256];
+    if (state == CHDBError)
+    {
+        strcpy(message, "Operation failed");
+    }
+    else
+    {
+        strcpy(message, "Unknown state");
+    }
+
+    test_assert(state == CHDBSuccess, operation_name,
+                state == CHDBError ? message : NULL);
+}
+
+static void test_assert_not_null(void * ptr, const char * test_name)
+{
+    test_assert(ptr != NULL, test_name, "Pointer is null");
+}
+
+static void test_assert_no_error(chdb_result * result, const char * query_name)
+{
+    char full_test_name[512];
+    const char * error;
+
+    snprintf(full_test_name, sizeof(full_test_name), "%s - Result is not null", query_name);
+    test_assert_not_null(result, full_test_name);
+
+    error = chdb_result_error(result);
+    snprintf(full_test_name, sizeof(full_test_name), "%s - No query error", query_name);
+
+    if (error)
+    {
+        char error_message[512];
+        snprintf(error_message, sizeof(error_message), "Error: %s", error);
+        test_assert(error == NULL, full_test_name, error_message);
+    }
+    else
+    {
+        test_assert(error == NULL, full_test_name, NULL);
+    }
+}
+
+static void test_assert_query_result_contains(chdb_result * result, const char * expected_content, const char * query_name)
+{
+    char * buffer;
+    char full_test_name[512];
+    bool contains;
+
+    test_assert_no_error(result, query_name);
+
+    buffer = chdb_result_buffer(result);
+    snprintf(full_test_name, sizeof(full_test_name), "%s - Result buffer is not null", query_name);
+    test_assert_not_null(buffer, full_test_name);
+
+    snprintf(full_test_name, sizeof(full_test_name), "%s - Result contains expected content", query_name);
+
+    contains = strstr(buffer, expected_content) != NULL;
+    if (!contains)
+    {
+        char error_message[1024];
+        snprintf(error_message, sizeof(error_message), "Expected: %s, Actual: %s", expected_content, buffer);
+        test_assert(contains, full_test_name, error_message);
+    }
+    else
+    {
+        test_assert(contains, full_test_name, NULL);
+    }
+}
+
+static void test_assert_row_count(chdb_result * result, uint64_t expected_rows, const char * query_name)
+{
+    char * buffer;
+    char full_test_name[512];
+    char * result_str;
+    char * end;
+    uint64_t actual_rows;
+
+    test_assert_no_error(result, query_name);
+
+    buffer = chdb_result_buffer(result);
+    snprintf(full_test_name, sizeof(full_test_name), "%s - Result buffer is not null", query_name);
+    test_assert_not_null(buffer, full_test_name);
+
+    /* Parse the count result (assuming CSV format with just the number) */
+    result_str = (char*)malloc(strlen(buffer) + 1);
+    strcpy(result_str, buffer);
+
+    /* Remove trailing whitespace/newlines */
+    end = result_str + strlen(result_str) - 1;
+    while (end > result_str && (*end == ' ' || *end == '\t' || *end == '\n' || *end == '\r' || *end == '\f' || *end == '\v')) {
+        *end = '\0';
+        end--;
+    }
+
+    actual_rows = strtoull(result_str, NULL, 10);
+
+    snprintf(full_test_name, sizeof(full_test_name), "%s - Row count matches", query_name);
+
+    if (actual_rows != expected_rows)
+    {
+        char error_message[256];
+        snprintf(error_message, sizeof(error_message), "Expected: %llu, Actual: %llu",
+                 (unsigned long long)expected_rows, (unsigned long long)actual_rows);
+        test_assert(actual_rows == expected_rows, full_test_name, error_message);
+    }
+    else
+    {
+        test_assert(actual_rows == expected_rows, full_test_name, NULL);
+    }
+
+    free(result_str);
+}
+
+void test_arrow_scan(chdb_connection conn)
+{
+    struct ArrowArrayStream stream;
+    struct ArrowArrayStream stream2;
+    struct ArrowArrayStream stream3;
+    CustomStreamData * stream_data;
+    CustomStreamData * stream_data2;
+    CustomStreamData * stream_data3;
+    const char* table_name = "test_arrow_table";
+    const char* non_exist_table_name = "non_exist_table";
+    const char* table_name2 = "test_arrow_table_2";
+    const char* table_name3 = "test_arrow_table_3";
+    chdb_arrow_stream arrow_stream;
+    chdb_arrow_stream arrow_stream2;
+    chdb_arrow_stream arrow_stream3;
+    chdb_state result;
+    chdb_result * count_result;
+    chdb_result * sample_result;
+    chdb_result * last_result;
+    chdb_result * count1_result;
+    chdb_result * count2_result;
+    chdb_result * count3_result;
+    chdb_result * join_result;
+    chdb_result * union_result;
+    chdb_result * unregister_result;
+    const char * error;
+    char error_message[512];
+
+    printf("\n=== Testing ArrowArrayStream Scan Functions ===\n");
+    printf("Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n");
+
+    memset(&stream, 0, sizeof(stream));
+
+    /* Create and initialize stream data */
+    stream_data = (CustomStreamData*)malloc(sizeof(CustomStreamData));
+    init_custom_stream_data(stream_data);
+
+    /* Set up the ArrowArrayStream callbacks */
+    stream.get_schema = custom_get_schema;
+    stream.get_next = custom_get_next;
+    stream.get_last_error = custom_get_last_error;
+    stream.release = custom_release;
+    stream.private_data = stream_data;
+
+    printf("✓ ArrowArrayStream initialization completed\n");
+    printf("Starting registration with chDB...\n");
+
+    arrow_stream = (chdb_arrow_stream)&stream;
+    result = chdb_arrow_scan(conn, table_name, arrow_stream);
+
+    /* Test 1: Verify arrow registration succeeded */
+    test_assert_chdb_state(result, "Register ArrowArrayStream to table: test_arrow_table");
+
+    /* Test 2: Unregister non-existent table should handle gracefully */
+    result = chdb_arrow_unregister_table(conn, non_exist_table_name);
+    test_assert_chdb_state(result, "Unregister non-existent table: non_exist_table");
+
+    /* Test 3: Count rows - should be exactly 1,000,000 */
+    count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_table)", "CSV");
+    test_assert_row_count(count_result, 1000000, "Count total rows");
+    chdb_destroy_query_result(count_result);
+
+    /* Test 4: Sample first 5 rows - should contain id=0,1,2,3,4 */
+    reset_arrow_stream(&stream);
+    sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) LIMIT 5", "CSV");
+    test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row");
+    test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row");
+    chdb_destroy_query_result(sample_result);
+
+    /* Test 5: Sample last 5 rows - should contain id=999999,999998,999997,999996,999995 */
+    reset_arrow_stream(&stream);
+    last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV");
+    test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row");
+    test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row");
+    chdb_destroy_query_result(last_result);
+
+    /* Test 6: Multiple table registration tests */
+    /* Create second ArrowArrayStream with different data (500,000 rows) */
+    memset(&stream2, 0, sizeof(stream2));
+    stream_data2 = (CustomStreamData *)malloc(sizeof(CustomStreamData));
+    init_custom_stream_data(stream_data2);
+    stream_data2->total_rows = 500000; /* Different row count */
+    stream_data2->current_row = 0;
+    stream2.get_schema = custom_get_schema;
+    stream2.get_next = custom_get_next;
+    stream2.get_last_error = custom_get_last_error;
+    stream2.release = custom_release;
+    stream2.private_data = stream_data2;
+
+    /* Create third ArrowArrayStream with different data (100,000 rows) */
+    memset(&stream3, 0, sizeof(stream3));
+    stream_data3 = (CustomStreamData *)malloc(sizeof(CustomStreamData));
+    init_custom_stream_data(stream_data3);
+    stream_data3->total_rows = 100000; /* Different row count */
+    stream_data3->current_row = 0;
+    stream3.get_schema = custom_get_schema;
+    stream3.get_next = custom_get_next;
+    stream3.get_last_error = custom_get_last_error;
+    stream3.release = custom_release;
+    stream3.private_data = stream_data3;
+
+    /* Register second table */
+    arrow_stream2 = (chdb_arrow_stream)&stream2;
+    result = chdb_arrow_scan(conn, table_name2, arrow_stream2);
+    test_assert_chdb_state(result, "Register second ArrowArrayStream to table: test_arrow_table_2");
+
+    /* Register third table */
+    arrow_stream3 = (chdb_arrow_stream)&stream3;
+    result = chdb_arrow_scan(conn, table_name3, arrow_stream3);
+    test_assert_chdb_state(result, "Register third ArrowArrayStream to table: test_arrow_table_3");
+
+    /* Test 6a: Verify each table has correct row counts */
+    reset_arrow_stream(&stream);
+    count1_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table)", "CSV");
+    test_assert_row_count(count1_result, 1000000, "First table row count");
+    chdb_destroy_query_result(count1_result);
+
+    reset_arrow_stream(&stream2);
+    count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_2)", "CSV");
+    test_assert_row_count(count2_result, 500000, "Second table row count");
+    chdb_destroy_query_result(count2_result);
+
+    reset_arrow_stream(&stream3);
+    count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_table_3)", "CSV");
+    test_assert_row_count(count3_result, 100000, "Third table row count");
+    chdb_destroy_query_result(count3_result);
+
+    /* Test 6b: Test cross-table JOIN query */
+    reset_arrow_stream(&stream);
+    reset_arrow_stream(&stream2);
+    join_result = chdb_query(conn,
+        "SELECT t1.id, t1.value, t2.value as value2 "
+        "FROM arrowstream(test_arrow_table) t1 "
+        "INNER JOIN arrowstream(test_arrow_table_2) t2 ON t1.id = t2.id "
+        "WHERE t1.id < 5 ORDER BY t1.id", "CSV");
+    test_assert_query_result_contains(join_result, "0,\"value_0\",\"value_0\"", "JOIN query contains expected data");
+    test_assert_query_result_contains(join_result, "4,\"value_4\",\"value_4\"", "JOIN query contains fifth row");
+    chdb_destroy_query_result(join_result);
+
+    /* Test 6c: Test UNION query across multiple tables */
+    reset_arrow_stream(&stream2);
+    reset_arrow_stream(&stream3);
+    union_result = chdb_query(conn,
+        "SELECT COUNT(*) FROM ("
+        "SELECT id FROM arrowstream(test_arrow_table_2) WHERE id < 10 "
+        "UNION ALL "
+        "SELECT id FROM arrowstream(test_arrow_table_3) WHERE id < 10"
+        ")", "CSV");
+    test_assert_row_count(union_result, 20, "UNION query row count");
+    chdb_destroy_query_result(union_result);
+
+    /* Cleanup additional tables */
+    result = chdb_arrow_unregister_table(conn, table_name2);
+    test_assert_chdb_state(result, "Unregister second ArrowArrayStream table");
+
+    result = chdb_arrow_unregister_table(conn, table_name3);
+    test_assert_chdb_state(result, "Unregister third ArrowArrayStream table");
+
+    /* Test 7: Unregister original table should succeed */
+    result = chdb_arrow_unregister_table(conn, table_name);
+    test_assert_chdb_state(result, "Unregister ArrowArrayStream table: test_arrow_table");
+
+    /* Test 8: Sample last 5 rows after unregister should fail */
+    reset_arrow_stream(&stream);
+    unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_table) ORDER BY id DESC LIMIT 5", "CSV");
+    error = chdb_result_error(unregister_result);
+
+    if (error)
+    {
+        snprintf(error_message, sizeof(error_message), "Got expected error: %s", error);
+        test_assert(error != NULL, "Query after unregister should fail", error_message);
+    }
+    else
+    {
+        test_assert(error != NULL, "Query after unregister should fail", "No error returned when error was expected");
+    }
+    chdb_destroy_query_result(unregister_result);
+}
+
+// Release function for array children in create_arrow_array
+static void release_array_child_id(struct ArrowArray* a)
+{
+    if (a->buffers)
+    {
+        free((void*)(uintptr_t)a->buffers[1]); // id data
+        free((void**)(uintptr_t)a->buffers);
+    }
+    free(a);
+}
+
+// Release function for array children (string) in create_arrow_array
+static void release_array_child_string(struct ArrowArray* a)
+{
+    if (a->buffers)
+    {
+        free((void*)(uintptr_t)a->buffers[1]); // offsets
+        free((void*)(uintptr_t)a->buffers[2]); // string data
+        free((void**)(uintptr_t)a->buffers);
+    }
+    free(a);
+}
+
+// Release function for main array in create_arrow_array
+static void release_arrow_array_main(struct ArrowArray * a)
+{
+    if (a->children)
+    {
+        for (int64_t i = 0; i < a->n_children; i++)
+        {
+            if (a->children[i] && a->children[i]->release)
+            {
+                a->children[i]->release(a->children[i]);
+            }
+        }
+        free(a->children);
+    }
+
+    if (a->buffers)
+    {
+        free((void**)(uintptr_t)a->buffers);
+    }
+    // Mark the array as released, as required by the Arrow C data interface
+    a->release = NULL;
+}
+
+// Helper function to create ArrowArray with specified row count
+static void create_arrow_array(struct ArrowArray * array, uint64_t row_count)
+{
+    struct ArrowArray * id_array;
+    struct ArrowArray * value_array;
+    int64_t * id_data;
+    int32_t * offsets;
+    size_t total_string_size;
+    char * string_data;
+    size_t current_pos;
+    uint64_t i;
+
+    array->length = row_count;
+    array->null_count = 0;
+    array->offset = 0;
+    array->n_buffers = 1;
+    array->n_children = 2;
+    array->buffers = (const void**)malloc(1 * sizeof(void*));
+    array->buffers[0] = NULL; // validity buffer
+
+    array->children = (struct ArrowArray**)malloc(2 * sizeof(struct ArrowArray*));
+    array->dictionary = NULL;
+
+    // Create id column (int64)
+    array->children[0] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray));
+    id_array = array->children[0];
+    id_array->length = row_count;
+    id_array->null_count = 0;
+    id_array->offset = 0;
+    id_array->n_buffers = 2;
+    id_array->n_children = 0;
+    id_array->children = NULL;
+    id_array->dictionary = NULL;
+
+    id_array->buffers = (const void**)malloc(2 * sizeof(void*));
+    id_array->buffers[0] = NULL; // validity buffer
+
+    // Allocate and populate id data
+    id_data = (int64_t*)malloc(row_count * sizeof(int64_t));
+    for (i = 0; i < row_count; i++)
+    {
+        id_data[i] = (int64_t)i;
+    }
+    id_array->buffers[1] = id_data;
+    id_array->release = release_array_child_id;
+
+    // Create value column (string)
+    array->children[1] = (struct ArrowArray*)malloc(sizeof(struct ArrowArray));
+    value_array = array->children[1];
+    value_array->length = row_count;
+    value_array->null_count = 0;
+    value_array->offset = 0;
+    value_array->n_buffers = 3;
+    value_array->n_children = 0;
+    value_array->children = NULL;
+    value_array->dictionary = NULL;
+
+    value_array->buffers = (const void**)malloc(3 * sizeof(void*));
+    value_array->buffers[0] = NULL; // validity buffer
+
+    // Calculate total string data size and create offset array
+    offsets = (int32_t*)malloc((row_count + 1) * sizeof(int32_t));
+    total_string_size = 0;
+    offsets[0] = 0;
+
+    for (i = 0; i < row_count; i++)
+    {
+        char value_str[64];
+        size_t len;
+        snprintf(value_str, sizeof(value_str), "value_%llu", (unsigned long long)i);
+        len = strlen(value_str);
+        total_string_size += len;
+        offsets[i + 1] = (int32_t)total_string_size;
+    }
+
+    value_array->buffers[1] = offsets;
+
+    // Allocate and populate string data
+    string_data = (char *)malloc(total_string_size);
+    current_pos = 0;
+    for (i = 0; i < row_count; i++) {
+        char value_str[64];
+        size_t len;
+        snprintf(value_str, sizeof(value_str), "value_%llu", (unsigned long long)i);
+        len = strlen(value_str);
+        memcpy(string_data + current_pos, value_str, len);
+        current_pos += len;
+    }
+    value_array->buffers[2] = string_data;
+    value_array->release = release_array_child_string;
+
+    // Set release callback for main array
+    array->release = release_arrow_array_main;
+}
+
+void test_arrow_array_scan(chdb_connection conn)
+{
+    struct ArrowSchema schema;
+    struct ArrowArray array;
+    struct ArrowSchema schema2;
+    struct ArrowArray array2;
+    struct ArrowSchema schema3;
+    struct ArrowArray array3;
+    const char * table_name = "test_arrow_array_table";
+    const char * non_exist_table_name = "non_exist_array_table";
+    const char * table_name2 = "test_arrow_array_table_2";
+    const char * table_name3 = "test_arrow_array_table_3";
+    chdb_arrow_schema arrow_schema;
+    chdb_arrow_array arrow_array;
+    chdb_arrow_schema arrow_schema2;
+    chdb_arrow_array arrow_array2;
+    chdb_arrow_schema arrow_schema3;
+    chdb_arrow_array arrow_array3;
+    chdb_state result;
+    chdb_result * count_result;
+    chdb_result * sample_result;
+    chdb_result * last_result;
+    chdb_result * count2_result;
+    chdb_result * count3_result;
+    chdb_result * join_result;
+    chdb_result * union_result;
+    chdb_result * unregister_result;
+    const char * error;
+    char error_message[512];
+
+    printf("\n=== Testing ArrowArray Scan Functions ===\n");
+    printf("Data specification: 1,000,000 rows × 2 columns (id: int64, value: string)\n");
+
+    // Create ArrowSchema (reuse existing function)
+    create_schema(&schema);
+
+    // Create ArrowArray with 1,000,000 rows
+    memset(&array, 0, sizeof(array));
+    create_arrow_array(&array, 1000000);
+
+    printf("✓ ArrowArray initialization completed\n");
+    printf("Starting registration with chDB...\n");
+
+    arrow_schema = (chdb_arrow_schema)&schema;
+    arrow_array = (chdb_arrow_array)&array;
+
+    // Test 1: Register -> Query -> Unregister for row count
+    result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array);
+    test_assert_chdb_state(result, "Register ArrowArray to table: test_arrow_array_table");
+
+    count_result = chdb_query(conn, "SELECT COUNT(*) as total_rows FROM arrowstream(test_arrow_array_table)", "CSV");
+    test_assert_row_count(count_result, 1000000, "Count total rows");
+    chdb_destroy_query_result(count_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name);
+    test_assert_chdb_state(result, "Unregister ArrowArray table after count query");
+
+    // Test 2: Unregister non-existent table should handle gracefully
+    result = chdb_arrow_unregister_table(conn, non_exist_table_name);
+    test_assert_chdb_state(result, "Unregister non-existent array table: non_exist_array_table");
+
+    // Test 3: Register -> Query -> Unregister for first 5 rows
+    result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array);
+    test_assert_chdb_state(result, "Register ArrowArray for sample query");
+
+    sample_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) LIMIT 5", "CSV");
+    test_assert_query_result_contains(sample_result, "0,\"value_0\"", "First 5 rows contain first row");
+    test_assert_query_result_contains(sample_result, "4,\"value_4\"", "First 5 rows contain fifth row");
+    chdb_destroy_query_result(sample_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name);
+    test_assert_chdb_state(result, "Unregister ArrowArray table after sample query");
+
+    // Test 4: Register -> Query -> Unregister for last 5 rows
+    result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array);
+    test_assert_chdb_state(result, "Register ArrowArray for last rows query");
+
+    last_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV");
+    test_assert_query_result_contains(last_result, "999999,\"value_999999\"", "Last 5 rows contain last row");
+    test_assert_query_result_contains(last_result, "999995,\"value_999995\"", "Last 5 rows contain fifth row");
+    chdb_destroy_query_result(last_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name);
+    test_assert_chdb_state(result, "Unregister ArrowArray table after last rows query");
+
+    // Test 5: Independent multiple table tests
+    // Create second ArrowArray with different data (500,000 rows)
+    create_schema(&schema2);
+    memset(&array2, 0, sizeof(array2));
+    create_arrow_array(&array2, 500000);
+
+    // Create third ArrowArray with different data (100,000 rows)
+    create_schema(&schema3);
+    memset(&array3, 0, sizeof(array3));
+    create_arrow_array(&array3, 100000);
+
+    arrow_schema2 = (chdb_arrow_schema)&schema2;
+    arrow_array2 = (chdb_arrow_array)&array2;
+    arrow_schema3 = (chdb_arrow_schema)&schema3;
+    arrow_array3 = (chdb_arrow_array)&array3;
+
+    // Test 5a: Register -> Query -> Unregister for second table (500K rows)
+    result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2);
+    test_assert_chdb_state(result, "Register second ArrowArray to table: test_arrow_array_table_2");
+
+    count2_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_2)", "CSV");
+    test_assert_row_count(count2_result, 500000, "Second array table row count");
+    chdb_destroy_query_result(count2_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name2);
+    test_assert_chdb_state(result, "Unregister second ArrowArray table");
+
+    // Test 5b: Register -> Query -> Unregister for third table (100K rows)
+    result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3);
+    test_assert_chdb_state(result, "Register third ArrowArray to table: test_arrow_array_table_3");
+
+    count3_result = chdb_query(conn, "SELECT COUNT(*) FROM arrowstream(test_arrow_array_table_3)", "CSV");
+    test_assert_row_count(count3_result, 100000, "Third array table row count");
+    chdb_destroy_query_result(count3_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name3);
+    test_assert_chdb_state(result, "Unregister third ArrowArray table");
+
+    // Test 6: Cross-table JOIN query (Register both -> Query -> Unregister both)
+    result = chdb_arrow_array_scan(conn, table_name, arrow_schema, arrow_array);
+    test_assert_chdb_state(result, "Register first ArrowArray for JOIN");
+
+    result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2);
+    test_assert_chdb_state(result, "Register second ArrowArray for JOIN");
+
+    join_result = chdb_query(conn,
+        "SELECT t1.id, t1.value, t2.value as value2 "
+        "FROM arrowstream(test_arrow_array_table) t1 "
+        "INNER JOIN arrowstream(test_arrow_array_table_2) t2 ON t1.id = t2.id "
+        "WHERE t1.id < 5 ORDER BY t1.id", "CSV");
+    test_assert_query_result_contains(join_result, "0,\"value_0\",\"value_0\"", "Array JOIN query contains expected data");
+    test_assert_query_result_contains(join_result, "4,\"value_4\",\"value_4\"", "Array JOIN query contains fifth row");
+    chdb_destroy_query_result(join_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name);
+    test_assert_chdb_state(result, "Unregister first ArrowArray after JOIN");
+
+    result = chdb_arrow_unregister_table(conn, table_name2);
+    test_assert_chdb_state(result, "Unregister second ArrowArray after JOIN");
+
+    // Test 7: Cross-table UNION query (Register both -> Query -> Unregister both)
+    result = chdb_arrow_array_scan(conn, table_name2, arrow_schema2, arrow_array2);
+    test_assert_chdb_state(result, "Register second ArrowArray for UNION");
+
+    result = chdb_arrow_array_scan(conn, table_name3, arrow_schema3, arrow_array3);
+    test_assert_chdb_state(result, "Register third ArrowArray for UNION");
+
+    union_result = chdb_query(conn,
+        "SELECT COUNT(*) FROM ("
+        "SELECT id FROM arrowstream(test_arrow_array_table_2) WHERE id < 10 "
+        "UNION ALL "
+        "SELECT id FROM arrowstream(test_arrow_array_table_3) WHERE id < 10"
+        ")", "CSV");
+    test_assert_row_count(union_result, 20, "Array UNION query row count");
+    chdb_destroy_query_result(union_result);
+
+    result = chdb_arrow_unregister_table(conn, table_name2);
+    test_assert_chdb_state(result, "Unregister second ArrowArray after UNION");
+
+    result = chdb_arrow_unregister_table(conn, table_name3);
+    test_assert_chdb_state(result, "Unregister third ArrowArray after UNION");
+
+    // Test 8: Query after unregister should fail
+    unregister_result = chdb_query(conn, "SELECT * FROM arrowstream(test_arrow_array_table) ORDER BY id DESC LIMIT 5", "CSV");
+    error = chdb_result_error(unregister_result);
+
+    if (error)
+    {
+        snprintf(error_message, sizeof(error_message), "Got expected error: %s", error);
+        test_assert(error != NULL, "Array query after unregister should fail", error_message);
+    }
+    else
+    {
+        test_assert(error != NULL, "Array query after unregister should fail", "No error returned when error was expected");
+    }
+    chdb_destroy_query_result(unregister_result);
+
+    // Cleanup ArrowArrays and schemas
+    if (array.release) array.release(&array);
+    if (schema.release) schema.release(&schema);
+    if (array2.release) array2.release(&array2);
+    if (schema2.release) schema2.release(&schema2);
+    if (array3.release) array3.release(&array3);
+    if (schema3.release) schema3.release(&schema3);
+}
+
+int main(void)
+{
+    char * argv[] = {"clickhouse", "--multiquery"};
+    int argc = sizeof(argv) / sizeof(argv[0]);
+    chdb_connection * conn_ptr;
+    chdb_connection conn;
+
+    printf("=== chDB Arrow Functions Test ===\n");
+
+    /* Create connection */
+    conn_ptr = chdb_connect(argc, argv);
+    if (!conn_ptr || !*conn_ptr) {
+        printf("Failed to create chDB connection\n");
+        exit(1);
+    }
+
+    conn = *conn_ptr;
+    printf("✓ chDB connection established\n");
+
+    /* Run test suites */
+    test_arrow_scan(conn);
+    test_arrow_array_scan(conn);
+
+    /* Clean up */
+    chdb_close_conn(conn_ptr);
+
+    printf("\n=== chDB Arrow Functions Test Completed ===\n");
+
+    return 0;
+}
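
runArrowTest.sh (below) compiles this test against libchdb and runs it. Distilled from the test above, the core registration flow of the new Arrow C API looks roughly like the following sketch; the stream is assumed to be set up as in `test_arrow_scan`, and the table name `events` is illustrative:

    #include <stdio.h>
    #include "chdb.h"

    // Sketch: register a caller-owned ArrowArrayStream, query it, unregister it.
    static int scan_once(struct ArrowArrayStream * stream)
    {
        char arg0[] = "clickhouse";
        char arg1[] = "--multiquery";
        char * argv[] = {arg0, arg1};
        chdb_connection * conn = chdb_connect(2, argv);
        if (!conn || !*conn)
            return 1;

        // Expose the stream to SQL under the name "events".
        if (chdb_arrow_scan(*conn, "events", (chdb_arrow_stream)stream) != CHDBSuccess)
            return 1;

        chdb_result * res = chdb_query(*conn, "SELECT count() FROM arrowstream(events)", "CSV");
        if (!chdb_result_error(res))
            printf("%s", chdb_result_buffer(res));
        chdb_destroy_query_result(res);

        // The name stops resolving once unregistered; the caller still owns the stream.
        chdb_arrow_unregister_table(*conn, "events");
        chdb_close_conn(conn);
        return 0;
    }
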
diff --git a/examples/runArrowTest.sh b/examples/runArrowTest.sh
new file mode 100755
index 00000000000..7696f361162
--- /dev/null
+++ b/examples/runArrowTest.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+set -e
+
+CFLAGS="-std=c99"
+
+# check current os type, and make ldd command
+if [ "$(uname)" == "Darwin" ]; then
+    LDD="otool -L"
+    LIB_PATH="DYLD_LIBRARY_PATH"
+elif [ "$(uname)" == "Linux" ]; then
+    LDD="ldd"
+    LIB_PATH="LD_LIBRARY_PATH"
+else
+    echo "OS not supported"
+    exit 1
+fi
+
+# cd to the directory of this script
+DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+cd "$DIR"
+
+echo "Compile and link chdbArrowTest (C version)"
+clang $CFLAGS chdbArrowTest.c -o chdbArrowTest \
+    -I../programs/local/ \
+    -I../contrib/arrow/cpp/src \
+    -I../contrib/arrow-cmake/cpp/src \
+    -I../src \
+    -L../ -lchdb
+
+export ${LIB_PATH}=..
+
+${LDD} chdbArrowTest
+
+echo "Run Arrow API tests (C version):"
+./chdbArrowTest
diff --git a/programs/local/ArrowScanState.h b/programs/local/ArrowScanState.h
new file mode 100644
index 00000000000..98501498f57
--- /dev/null
+++ b/programs/local/ArrowScanState.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include "ArrowStreamWrapper.h"
+
+#include <arrow/record_batch.h>
+#include <memory>
+
+namespace CHDB
+{
+
+/// Scan state for each stream - shared between ArrowTableReader and ArrowStreamReader
+struct ArrowScanState
+{
+    /// Current Arrow array being processed (for ArrowTableReader)
+    std::unique_ptr<ArrowArrayWrapper> current_array;
+    /// Current offset within the array
+    size_t current_offset = 0;
+    /// Whether this stream is exhausted
+    bool exhausted = false;
+    /// Cached imported RecordBatch to avoid repeated imports
+    std::shared_ptr<arrow::RecordBatch> cached_record_batch;
+
+    virtual ~ArrowScanState() = default;
+
+    virtual void reset()
+    {
+        current_array.reset();
+        current_offset = 0;
+        exhausted = false;
+        cached_record_batch.reset();
+    }
+};
+
+} // namespace CHDB
diff --git a/programs/local/ArrowSchema.cpp b/programs/local/ArrowSchema.cpp
new file mode 100644
index 00000000000..8e0b3f21d0f
--- /dev/null
+++ b/programs/local/ArrowSchema.cpp
@@ -0,0 +1,89 @@
+#include "ArrowSchema.h"
+
+#include <Formats/FormatFactory.h>
+#include <Processors/Formats/Impl/ArrowColumnToCHColumn.h>
+#include <arrow/c/bridge.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+}
+
+using namespace DB;
+
+namespace CHDB
+{
+
+ArrowSchemaWrapper::~ArrowSchemaWrapper()
+{
+    if (arrow_schema.release != nullptr)
+    {
+        arrow_schema.release(&arrow_schema);
+        chassert(!arrow_schema.release);
+    }
+}
+
+ArrowSchemaWrapper::ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept
+    : arrow_schema(other.arrow_schema)
+{
+    other.arrow_schema.release = nullptr;
+}
+
+ArrowSchemaWrapper & ArrowSchemaWrapper::operator=(ArrowSchemaWrapper && other) noexcept
+{
+    if (this != &other)
+    {
+        if (arrow_schema.release)
+        {
+            arrow_schema.release(&arrow_schema);
+        }
+        arrow_schema = other.arrow_schema;
+        other.arrow_schema.release = nullptr;
+    }
+    return *this;
+}
+
+void ArrowSchemaWrapper::convertArrowSchema(
+    ArrowSchemaWrapper & schema,
+    NamesAndTypesList & names_and_types,
+    ContextPtr & context)
+{
+    if (!schema.arrow_schema.release)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowSchema is already released");
+    }
+
+    /// Import ArrowSchema to arrow::Schema
+    auto arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema);
+    if (!arrow_schema_result.ok())
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+            "Failed to import Arrow schema: {}", arrow_schema_result.status().message());
+    }
+
+    const auto & arrow_schema = arrow_schema_result.ValueOrDie();
+
+    const auto format_settings = getFormatSettings(context);
+
+    /// Convert Arrow schema to ClickHouse header
+    auto block = ArrowColumnToCHColumn::arrowSchemaToCHHeader(
+        *arrow_schema,
+        nullptr,
+        "Arrow",
+        format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference,
+        format_settings.schema_inference_make_columns_nullable != 0,
+        false,
+        format_settings.parquet.allow_geoparquet_parser);
+
+    for (const auto & column : block)
+    {
+        names_and_types.emplace_back(column.name, column.type);
+    }
+}
+
+} // namespace CHDB
diff --git a/programs/local/ArrowSchema.h b/programs/local/ArrowSchema.h
new file mode 100644
index 00000000000..5e6720386ae
--- /dev/null
+++ b/programs/local/ArrowSchema.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <Core/NamesAndTypes.h>
+#include <Interpreters/Context_fwd.h>
+#include <arrow/c/abi.h>
+
+namespace CHDB
+{
+
+/// Wrapper for Arrow C Data Interface structures with RAII resource management
+class ArrowSchemaWrapper
+{
+public:
+    ArrowSchema arrow_schema;
+
+    ArrowSchemaWrapper() {
+        arrow_schema.release = nullptr;
+    }
+
+    ~ArrowSchemaWrapper();
+
+    /// Non-copyable but moveable
+    ArrowSchemaWrapper(const ArrowSchemaWrapper &) = delete;
+    ArrowSchemaWrapper & operator=(const ArrowSchemaWrapper &) = delete;
+    ArrowSchemaWrapper(ArrowSchemaWrapper && other) noexcept;
+    ArrowSchemaWrapper & operator=(ArrowSchemaWrapper && other) noexcept;
+
+    static void convertArrowSchema(
+        ArrowSchemaWrapper & schema,
+        DB::NamesAndTypesList & names_and_types,
+        DB::ContextPtr & context);
+};
+
+} // namespace CHDB
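
Taken together, ArrowScanState and the conversion entry point above suggest the following internal flow for turning a caller's C schema into ClickHouse columns. A hedged sketch: `context` (a `DB::ContextPtr`) and the producer filling `arrow_schema` are assumed:

    // Sketch: ArrowSchemaWrapper releases the C schema via RAII if still owned.
    CHDB::ArrowSchemaWrapper schema;
    /* ...schema.arrow_schema filled by a producer's get_schema() callback... */

    DB::NamesAndTypesList columns;
    CHDB::ArrowSchemaWrapper::convertArrowSchema(schema, columns, context);
    // For the test's data this yields {id: Int64, value: String}. ImportSchema
    // consumes schema.arrow_schema, so the wrapper won't double-release it.
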
diff --git a/programs/local/ArrowStreamRegistry.h b/programs/local/ArrowStreamRegistry.h
new file mode 100644
index 00000000000..c1fb465999e
--- /dev/null
+++ b/programs/local/ArrowStreamRegistry.h
@@ -0,0 +1,100 @@
+#pragma once
+
+#include "chdb-internal.h"
+
+#include <optional>
+#include <shared_mutex>
+#include <unordered_map>
+#include <vector>
+
+#include <base/types.h>
+
+struct ArrowArrayStream;
+
+namespace CHDB
+{
+
+class ArrowStreamRegistry
+{
+public:
+    struct ArrowStreamInfo
+    {
+        ArrowArrayStream * stream = nullptr;
+        bool is_owner = false;
+    };
+
+private:
+    std::unordered_map<String, ArrowStreamInfo> registered_streams;
+    mutable std::shared_mutex registry_mutex;
+
+public:
+    static ArrowStreamRegistry & instance()
+    {
+        static ArrowStreamRegistry instance;
+        return instance;
+    }
+
+    bool registerArrowStream(const String & name, ArrowArrayStream * arrow_stream, bool is_owner)
+    {
+        std::unique_lock lock(registry_mutex);
+
+        ArrowStreamInfo info;
+        info.stream = arrow_stream;
+        info.is_owner = is_owner;
+
+        auto [iter, inserted] = registered_streams.emplace(name, std::move(info));
+        return inserted;
+    }
+
+    std::optional<ArrowStreamInfo> getArrowStream(const String & name) const
+    {
+        std::shared_lock lock(registry_mutex);
+        auto it = registered_streams.find(name);
+        if (it != registered_streams.end())
+            return it->second;
+        return {};
+    }
+
+    bool unregisterArrowStream(const String & name)
+    {
+        std::unique_lock lock(registry_mutex);
+        auto it = registered_streams.find(name);
+        if (it != registered_streams.end())
+        {
+            if (it->second.is_owner && it->second.stream)
+            {
+                /// Clean up owned Arrow stream
+                chdb_destroy_arrow_stream(it->second.stream);
+            }
+            registered_streams.erase(it);
+            return true;
+        }
+        return false;
+    }
+
+    std::vector<String> listRegisteredNames() const
+    {
+        std::shared_lock lock(registry_mutex);
+        std::vector<String> names;
+        names.reserve(registered_streams.size());
+
+        for (const auto& [name, info] : registered_streams)
+            names.push_back(name);
+
+        return names;
+    }
+
+    size_t size() const
+    {
+        std::shared_lock lock(registry_mutex);
+        return registered_streams.size();
+    }
+
+    void clear()
+    {
+        std::unique_lock lock(registry_mutex);
+        registered_streams.clear();
+    }
+};
+
+}
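
A sketch of how the C entry points (`chdb_arrow_scan` / `chdb_arrow_unregister_table`, implemented in chdb-arrow.cpp) presumably drive this singleton; `stream_ptr` and `scanStream` are illustrative:

    auto & registry = CHDB::ArrowStreamRegistry::instance();

    // Registration: remember the caller's stream under a table name.
    // is_owner = false means unregistering must not destroy the caller's stream.
    registry.registerArrowStream("test_arrow_table", stream_ptr, /*is_owner=*/false);

    // Query path: the arrowstream() table function resolves the name under a
    // shared lock, so concurrent queries can look up streams in parallel.
    if (auto info = registry.getArrowStream("test_arrow_table"))
        scanStream(info->stream);  // hypothetical consumer

    // Unregistration drops the entry (and destroys the stream only if owned).
    registry.unregisterArrowStream("test_arrow_table");
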
diff --git a/programs/local/ArrowStreamSource.cpp b/programs/local/ArrowStreamSource.cpp
new file mode 100644
index 00000000000..c9545f8205c
--- /dev/null
+++ b/programs/local/ArrowStreamSource.cpp
@@ -0,0 +1,51 @@
+#include "ArrowStreamSource.h"
+#include <Common/Exception.h>
+
+#include <base/defines.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+ArrowStreamSource::ArrowStreamSource(
+    const Block & sample_block_,
+    CHDB::ArrowTableReaderPtr arrow_table_reader_,
+    size_t stream_index_)
+    : ISource(sample_block_.cloneEmpty())
+    , arrow_table_reader(arrow_table_reader_)
+    , sample_block(sample_block_)
+    , stream_index(stream_index_)
+{
+}
+
+Chunk ArrowStreamSource::generate()
+{
+    chassert(arrow_table_reader);
+
+    if (sample_block.getNames().empty())
+        return {};
+
+    try
+    {
+        auto chunk = arrow_table_reader->readNextChunk(stream_index);
+        return chunk;
+    }
+    catch (const Exception &)
+    {
+        throw;
+    }
+    catch (const std::exception & e)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowStreamSource error: {}", e.what());
+    }
+    catch (...)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowStreamSource unknown exception");
+    }
+}
+
+}
diff --git a/programs/local/ArrowStreamSource.h b/programs/local/ArrowStreamSource.h
new file mode 100644
index 00000000000..2d5055f394f
--- /dev/null
+++ b/programs/local/ArrowStreamSource.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include "ArrowTableReader.h"
+
+#include <Processors/ISource.h>
+#include <Processors/Chunk.h>
+#include <Poco/Logger.h>
+
+namespace DB
+{
+
+class ArrowStreamSource : public ISource
+{
+public:
+    ArrowStreamSource(
+        const Block & sample_block_,
+        CHDB::ArrowTableReaderPtr arrow_table_reader_,
+        size_t stream_index_);
+
+    String getName() const override { return "ArrowStream"; }
+
+    Chunk generate() override;
+
+private:
+    CHDB::ArrowTableReaderPtr arrow_table_reader;
+    Block sample_block;
+    size_t stream_index;
+    Poco::Logger * logger = &Poco::Logger::get("ArrowStreamSource");
+};
+
+}
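
The `stream_index` parameter hints that several sources share one ArrowTableReader, one per pipeline output stream; the wiring presumably resembles this sketch, with `header`, `reader` and `num_streams` assumed in scope:

    // Sketch: fan one shared ArrowTableReader out to num_streams sources.
    DB::Pipes pipes;
    for (size_t i = 0; i < num_streams; ++i)
        pipes.emplace_back(std::make_shared<DB::ArrowStreamSource>(header, reader, i));
    // Each source pulls with reader->readNextChunk(i); the reader serializes
    // access to the single underlying ArrowArrayStream behind its own mutex.
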
diff --git a/programs/local/ArrowStreamWrapper.cpp b/programs/local/ArrowStreamWrapper.cpp
new file mode 100644
index 00000000000..a1c834a49b1
--- /dev/null
+++ b/programs/local/ArrowStreamWrapper.cpp
@@ -0,0 +1,139 @@
+#include "ArrowStreamWrapper.h"
+
+#include <Common/Exception.h>
+#include <base/defines.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+}
+
+using namespace DB;
+
+namespace CHDB
+{
+
+/// ArrowArrayWrapper implementation
+ArrowArrayWrapper::~ArrowArrayWrapper()
+{
+    if (arrow_array.release)
+    {
+        arrow_array.release(&arrow_array);
+    }
+}
+
+ArrowArrayWrapper::ArrowArrayWrapper(ArrowArrayWrapper && other) noexcept
+    : arrow_array(other.arrow_array)
+{
+    other.arrow_array.release = nullptr;
+}
+
+ArrowArrayWrapper & ArrowArrayWrapper::operator=(ArrowArrayWrapper && other) noexcept
+{
+    if (this != &other)
+    {
+        if (arrow_array.release)
+        {
+            arrow_array.release(&arrow_array);
+        }
+        arrow_array = other.arrow_array;
+        other.arrow_array.release = nullptr;
+    }
+    return *this;
+}
+
+/// ArrowArrayStreamWrapper implementation
+ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper()
+{
+    if (should_release_on_destroy && arrow_array_stream.release)
+    {
+        arrow_array_stream.release(&arrow_array_stream);
+    }
+}
+
+ArrowArrayStreamWrapper::ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept
+    : arrow_array_stream(other.arrow_array_stream)
+    , should_release_on_destroy(other.should_release_on_destroy)
+{
+    other.arrow_array_stream.release = nullptr;
+    other.should_release_on_destroy = true;
+}
+
+ArrowArrayStreamWrapper & ArrowArrayStreamWrapper::operator=(ArrowArrayStreamWrapper && other) noexcept
+{
+    if (this != &other)
+    {
+        if (should_release_on_destroy && arrow_array_stream.release)
+        {
+            arrow_array_stream.release(&arrow_array_stream);
+        }
+        arrow_array_stream = other.arrow_array_stream;
+        should_release_on_destroy = other.should_release_on_destroy;
+        other.arrow_array_stream.release = nullptr;
+        other.should_release_on_destroy = true;
+    }
+    return *this;
+}
+
+void ArrowArrayStreamWrapper::getSchema(ArrowSchemaWrapper& schema)
+{
+    if (!isValid())
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowArrayStream is not valid");
+    }
+
+    if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema) != 0)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+            "Failed to get schema from ArrowArrayStream: {}", getError());
+    }
+
+    if (!schema.arrow_schema.release)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Released schema returned from ArrowArrayStream");
+    }
+}
+
+std::unique_ptr<ArrowArrayWrapper> ArrowArrayStreamWrapper::getNextChunk()
+{
+    chassert(isValid());
+
+    auto chunk = std::make_unique<ArrowArrayWrapper>();
+
+    /// Get next non-empty chunk, skipping empty ones
+    do
+    {
+        chunk->reset();
+        if (arrow_array_stream.get_next(&arrow_array_stream, &chunk->arrow_array) != 0)
+        {
+            throw Exception(ErrorCodes::BAD_ARGUMENTS,
+                "Failed to get next chunk from ArrowArrayStream: {}", getError());
+        }
+
+        /// Check if we've reached the end of the stream
+        if (!chunk->arrow_array.release)
+        {
+            return nullptr;
+        }
+    }
+    while (chunk->arrow_array.length == 0);
+
+    return chunk;
+}
+
+const char* ArrowArrayStreamWrapper::getError()
+{
+    if (!isValid())
+    {
+        return "ArrowArrayStream is not valid";
+    }
+
+    return arrow_array_stream.get_last_error(&arrow_array_stream);
+}
+
+} // namespace CHDB
diff --git a/programs/local/ArrowStreamWrapper.h b/programs/local/ArrowStreamWrapper.h
new file mode 100644
index 00000000000..0eb5c229d51
--- /dev/null
+++ b/programs/local/ArrowStreamWrapper.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "ArrowSchema.h"
+
+#include <arrow/c/abi.h>
+#include <memory>
+
+namespace CHDB
+{
+
+class ArrowArrayWrapper
+{
+public:
+    ArrowArray arrow_array;
+
+    ArrowArrayWrapper()
+    {
+        reset();
+    }
+
+    ~ArrowArrayWrapper();
+
+    void reset()
+    {
+        arrow_array.length = 0;
+        arrow_array.release = nullptr;
+    }
+
+    /// Non-copyable but moveable
+    ArrowArrayWrapper(const ArrowArrayWrapper &) = delete;
+    ArrowArrayWrapper & operator=(const ArrowArrayWrapper &) = delete;
+    ArrowArrayWrapper(ArrowArrayWrapper && other) noexcept;
+    ArrowArrayWrapper & operator=(ArrowArrayWrapper && other) noexcept;
+};
+
+class ArrowArrayStreamWrapper
+{
+public:
+    ArrowArrayStream arrow_array_stream;
+
+    explicit ArrowArrayStreamWrapper(bool should_release = true)
+        : should_release_on_destroy(should_release) {
+        arrow_array_stream.release = nullptr;
+    }
+
+    ~ArrowArrayStreamWrapper();
+
+    /// Non-copyable but moveable
+    ArrowArrayStreamWrapper(const ArrowArrayStreamWrapper&) = delete;
+    ArrowArrayStreamWrapper& operator=(const ArrowArrayStreamWrapper&) = delete;
+    ArrowArrayStreamWrapper(ArrowArrayStreamWrapper&& other) noexcept;
+    ArrowArrayStreamWrapper& operator=(ArrowArrayStreamWrapper&& other) noexcept;
+
+    /// Get schema from the stream
+    void getSchema(ArrowSchemaWrapper & schema);
+
+    /// Get next chunk from the stream
+    std::unique_ptr<ArrowArrayWrapper> getNextChunk();
+
+    /// Get last error message
+    const char* getError();
+
+    /// Check if stream is valid
+    bool isValid() const { return arrow_array_stream.release != nullptr; }
+
+    /// Set whether to release on destruction
+    void setShouldRelease(bool should_release) { should_release_on_destroy = should_release; }
+
+    /// Get whether will release on destruction
+    bool getShouldRelease() const { return should_release_on_destroy; }
+
+private:
+    bool should_release_on_destroy = true;
+};
+
+} // namespace CHDB
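
The `should_release_on_destroy` flag separates streams chDB owns from streams merely borrowed from the C API caller; a sketch of the intended distinction, read off the accessors (`caller_stream` is illustrative):

    // Borrowed: shallow-copy the caller's stream; its release() stays with the caller.
    CHDB::ArrowArrayStreamWrapper borrowed(/*should_release=*/false);
    borrowed.arrow_array_stream = *caller_stream;

    // Owned (the default): the destructor calls arrow_array_stream.release().
    CHDB::ArrowArrayStreamWrapper owned;
    /* ...owned.arrow_array_stream populated by an exporter... */
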
diff --git a/programs/local/ArrowTableReader.cpp b/programs/local/ArrowTableReader.cpp
new file mode 100644
index 00000000000..8e5d31a9cd1
--- /dev/null
+++ b/programs/local/ArrowTableReader.cpp
@@ -0,0 +1,197 @@
+#include "ArrowTableReader.h"
+
+#include <Common/Exception.h>
+#include <Processors/Formats/Impl/ArrowColumnToCHColumn.h>
+#include <arrow/c/bridge.h>
+#include <arrow/table.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+
+}
+
+using namespace DB;
+
+namespace CHDB
+{
+
+ArrowTableReader::ArrowTableReader(
+    std::unique_ptr<ArrowArrayStreamWrapper> arrow_stream_,
+    const DB::Block & sample_block_,
+    const DB::FormatSettings & format_settings_,
+    size_t num_streams_,
+    size_t max_block_size_)
+    : sample_block(sample_block_),
+      format_settings(format_settings_),
+      arrow_stream(std::move(arrow_stream_)),
+      num_streams(num_streams_),
+      max_block_size(max_block_size_),
+      scan_states(num_streams_)
+{
+    initializeStream();
+}
+
+void ArrowTableReader::initializeStream()
+{
+    if (!arrow_stream || !arrow_stream->isValid())
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+            "ArrowArrayStream is not valid");
+    }
+
+    /// Get schema from stream
+    arrow_stream->getSchema(schema);
+    auto arrow_schema_result = arrow::ImportSchema(&schema.arrow_schema);
+    if (!arrow_schema_result.ok())
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+            "Failed to import Arrow schema during initialization: {}", arrow_schema_result.status().message());
+    }
+    cached_arrow_schema = arrow_schema_result.ValueOrDie();
+}
+
+Chunk ArrowTableReader::readNextChunk(size_t stream_index)
+{
+    if (stream_index >= num_streams)
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+            "Stream index {} is out of range [0, {})", stream_index, num_streams);
+    }
+
+    auto & state = scan_states[stream_index];
+
+    if (state.exhausted)
+    {
+        return {};
+    }
+
+    try
+    {
+        /// If we don't have a current array or it's exhausted, get the next one
+        if (!state.current_array || state.current_offset >= static_cast<size_t>(state.current_array->arrow_array.length))
+        {
+            auto next_array = getNextArrowArray();
+            if (!next_array)
+            {
+                state.exhausted = true;
+                return {};
+            }
+            state.current_array = std::move(next_array);
+            state.current_offset = 0;
+            state.cached_record_batch.reset();
+        }
+
+        /// Calculate how many rows to read from current array
+        size_t available_rows = static_cast<size_t>(state.current_array->arrow_array.length) - state.current_offset;
+        size_t rows_to_read = std::min(max_block_size, available_rows);
+
+        /// Convert the slice to chunk
+        auto chunk = convertArrowArrayToChunk(*state.current_array, state.current_offset, rows_to_read, stream_index);
+
+        /// Update offset
+        state.current_offset += rows_to_read;
+
+        return chunk;
+    }
+    catch (const Exception &)
+    {
+        state.exhausted = true;
+        throw;
+    }
+}
+
+std::unique_ptr<ArrowArrayWrapper> ArrowTableReader::getNextArrowArray()
+{
+    std::lock_guard lock(stream_mutex);
+
+    if (global_stream_exhausted || !arrow_stream || !arrow_stream->isValid())
+    {
+        return nullptr;
+    }
+
+    try
+    {
+        auto arrow_array = arrow_stream->getNextChunk();
+
+        if (!arrow_array || arrow_array->arrow_array.length == 0)
+        {
+            global_stream_exhausted = true;
+            return nullptr;
+        }
+
+        return arrow_array;
+    }
+    catch (const Exception &)
+    {
+        global_stream_exhausted = true;
+        throw;
+    }
+}
+
+Chunk ArrowTableReader::convertArrowArrayToChunk(const ArrowArrayWrapper & arrow_array_wrapper, size_t offset, size_t count, size_t stream_index)
+{
+    chassert(arrow_array_wrapper.arrow_array.length && count && offset < static_cast<size_t>(arrow_array_wrapper.arrow_array.length));
+    chassert(count <= static_cast<size_t>(arrow_array_wrapper.arrow_array.length) - offset);
+    chassert(stream_index < num_streams);
+
+    auto & state = scan_states[stream_index];
+    std::shared_ptr<arrow::RecordBatch> record_batch;
+
+    /// Check if we have a cached RecordBatch for this ArrowArray
+    if (!state.cached_record_batch)
+    {
+        /// Import the full ArrowArray to RecordBatch and cache it
+        ArrowArray array_copy = arrow_array_wrapper.arrow_array;
+
+        /// Set a dummy release function to prevent Arrow from freeing the underlying data
+        static auto dummy_release = [](ArrowArray* array)
+        {
+            // No-op: ArrowArrayWrapper will handle the actual cleanup
+            // But we must set release to nullptr to follow Arrow C ABI convention
+            array->release = nullptr;
+        };
+        array_copy.release = dummy_release;
+
+        /// Import the full Arrow array to Arrow RecordBatch
+        auto arrow_batch_result = arrow::ImportRecordBatch(&array_copy, cached_arrow_schema);
+        if (!arrow_batch_result.ok())
+        {
+            throw Exception(ErrorCodes::BAD_ARGUMENTS,
+                "Failed to import Arrow RecordBatch: {}", arrow_batch_result.status().message());
+        }
+
+        state.cached_record_batch = arrow_batch_result.ValueOrDie();
+    }
+
+    /// Use the cached RecordBatch and slice it
+    record_batch = state.cached_record_batch;
+    auto sliced_batch = record_batch->Slice(offset, count);
+    auto table_result = arrow::Table::FromRecordBatches({sliced_batch});
+    if (!table_result.ok())
+    {
+        throw Exception(ErrorCodes::BAD_ARGUMENTS,
+            "Failed to create Arrow table from RecordBatch: {}", table_result.status().ToString());
+    }
+    const auto & arrow_table = table_result.ValueOrDie();
+
+    /// Use ArrowColumnToCHColumn to convert the batch
+    ArrowColumnToCHColumn converter(
+        sample_block,
+        "Arrow",
+        format_settings.arrow.allow_missing_columns,
+        format_settings.null_as_default,
+        format_settings.date_time_overflow_behavior,
+        format_settings.parquet.allow_geoparquet_parser,
+        format_settings.arrow.case_insensitive_column_matching,
+        false
+    );
+
+    return converter.arrowTableToCHChunk(arrow_table, sliced_batch->num_rows(), nullptr);
+}
+
+} // namespace CHDB
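
The reader imports each incoming ArrowArray into an arrow::RecordBatch once per stream state, then serves `max_block_size`-row windows from the cached batch. Schematically, for one 10,000-row array with max_block_size = 8192:

    // call 1: offset 0,     rows_to_read = min(8192, 10000 - 0)    = 8192
    // call 2: offset 8192,  rows_to_read = min(8192, 10000 - 8192) = 1808
    // call 3: offset 10000 -> array exhausted, fetch the next one (or finish)
    auto sliced = state.cached_record_batch->Slice(state.current_offset, rows_to_read);
    // RecordBatch::Slice() is zero-copy, so the C buffers are imported only once
    // per array instead of once per chunk.
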
actual cleanup + // But we must set release to nullptr to follow Arrow C ABI convention + array->release = nullptr; + }; + array_copy.release = dummy_release; + + /// Import the full Arrow array to Arrow RecordBatch + auto arrow_batch_result = arrow::ImportRecordBatch(&array_copy, cached_arrow_schema); + if (!arrow_batch_result.ok()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Failed to import Arrow RecordBatch: {}", arrow_batch_result.status().message()); + } + + state.cached_record_batch = arrow_batch_result.ValueOrDie(); + } + + /// Use the cached RecordBatch and slice it + record_batch = state.cached_record_batch; + auto sliced_batch = record_batch->Slice(offset, count); + auto table_result = arrow::Table::FromRecordBatches({sliced_batch}); + if (!table_result.ok()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Failed to create Arrow table from RecordBatch: {}", table_result.status().ToString()); + } + const auto & arrow_table = table_result.ValueOrDie(); + + /// Use ArrowColumnToCHColumn to convert the batch + ArrowColumnToCHColumn converter( + sample_block, + "Arrow", + format_settings.arrow.allow_missing_columns, + format_settings.null_as_default, + format_settings.date_time_overflow_behavior, + format_settings.parquet.allow_geoparquet_parser, + format_settings.arrow.case_insensitive_column_matching, + false + ); + + return converter.arrowTableToCHChunk(arrow_table, sliced_batch->num_rows(), nullptr); +} + +} // namespace CHDB diff --git a/programs/local/ArrowTableReader.h b/programs/local/ArrowTableReader.h new file mode 100644 index 00000000000..c2a29fc82c4 --- /dev/null +++ b/programs/local/ArrowTableReader.h @@ -0,0 +1,64 @@ +#pragma once + +#include "ArrowScanState.h" +#include "ArrowStreamWrapper.h" + +#include +#include +#include +#include + +namespace CHDB +{ + +class ArrowTableReader; +using ArrowTableReaderPtr = std::shared_ptr; + +class ArrowTableReader +{ +public: + ArrowTableReader( + std::unique_ptr arrow_stream_, + const DB::Block & sample_block_, + const DB::FormatSettings & format_settings_, + size_t num_streams_, + size_t max_block_size_); + + ~ArrowTableReader() = default; + + /// Read next chunk from the specified stream + DB::Chunk readNextChunk(size_t stream_index); + +private: + /// Initialize the Arrow stream from ArrowArrayStreamWrapper + void initializeStream(); + + /// Convert Arrow array slice to ClickHouse chunk + DB::Chunk convertArrowArrayToChunk(const ArrowArrayWrapper & arrow_array, size_t offset, size_t count, size_t stream_index); + + /// Get next Arrow array from stream + std::unique_ptr getNextArrowArray(); + + DB::Block sample_block; + DB::FormatSettings format_settings; + std::unique_ptr arrow_stream; + ArrowSchemaWrapper schema; + + /// Cached Arrow schema to avoid repeated imports + std::shared_ptr cached_arrow_schema; + + /// Multi-stream scanning parameters + size_t num_streams; + size_t max_block_size; + + /// Scan states for each stream + std::vector scan_states; + + /// Global stream state + bool global_stream_exhausted = false; + + /// Mutex for thread-safe access to arrow_stream + mutable std::mutex stream_mutex; +}; + +} // namespace CHDB diff --git a/programs/local/CMakeLists.txt b/programs/local/CMakeLists.txt index 83095fe2dd0..f84770e6392 100644 --- a/programs/local/CMakeLists.txt +++ b/programs/local/CMakeLists.txt @@ -1,8 +1,21 @@ set (CLICKHOUSE_LOCAL_SOURCES chdb.cpp + ArrowSchema.cpp + ArrowStreamWrapper.cpp + ArrowTableReader.cpp LocalServer.cpp ) +if (NOT USE_PYTHON) + set (CHDB_ARROW_SOURCES + chdb-arrow.cpp 
+ ArrowStreamSource.cpp + StorageArrowStream.cpp + TableFunctionArrowStream.cpp + ) + set (CLICKHOUSE_LOCAL_SOURCES ${CLICKHOUSE_LOCAL_SOURCES} ${CHDB_ARROW_SOURCES}) +endif() + # Add force function references only for static library builds if (CHDB_STATIC_LIBRARY_BUILD) list(APPEND CLICKHOUSE_LOCAL_SOURCES ForceFunctionReferences.cpp) @@ -20,6 +33,8 @@ if (USE_PYTHON) PandasAnalyzer.cpp PandasDataFrame.cpp PandasScan.cpp + PyArrowStreamFactory.cpp + PyArrowTable.cpp PybindWrapper.cpp PythonConversion.cpp PythonDict.cpp @@ -117,6 +132,9 @@ endif() if (TARGET ch_contrib::utf8proc) target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::utf8proc) endif() +if (TARGET ch_contrib::arrow) + target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::arrow) +endif() if (TARGET ch_contrib::pybind11_stubs) target_link_libraries(clickhouse-local-lib PRIVATE ch_contrib::pybind11_stubs) target_compile_definitions(clickhouse-local-lib PRIVATE Py_LIMITED_API=0x03080000) diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index edf3f67ad20..c7b6026fea8 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -2,10 +2,13 @@ #include "chdb-internal.h" #if USE_PYTHON +#include "StoragePython.h" #include "TableFunctionPython.h" -#include -#include +#else +#include "StorageArrowStream.h" +#include "TableFunctionArrowStream.h" #endif +#include #include #include @@ -645,17 +648,22 @@ try registerAggregateFunctions(); registerTableFunctions(); -#if USE_PYTHON + auto & table_function_factory = TableFunctionFactory::instance(); +#if USE_PYTHON registerTableFunctionPython(table_function_factory); +#else + registerTableFunctionArrowStream(table_function_factory); #endif registerDatabases(); registerStorages(); -#if USE_PYTHON auto & storage_factory = StorageFactory::instance(); +#if USE_PYTHON registerStoragePython(storage_factory); +#else + registerStorageArrowStream(storage_factory); #endif registerDictionaries(); diff --git a/programs/local/PandasDataFrame.cpp b/programs/local/PandasDataFrame.cpp index e6841ce3937..c304dabbf0a 100644 --- a/programs/local/PandasDataFrame.cpp +++ b/programs/local/PandasDataFrame.cpp @@ -22,13 +22,6 @@ using namespace DB; namespace CHDB { -template -static bool ModuleIsLoaded() -{ - auto dict = pybind11::module_::import("sys").attr("modules"); - return dict.contains(py::str(T::Name)); -} - struct PandasBindColumn { public: PandasBindColumn(py::handle name, py::handle type, py::object column) @@ -92,6 +85,8 @@ static DataTypePtr inferDataTypeFromPandasColumn(PandasBindColumn & column, Cont ColumnsDescription PandasDataFrame::getActualTableStructure(const py::object & object, ContextPtr & context) { + chassert(py::gil_check()); + NamesAndTypesList names_and_types; PandasDataFrameBind df(object); @@ -116,6 +111,8 @@ ColumnsDescription PandasDataFrame::getActualTableStructure(const py::object & o bool PandasDataFrame::isPandasDataframe(const py::object & object) { + chassert(py::gil_check()); + if (!ModuleIsLoaded()) return false; diff --git a/programs/local/PyArrowCacheItem.h b/programs/local/PyArrowCacheItem.h new file mode 100644 index 00000000000..494cf149870 --- /dev/null +++ b/programs/local/PyArrowCacheItem.h @@ -0,0 +1,47 @@ +#pragma once + +#include "PythonImportCacheItem.h" + +namespace CHDB +{ + +struct PyarrowIpcCacheItem : public PythonImportCacheItem +{ + explicit PyarrowIpcCacheItem(PythonImportCacheItem * parent) + : PythonImportCacheItem("ipc", parent), message_reader("MessageReader", this) + {} + 
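A minimal sketch of how the ArrowTableReader introduced above is meant to be driven. Only the reader API itself comes from this diff; the ready-made `stream` wrapper, `sample_block`, `format_settings`, and the `process()` consumer are hypothetical stand-ins:

auto reader = std::make_shared<CHDB::ArrowTableReader>(
    std::move(stream),        // std::unique_ptr<CHDB::ArrowArrayStreamWrapper>, already initialized
    sample_block,             // DB::Block listing the expected columns
    format_settings,          // DB::FormatSettings; only the arrow section matters here
    /*num_streams=*/4,
    /*max_block_size=*/65536);

for (size_t stream_index = 0; stream_index < 4; ++stream_index)
{
    while (true)
    {
        DB::Chunk chunk = reader->readNextChunk(stream_index);
        if (!chunk.hasRows())
            break;            // an empty chunk means this stream (or the shared source) is drained
        process(std::move(chunk));   // hypothetical consumer
    }
}

In real use each stream_index is polled by its own source on its own thread; the serial loop above only illustrates the per-stream cursor API.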
~PyarrowIpcCacheItem() override = default; + + PythonImportCacheItem message_reader; +}; + +struct PyarrowDatasetCacheItem : public PythonImportCacheItem +{ + static constexpr const char * Name = "pyarrow.dataset"; + + PyarrowDatasetCacheItem() + : PythonImportCacheItem("pyarrow.dataset"), scanner("Scanner", this), dataset("Dataset", this) + {} + ~PyarrowDatasetCacheItem() override = default; + + PythonImportCacheItem scanner; + PythonImportCacheItem dataset; +}; + +struct PyarrowCacheItem : public PythonImportCacheItem +{ + static constexpr const char * Name = "pyarrow"; + + PyarrowCacheItem() + : PythonImportCacheItem("pyarrow"), dataset(), table("Table", this), + record_batch_reader("RecordBatchReader", this), ipc(this) + {} + ~PyarrowCacheItem() override = default; + + PyarrowDatasetCacheItem dataset; + PythonImportCacheItem table; + PythonImportCacheItem record_batch_reader; + PyarrowIpcCacheItem ipc; +}; + +} // namespace CHDB diff --git a/programs/local/PyArrowStreamFactory.cpp b/programs/local/PyArrowStreamFactory.cpp new file mode 100644 index 00000000000..0272985194e --- /dev/null +++ b/programs/local/PyArrowStreamFactory.cpp @@ -0,0 +1,113 @@ +#include "PyArrowStreamFactory.h" +#include "PyArrowTable.h" +#include "PythonImporter.h" + +#include +#include + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int PY_EXCEPTION_OCCURED; +} + +} + +using namespace DB; +namespace py = pybind11; + +namespace CHDB +{ + +std::unique_ptr<ArrowArrayStreamWrapper> PyArrowStreamFactory::createFromPyObject( + py::object & py_obj, + const Names & column_names) +{ + chassert(py::gil_check()); + + try + { + auto arrow_object_type = PyArrowTable::getArrowType(py_obj); + + switch (arrow_object_type) + { + case PyArrowObjectType::Table: + return createFromTable(py_obj, column_names); + default: + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Unsupported PyArrow object type: {}", arrow_object_type); + } + } + catch (const py::error_already_set & e) + { + throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, + "Failed to convert PyArrow object to arrow array stream: {}", e.what()); + } +} + +std::unique_ptr<ArrowArrayStreamWrapper> PyArrowStreamFactory::createFromTable( + py::object & table, + const Names & column_names) +{ + chassert(py::gil_check()); + + py::handle table_handle(table); + auto & import_cache = PythonImporter::ImportCache(); + auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset"); + + auto dataset = arrow_dataset(table_handle); + py::object arrow_scanner = dataset.attr("__class__").attr("scanner"); + + py::dict kwargs; + if (!column_names.empty()) + { + ArrowSchemaWrapper schema; + auto obj_schema = table_handle.attr("schema"); + auto export_to_c = obj_schema.attr("_export_to_c"); + export_to_c(reinterpret_cast<uintptr_t>(&schema.arrow_schema)); + + /// Get available column names from schema + std::unordered_set<String> available_columns; + if (schema.arrow_schema.n_children > 0 && schema.arrow_schema.children) + { + for (int64_t i = 0; i < schema.arrow_schema.n_children; ++i) + { + if (schema.arrow_schema.children[i] && schema.arrow_schema.children[i]->name) + { + available_columns.insert(schema.arrow_schema.children[i]->name); + } + } + } + + /// Only add column names that exist in the schema + py::list projection_list; + for (const auto & name : column_names) + { + if (available_columns.contains(name)) + { + projection_list.append(name); + } + } + + /// Only set columns if we have valid projections + if (projection_list.size() > 0) + { + kwargs["columns"] = projection_list; + } + } + + auto scanner = arrow_scanner(dataset, **kwargs); + + auto record_batches = scanner.attr("to_reader")(); + auto res = std::make_unique<ArrowArrayStreamWrapper>(); + auto export_to_c = record_batches.attr("_export_to_c"); + export_to_c(reinterpret_cast<uintptr_t>(&res->arrow_array_stream)); + return res; +} + +} // namespace CHDB diff --git a/programs/local/PyArrowStreamFactory.h b/programs/local/PyArrowStreamFactory.h new file mode 100644 index 00000000000..4c480d1d113 --- /dev/null +++ b/programs/local/PyArrowStreamFactory.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ArrowStreamWrapper.h" + +#include +#include + +namespace CHDB +{ + +/// Factory class for creating ArrowArrayStream from Python objects +class PyArrowStreamFactory +{ +public: + static std::unique_ptr<ArrowArrayStreamWrapper> createFromPyObject( + pybind11::object & py_obj, + const DB::Names & column_names); + +private: + static std::unique_ptr<ArrowArrayStreamWrapper> createFromTable( + pybind11::object & table, + const DB::Names & column_names); +}; + +} // namespace CHDB diff --git a/programs/local/PyArrowTable.cpp b/programs/local/PyArrowTable.cpp new file mode 100644 index 00000000000..5eef4b43ee4 --- /dev/null +++ b/programs/local/PyArrowTable.cpp @@ -0,0 +1,58 @@ +#include "PyArrowTable.h" +#include "ArrowSchema.h" +#include "PyArrowCacheItem.h" +#include "PythonImporter.h" + +#include + +using namespace DB; + +namespace CHDB +{ + +PyArrowObjectType PyArrowTable::getArrowType(const py::object & obj) +{ + chassert(py::gil_check()); + + if (ModuleIsLoaded<PyarrowCacheItem>()) + { + auto & import_cache = PythonImporter::ImportCache(); + auto table_class = import_cache.pyarrow.table(); + + if (py::isinstance(obj, table_class)) + return PyArrowObjectType::Table; + } + + return PyArrowObjectType::Invalid; +} + +bool PyArrowTable::isPyArrowTable(const py::object & object) +{ + try + { + return getArrowType(object) == PyArrowObjectType::Table; + } + catch (const py::error_already_set &) + { + return false; + } +} + +ColumnsDescription PyArrowTable::getActualTableStructure(const py::object & object, ContextPtr & context) +{ + chassert(py::gil_check()); + chassert(isPyArrowTable(object)); + + NamesAndTypesList names_and_types; + + auto obj_schema = object.attr("schema"); + auto export_to_c = obj_schema.attr("_export_to_c"); + ArrowSchemaWrapper schema; + export_to_c(reinterpret_cast<uintptr_t>(&schema.arrow_schema)); + + ArrowSchemaWrapper::convertArrowSchema(schema, names_and_types, context); + + return ColumnsDescription(names_and_types); +} + +} // namespace CHDB diff --git a/programs/local/PyArrowTable.h b/programs/local/PyArrowTable.h new file mode 100644 index 00000000000..d4fccb634ca --- /dev/null +++ b/programs/local/PyArrowTable.h @@ -0,0 +1,27 @@ +#pragma once + +#include "PybindWrapper.h" + +#include +#include + +namespace CHDB +{ + +enum class PyArrowObjectType +{ + Invalid, + Table +}; + +class PyArrowTable +{ +public: + static DB::ColumnsDescription getActualTableStructure(const py::object & object, DB::ContextPtr & context); + + static bool isPyArrowTable(const py::object & object); + + static PyArrowObjectType getArrowType(const py::object & object); +}; + +} // namespace CHDB diff --git a/programs/local/PybindWrapper.h b/programs/local/PybindWrapper.h index d653ab1ea73..17e630f10eb 100644 --- a/programs/local/PybindWrapper.h +++ b/programs/local/PybindWrapper.h @@ -4,15 +4,19 @@ #include #include -namespace pybind11 { +namespace pybind11 +{ +bool gil_check(); void gil_assert(); } -namespace CHDB { +namespace CHDB +{ -namespace py { +namespace py +{ using namespace pybind11; diff --git a/programs/local/PythonImportCache.h
b/programs/local/PythonImportCache.h index 516ed057875..6bdf5cf7c8f 100644 --- a/programs/local/PythonImportCache.h +++ b/programs/local/PythonImportCache.h @@ -3,6 +3,7 @@ #include "DatetimeCacheItem.h" #include "DecimalCacheItem.h" #include "PandasCacheItem.h" +#include "PyArrowCacheItem.h" #include "PythonImportCacheItem.h" #include @@ -18,12 +19,11 @@ struct PythonImportCache { ~PythonImportCache(); -public: PandasCacheItem pandas; + PyarrowCacheItem pyarrow; DatetimeCacheItem datetime; DecimalCacheItem decimal; -public: py::handle AddCache(py::object item); private: diff --git a/programs/local/PythonImportCacheItem.h b/programs/local/PythonImportCacheItem.h index bea49e5e02c..5908bbf57ac 100644 --- a/programs/local/PythonImportCacheItem.h +++ b/programs/local/PythonImportCacheItem.h @@ -6,6 +6,13 @@ namespace CHDB { +template +static bool ModuleIsLoaded() +{ + auto dict = pybind11::module_::import("sys").attr("modules"); + return dict.contains(py::str(T::Name)); +} + struct PythonImportCache; struct PythonImportCacheItem { diff --git a/programs/local/PythonSource.cpp b/programs/local/PythonSource.cpp index cfecdbaab1d..d2a1435eb01 100644 --- a/programs/local/PythonSource.cpp +++ b/programs/local/PythonSource.cpp @@ -5,9 +5,7 @@ #include "StoragePython.h" #include -#include #include -#include #include #include #include @@ -20,7 +18,6 @@ #include #include #include -#include "PythonUtils.h" #include #include #include @@ -37,6 +34,8 @@ #include #include +using namespace CHDB; + namespace DB { @@ -58,7 +57,8 @@ PythonSource::PythonSource( size_t max_block_size_, size_t stream_index, size_t num_streams, - const FormatSettings & format_settings_) + const FormatSettings & format_settings_, + ArrowTableReaderPtr arrow_table_reader_) : ISource(sample_block_.cloneEmpty()) , data_source(data_source_) , isInheritsFromPyReader(isInheritsFromPyReader_) @@ -70,6 +70,7 @@ PythonSource::PythonSource( , num_streams(num_streams) , cursor(0) , format_settings(format_settings_) + , arrow_table_reader(arrow_table_reader_) { } @@ -438,14 +439,7 @@ Chunk PythonSource::scanDataToChunk() if (names.size() != columns.size()) throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Column cache size mismatch"); - auto rows_per_stream = data_source_row_count / num_streams; - auto start = stream_index * rows_per_stream; - auto end = (stream_index + 1) * rows_per_stream; - if (stream_index == num_streams - 1) - end = data_source_row_count; - if (cursor == 0) - cursor = start; - auto count = std::min(max_block_size, end - cursor); + auto [offset, count] = calculateOffsetAndCount(); if (count == 0) return {}; LOG_DEBUG(logger, "Stream index {} Reading {} rows from {}", stream_index, count, cursor); @@ -554,7 +548,6 @@ Chunk PythonSource::scanDataToChunk() return Chunk(std::move(columns), count); } - Chunk PythonSource::generate() { size_t num_rows = 0; @@ -564,6 +557,12 @@ Chunk PythonSource::generate() try { + if (arrow_table_reader) + { + auto chunk = arrow_table_reader->readNextChunk(stream_index); + return chunk; + } + if (isInheritsFromPyReader) { PyObjectVecPtr data; @@ -574,10 +573,8 @@ Chunk PythonSource::generate() return std::move(genChunk(num_rows, data)); } - else - { - return std::move(scanDataToChunk()); - } + + return std::move(scanDataToChunk()); } catch (const Exception & e) { @@ -596,4 +593,19 @@ Chunk PythonSource::generate() throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Python data handling unknown exception"); } } + +std::pair PythonSource::calculateOffsetAndCount() +{ + auto rows_per_stream = 
data_source_row_count / num_streams; + auto start = stream_index * rows_per_stream; + auto end = (stream_index + 1) * rows_per_stream; + if (stream_index == num_streams - 1) + end = data_source_row_count; + if (cursor == 0) + cursor = start; + auto count = std::min(max_block_size, end - cursor); + + return std::make_pair(cursor, count); } + +} \ No newline at end of file diff --git a/programs/local/PythonSource.h b/programs/local/PythonSource.h index 610fe823679..1ef138c68a6 100644 --- a/programs/local/PythonSource.h +++ b/programs/local/PythonSource.h @@ -1,16 +1,15 @@ #pragma once +#include "ArrowTableReader.h" +#include "PythonUtils.h" #include "config.h" -#include #include - #include #include #include #include #include -#include "PythonUtils.h" namespace DB { @@ -32,7 +31,8 @@ class PythonSource : public ISource size_t max_block_size_, size_t stream_index, size_t num_streams, - const FormatSettings & format_settings_); + const FormatSettings & format_settings_, + CHDB::ArrowTableReaderPtr arrow_table_reader_ = nullptr); ~PythonSource() override = default; @@ -59,6 +59,8 @@ class PythonSource : public ISource const FormatSettings format_settings; + CHDB::ArrowTableReaderPtr arrow_table_reader; + Chunk genChunk(size_t & num_rows, PyObjectVecPtr data); PyObjectVecPtr scanData(const py::object & data, const std::vector & col_names, size_t & cursor, size_t count); @@ -78,8 +80,8 @@ class PythonSource : public ISource void insert_string_from_array(py::handle obj, const MutableColumnPtr & column); - void prepareColumnCache(Names & names, Columns & columns); Chunk scanDataToChunk(); void destory(PyObjectVecPtr & data); + std::pair calculateOffsetAndCount(); }; } diff --git a/programs/local/StorageArrowStream.cpp b/programs/local/StorageArrowStream.cpp new file mode 100644 index 00000000000..403875886f1 --- /dev/null +++ b/programs/local/StorageArrowStream.cpp @@ -0,0 +1,97 @@ +#include "StorageArrowStream.h" +#include "ArrowStreamSource.h" +#include "ArrowStreamWrapper.h" +#include "ArrowTableReader.h" + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +void registerStorageArrowStream(StorageFactory & factory) +{ + factory.registerStorage( + "ArrowStream", + [](const StorageFactory::Arguments & args) -> StoragePtr + { + if (args.engine_args.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "ArrowStream engine requires 1 argument: ArrowStreamInfo object"); + + CHDB::ArrowStreamRegistry::ArrowStreamInfo stream_info = std::any_cast(args.engine_args[0]); + return std::make_shared(args.table_id, stream_info, args.columns, args.getLocalContext()); + }, + { + .supports_settings = false, + .supports_parallel_insert = false, + }); +} + +StorageArrowStream::StorageArrowStream( + const StorageID & storage_id_, + const CHDB::ArrowStreamRegistry::ArrowStreamInfo & stream_info_, + const ColumnsDescription & columns_, + ContextPtr context_) + : IStorage(storage_id_) + , WithContext(context_) + , stream_info(stream_info_) +{ + StorageInMemoryMetadata storage_metadata; + storage_metadata.setColumns(columns_); + setInMemoryMetadata(storage_metadata); +} + +Pipe StorageArrowStream::read( + const Names & column_names, + const StorageSnapshotPtr & storage_snapshot, + SelectQueryInfo & /*query_info*/, + ContextPtr /*context*/, + QueryProcessingStage::Enum /*processed_stage*/, + size_t max_block_size, + size_t num_streams) +{ + chassert(stream_info.stream); + storage_snapshot->check(column_names); + + Block 
sample_block = prepareSampleBlock(column_names, storage_snapshot); + auto format_settings = getFormatSettings(getContext()); + + /// Create ArrowArrayStreamWrapper from the registered stream + auto arrow_stream_wrapper = std::make_unique<ArrowArrayStreamWrapper>(false); + arrow_stream_wrapper->arrow_array_stream = *stream_info.stream; + + auto arrow_table_reader = std::make_shared<ArrowTableReader>( + std::move(arrow_stream_wrapper), + sample_block, + format_settings, + num_streams, + max_block_size + ); + + Pipes pipes; + for (size_t stream = 0; stream < num_streams; ++stream) + { + pipes.emplace_back(std::make_shared<ArrowStreamSource>( + sample_block, arrow_table_reader, stream)); + } + return Pipe::unitePipes(std::move(pipes)); +} + +Block StorageArrowStream::prepareSampleBlock(const Names & column_names, const StorageSnapshotPtr & storage_snapshot) +{ + Block sample_block; + for (const String & column_name : column_names) + { + auto column_data = storage_snapshot->metadata->getColumns().getPhysical(column_name); + sample_block.insert({column_data.type, column_data.name}); + } + return sample_block; +} + +} diff --git a/programs/local/StorageArrowStream.h b/programs/local/StorageArrowStream.h new file mode 100644 index 00000000000..46d0262680c --- /dev/null +++ b/programs/local/StorageArrowStream.h @@ -0,0 +1,44 @@ +#pragma once + +#include "ArrowStreamRegistry.h" + +#include +#include +#include +#include + +namespace DB +{ + +void registerStorageArrowStream(StorageFactory & factory); + +class StorageArrowStream : public IStorage, public WithContext +{ +public: + StorageArrowStream( + const StorageID & storage_id_, + const CHDB::ArrowStreamRegistry::ArrowStreamInfo & stream_info_, + const ColumnsDescription & columns_, + ContextPtr context_); + + ~StorageArrowStream() override = default; + + std::string getName() const override { return "ArrowStream"; } + + Pipe read( + const Names & column_names, + const StorageSnapshotPtr & storage_snapshot, + SelectQueryInfo & query_info, + ContextPtr context, + QueryProcessingStage::Enum processed_stage, + size_t max_block_size, + size_t num_streams) override; + + Block prepareSampleBlock(const Names & column_names, const StorageSnapshotPtr & storage_snapshot); + +private: + CHDB::ArrowStreamRegistry::ArrowStreamInfo stream_info; + Poco::Logger * logger = &Poco::Logger::get("StorageArrowStream"); +}; + +} diff --git a/programs/local/StoragePython.cpp b/programs/local/StoragePython.cpp index c97ad161117..8f3b4f8002f 100644 --- a/programs/local/StoragePython.cpp +++ b/programs/local/StoragePython.cpp @@ -2,6 +2,8 @@ #include "FormatHelper.h" #include "PybindWrapper.h" #include "PythonSource.h" +#include "PyArrowTable.h" +#include "PyArrowStreamFactory.h" #include #include @@ -31,6 +33,8 @@ #include + +using namespace CHDB; + namespace DB { @@ -78,13 +82,25 @@ Pipe StoragePython::read( std::make_shared<PythonSource>(data_source, true, sample_block, column_cache, data_source_row_count, max_block_size, 0, 1, format_settings)); } - prepareColumnCache(column_names, sample_block.getColumns(), sample_block); + ArrowTableReaderPtr arrow_table_reader; + { + py::gil_scoped_acquire acquire; + if (PyArrowTable::isPyArrowTable(data_source)) + { + auto arrow_stream = PyArrowStreamFactory::createFromPyObject(data_source, sample_block.getNames()); + arrow_table_reader = std::make_shared<ArrowTableReader>( + std::move(arrow_stream), sample_block, + format_settings, num_streams, max_block_size); + } + } + + if (!arrow_table_reader) + prepareColumnCache(column_names, sample_block.getColumns(), sample_block); Pipes pipes; - // num_streams = 32; // for chdb testing
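ArrowStreamSource.cpp itself is not among the hunks shown here; based on how StorageArrowStream::read constructs it above (block, shared reader, stream index), a plausible minimal shape would be the following hedged sketch, not the actual file:

class ArrowStreamSource : public DB::ISource
{
public:
    ArrowStreamSource(const DB::Block & sample_block_, CHDB::ArrowTableReaderPtr reader_, size_t stream_index_)
        : DB::ISource(sample_block_.cloneEmpty())
        , reader(std::move(reader_))
        , stream_index(stream_index_)
    {}

    std::string getName() const override { return "ArrowStreamSource"; }

protected:
    DB::Chunk generate() override
    {
        /// Returning an empty chunk ends this source; ArrowTableReader yields one
        /// once its per-stream state (or the shared underlying stream) is exhausted.
        return reader->readNextChunk(stream_index);
    }

private:
    CHDB::ArrowTableReaderPtr reader;
    size_t stream_index;
};

The design point is that all per-stream coordination lives in ArrowTableReader, so the source can stay a thin stateless forwarder.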
for (size_t stream = 0; stream < num_streams; ++stream) pipes.emplace_back(std::make_shared( - data_source, false, sample_block, column_cache, data_source_row_count, max_block_size, stream, num_streams, format_settings)); + data_source, false, sample_block, column_cache, data_source_row_count, max_block_size, stream, num_streams, format_settings, arrow_table_reader)); return Pipe::unitePipes(std::move(pipes)); } diff --git a/programs/local/TableFunctionArrowStream.cpp b/programs/local/TableFunctionArrowStream.cpp new file mode 100644 index 00000000000..a59176c71c0 --- /dev/null +++ b/programs/local/TableFunctionArrowStream.cpp @@ -0,0 +1,127 @@ +#include "TableFunctionArrowStream.h" +#include "ArrowSchema.h" +#include "ArrowStreamWrapper.h" +#include "StorageArrowStream.h" + +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int UNKNOWN_IDENTIFIER; + extern const int BAD_ARGUMENTS; +} + +void TableFunctionArrowStream::parseArguments(const ASTPtr & ast_function, ContextPtr context) +{ + const auto & func_args = ast_function->as(); + + if (!func_args.arguments) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Table function 'arrowstream' must have arguments."); + + ASTs & args = func_args.arguments->children; + + if (args.size() != 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "ArrowStream table requires 1 argument: stream name"); + + auto stream_name_arg = evaluateConstantExpressionOrIdentifierAsLiteral(args[0], context); + + try + { + stream_name = stream_name_arg->as().value.safeGet(); + + stream_name.erase( + std::remove_if(stream_name.begin(), stream_name.end(), + [](char c) { return c == '\'' || c == '\"' || c == '`'; }), + stream_name.end()); + + auto stream_opt = CHDB::ArrowStreamRegistry::instance().getArrowStream(stream_name); + if (!stream_opt) + { + throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, + "ArrowStream '{}' not found in registry. " + "Please register it first using chdb_arrow_scan.", + stream_name); + } + + stream_info = *stream_opt; + } + catch (const Exception &) + { + throw; + } + catch (const std::exception & e) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Error parsing arrowstream argument: {}", e.what()); + } + catch (...) 
+ { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Error parsing arrowstream argument"); + } +} + +StoragePtr TableFunctionArrowStream::executeImpl( + const ASTPtr & /*ast_function*/, + ContextPtr context, + const String & table_name, + ColumnsDescription /*cached_columns*/, + bool is_insert_query) const +{ + if (stream_name.empty() || !stream_info.stream) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "ArrowStream name not initialized"); + + auto columns = getActualTableStructure(context, is_insert_query); + + auto storage = std::make_shared( + StorageID(getDatabaseName(), table_name), + stream_info, + columns, + context); + + storage->startup(); + return storage; +} + +ColumnsDescription TableFunctionArrowStream::getActualTableStructure( + ContextPtr context, bool /*is_insert_query*/) const +{ + auto * arrow_stream = reinterpret_cast(stream_info.stream); + CHDB::ArrowSchemaWrapper schema; + + if (arrow_stream->get_schema(arrow_stream, &schema.arrow_schema) != 0) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Failed to get schema from ArrowStream '{}'", stream_name); + } + + NamesAndTypesList names_and_types; + CHDB::ArrowSchemaWrapper::convertArrowSchema(schema, names_and_types, context); + + return ColumnsDescription(names_and_types); +} + +void registerTableFunctionArrowStream(TableFunctionFactory & factory) +{ + factory.registerFunction( + {.documentation = { + .description = R"( +Creates a table from a registered ArrowStream. +This table function requires a single argument which is the name of a registered ArrowStream. +Use chdb_arrow_register_table() to register ArrowStreams first. +)", + .examples = {{"arrowstream", "SELECT * FROM arrowstream('my_data')", ""}}, + .category = FunctionDocumentation::Category::TableFunction + }}, + TableFunctionFactory::Case::Insensitive); +} + +} diff --git a/programs/local/TableFunctionArrowStream.h b/programs/local/TableFunctionArrowStream.h new file mode 100644 index 00000000000..761830e3633 --- /dev/null +++ b/programs/local/TableFunctionArrowStream.h @@ -0,0 +1,41 @@ +#pragma once + +#include "ArrowStreamRegistry.h" + +#include +#include +#include + +namespace DB +{ + +class TableFunctionFactory; +void registerTableFunctionArrowStream(TableFunctionFactory & factory); + +class TableFunctionArrowStream : public ITableFunction +{ +public: + static constexpr auto name = "arrowstream"; + std::string getName() const override { return name; } + +private: + Poco::Logger * logger = &Poco::Logger::get("TableFunctionArrowStream"); + + StoragePtr executeImpl( + const ASTPtr & ast_function, + ContextPtr context, + const std::string & table_name, + ColumnsDescription cached_columns, + bool is_insert_query) const override; + + const char * getStorageTypeName() const override { return "ArrowStream"; } + + void parseArguments(const ASTPtr & ast_function, ContextPtr context) override; + + ColumnsDescription getActualTableStructure(ContextPtr context, bool is_insert_query) const override; + + String stream_name; + CHDB::ArrowStreamRegistry::ArrowStreamInfo stream_info; +}; + +} diff --git a/programs/local/TableFunctionPython.cpp b/programs/local/TableFunctionPython.cpp index 042a41fcaa8..ecb4ca09ba8 100644 --- a/programs/local/TableFunctionPython.cpp +++ b/programs/local/TableFunctionPython.cpp @@ -1,13 +1,13 @@ +#include "TableFunctionPython.h" #include "StoragePython.h" #include "PandasDataFrame.h" +#include "PyArrowTable.h" #include "PythonDict.h" #include "PythonReader.h" #include "PythonTableCache.h" #include "PythonUtils.h" -#include 
"TableFunctionPython.h" #include - #include #include #include @@ -39,7 +39,6 @@ extern const int UNKNOWN_FORMAT; void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr context) { - // py::gil_scoped_acquire acquire; const auto & func_args = ast_function->as(); if (!func_args.arguments) @@ -115,6 +114,9 @@ ColumnsDescription TableFunctionPython::getActualTableStructure(ContextPtr conte if (PandasDataFrame::isPandasDataframe(reader)) return PandasDataFrame::getActualTableStructure(reader, context); + if (PyArrowTable::isPyArrowTable(reader)) + return PyArrowTable::getActualTableStructure(reader, context); + if (PythonDict::isPythonDict(reader)) return PythonDict::getActualTableStructure(reader, context); diff --git a/programs/local/TableFunctionPython.h b/programs/local/TableFunctionPython.h index 067a6fa4601..ffea035e24b 100644 --- a/programs/local/TableFunctionPython.h +++ b/programs/local/TableFunctionPython.h @@ -1,9 +1,7 @@ #pragma once -#include "StoragePython.h" #include "PybindWrapper.h" -#include "config.h" #include #include #include @@ -22,7 +20,7 @@ class TableFunctionPython : public ITableFunction ~TableFunctionPython() override { // Acquire the GIL before destroying the reader object - py::gil_scoped_acquire acquire; + pybind11::gil_scoped_acquire acquire; reader.dec_ref(); reader.release(); } @@ -40,7 +38,7 @@ class TableFunctionPython : public ITableFunction void parseArguments(const ASTPtr & ast_function, ContextPtr context) override; ColumnsDescription getActualTableStructure(ContextPtr context, bool is_insert_query) const override; - py::object reader; + pybind11::object reader; }; } diff --git a/programs/local/chdb-arrow.cpp b/programs/local/chdb-arrow.cpp new file mode 100644 index 00000000000..e899e47e0c5 --- /dev/null +++ b/programs/local/chdb-arrow.cpp @@ -0,0 +1,187 @@ +#include "chdb.h" +#include "chdb-internal.h" +#include "ArrowStreamRegistry.h" + +#include +#include +#include + +namespace CHDB +{ + +struct PrivateData +{ + ArrowSchema * schema; + ArrowArray * array; + bool done = false; +}; + +void EmptySchemaRelease(ArrowSchema * schema) +{ + schema->release = nullptr; +} + +void EmptyArrayRelease(ArrowArray * array) +{ + array->release = nullptr; +} + +void EmptyStreamRelease(ArrowArrayStream * stream) +{ + stream->release = nullptr; +} + +int GetSchema(struct ArrowArrayStream * stream, struct ArrowSchema * out) +{ + auto * private_data = static_cast((stream->private_data)); + if (private_data->schema == nullptr) + return CHDBError; + + *out = *private_data->schema; + out->release = EmptySchemaRelease; + return CHDBSuccess; +} + +int GetNext(struct ArrowArrayStream * stream, struct ArrowArray * out) +{ + auto * private_data = static_cast((stream->private_data)); + *out = *private_data->array; + if (private_data->done) + { + out->release = nullptr; + } + else + { + out->release = EmptyArrayRelease; + } + + private_data->done = true; + return CHDBSuccess; +} + +const char * GetLastError(struct ArrowArrayStream * /*stream*/) +{ + return nullptr; +} + +void Release(struct ArrowArrayStream * stream) +{ + if (stream->private_data != nullptr) + delete reinterpret_cast(stream->private_data); + + stream->private_data = nullptr; + stream->release = nullptr; +} + +void chdb_destroy_arrow_stream(ArrowArrayStream * arrow_stream) +{ + if (!arrow_stream) + return; + + if (arrow_stream->release) + arrow_stream->release(arrow_stream); + chassert(!arrow_stream->release); + + delete arrow_stream; +} + +} // namespace CHDB + +static chdb_state 
chdb_inner_arrow_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_stream arrow_stream, bool is_owner) +{ + std::shared_lock global_lock(global_connection_mutex); + + if (!table_name || !arrow_stream) + return CHDBError; + + auto * connection = reinterpret_cast(conn); + if (!checkConnectionValidity(connection)) + return CHDBError; + + auto * stream = reinterpret_cast(arrow_stream); + + ArrowSchema schema; + if (stream->get_schema(stream, &schema) == CHDBError) + return CHDBError; + + using ReleaseFunction = void (*)(ArrowSchema *); + std::vector releases(static_cast(schema.n_children)); + for (size_t i = 0; i < static_cast(schema.n_children); i++) + { + auto * child = schema.children[i]; + releases[i] = child->release; + child->release = CHDB::EmptySchemaRelease; + } + + bool success = false; + try + { + success = CHDB::ArrowStreamRegistry::instance().registerArrowStream(String(table_name), stream, is_owner); + } + catch (...) + { + return CHDBError; + } + + for (size_t i = 0; i < static_cast(schema.n_children); ++i) + { + schema.children[i]->release = releases[i]; + } + + return success ? CHDBSuccess : CHDBError; +} + +chdb_state chdb_arrow_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_stream arrow_stream) +{ + ChdbDestructorGuard guard; + return chdb_inner_arrow_scan(conn, table_name, arrow_stream, false); +} + +chdb_state chdb_arrow_array_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array) +{ + ChdbDestructorGuard guard; + + auto * private_data = new CHDB::PrivateData(); + private_data->schema = reinterpret_cast(arrow_schema); + private_data->array = reinterpret_cast(arrow_array); + private_data->done = false; + + auto * stream = new ArrowArrayStream(); + stream->get_schema = CHDB::GetSchema; + stream->get_next = CHDB::GetNext; + stream->get_last_error = CHDB::GetLastError; + stream->release = CHDB::Release; + stream->private_data = private_data; + + return chdb_inner_arrow_scan(conn, table_name, reinterpret_cast(stream), true); +} + +chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name) +{ + ChdbDestructorGuard guard; + + std::shared_lock global_lock(global_connection_mutex); + + if (!table_name) + return CHDBError; + + auto * connection = reinterpret_cast(conn); + if (!checkConnectionValidity(connection)) + return CHDBError; + + try + { + CHDB::ArrowStreamRegistry::instance().unregisterArrowStream(String(table_name)); + return CHDBSuccess; + } + catch (...) + { + return CHDBError; + } +} diff --git a/programs/local/chdb-internal.h b/programs/local/chdb-internal.h index 0b93907d981..945cf4ba3ae 100644 --- a/programs/local/chdb-internal.h +++ b/programs/local/chdb-internal.h @@ -6,14 +6,40 @@ #include #include #include +#include #include -#include +#include namespace DB { class LocalServer; } +extern std::shared_mutex global_connection_mutex; +extern thread_local bool chdb_destructor_cleanup_in_progress; + +/** + * RAII guard for accurate memory tracking in chDB external interfaces + * used at the beginning of execution to provide thread marking, enabling MemoryTracker + * to accurately track memory changes. 
+ */ +class ChdbDestructorGuard +{ +public: + ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = true; } + ~ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = false; } + ChdbDestructorGuard(const ChdbDestructorGuard &) = delete; + ChdbDestructorGuard & operator=(const ChdbDestructorGuard &) = delete; + ChdbDestructorGuard(ChdbDestructorGuard &&) = delete; + ChdbDestructorGuard & operator=(ChdbDestructorGuard &&) = delete; +}; + +/// Connection validity check function +inline bool checkConnectionValidity(chdb_conn * connection) +{ + return connection && connection->connected && connection->queue; +} + namespace CHDB { @@ -93,4 +119,7 @@ void cancelStreamQuery(DB::LocalServer * server, void * stream_result); const std::string & chdb_result_error_string(chdb_result * result); const std::string & chdb_streaming_result_error_string(chdb_streaming_result * result); + +void chdb_destroy_arrow_stream(ArrowArrayStream * arrow_stream); + } diff --git a/programs/local/chdb.cpp b/programs/local/chdb.cpp index bd933e6ab93..26885cb6cdb 100644 --- a/programs/local/chdb.cpp +++ b/programs/local/chdb.cpp @@ -1,14 +1,11 @@ #include "chdb.h" -#include -#include -#include "Common/MemoryTracker.h" +#include "chdb-internal.h" #include "LocalServer.h" #include "QueryResult.h" -#include "chdb-internal.h" #if USE_PYTHON -# include "FormatHelper.h" -# include "PythonTableCache.h" +#include "FormatHelper.h" +#include "PythonTableCache.h" #endif #ifdef CHDB_STATIC_LIBRARY_BUILD @@ -22,35 +19,22 @@ namespace DB #endif extern thread_local bool chdb_destructor_cleanup_in_progress; +std::shared_mutex global_connection_mutex; namespace CHDB { -/** - * RAII guard for accurate memory tracking in chDB external interfaces - * - * When Python (or other programming language) threads call chDB-provided interfaces - * such as chdb_destroy_query_result, the memory released cannot be accurately tracked - * by ClickHouse's MemoryTracker, which may lead to false reports of insufficient memory. - * - * Therefore, for all externally exposed chDB interfaces, ChdbDestructorGuard must be - * used at the beginning of execution to provide thread marking, enabling MemoryTracker - * to accurately track memory changes. 
- */ -class ChdbDestructorGuard +#if !USE_PYTHON +extern "C" { -public: - ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = true; } - - ~ChdbDestructorGuard() { chdb_destructor_cleanup_in_progress = false; } + extern chdb_state chdb_arrow_scan(chdb_connection, const char *, chdb_arrow_stream); +} - ChdbDestructorGuard(const ChdbDestructorGuard &) = delete; - ChdbDestructorGuard & operator=(const ChdbDestructorGuard &) = delete; - ChdbDestructorGuard(ChdbDestructorGuard &&) = delete; - ChdbDestructorGuard & operator=(ChdbDestructorGuard &&) = delete; +[[maybe_unused]] void * force_link_arrow_functions[] = { + reinterpret_cast(chdb_arrow_scan) }; +#endif -static std::shared_mutex global_connection_mutex; static std::mutex CHDB_MUTEX; chdb_conn * global_conn_ptr = nullptr; std::string global_db_path; @@ -298,11 +282,6 @@ static std::pair createQueryResult(DB::LocalServer * serve return std::make_pair(std::move(query_result), is_end); } -static bool checkConnectionValidity(chdb_conn * conn) -{ - return conn && conn->connected && conn->queue; -} - static QueryResultPtr executeQueryRequest( CHDB::QueryQueue * queue, const char * query, diff --git a/programs/local/chdb.h b/programs/local/chdb.h index e4ecaddb4d9..d16f5172f43 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -66,6 +66,13 @@ typedef struct #endif +// Return state enumeration for chDB API functions +typedef enum chdb_state +{ + CHDBSuccess = 0, + CHDBError = 1 +} chdb_state; + // Opaque handle for query results. // Internal data structure managed by chDB implementation. // Users should only interact through API functions. @@ -82,6 +89,24 @@ typedef struct chdb_connection_ void * internal_data; } * chdb_connection; +// Holds an arrow array stream. +typedef struct chdb_arrow_stream_ +{ + void * internal_data; +} * chdb_arrow_stream; + +// Holds an arrow schema. +typedef struct chdb_arrow_schema_ +{ + void * internal_data; +} * chdb_arrow_schema; + +// Holds an arrow array. +typedef struct chdb_arrow_array_ +{ + void * internal_data; +} * chdb_arrow_array; + #ifndef CHDB_NO_DEPRECATED // WARNING: The following interfaces are deprecated and will be removed in a future version. CHDB_EXPORT struct local_result * query_stable(int argc, char ** argv); @@ -266,20 +291,6 @@ CHDB_EXPORT chdb_result * chdb_query_cmdline(int argc, char ** argv); */ CHDB_EXPORT chdb_result * chdb_stream_query(chdb_connection conn, const char * query, const char * format); -/** - * Executes a query with explicit string lengths (binary-safe). - * @brief Thread-safe function that handles query execution with specified buffer lengths - * @param conn Connection to execute query on - * @param query SQL query buffer (may contain null bytes) - * @param query_len Length of query buffer in bytes - * @param format Output format buffer (may contain null bytes) - * @param format_len Length of format buffer in bytes - * @return Query result structure containing output or error message - * @note Strings do not need to be null-terminated - * @note Use this function when dealing with queries/formats containing null bytes - */ -CHDB_EXPORT chdb_result * chdb_query_n(chdb_connection conn, const char * query, size_t query_len, const char * format, size_t format_len); - /** * Executes a streaming query with explicit string lengths (binary-safe). 
* @brief Initializes streaming query execution with specified buffer lengths @@ -375,6 +386,41 @@ CHDB_EXPORT uint64_t chdb_result_storage_bytes_read(chdb_result * result); */ CHDB_EXPORT const char * chdb_result_error(chdb_result * result); +//===--------------------------------------------------------------------===// +// Arrow Integration +//===--------------------------------------------------------------------===// + +/** + * Registers an Arrow stream as an arrow stream table function with the given name + * @param conn The connection on which to execute the registration + * @param table_name Name to register for the arrow stream table function + * @param arrow_stream chdb Arrow stream handle + * @return CHDBSuccess on success, CHDBError on failure + */ +CHDB_EXPORT chdb_state chdb_arrow_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_stream arrow_stream); + +/** + * Registers an Arrow array as an arrow stream table function with the given name + * @param conn The connection on which to execute the registration + * @param table_name Name to register for the arrow stream table function + * @param arrow_schema chdb Arrow schema handle + * @param arrow_array chdb Arrow array handle + * @return CHDBSuccess on success, CHDBError on failure + */ +CHDB_EXPORT chdb_state chdb_arrow_array_scan( + chdb_connection conn, const char * table_name, + chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array); + +/** + * Unregisters an arrow stream table function that was previously registered via chdb_arrow_scan + * @param conn The connection on which to execute the unregister operation + * @param table_name Name of the arrow stream table function to unregister + * @return CHDBSuccess on success, CHDBError on failure + */ +CHDB_EXPORT chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name); + #ifdef __cplusplus } #endif diff --git a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp index 273122ec09f..3a501a4d29a 100644 --- a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp +++ b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp @@ -1220,13 +1220,9 @@ static ColumnWithTypeAndName readNonNullableColumnFromArrowColumn( // TODO: read UUID as a string? 
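The three entry points documented above form a register / query / unregister cycle. A hedged usage sketch follows; the connection `conn` (e.g. from chdb_connect()) and the `make_my_stream()` producer are assumptions for illustration, and error handling is trimmed:

ArrowArrayStream stream = make_my_stream();   // hypothetical Arrow C ABI producer

if (chdb_arrow_scan(conn, "my_data", reinterpret_cast<chdb_arrow_stream>(&stream)) == CHDBSuccess)
{
    // The registered stream becomes visible to SQL via the arrowstream() table function.
    chdb_result * result = chdb_query(conn, "SELECT count() FROM arrowstream('my_data')", "CSV");
    if (!chdb_result_error(result))
    {
        /* ... consume the result buffer ... */
    }
    chdb_destroy_query_result(result);

    chdb_arrow_unregister_table(conn, "my_data");   // the caller still owns `stream`
}

Note that chdb_arrow_scan registers with is_owner = false, so the stream's lifetime remains the caller's responsibility for as long as the registration is live.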
case arrow::Type::NA: { - if (settings.allow_arrow_null_type) - { - auto type = std::make_shared(); - auto column = ColumnNothing::create(arrow_column->length()); - return {std::move(column), type, column_name}; - } - [[fallthrough]]; + auto type = std::make_shared(); + auto column = ColumnNothing::create(arrow_column->length()); + return {std::move(column), type, column_name}; } default: { diff --git a/tests/test_arrow_table_queries.py b/tests/test_arrow_table_queries.py new file mode 100644 index 00000000000..a05bf3b0e95 --- /dev/null +++ b/tests/test_arrow_table_queries.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 + +import unittest +import tempfile +import os +import shutil +import pyarrow as pa +import pyarrow.parquet as pq +import chdb +from chdb import session +from urllib.request import urlretrieve + +# Clean up and create session in the test methods instead of globally + +class TestChDBArrowTable(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Download parquet file if it doesn't exist + cls.parquet_file = "hits_0.parquet" + if not os.path.exists(cls.parquet_file): + print(f"Downloading {cls.parquet_file}...") + url = "https://datasets.clickhouse.com/hits_compatible/athena_partitioned/hits_0.parquet" + urlretrieve(url, cls.parquet_file) + print("Download complete!") + + # Load parquet as PyArrow table + cls.arrow_table = pq.read_table(cls.parquet_file) + cls.table_size = cls.arrow_table.nbytes + cls.num_rows = cls.arrow_table.num_rows + cls.num_columns = cls.arrow_table.num_columns + + print(f"Loaded Arrow table: {cls.num_rows} rows, {cls.num_columns} columns, {cls.table_size} bytes") + + if os.path.exists(".test_chdb_arrow_table"): + shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True) + cls.sess = session.Session(".test_chdb_arrow_table") + + @classmethod + def tearDownClass(cls): + # Clean up session directory + if os.path.exists(".test_chdb_arrow_table"): + shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True) + cls.sess.close() + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_arrow_table_basic_info(self): + """Test basic Arrow table information""" + self.assertEqual(self.table_size, 729898624) + self.assertEqual(self.num_rows, 1000000) + self.assertEqual(self.num_columns, 105) + + def test_arrow_table_count(self): + """Test counting rows in Arrow table""" + my_arrow_table = self.arrow_table + result = self.sess.query("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)", "CSV") + lines = str(result).strip().split('\n') + count = int(lines[0]) + self.assertEqual(count, self.num_rows, f"Count should match table rows: {self.num_rows}") + + def test_arrow_table_schema(self): + """Test querying Arrow table schema information""" + my_arrow_table = self.arrow_table + result = self.sess.query("DESCRIBE Python(my_arrow_table)", "CSV") + # print(result) + self.assertIn('WatchID', str(result)) + self.assertIn('URLHash', str(result)) + + def test_arrow_table_limit(self): + """Test LIMIT queries on Arrow table""" + my_arrow_table = self.arrow_table + result = self.sess.query("SELECT * FROM Python(my_arrow_table) LIMIT 5", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 5, "Should have 5 data rows") + + def test_arrow_table_select_columns(self): + """Test selecting specific columns from Arrow table""" + my_arrow_table = self.arrow_table + # Get first few column names from schema + schema = self.arrow_table.schema + first_col = schema.field(0).name + second_col = schema.field(1).name if len(schema) > 1 
else first_col + + result = self.sess.query(f"SELECT {first_col}, {second_col} FROM Python(my_arrow_table) LIMIT 3", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 3, "Should have 3 data rows") + + def test_arrow_table_where_clause(self): + """Test WHERE clause filtering on Arrow table""" + my_arrow_table = self.arrow_table + # Find a numeric column for filtering + numeric_col = None + for field in self.arrow_table.schema: + if pa.types.is_integer(field.type) or pa.types.is_floating(field.type): + numeric_col = field.name + break + + result = self.sess.query(f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE {numeric_col} > 1", "CSV") + lines = str(result).strip().split('\n') + count = int(lines[0]) + self.assertEqual(count, 1000000) + + def test_arrow_table_group_by(self): + """Test GROUP BY queries on Arrow table""" + my_arrow_table = self.arrow_table + # Find a string column for grouping + string_col = None + for field in self.arrow_table.schema: + if pa.types.is_binary(field.type) or pa.types.is_large_binary(field.type): + string_col = field.name + break + + result = self.sess.query(f"SELECT {string_col}, COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY {string_col} ORDER BY cnt DESC LIMIT 5", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 5) + + def test_arrow_table_aggregations(self): + """Test aggregation functions on Arrow table""" + my_arrow_table = self.arrow_table + # Find a numeric column for aggregation + numeric_col = None + for field in self.arrow_table.schema: + if pa.types.is_integer(field.type) or pa.types.is_floating(field.type): + numeric_col = field.name + break + + result = self.sess.query(f"SELECT AVG({numeric_col}) as avg_val, MIN({numeric_col}) as min_val, MAX({numeric_col}) as max_val FROM Python(my_arrow_table)", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 1) + + def test_arrow_table_order_by(self): + """Test ORDER BY queries on Arrow table""" + my_arrow_table = self.arrow_table + # Use first column for ordering + first_col = self.arrow_table.schema.field(0).name + + result = self.sess.query(f"SELECT {first_col} FROM Python(my_arrow_table) ORDER BY {first_col} LIMIT 10", "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 10) + + def test_arrow_table_subquery(self): + """Test subqueries with Arrow table""" + my_arrow_table = self.arrow_table + result = self.sess.query(""" + SELECT COUNT(*) as total_count + FROM ( + SELECT * FROM Python(my_arrow_table) + WHERE WatchID IS NOT NULL + LIMIT 1000 + ) subq + """, "CSV") + lines = str(result).strip().split('\n') + self.assertEqual(len(lines), 1) + count = int(lines[0]) + self.assertEqual(count, 1000) + + def test_arrow_table_multiple_tables(self): + """Test using multiple Arrow tables in one query""" + my_arrow_table = self.arrow_table + # Create a smaller subset table + subset_table = my_arrow_table.slice(0, min(100, my_arrow_table.num_rows)) + + result = self.sess.query(""" + SELECT + (SELECT COUNT(*) FROM Python(my_arrow_table)) as full_count, + (SELECT COUNT(*) FROM Python(subset_table)) as subset_count + """, "CSV") + self.assertEqual(str(result).strip(), '1000000,100') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_query_json.py b/tests/test_query_json.py index c52a24f589c..b4b1f3b35d1 100644 --- a/tests/test_query_json.py +++ b/tests/test_query_json.py @@ -21,6 +21,9 @@ \\N,\\N,"[1,666]" """ EXPECTED2 = '"apple1",3,\\N\n\\N,4,2\n' +EXPECTED3 = 
""""['urgent','important']",100.3,"[]" +"[]",0,"[1,666]" +""" dict1 = { "c1": [1, 2, 3, 4, 5, 6, 7, 8], @@ -389,9 +392,9 @@ def test_special_numpy_types(self): self.assertEqual(str(ret), '"2025-05-30 20:08:08.123000000"\n') def test_query_pyarrow_table1(self): - ret = self.sess.query("SELECT c4.tags, c3.deep.level2.level3, c3.mixed_list[].a FROM Python(arrow_table1) WHERE c1 <= 2 ORDER BY c1") + ret = self.sess.query("SELECT c4.tags, c3.deep.level2.level3, c3.mixed_list.a FROM Python(arrow_table1) WHERE c1 <= 2 ORDER BY c1") - self.assertEqual(str(ret), EXPECTED1) + self.assertEqual(str(ret), EXPECTED3) def test_pyarrow_complex_types(self): struct_type = pa.struct([ diff --git a/tests/test_query_py.py b/tests/test_query_py.py index ea6a2074665..a35dccf8d47 100644 --- a/tests/test_query_py.py +++ b/tests/test_query_py.py @@ -185,7 +185,7 @@ def test_query_arrow3(self): ) self.assertEqual( str(ret), - "5872873,587287.3,553446.5,470878.25,3,0,7,10\n", + "5872873,587287.3,553446.5,582813.5,3,0,7,10\n", ) def test_query_arrow4(self): @@ -209,17 +209,17 @@ def test_query_arrow5(self): self.assertDictEqual( schema_dict, { - "quadkey": "String", - "tile": "String", - "tile_x": "Float64", - "tile_y": "Float64", - "avg_d_kbps": "Int64", - "avg_u_kbps": "Int64", - "avg_lat_ms": "Int64", - "avg_lat_down_ms": "Float64", - "avg_lat_up_ms": "Float64", - "tests": "Int64", - "devices": "Int64", + "quadkey": "Nullable(String)", + "tile": "Nullable(String)", + "tile_x": "Nullable(Float64)", + "tile_y": "Nullable(Float64)", + "avg_d_kbps": "Nullable(Int64)", + "avg_u_kbps": "Nullable(Int64)", + "avg_lat_ms": "Nullable(Int64)", + "avg_lat_down_ms": "Nullable(Float64)", + "avg_lat_up_ms": "Nullable(Float64)", + "tests": "Nullable(Int64)", + "devices": "Nullable(Int64)", }, ) ret = chdb.query( @@ -237,23 +237,34 @@ def test_query_arrow5(self): self.assertDictEqual( {x["name"]: x["type"] for x in json.loads(str(ret)).get("meta")}, { - "max(avg_d_kbps)": "Int64", - "max(avg_lat_down_ms)": "Float64", - "max(avg_lat_ms)": "Int64", - "max(avg_lat_up_ms)": "Float64", - "max(avg_u_kbps)": "Int64", - "max(devices)": "Int64", - "max(tests)": "Int64", - "round(median(avg_d_kbps), 2)": "Float64", - "round(median(avg_lat_down_ms), 2)": "Float64", - "round(median(avg_lat_ms), 2)": "Float64", - "round(median(avg_lat_up_ms), 2)": "Float64", - "round(median(avg_u_kbps), 2)": "Float64", - "round(median(devices), 2)": "Float64", - "round(median(tests), 2)": "Float64", + "max(avg_d_kbps)": "Nullable(Int64)", + "max(avg_lat_down_ms)": "Nullable(Float64)", + "max(avg_lat_ms)": "Nullable(Int64)", + "max(avg_lat_up_ms)": "Nullable(Float64)", + "max(avg_u_kbps)": "Nullable(Int64)", + "max(devices)": "Nullable(Int64)", + "max(tests)": "Nullable(Int64)", + "round(median(avg_d_kbps), 2)": "Nullable(Float64)", + "round(median(avg_lat_down_ms), 2)": "Nullable(Float64)", + "round(median(avg_lat_ms), 2)": "Nullable(Float64)", + "round(median(avg_lat_up_ms), 2)": "Nullable(Float64)", + "round(median(avg_u_kbps), 2)": "Nullable(Float64)", + "round(median(devices), 2)": "Nullable(Float64)", + "round(median(tests), 2)": "Nullable(Float64)", }, ) + def test_query_arrow_null_type(self): + null_array = pa.array([None, None, None]) + table = pa.table([null_array], names=["null_col"]) + ret = chdb.query("SELECT * FROM Python(table)") + self.assertEqual(str(ret), "\\N\n\\N\n\\N\n") + + null_array = pa.array([None, 1, None]) + table = pa.table([null_array], names=["null_col"]) + ret = chdb.query("SELECT * FROM Python(table)") + 
self.assertEqual(str(ret), "\\N\n1\n\\N\n") + def test_random_float(self): x = {"col1": [random.uniform(0, 1) for _ in range(0, 100000)]} ret = chdb.sql( diff --git a/tests/test_unsupported_arrow_types.py b/tests/test_unsupported_arrow_types.py new file mode 100644 index 00000000000..88640876e1f --- /dev/null +++ b/tests/test_unsupported_arrow_types.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 + +import unittest +import pyarrow as pa +import pyarrow.compute as pc +import chdb +from chdb import ChdbError + + +class TestUnsupportedArrowTypes(unittest.TestCase): + """Test that chDB properly handles unsupported Arrow types""" + + def setUp(self): + """Set up test data""" + self.sample_data = [1, 2, 3, 4, 5] + self.sample_strings = ["a", "b", "c", "d", "e"] + + def test_sparse_union_type(self): + """Test SPARSE_UNION type - should fail""" + # Create a sparse union type + children = [ + pa.array([1, None, 3, None, 5]), + pa.array([None, "b", None, "d", None]) + ] + type_ids = pa.array([0, 1, 0, 1, 0], type=pa.int8()) + + union_array = pa.UnionArray.from_sparse(type_ids, children) + table = pa.table([union_array], names=["sparse_union_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_dense_union_type(self): + """Test DENSE_UNION type - should fail""" + # Create a dense union type + children = [ + pa.array([1, 3, 5]), + pa.array(["b", "d"]) + ] + type_ids = pa.array([0, 1, 0, 1, 0], type=pa.int8()) + offsets = pa.array([0, 0, 1, 1, 2], type=pa.int32()) + + union_array = pa.UnionArray.from_dense(type_ids, offsets, children) + table = pa.table([union_array], names=["dense_union_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_interval_month_day_type(self): + """Test INTERVAL_MONTH_DAY type - should fail""" + pass + + def test_interval_day_time_type(self): + """Test INTERVAL_DAY_TIME type - should fail""" + pass + + def test_interval_month_day_nano_type(self): + """Test INTERVAL_MONTH_DAY_NANO type - should fail""" + start_timestamps = pc.strptime( + pa.array(["2021-01-01 00:00:00", "2022-01-01 00:00:00", "2023-01-01 00:00:00"]), + format="%Y-%m-%d %H:%M:%S", + unit="ns" + ) + + end_timestamps = pc.strptime( + pa.array(["2021-04-01 00:00:00", "2022-05-01 00:00:00", "2023-07-01 00:00:00"]), + format="%Y-%m-%d %H:%M:%S", + unit="ns" + ) + + interval_array = pc.month_day_nano_interval_between(start_timestamps, end_timestamps) + table = pa.table([interval_array], names=["interval_month_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_list_view_type(self): + """Test LIST_VIEW type - should fail""" + # Create list view array + list_data = [[1, 2], [3, 4, 5], [6], [], [7, 8, 9]] + list_view_array = pa.array(list_data, type=pa.list_view(pa.int64())) + table = pa.table([list_view_array], names=["list_view_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_large_list_view_type(self): + """Test LARGE_LIST_VIEW type - should fail""" + # Create large list view array (if available) + list_data = [[1, 2], [3, 4, 5], [6], [], [7, 8, 9]] + large_list_view_array = pa.array(list_data, type=pa.large_list_view(pa.int64())) + table = 
pa.table([large_list_view_array], names=["large_list_view_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_run_end_encoded_type(self): + """Test RUN_END_ENCODED type - should fail""" + # Create run-end encoded array + values = pa.array([1, 2, 3]) + run_ends = pa.array([3, 7, 10], type=pa.int32()) + ree_array = pa.RunEndEncodedArray.from_arrays(run_ends, values) + table = pa.table([ree_array], names=["run_end_encoded_col"]) + + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + + self.assertIn("Unsupported", str(context.exception)) + + def test_skip_unsupported_columns_setting(self): + """Test input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference=1 skips unsupported columns""" + # Create a table with both supported and unsupported columns + supported_col = pa.array([1, 2, 3, 4, 5]) # int64 - supported + # Create union array (unsupported) + union_children = [ + pa.array([10, None, 30, None, 50]), + pa.array([None, "b", None, "d", None]) + ] + union_type_ids = pa.array([0, 1, 0, 1, 0], type=pa.int8()) + unsupported_col = pa.UnionArray.from_sparse(union_type_ids, union_children) + + table = pa.table([ + supported_col, + unsupported_col + ], names=["supported_col", "unsupported_col"]) + + # Without the setting, query should fail + with self.assertRaises(Exception) as context: + chdb.query("SELECT * FROM Python(table)") + self.assertIn("Unsupported", str(context.exception)) + + # With the setting, query should succeed but skip unsupported column + result = chdb.query( + "SELECT * FROM Python(table) settings input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference=1" + ) + self.assertEqual(str(result), "1\n2\n3\n4\n5\n") + + +if __name__ == "__main__": + unittest.main(verbosity=2)
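As a C++ counterpart to the pyarrow-based tests above, here is a hedged sketch of the single-array path through chdb_arrow_array_scan (see chdb.h and chdb-arrow.cpp earlier in this diff). The connection setup and the exported `schema`/`array` pair are assumptions for illustration; chdb-arrow.cpp wraps the pair in a one-shot ArrowArrayStream whose no-op release hooks deliberately leave the buffers with the caller:

ArrowSchema schema;   // assume these were filled by an Arrow producer,
ArrowArray array;     // e.g. via arrow::ExportRecordBatch(batch, &array, &schema)

if (chdb_arrow_array_scan(conn, "one_batch",
                          reinterpret_cast<chdb_arrow_schema>(&schema),
                          reinterpret_cast<chdb_arrow_array>(&array)) == CHDBSuccess)
{
    chdb_result * result = chdb_query(conn, "SELECT * FROM arrowstream('one_batch') LIMIT 10", "CSV");
    if (!chdb_result_error(result))
        fwrite(chdb_result_buffer(result), 1, chdb_result_length(result), stdout);
    chdb_destroy_query_result(result);

    chdb_arrow_unregister_table(conn, "one_batch");
    // The one-shot stream object created by chdb_arrow_array_scan is owned by chdb
    // (is_owner = true), but the schema/array memory stays with the caller.
}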