Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bindings/python): convert decimal values to python Decimal #370

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bindings/nodejs/tests/binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ Then("Select string {string} should be equal to {string}", async function (input
assert.equal(output, value);
});

Then("Select types should be expected native types", async function () {
// NumberValue::Decimal
const row = await this.conn.queryRow(`SELECT 15.7563::Decimal(8,4), 2.0+3.0`);
const excepted = ["15.7563", "5.0"];
assert.deepEqual(row.values(), excepted);
});

Then("Select numbers should iterate all rows", async function () {
let rows = await this.conn.queryIter("SELECT number FROM numbers(5)");
let ret = [];
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ docs/_build/

# Generated docs
docs

.ruff_cache/
31 changes: 25 additions & 6 deletions bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use std::sync::Arc;

use once_cell::sync::Lazy;
use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::sync::GILOnceCell;
use pyo3::types::{PyDict, PyList, PyTuple, PyType};
use pyo3::{intern, prelude::*};
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
Expand All @@ -29,6 +30,18 @@ pub static VERSION: Lazy<String> = Lazy::new(|| {
version.to_string()
});

pub static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();

fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> {
DECIMAL_CLS
.get_or_try_init(py, || {
py.import(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()
})
.map(|ty| ty.as_ref(py))
}

pub struct Value(databend_driver::Value);

impl IntoPy<PyObject> for Value {
Expand Down Expand Up @@ -97,12 +110,18 @@ impl IntoPy<PyObject> for NumberValue {
databend_driver::NumberValue::Float32(i) => i.into_py(py),
databend_driver::NumberValue::Float64(i) => i.into_py(py),
databend_driver::NumberValue::Decimal128(_, _) => {
let s = self.0.to_string();
s.into_py(py)
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
let ret = dec_cls
.call1((self.0.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
databend_driver::NumberValue::Decimal256(_, _) => {
let s = self.0.to_string();
s.into_py(py)
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
let ret = dec_cls
.call1((self.0.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions bindings/python/tests/asyncio/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from decimal import Decimal

from behave import given, when, then
from behave.api.async_step import async_run_until_complete
Expand Down Expand Up @@ -56,6 +57,15 @@ async def _(context, input, output):
assert output == value


@then("Select types should be expected native types")
@async_run_until_complete
async def _(context):
# NumberValue::Decimal
row = await context.conn.query_row("SELECT 15.7563::Decimal(8,4), 2.0+3.0")
expected = (Decimal("15.7563"), Decimal("5.0"))
assert row.values() == expected


@then("Select numbers should iterate all rows")
@async_run_until_complete
async def _(context):
Expand Down
9 changes: 9 additions & 0 deletions bindings/python/tests/blocking/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from decimal import Decimal

from behave import given, when, then
import databend_driver
Expand Down Expand Up @@ -52,6 +53,14 @@ def _(context, input, output):
assert output == value


@then("Select types should be expected native types")
async def _(context):
# NumberValue::Decimal
row = context.conn.query_row("SELECT 15.7563::Decimal(8,4), 2.0+3.0")
expected = (Decimal("15.7563"), Decimal("5.0"))
assert row.values() == expected


@then("Select numbers should iterate all rows")
def _(context):
rows = context.conn.query_iter("SELECT number FROM numbers(5)")
Expand Down
4 changes: 4 additions & 0 deletions bindings/tests/features/binding.feature
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Feature: Databend Driver
Given A new Databend Driver Client
Then Select string "Hello, Databend!" should be equal to "Hello, Databend!"

Scenario: Select Types
Given A new Databend Driver Client
Then Select types should be expected native types

Scenario: Select Iter
Given A new Databend Driver Client
Then Select numbers should iterate all rows
Expand Down
Loading