Skip to content

Commit

Permalink
Update df(),pl(),arrow() and fetchnumpy()
Browse files Browse the repository at this point in the history
  • Loading branch information
Yibo-Chen13 committed Nov 4, 2024
1 parent f5e394b commit 244eb04
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 2 deletions.
18 changes: 17 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ Insert Data
Pandas DataFrame
----------------
Big fan of Pandas? We too! You can mix SQL and Pandas API together:
Big fan of Pandas? So are we! You can mix SQL and the Pandas API together. You can also convert query results to a variety of formats (e.g. NumPy arrays, Pandas DataFrame, Polars DataFrame, Arrow Table) through the DB-API cursor.


.. code-block:: python
Expand Down Expand Up @@ -128,3 +129,18 @@ Big fan of Pandas? We too! You can mix SQL and Pandas API together:
df = c.query_dataframe('SELECT * FROM table(test)')
print(df)
print(df.describe())
# Converting query results to a variety of formats with dbapi
with connect('proton://localhost') as conn, conn.cursor() as cur:
    # The query is issued again before every conversion so that each
    # converter starts from a fresh result set.
    cur.execute('SELECT * FROM table(test)')
    print(cur.df())  # Pandas DataFrame

    cur.execute('SELECT * FROM table(test)')
    print(cur.fetchnumpy())  # Numpy Arrays

    cur.execute('SELECT * FROM table(test)')
    print(cur.pl())  # Polars DataFrame

    cur.execute('SELECT * FROM table(test)')
    print(cur.arrow())  # Arrow Table
21 changes: 20 additions & 1 deletion example/pandas/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import time

from proton_driver import client
from proton_driver import client, connect

if __name__ == "__main__":
c = client.Client(host='127.0.0.1', port=8463)
Expand Down Expand Up @@ -37,3 +37,22 @@
df = c.query_dataframe('SELECT * FROM table(test)')
print(df)
print(df.describe())

# Converting query results to a variety of formats with dbapi
with connect('proton://localhost') as conn, conn.cursor() as cur:
    # The query is issued again before every conversion so that each
    # converter starts from a fresh result set.
    cur.execute('SELECT * FROM table(test)')
    print('--------------Pandas DataFrame--------------')
    print(cur.df())

    cur.execute('SELECT * FROM table(test)')
    print('----------------Numpy Arrays----------------')
    print(cur.fetchnumpy())

    cur.execute('SELECT * FROM table(test)')
    print('--------------Polars DataFrame--------------')
    print(cur.pl())

    cur.execute('SELECT * FROM table(test)')
    print('-----------------Arrow Table----------------')
    print(cur.arrow())
72 changes: 72 additions & 0 deletions proton_driver/dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,78 @@ def fetchall(self):
self._rows = []
return rv

def df(self):
    """
    Fetch all (remaining) rows of a query result, returning them as
    a pandas DataFrame.

    The pending rows are consumed: once the frame has been built the
    cursor's row buffer is reset to an empty list.

    :return: Pandas DataFrame of fetched rows, one column per entry in
             the cursor's column list.
    """
    self._check_query_started()

    import pandas as pd

    # Materialize the rows once. They are walked once per column below,
    # so a one-shot iterable would be exhausted after the first column
    # and every later column would come back empty.
    rows = list(self._rows)

    # NOTE(review): a column whose name is falsy maps to None instead of
    # its values; kept as-is for backward compatibility -- confirm intent.
    rv = pd.DataFrame({
        name: [row[i] for row in rows] if name else None
        for i, name in enumerate(self._columns)
    })
    self._rows = []
    return rv

def fetchnumpy(self):
    """
    Fetch all (remaining) rows of a query result, returning
    them as a dictionary of NumPy arrays.

    The pending rows are consumed: once the arrays have been built the
    cursor's row buffer is reset to an empty list.

    :return: Dictionary mapping each column name to a NumPy array of
             that column's values.
    """
    self._check_query_started()

    import numpy as np

    # Materialize the rows once. They are walked once per column below,
    # so a one-shot iterable would be exhausted after the first column
    # and every later column would silently become an empty array.
    rows = list(self._rows)

    # NOTE(review): a column whose name is falsy maps to None instead of
    # an array; kept as-is for backward compatibility -- confirm intent.
    rv = {
        name: np.array([row[i] for row in rows]) if name else None
        for i, name in enumerate(self._columns)
    }
    self._rows = []
    return rv

def pl(self):
    """
    Fetch all (remaining) rows of a query result, returning them as
    a Polars DataFrame.

    The pending rows are consumed: once the frame has been built the
    cursor's row buffer is reset to an empty list.

    :return: Polars DataFrame of fetched rows, one column per entry in
             the cursor's column list.
    """
    self._check_query_started()

    # Imported under its full name so the module is not confused with
    # this method, which is itself called ``pl``.
    import polars

    # Materialize the rows once. They are walked once per column below,
    # so a one-shot iterable would be exhausted after the first column
    # and every later column would come back empty.
    rows = list(self._rows)

    # NOTE(review): a column whose name is falsy maps to None instead of
    # its values; kept as-is for backward compatibility -- confirm intent.
    rv = polars.DataFrame({
        name: [row[i] for row in rows] if name else None
        for i, name in enumerate(self._columns)
    })
    self._rows = []
    return rv

def arrow(self):
    """
    Fetch all (remaining) rows of a query result, returning them as
    an Arrow Table.

    The pending rows are consumed: once the table has been built the
    cursor's row buffer is reset to an empty list.

    :return: Arrow Table of fetched rows, one column per entry in the
             cursor's column list.
    """
    self._check_query_started()

    import pyarrow as pa

    # Materialize the rows once. They are walked once per column below,
    # so a one-shot iterable would be exhausted after the first column
    # and every later column would come back empty.
    rows = list(self._rows)

    # NOTE(review): a column whose name is falsy maps to None instead of
    # its values; kept as-is for backward compatibility -- confirm intent.
    rv = pa.table({
        name: [row[i] for row in rows] if name else None
        for i, name in enumerate(self._columns)
    })
    self._rows = []
    return rv

def setinputsizes(self, sizes):
    """
    Accept predefined input sizes and ignore them.

    :param sizes: ignored.
    :return: None
    """
    # Intentionally a no-op.
    return None
Expand Down
110 changes: 110 additions & 0 deletions tests/numpy/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from tests.testcase import BaseTestCase
from tests.numpy.testcase import NumpyBaseTestCase
from proton_driver import connect
from datetime import datetime
from decimal import Decimal


class GenericTestCase(NumpyBaseTestCase):
Expand Down Expand Up @@ -171,3 +174,110 @@ def test_query_dataframe(self):
self.assertEqual(
'Extras for NumPy must be installed', str(e.exception)
)


class DataFrameDBAPITestCase(NumpyBaseTestCase):
    """
    Exercise the DB-API cursor converters ``fetchnumpy()``, ``df()``,
    ``pl()`` and ``arrow()`` against a three-row ``test`` stream.

    ``setUp`` recreates the stream, inserts ``data`` and leaves a pending
    ``SELECT`` on ``self.cur`` so each test only has to call a converter.
    """

    # Column definitions used in the CREATE STREAM statement.
    types = \
        'a int64, b string, c datetime,' \
        'd fixed_string(10), e decimal(9, 5), f float64,' \
        'g low_cardinality(string), h nullable(int32)'

    # Comma-separated column names, in the same order as ``types``.
    columns = 'a,b,c,d,e,f,g,h'

    # Rows inserted into the stream; also the source of expected values.
    data = [
        [
            123, 'abc', datetime(2024, 5, 20, 12, 11, 10),
            'abcefgcxxx', Decimal('300.42'), 3.402823e12,
            '127001', 332
        ],
        [
            456, 'cde', datetime(2024, 6, 21, 12, 13, 50),
            '1234567890', Decimal('171.31'), -3.4028235e13,
            '127001', None
        ],
        [
            789, 'efg', datetime(1998, 7, 22, 12, 30, 10),
            'stream sql', Decimal('894.22'), float('inf'),
            '127001', None
        ],
    ]

    def _expected_columns(self):
        """Return ``data`` transposed into a column-name -> values dict."""
        return {
            col: [row[i] for row in self.data]
            for i, col in enumerate(self.columns.split(','))
        }

    def setUp(self):
        """Create the stream, load ``data`` and start the SELECT."""
        super(DataFrameDBAPITestCase, self).setUp()
        self.conn = connect('proton://localhost')
        # Cleanups run in reverse registration order and also fire when a
        # later setUp step fails, so the connection is never left open.
        self.addCleanup(self.conn.close)
        self.cur = self.conn.cursor()
        self.addCleanup(self.cur.close)
        self.cur.execute('DROP STREAM IF EXISTS test')
        self.cur.execute(f'CREATE STREAM test ({self.types}) ENGINE = Memory')
        self.cur.execute(
            f'INSERT INTO test ({self.columns}) VALUES',
            self.data
        )
        self.cur.execute(f'SELECT {self.columns} FROM test')

    def tearDown(self):
        """Drop the stream, then run the parent teardown."""
        # Undo this class's own setup before the parent's, mirroring the
        # order in which setUp built things up.
        try:
            self.cur.execute('DROP STREAM test')
        finally:
            super(DataFrameDBAPITestCase, self).tearDown()

    def test_dbapi_fetchnumpy(self):
        """``fetchnumpy()`` yields one array per column with the data."""
        rv = self.cur.fetchnumpy()
        for key, values in self._expected_columns().items():
            self.assertIn(key, rv)
            # NOTE(review): helper name kept exactly as written; confirm it
            # matches the spelling defined on NumpyBaseTestCase.
            self.assertarraysEqual(np.array(values), rv[key])

    def test_dbapi_df(self):
        """``df()`` yields a pandas frame with the expected dtypes/data."""
        expect = pd.DataFrame(self.data, columns=self.columns.split(','))
        df = self.cur.df()

        self.assertIsInstance(df, pd.DataFrame)
        self.assertEqual(df.shape, (3, 8))
        self.assertEqual(
            [dtype.name for dtype in df.dtypes],
            ['int64', 'object', 'datetime64[ns]',
             'object', 'object', 'float64',
             'object', 'float64']
        )
        self.assertTrue(expect.equals(df))

    def test_dbapi_pl(self):
        """``pl()`` yields a Polars frame with the expected schema/data."""
        try:
            import polars as pl
        except ImportError:
            self.skipTest('Polars extras are not installed')

        expect = pl.DataFrame(self._expected_columns())

        df = self.cur.pl()
        self.assertIsInstance(df, pl.DataFrame)
        self.assertEqual(df.shape, (3, 8))
        self.assertSequenceEqual(
            df.schema.dtypes(),
            [pl.Int64, pl.String, pl.Datetime, pl.String,
             pl.Decimal, pl.Float64, pl.String, pl.Int64]
        )
        self.assertTrue(expect.equals(df))

    def test_dbapi_arrow(self):
        """``arrow()`` yields an Arrow table with the expected types/data."""
        try:
            import pyarrow as pa
        except ImportError:
            self.skipTest('Pyarrow extras are not installed')

        expect = pa.table(self._expected_columns())
        at = self.cur.arrow()
        self.assertEqual(at.shape, (3, 8))
        self.assertSequenceEqual(
            at.schema.types,
            [pa.int64(), pa.string(), pa.timestamp('us'),
             pa.string(), pa.decimal128(5, 2), pa.float64(),
             pa.string(), pa.int64()]
        )
        self.assertTrue(expect.equals(at))

0 comments on commit 244eb04

Please sign in to comment.