Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support create_on_conflict in CTAS #132

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main_distribution.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
name: Build extension binaries
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main
with:
duckdb_version: c29c67bb971362cd1e9143305acffebb1bc9bd63
duckdb_version: 78ebe44ef9f31b43cfc41b3bf739ab9069e16ae8
ci_tools_version: 5bdbe4d606d78dbd749f9578ba8ca639feece023
exclude_archs: "wasm_mvp;wasm_eh;wasm_threads;windows_amd64;windows_amd64_mingw;windows_amd64_rtools"
extension_name: substrait
Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 1122 files
21 changes: 19 additions & 2 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
interval_t interval {};
interval.months = 0;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]
return Value::INTERVAL(interval);
}
default:
Expand Down Expand Up @@ -515,7 +515,7 @@

if (sop.aggregate().groupings_size() > 0) {
for (auto &sgrp : sop.aggregate().groupings()) {
for (auto &sgrpexpr : sgrp.grouping_expressions()) {

Check warning on line 518 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 518 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]
groups.push_back(TransformExpr(sgrpexpr));
expressions.push_back(TransformExpr(sgrpexpr));
}
Expand Down Expand Up @@ -615,8 +615,8 @@
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
if (!sget.virtual_table().values().empty()) {

Check warning on line 618 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 618 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
auto literal_values = sget.virtual_table().values();

Check warning on line 619 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 619 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
Expand Down Expand Up @@ -725,6 +725,19 @@
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

OnCreateConflict SubstraitToDuckDB::TransformCreateMode(substrait::WriteRel_CreateMode mode) {
switch (mode) {
case substrait::WriteRel::CREATE_MODE_ERROR_IF_EXISTS:
return OnCreateConflict::ERROR_ON_CONFLICT;
case substrait::WriteRel::CREATE_MODE_IGNORE_IF_EXISTS:
return OnCreateConflict::IGNORE_ON_CONFLICT;
case substrait::WriteRel::CREATE_MODE_REPLACE_IF_EXISTS:
return OnCreateConflict::REPLACE_ON_CONFLICT;
default:
throw NotImplementedException("Unsupported on conflict type " + to_string(mode));
}
}

shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
auto &swrite = sop.write();
auto &nobj = swrite.named_table();
Expand All @@ -738,9 +751,13 @@
schema_name = nobj.names(0);
}
auto input = TransformOp(swrite.input());
auto on_conflict = OnCreateConflict::ERROR_ON_CONFLICT;
if (swrite.create_mode()) {
on_conflict = TransformCreateMode(swrite.create_mode());
}
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
return input->CreateRel(schema_name, table_name, false, on_conflict);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
return input->InsertRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
Expand Down Expand Up @@ -841,7 +858,7 @@
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
const auto create_table = static_cast<CreateTableRelation *>(child.get());
auto proj = make_shared_ptr<ProjectionRelation>(create_table->child, std::move(expressions), aliases);
return proj->CreateRel(create_table->schema_name, create_table->table_name);
return proj->CreateRel(create_table->schema_name, create_table->table_name, create_table->temporary, create_table->on_conflict);
}
default:
return child;
Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class SubstraitToDuckDB {
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformWriteOp(const substrait::Rel &sop);
static OnCreateConflict TransformCreateMode(substrait::WriteRel_CreateMode mode);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr,
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::WriteRel_CreateMode TransformOnCreateConflict(OnCreateConflict on_conflict);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
15 changes: 14 additions & 1 deletion src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
} else {
auto interval_day = make_uniq<substrait::Expression_Literal_IntervalDayToSecond>();
interval_day->set_days(dval.GetValue<interval_t>().days);
interval_day->set_microseconds(static_cast<int32_t>(dval.GetValue<interval_t>().micros));

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]
sval.set_allocated_interval_day_to_second(interval_day.release());
}
}
Expand Down Expand Up @@ -1012,7 +1012,7 @@
// TODO push projection or push substrait to allow expressions here
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1015 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
Expand Down Expand Up @@ -1280,7 +1280,7 @@
auto virtual_table = sget->mutable_virtual_table();

// Add a dummy value to emit one row
auto dummy_value = virtual_table->add_values();

Check warning on line 1283 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]

Check warning on line 1283 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]
dummy_value->add_fields()->set_i32(42);
return get_rel;
}
Expand Down Expand Up @@ -1437,6 +1437,19 @@
return rel;
}

substrait::WriteRel_CreateMode DuckDBToSubstrait::TransformOnCreateConflict(OnCreateConflict on_conflict) {
switch(on_conflict) {
case OnCreateConflict::ERROR_ON_CONFLICT:
return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_ERROR_IF_EXISTS;
case OnCreateConflict::IGNORE_ON_CONFLICT:
return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_IGNORE_IF_EXISTS;
case OnCreateConflict::REPLACE_ON_CONFLICT:
return substrait::WriteRel_CreateMode::WriteRel_CreateMode_CREATE_MODE_REPLACE_IF_EXISTS;
default:
throw NotImplementedException("Unknown OnCreateConflict type " + to_string((int)on_conflict));
}
}

substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &create_table = dop.Cast<LogicalCreateTable>();
Expand Down Expand Up @@ -1468,7 +1481,7 @@
auto named_table = write->mutable_named_table();
named_table->add_names(create_info.schema);
named_table->add_names(create_info.table);

write->set_create_mode(TransformOnCreateConflict(create_info.on_conflict));
return rel;
}

Expand Down
38 changes: 38 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,41 @@ TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 0, {2, 6}));
REQUIRE(CHECK_COLUMN(result, 1, {4, 8}));
}

TEST_CASE("Test C CTAS with create_on_conflict via Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

auto res1 = ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
"SELECT employee_id, salary FROM employees"
);

auto result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));


REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE TABLE IF NOT EXISTS employee_salaries AS "
"SELECT employee_id, department_id, salary FROM employees"));

result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));

auto res3 = ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
"SELECT employee_id, department_id, salary FROM employees");
REQUIRE_FAIL(res3);
result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));

REQUIRE_NO_FAIL(ExecuteViaSubstraitJSON(con, "CREATE OR REPLACE TABLE employee_salaries AS "
"SELECT name, salary FROM employees"));

result = con.Query("SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
}

Loading