diff --git a/src/getdata.cpp b/src/getdata.cpp index 3831c5d9..14f60005 100644 --- a/src/getdata.cpp +++ b/src/getdata.cpp @@ -542,9 +542,11 @@ static PyObject* GetDataTimestamp(Cursor* cur, Py_ssize_t iCol) } case SQL_TYPE_DATE: + case SQL_DATE: return PyDate_FromDate(value.year, value.month, value.day); case SQL_TYPE_TIMESTAMP: + case SQL_TIMESTAMP: { if (value.year < 1) { @@ -646,6 +648,7 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type) break; case SQL_TYPE_DATE: + case SQL_DATE: pytype = (PyObject*)PyDateTimeAPI->DateType; break; @@ -655,6 +658,7 @@ PyObject* PythonTypeFromSqlType(Cursor* cur, SQLSMALLINT type) break; case SQL_TYPE_TIMESTAMP: + case SQL_TIMESTAMP: pytype = (PyObject*)PyDateTimeAPI->DateTimeType; break; @@ -745,8 +749,10 @@ PyObject* GetData(Cursor* cur, Py_ssize_t iCol) return GetDataDouble(cur, iCol); + case SQL_DATE: case SQL_TYPE_DATE: case SQL_TYPE_TIME: + case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: return GetDataTimestamp(cur, iCol); @@ -767,6 +773,8 @@ PyObject *GetData_SqlVariant(Cursor *cur, Py_ssize_t iCol) { SQLLEN indicator, variantType; SQLRETURN retcode; + PyObject *decodeResult; + // Call SQLGetData on the current column with a data length of 0. According to MS, this makes // the ODBC driver read the sql_variant header which contains the underlying data type pBuff = 0; @@ -784,7 +792,12 @@ PyObject *GetData_SqlVariant(Cursor *cur, Py_ssize_t iCol) { // Replace the original SQL_VARIANT data type with the underlying data type then call GetData() again cur->colinfos[iCol].sql_type = static_cast(variantType); - return GetData(cur, iCol); + decodeResult = GetData(cur, iCol); + + // Restore the original SQL_VARIANT data type so that the next decode will call this method again + cur->colinfos[iCol].sql_type = static_cast(SQL_SS_VARIANT); + + return decodeResult; // NOTE: we don't free the hstmt here as it's managed by the cursor } diff --git a/tests/sqlserver_test.py b/tests/sqlserver_test.py index 9602de4a..ab51ce0d 100755 --- a/tests/sqlserver_test.py +++ b/tests/sqlserver_test.py @@ -1614,6 +1614,44 @@ def test_tvp_diffschema(cursor: pyodbc.Cursor): _test_tvp(cursor, True) +@pytest.mark.skipif(SQLSERVER_YEAR < 2000, reason='sql_variant not supported until 2000') +def test_sql_variant(cursor: pyodbc.Cursor): + """ + Tests decoding of the sql_variant data type as performed by the GetData_SqlVariant() method. + """ + + cursor.execute("create table t1 (a sql_variant)") + + # insert a number of values of disparate types. this is not exhaustive as not all + # types that can be contained within a sql_variant field are supported by pyodbc + cursor.execute("insert into t1 values (456.7)") + cursor.execute("insert into t1 values ('a string')") + cursor.execute("insert into t1 values (CAST('2024-06-03' AS DATE))") + cursor.execute("insert into t1 values (CAST('2024-06-03 23:46:03.000' AS DATETIME))") + cursor.execute("insert into t1 values (CAST('binary data' AS VARBINARY(200)))") + cursor.execute( + "insert into t1 values (CAST('0592b437-745f-4b2c-a997-97022c624cf6' AS UNIQUEIDENTIFIER))" + ) + + # select all of the values we inserted and ensure they have the correct types + results = [record[0] for record in cursor.execute("select a from t1").fetchall()] + for index, assertion_tuple in enumerate( + [ + (Decimal, Decimal("456.7")), + (str, "a string"), + (date, date(2024, 6, 3)), + (datetime, datetime(2024, 6, 3, 23, 46, 3)), + (bytes, b'binary data'), + (uuid.UUID, uuid.UUID("0592b437-745f-4b2c-a997-97022c624cf6")) + ] + ): + # pylint: disable=unidiomatic-typecheck + expected_type, expected_value = assertion_tuple + + assert type(results[index]) == expected_type + assert results[index] == expected_value + + def get_sqlserver_version(cursor: pyodbc.Cursor): """