Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented support for the SQL Server sql_variant data type #1354

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dbspecific.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// SQL Server


#define SQL_SS_VARIANT -150 // SQL Server 2008 SQL_VARIANT type
#define SQL_SS_XML -152 // SQL Server 2005 XML type
#define SQL_DB2_DECFLOAT -360 // IBM DB/2 DECFLOAT type
#define SQL_DB2_XML -370 // IBM DB/2 XML type
Expand Down
87 changes: 66 additions & 21 deletions src/getdata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ void GetData_init()
}

static byte* ReallocOrFreeBuffer(byte* pb, Py_ssize_t cbNeed);
PyObject *GetData_SqlVariant(Cursor *cur, Py_ssize_t iCol);

inline bool IsBinaryType(SQLSMALLINT sqltype)
{
Expand Down Expand Up @@ -534,28 +535,30 @@ static PyObject* GetDataTimestamp(Cursor* cur, Py_ssize_t iCol)

switch (cur->colinfos[iCol].sql_type)
{
case SQL_TYPE_TIME:
{
int micros = (int)(value.fraction / 1000); // nanos --> micros
return PyTime_FromTime(value.hour, value.minute, value.second, micros);
}

case SQL_TYPE_DATE:
return PyDate_FromDate(value.year, value.month, value.day);

case SQL_TYPE_TIMESTAMP:
{
if (value.year < 1)
{
value.year = 1;
}
else if (value.year > 9999)
{
value.year = 9999;
}
}
case SQL_TYPE_TIME:
{
int micros = (int)(value.fraction / 1000); // nanos --> micros
return PyTime_FromTime(value.hour, value.minute, value.second, micros);
}

case SQL_TYPE_DATE:
case SQL_DATE:
return PyDate_FromDate(value.year, value.month, value.day);

case SQL_TYPE_TIMESTAMP:
case SQL_TIMESTAMP:
{
if (value.year < 1)
{
value.year = 1;
}
else if (value.year > 9999)
{
value.year = 9999;
}
}
}


int micros = (int)(value.fraction / 1000); // nanos --> micros

Expand Down Expand Up @@ -645,6 +648,7 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type)
break;

case SQL_TYPE_DATE:
case SQL_DATE:
pytype = (PyObject*)PyDateTimeAPI->DateType;
break;

Expand All @@ -654,6 +658,7 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type)
break;

case SQL_TYPE_TIMESTAMP:
case SQL_TIMESTAMP:
pytype = (PyObject*)PyDateTimeAPI->DateTimeType;
break;

Expand Down Expand Up @@ -744,15 +749,55 @@ PyObject* GetData(Cursor* cur, Py_ssize_t iCol)
return GetDataDouble(cur, iCol);


case SQL_DATE:
case SQL_TYPE_DATE:
case SQL_TYPE_TIME:
case SQL_TIMESTAMP:
case SQL_TYPE_TIMESTAMP:
return GetDataTimestamp(cur, iCol);

case SQL_SS_TIME2:
return GetSqlServerTime(cur, iCol);

case SQL_SS_VARIANT:
return GetData_SqlVariant(cur, iCol);
}

return RaiseErrorV("HY106", ProgrammingError, "ODBC SQL type %d is not yet supported. column-index=%zd type=%d",
(int)pinfo->sql_type, iCol, (int)pinfo->sql_type);
}

PyObject *GetData_SqlVariant(Cursor *cur, Py_ssize_t iCol) {
char pBuff;

SQLLEN indicator, variantType;
SQLRETURN retcode;

PyObject *decodeResult;

// Call SQLGetData on the current column with a data length of 0. According to MS, this makes
// the ODBC driver read the sql_variant header which contains the underlying data type
pBuff = 0;
indicator = 0;
retcode = SQLGetData(cur->hstmt, static_cast<SQLSMALLINT>(iCol + 1), SQL_C_BINARY,
&pBuff, 0, &indicator);
if (!SQL_SUCCEEDED(retcode))
return RaiseErrorFromHandle(cur->cnxn, "SQLGetData", cur->cnxn->hdbc, cur->hstmt);

// Get the SQL_CA_SS_VARIANT_TYPE field for the column which will contain the underlying data type
variantType = 0;
retcode = SQLColAttribute(cur->hstmt, iCol + 1, SQL_CA_SS_VARIANT_TYPE, NULL, 0, NULL, &variantType);
if (!SQL_SUCCEEDED(retcode))
return RaiseErrorFromHandle(cur->cnxn, "SQLColAttribute", cur->cnxn->hdbc, cur->hstmt);

// Replace the original SQL_VARIANT data type with the underlying data type then call GetData() again
cur->colinfos[iCol].sql_type = static_cast<SQLSMALLINT>(variantType);
decodeResult = GetData(cur, iCol);

// Restore the original SQL_VARIANT data type so that the next decode will call this method again
cur->colinfos[iCol].sql_type = static_cast<SQLSMALLINT>(SQL_SS_VARIANT);

return decodeResult;

// NOTE: we don't free the hstmt here as it's managed by the cursor
}
6 changes: 5 additions & 1 deletion src/pyodbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ typedef unsigned long long UINT64;
#define SQL_CA_SS_CATALOG_NAME 1225
#endif

#ifndef SQL_CA_SS_VARIANT_TYPE
#define SQL_CA_SS_VARIANT_TYPE 1215
#endif

inline bool IsSet(DWORD grf, DWORD flags)
{
return (grf & flags) == flags;
Expand Down Expand Up @@ -117,7 +121,7 @@ inline void DebugTrace(const char* szFmt, ...) { UNUSED(szFmt); }

// issue #880: entry missing from iODBC sqltypes.h
#ifndef BYTE
typedef unsigned char BYTE;
typedef unsigned char BYTE;
#endif
bool PyMem_Realloc(BYTE** pp, size_t newlen);
// A wrapper around realloc with a safer interface. If it is successful, *pp is updated to the
Expand Down
38 changes: 38 additions & 0 deletions tests/sqlserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,44 @@ def test_tvp_diffschema(cursor: pyodbc.Cursor):
_test_tvp(cursor, True)


@pytest.mark.skipif(SQLSERVER_YEAR < 2000, reason='sql_variant not supported until 2000')
def test_sql_variant(cursor: pyodbc.Cursor):
"""
Tests decoding of the sql_variant data type as performed by the GetData_SqlVariant() method.
"""

cursor.execute("create table t1 (a sql_variant)")

# insert a number of values of disparate types. this is not exhaustive as not all
# types that can be contained within a sql_variant field are supported by pyodbc
cursor.execute("insert into t1 values (456.7)")
cursor.execute("insert into t1 values ('a string')")
cursor.execute("insert into t1 values (CAST('2024-06-03' AS DATE))")
cursor.execute("insert into t1 values (CAST('2024-06-03 23:46:03.000' AS DATETIME))")
cursor.execute("insert into t1 values (CAST('binary data' AS VARBINARY(200)))")
cursor.execute(
"insert into t1 values (CAST('0592b437-745f-4b2c-a997-97022c624cf6' AS UNIQUEIDENTIFIER))"
)

# select all of the values we inserted and ensure they have the correct types
results = [record[0] for record in cursor.execute("select a from t1").fetchall()]
for index, assertion_tuple in enumerate(
[
(Decimal, Decimal("456.7")),
(str, "a string"),
(date, date(2024, 6, 3)),
(datetime, datetime(2024, 6, 3, 23, 46, 3)),
(bytes, b'binary data'),
(uuid.UUID, uuid.UUID("0592b437-745f-4b2c-a997-97022c624cf6"))
]
):
# pylint: disable=unidiomatic-typecheck
expected_type, expected_value = assertion_tuple

assert type(results[index]) == expected_type
assert results[index] == expected_value


def get_sqlserver_version(cursor: pyodbc.Cursor):

"""
Expand Down