Skip to content

Commit

Permalink
feat(bindings/python): convert decimal values to python Decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc committed Mar 21, 2024
1 parent 3cc3100 commit 639ad58
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 6 deletions.
7 changes: 7 additions & 0 deletions bindings/nodejs/tests/binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ Then("Select string {string} should be equal to {string}", async function (input
assert.equal(output, value);
});

Then("Select types should be expected native types", async function () {
// NumberValue::Decimal
const row = await this.conn.queryRow(`SELECT 15.7563::Decimal(8,4), 2.0+3.0`);
const excepted = ["15.7563", "5.0"];
assert.equal(row.values(), excepted);
});

Then("Select numbers should iterate all rows", async function () {
let rows = await this.conn.queryIter("SELECT number FROM numbers(5)");
let ret = [];
Expand Down
2 changes: 2 additions & 0 deletions bindings/python/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ docs/_build/

# Generated docs
docs

.ruff_cache/
31 changes: 25 additions & 6 deletions bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ use std::sync::Arc;

use once_cell::sync::Lazy;
use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::sync::GILOnceCell;
use pyo3::types::{PyDict, PyList, PyTuple, PyType};
use pyo3::{intern, prelude::*};
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
Expand All @@ -29,6 +30,18 @@ pub static VERSION: Lazy<String> = Lazy::new(|| {
version.to_string()
});

pub static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();

fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> {
DECIMAL_CLS
.get_or_try_init(py, || {
py.import(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()
})
.map(|ty| ty.as_ref(py))
}

pub struct Value(databend_driver::Value);

impl IntoPy<PyObject> for Value {
Expand Down Expand Up @@ -97,12 +110,18 @@ impl IntoPy<PyObject> for NumberValue {
databend_driver::NumberValue::Float32(i) => i.into_py(py),
databend_driver::NumberValue::Float64(i) => i.into_py(py),
databend_driver::NumberValue::Decimal128(_, _) => {
let s = self.0.to_string();
s.into_py(py)
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
let ret = dec_cls
.call1((self.0.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
databend_driver::NumberValue::Decimal256(_, _) => {
let s = self.0.to_string();
s.into_py(py)
let dec_cls = get_decimal_cls(py).expect("failed to load decimal.Decimal");
let ret = dec_cls
.call1((self.0.to_string(),))
.expect("failed to call decimal.Decimal(value)");
ret.to_object(py)
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions bindings/python/tests/asyncio/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from decimal import Decimal

from behave import given, when, then
from behave.api.async_step import async_run_until_complete
Expand Down Expand Up @@ -56,6 +57,15 @@ async def _(context, input, output):
assert output == value


@then("Select types should be expected native types")
@async_run_until_complete
async def _(context):
# NumberValue::Decimal
row = await context.conn.query_row("SELECT 15.7563::Decimal(8,4), 2.0+3.0")
expected = (Decimal("15.7563"), Decimal("5.0"))
assert row.values() == expected


@then("Select numbers should iterate all rows")
@async_run_until_complete
async def _(context):
Expand Down
9 changes: 9 additions & 0 deletions bindings/python/tests/blocking/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from decimal import Decimal

from behave import given, when, then
import databend_driver
Expand Down Expand Up @@ -52,6 +53,14 @@ def _(context, input, output):
assert output == value


@then("Select types should be expected native types")
async def _(context):
# NumberValue::Decimal
row = context.conn.query_row("SELECT 15.7563::Decimal(8,4), 2.0+3.0")
expected = (Decimal("15.7563"), Decimal("5.0"))
assert row.values() == expected


@then("Select numbers should iterate all rows")
def _(context):
rows = context.conn.query_iter("SELECT number FROM numbers(5)")
Expand Down
4 changes: 4 additions & 0 deletions bindings/tests/features/binding.feature
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Feature: Databend Driver
Given A new Databend Driver Client
Then Select string "Hello, Databend!" should be equal to "Hello, Databend!"

Scenario: Select Types
Given A new Databend Driver Client
Then Select types should be expected native types

Scenario: Select Iter
Given A new Databend Driver Client
Then Select numbers should iterate all rows
Expand Down

0 comments on commit 639ad58

Please sign in to comment.