diff --git a/bindings/nodejs/tests/binding.js b/bindings/nodejs/tests/binding.js index 80fb2ada..d6b14e78 100644 --- a/bindings/nodejs/tests/binding.js +++ b/bindings/nodejs/tests/binding.js @@ -33,6 +33,13 @@ Then("Select string {string} should be equal to {string}", async function (input assert.equal(output, value); }); +Then("Select types should be expected native types", async function () { + // NumberValue::Decimal + const row = await this.conn.queryRow(`SELECT 15.7563::Decimal(8,4), 2.0+3.0`); + const excepted = ["15.7563", "5.0"]; + assert.deepEqual(row.values(), excepted); +}); + Then("Select numbers should iterate all rows", async function () { let rows = await this.conn.queryIter("SELECT number FROM numbers(5)"); let ret = []; diff --git a/bindings/python/.gitignore b/bindings/python/.gitignore index 65036982..d106553c 100644 --- a/bindings/python/.gitignore +++ b/bindings/python/.gitignore @@ -73,3 +73,5 @@ docs/_build/ # Generated docs docs + +.ruff_cache/ diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 9dd0d479..2884f4a4 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -16,8 +16,9 @@ use std::sync::Arc; use once_cell::sync::Lazy; use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration}; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyTuple}; +use pyo3::sync::GILOnceCell; +use pyo3::types::{PyDict, PyList, PyTuple, PyType}; +use pyo3::{intern, prelude::*}; use pyo3_asyncio::tokio::future_into_py; use tokio::sync::Mutex; use tokio_stream::StreamExt; @@ -29,6 +30,18 @@ pub static VERSION: Lazy = Lazy::new(|| { version.to_string() }); +pub static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new(); + +fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> { + DECIMAL_CLS + .get_or_try_init(py, || { + py.import(intern!(py, "decimal"))? + .getattr(intern!(py, "Decimal"))? + .extract() + }) + .map(|ty| ty.as_ref(py)) +} + pub struct Value(databend_driver::Value); impl IntoPy for Value { @@ -97,12 +110,18 @@ impl IntoPy for NumberValue { databend_driver::NumberValue::Float32(i) => i.into_py(py), databend_driver::NumberValue::Float64(i) => i.into_py(py), databend_driver::NumberValue::Decimal128(_, _) => { - let s = self.0.to_string(); - s.into_py(py) + let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal"); + let ret = dec_cls + .call1((self.0.to_string(),)) + .expect("failed to call decimal.Decimal(value)"); + ret.to_object(py) } databend_driver::NumberValue::Decimal256(_, _) => { - let s = self.0.to_string(); - s.into_py(py) + let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal"); + let ret = dec_cls + .call1((self.0.to_string(),)) + .expect("failed to call decimal.Decimal(value)"); + ret.to_object(py) } } } diff --git a/bindings/python/tests/asyncio/steps/binding.py b/bindings/python/tests/asyncio/steps/binding.py index 5d13b68f..eec63532 100644 --- a/bindings/python/tests/asyncio/steps/binding.py +++ b/bindings/python/tests/asyncio/steps/binding.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from decimal import Decimal from behave import given, when, then from behave.api.async_step import async_run_until_complete @@ -56,6 +57,15 @@ async def _(context, input, output): assert output == value +@then("Select types should be expected native types") +@async_run_until_complete +async def _(context): + # NumberValue::Decimal + row = await context.conn.query_row("SELECT 15.7563::Decimal(8,4), 2.0+3.0") + expected = (Decimal("15.7563"), Decimal("5.0")) + assert row.values() == expected + + @then("Select numbers should iterate all rows") @async_run_until_complete async def _(context): diff --git a/bindings/python/tests/blocking/steps/binding.py b/bindings/python/tests/blocking/steps/binding.py index 071d733a..67ccf27f 100644 --- a/bindings/python/tests/blocking/steps/binding.py +++ b/bindings/python/tests/blocking/steps/binding.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from decimal import Decimal from behave import given, when, then import databend_driver @@ -52,6 +53,14 @@ def _(context, input, output): assert output == value +@then("Select types should be expected native types") +async def _(context): + # NumberValue::Decimal + row = context.conn.query_row("SELECT 15.7563::Decimal(8,4), 2.0+3.0") + expected = (Decimal("15.7563"), Decimal("5.0")) + assert row.values() == expected + + @then("Select numbers should iterate all rows") def _(context): rows = context.conn.query_iter("SELECT number FROM numbers(5)") diff --git a/bindings/tests/features/binding.feature b/bindings/tests/features/binding.feature index 35953ec6..1db9dfb2 100644 --- a/bindings/tests/features/binding.feature +++ b/bindings/tests/features/binding.feature @@ -18,6 +18,10 @@ Feature: Databend Driver Given A new Databend Driver Client Then Select string "Hello, Databend!" should be equal to "Hello, Databend!" + Scenario: Select Types + Given A new Databend Driver Client + Then Select types should be expected native types + Scenario: Select Iter Given A new Databend Driver Client Then Select numbers should iterate all rows