diff --git a/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java b/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java index f02fe00..98fb525 100755 --- a/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java +++ b/calcite-rs-jni/jni-arrow/src/main/java/com/hasura/ArrowJdbcWrapper.java @@ -77,8 +77,7 @@ public VectorSchemaRoot executeQuery(String query) throws Exception { try { ArrowResultSet resultSet = executeQueryBatched(query); VectorSchemaRoot result = resultSet.nextBatch(); - resultSet.close(); - logger.info("Successfully executed query and got results"); + logger.info("Successfully executed query and got results. Remember to close the vector root to release memory."); return result; } catch (Exception e) { logger.severe("Error executing query: " + e.getMessage()); diff --git a/calcite-rs-jni/odbc/.run/DDN-ODBC-Tester.run.xml b/calcite-rs-jni/odbc/.run/DDN-ODBC-Tester.run.xml new file mode 100755 index 0000000..bb53872 --- /dev/null +++ b/calcite-rs-jni/odbc/.run/DDN-ODBC-Tester.run.xml @@ -0,0 +1,21 @@ + + + + \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt b/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt index 47dacfa..c225c03 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/CMakeLists.txt @@ -101,6 +101,7 @@ set(SOURCE_FILES src/SQLGetStmtAttr.cpp src/SQLSetDescField.cpp src/SQLGetTypeInfo.cpp + src/SQLBindParameter.cpp ) # Header files @@ -110,6 +111,8 @@ set(HEADER_FILES include/globals.hpp include/logging.hpp include/statement.hpp + include/JniParam.hpp + src/JniParam.cpp ) # Function to filter files starting with `._` @@ -125,6 +128,8 @@ endfunction() filter_files(SOURCE_FILES) filter_files(HEADER_FILES) +add_definitions(-D_SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING) + # Create the library add_library(${PROJECT_NAME} SHARED ${SOURCE_FILES} diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/JniParam.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/JniParam.hpp new file mode 100755 index 0000000..c9496a7 --- /dev/null +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/JniParam.hpp @@ -0,0 +1,63 @@ +// +// Created by kennethstott on 11/26/2024. +// + +#ifndef JNIPARAM_H +#define JNIPARAM_H + +#endif //JNIPARAM_H + +#pragma once + +#include +#include +#include + +class JniParam { +public: + enum class Type { + String, + StringArray, + Integer, + Float, + Double, + Boolean + }; + + // Constructors for different types + explicit JniParam(const std::string& value); + explicit JniParam(const std::vector& value); + explicit JniParam(int value); + explicit JniParam(float value); + explicit JniParam(double value); + explicit JniParam(bool value); + + JniParam(); + + // Get the JNI signature for this type + std::string getSignature() const; + + // Convert the parameter to JNI value + jvalue toJValue(JNIEnv* env) const; + + // Clean up any JNI resources + void cleanup(JNIEnv* env, const jvalue& value) const; + + // Get the parameter type + [[nodiscard]] Type getType() const { return type_; } + [[nodiscard]] std::string getString() const { return stringValue_; } + [[nodiscard]] std::vector getStringArray() const { return stringArrayValue_; } + [[nodiscard]] int getInt() const { return intValue_; } + [[nodiscard]] float getFloat() const { return floatValue_; } + [[nodiscard]] double getDouble() const { return doubleValue_; } + [[nodiscard]] bool getBool() const { return boolValue_; } + +private: + Type type_; + std::string stringValue_; + std::vector stringArrayValue_; + int intValue_ = 0; + float floatValue_ = 0.0f; + double doubleValue_ = 0.0; + bool boolValue_ = false; +}; \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp index 1dfd875..50b0dc7 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/connection.hpp @@ -1,93 +1,120 @@ #pragma once - -// Windows includes must come first -#define WIN32_LEAN_AND_MEAN +#include #include +#define WIN32_LEAN_AND_MEAN -// SQL includes +// ODBC includes #include +// JNI includes +#include // Standard includes +#include #include +#include - -// JNI includes -#include - +#include "JniParam.hpp" #include "statement.hpp" // Forward declarations -class Environment; +class Statement; + +struct ConnectionParams { + std::string server; + std::string port; + std::string database; + std::string role; + std::string auth; + std::string uid; + std::string pwd; + std::string encrypt = "no"; + std::string timeout; + + bool isValid() const { + return !server.empty() && !port.empty() && !database.empty(); + } +}; + +// struct ColumnDesc { +// std::string name; +// SQLSMALLINT sqlType; +// SQLULEN columnSize; +// SQLSMALLINT decimalDigits; +// SQLSMALLINT nullable; +// }; class Connection { -private: - JavaVM* jvm{}; - jobject wrapperInstance{}; - jclass wrapperClass{}; - std::string connectionString; - bool isConnected{}; - - struct ConnectionParams { - std::string server; - std::string port; - std::string database; - std::string role; - std::string auth; - std::string uid; - std::string pwd; - std::string encrypt; - std::string timeout; - - [[nodiscard]] bool isValid() const { - return !server.empty() && !port.empty() && !database.empty(); - } - }; +public: + Connection() = default; + ~Connection(); - bool initJVM(); - bool initWrapper(const ConnectionParams& params); + // Delete copy constructor and assignment operator + Connection(const Connection&) = delete; + Connection& operator=(const Connection&) = delete; - static bool parseConnectionString(const std::string& connStr, ConnectionParams& params); + // Connection string methods + void setConnectionString(const std::string& dsn, const std::string& uid, const std::string& authStr); + void setConnectionString(const std::string& dsn); - static std::string buildJdbcUrl(const ConnectionParams ¶ms); - std::vector activeStmts; + // Connection management + SQLRETURN connect(); + SQLRETURN disconnect(); -public: + // Query methods + SQLRETURN Query(const std::string& query, Statement* stmt); + SQLRETURN GetTables( + const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType, + Statement* stmt) const; + SQLRETURN GetColumns( + const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName, + Statement* stmt) const; + + // Statement management bool hasActiveStmts() const; - void cleanupActiveStmts(); - SQLRETURN GetTables(const std::string &catalogName, const std::string &schemaName, const std::string &tableName, - const std::string &tableType, Statement *stmt) const; - - SQLRETURN GetColumns(const std::string &catalogName, const std::string &schemaName, const std::string &tableName, - const std::string &columnName, Statement *stmt) const; - - Connection() = default; - ~Connection(); - - SQLRETURN connect(); + // Get connection state + bool isConnected() const { return connected; } + JNIEnv* env = nullptr; + [[nodiscard]] const std::string& getConnectionString() const { return connectionString; } - void setConnectionString(const std::string &dsn, const std::string &uid, const std::string &authStr); - void setConnectionString(const std::string &dsn); +private: + // Connection state + bool connected = false; + std::string connectionString; + std::vector activeStmts; + // JVM/JNI state + JavaVM* jvm = nullptr; + jclass wrapperClass = nullptr; + jobject wrapperInstance = nullptr; + // Initialization helpers + bool initJVM(); + bool initWrapper(const ConnectionParams& params); + bool parseConnectionString(const std::string& connStr, ConnectionParams& params); + std::string buildJdbcUrl(const ConnectionParams& params); + SQLRETURN populateColumnDescriptors(jobject schemaRoot, Statement *stmt) const; - // Delete copy constructor and assignment operator - Connection(const Connection&) = delete; - Connection& operator=(const Connection&) = delete; - - SQLRETURN disconnect(); - JNIEnv* env{}; + // Result set handling + SQLRETURN executeAndGetArrowResult( + const char *methodName, + const std::vector ¶ms, + Statement *stmt) const; - // Getters - [[nodiscard]] JavaVM* getJVM() const { return jvm; } - [[nodiscard]] jobject getWrapperInstance() const { return wrapperInstance; } - [[nodiscard]] jclass getWrapperClass() const { return wrapperClass; } - [[nodiscard]] const std::string& getConnectionString() const { return connectionString; } + // Type mapping helpers + static SQLSMALLINT mapArrowTypeToSQL(JNIEnv* env, jobject arrowType); + static SQLULEN getSQLTypeSize(SQLSMALLINT sqlType); }; // Helper functions declarations -std::string GetModuleDirectory(HMODULE hModule); +std::string GetModuleDirectory(); std::string WideStringToString(const std::wstring& wstr); \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp index 0034a45..61d2ce9 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/environment.hpp @@ -1,5 +1,7 @@ #pragma once +#include #include +#define WIN32_LEAN_AND_MEAN #include #include #include diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp index 798ce37..51c5aed 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/include/statement.hpp @@ -7,6 +7,8 @@ #include #include // Add this if you're using JNI types +#include "JniParam.hpp" + // Forward declaration of Connection class class Connection; @@ -17,6 +19,27 @@ struct ColumnDesc { SQLULEN columnSize; SQLSMALLINT decimalDigits; SQLSMALLINT nullable; + const char* catalogName; + SQLSMALLINT catalogNameLength; + const char* schemaName; + SQLSMALLINT schemaNameLength; + const char* tableName; + SQLSMALLINT tableNameLength; + const char* baseColumnName; + SQLSMALLINT baseColumnNameLength; + const char* baseTableName; + SQLSMALLINT baseTableNameLength; + const char* literalPrefix; + SQLSMALLINT literalPrefixLength; + const char* literalSuffix; + SQLSMALLINT literalSuffixLength; + const char* localTypeName; + SQLSMALLINT localTypeNameLength; + SQLSMALLINT unnamed; + const char* label; + SQLSMALLINT labelLength; + SQLULEN displaySize; + SQLSMALLINT scale; }; struct ColumnData { @@ -25,9 +48,11 @@ struct ColumnData { }; class Statement { -public: - std::vector setupColumnResultColumns(); +private: + std::vector boundParams; + std::string originalQuery; +public: explicit Statement(Connection* connection); // Delete copy constructor and assignment operator @@ -36,11 +61,17 @@ class Statement { void clearResults(); - std::vector setupTableResultColumns(); + SQLRETURN bindParameter(SQLUSMALLINT parameterNumber, SQLSMALLINT inputOutputType, SQLSMALLINT valueType, + SQLSMALLINT parameterType, SQLULEN columnSize, SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, SQLLEN bufferLength, SQLLEN *strLen_or_IndPtr); + + std::string escapeString(const std::string &str) const; + std::string buildInterpolatedQuery() const; SQLRETURN setArrowResult(jobject schemaRoot, const std::vector &columnDescriptors); + SQLRETURN setOriginalQuery(const std::string &query) { originalQuery = query; return SQL_SUCCESS; } SQLRETURN getData(SQLUSMALLINT colNum, SQLSMALLINT targetType, - SQLPOINTER targetValue, SQLLEN bufferLength, - SQLLEN* strLengthOrIndicator); + SQLPOINTER targetValue, SQLLEN bufferLength, + SQLLEN* strLengthOrIndicator); SQLRETURN fetch(); [[nodiscard]] SQLRETURN getFetchStatus() const; [[nodiscard]] bool hasData() const; diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def b/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def index 4ee45f1..166dcd4 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/odbc_driver.def @@ -10,8 +10,8 @@ EXPORTS SQLDisconnect SQLGetInfo=SQLGetInfo_A SQLGetInfoW=SQLGetInfo_W - SQLExecDirect=SQLExecDirect_A - SQLExecDirectW=SQLExecDirect_W + SQLExecDirect + SQLExecDirectW SQLSetDescField SQLSetDescFieldW SQLTables=SQLTables_A @@ -27,6 +27,7 @@ EXPORTS SQLColAttributeW SQLFetch SQLGetTypeInfo + SQLBindParameter SQLGetTypeInfoW SQLGetData=SQLGetData_A SQLGetDataW=SQLGetData_W diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/JniParam.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/JniParam.cpp new file mode 100755 index 0000000..65f2d88 --- /dev/null +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/JniParam.cpp @@ -0,0 +1,136 @@ +#include "../include/JniParam.hpp" +#include "../include/logging.hpp" + +JniParam::JniParam(const std::string& value) + : type_(Type::String) + , stringValue_(value) +{} + +JniParam::JniParam(const std::vector& value) + : type_(Type::StringArray) + , stringArrayValue_(value) +{} + +JniParam::JniParam(int value) + : type_(Type::Integer) + , intValue_(value) +{} + +JniParam::JniParam(float value) + : type_(Type::Float) + , floatValue_(value) +{} + +JniParam::JniParam(double value) + : type_(Type::Double) + , doubleValue_(value) +{} + +JniParam::JniParam(bool value) + : type_(Type::Boolean) + , boolValue_(value) +{} + +JniParam::JniParam() {} + +std::string JniParam::getSignature() const { + switch (type_) { + case Type::String: + return "Ljava/lang/String;"; + case Type::StringArray: + return "[Ljava/lang/String;"; + case Type::Integer: + return "I"; + case Type::Float: + return "F"; + case Type::Double: + return "D"; + case Type::Boolean: + return "Z"; + default: + return ""; + } +} + +jvalue JniParam::toJValue(JNIEnv* env) const { + jvalue val{}; + try { + switch (type_) { + case Type::String: + val.l = stringValue_.empty() ? + nullptr : + env->NewStringUTF(stringValue_.c_str()); + LOGF("Created jstring from: %s", stringValue_.c_str()); + break; + + case Type::StringArray: { + jclass stringClass = env->FindClass("java/lang/String"); + if (!stringClass) { + LOG("Failed to find String class"); + break; + } + + jobjectArray arr = env->NewObjectArray( + stringArrayValue_.size(), stringClass, nullptr); + if (!arr) { + LOG("Failed to create String array"); + env->DeleteLocalRef(stringClass); + break; + } + + for (size_t i = 0; i < stringArrayValue_.size(); i++) { + jstring str = env->NewStringUTF(stringArrayValue_[i].c_str()); + if (str) { + env->SetObjectArrayElement(arr, i, str); + env->DeleteLocalRef(str); + } + } + + val.l = arr; + env->DeleteLocalRef(stringClass); + LOGF("Created String array with %zu elements", stringArrayValue_.size()); + break; + } + + case Type::Integer: + val.i = intValue_; + LOGF("Set integer value: %d", intValue_); + break; + + case Type::Float: + val.f = floatValue_; + LOGF("Set float value: %f", floatValue_); + break; + + case Type::Double: + val.d = doubleValue_; + LOGF("Set double value: %f", doubleValue_); + break; + + case Type::Boolean: + val.z = boolValue_; + LOGF("Set boolean value: %d", boolValue_); + break; + } + } + catch (const std::exception& e) { + LOGF("Exception in toJValue: %s", e.what()); + // Ensure val is zeroed in case of error + val = jvalue{}; + } + return val; +} + +void JniParam::cleanup(JNIEnv* env, const jvalue& value) const { + try { + if (type_ == Type::String || type_ == Type::StringArray) { + if (value.l != nullptr) { + env->DeleteLocalRef(value.l); + LOG("Cleaned up JNI reference"); + } + } + } + catch (const std::exception& e) { + LOGF("Exception in cleanup: %s", e.what()); + } +} \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLBindParameter.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLBindParameter.cpp new file mode 100755 index 0000000..3bc5a9d --- /dev/null +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLBindParameter.cpp @@ -0,0 +1,42 @@ +#include +#include +#define WIN32_LEAN_AND_MEAN +#include "../include/connection.hpp" +#include "../include/logging.hpp" +//#include "../include/httplib.h" +#include "../include/globals.hpp" +#include "../include/environment.hpp" +#include "../include/statement.hpp" + +extern "C" { + SQLRETURN SQL_API SQLBindParameter( + SQLHSTMT StatementHandle, + SQLUSMALLINT ParameterNumber, + SQLSMALLINT InputOutputType, + SQLSMALLINT ValueType, + SQLSMALLINT ParameterType, + SQLULEN ColumnSize, + SQLSMALLINT DecimalDigits, + SQLPOINTER ParameterValuePtr, + SQLLEN BufferLength, + SQLLEN *StrLen_or_IndPtr + ) { + if (!StatementHandle) { + return SQL_ERROR; + } + + Statement* stmt = reinterpret_cast(StatementHandle); + + return stmt->bindParameter( + ParameterNumber, + InputOutputType, + ValueType, + ParameterType, + ColumnSize, + DecimalDigits, + ParameterValuePtr, + BufferLength, + StrLen_or_IndPtr + ); + } +} \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp index d16650d..cecb160 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLColAttribute.cpp @@ -10,222 +10,244 @@ // Column attribute functions extern "C" { - SQLRETURN SQL_API SQLColAttribute( - SQLHSTMT StatementHandle, - SQLUSMALLINT ColumnNumber, - SQLUSMALLINT FieldIdentifier, - SQLPOINTER CharacterAttribute, - SQLSMALLINT BufferLength, - SQLSMALLINT* StringLength, - SQLLEN* NumericAttribute) { - - LOGF("SQLColAttribute called - Column: %u, Field: %u", ColumnNumber, FieldIdentifier); - auto* stmt = static_cast(StatementHandle); - if (!stmt) { - LOG("Invalid statement handle"); - return SQL_INVALID_HANDLE; - } - - if (!stmt->hasResult) { - LOG("No result set available"); - return SQL_ERROR; - } - - LOGF("Accessing column %u of %zu", ColumnNumber, stmt->resultColumns.size()); - if (ColumnNumber <= 0 || ColumnNumber > stmt->resultColumns.size()) { - LOG("Invalid column number"); - return SQL_ERROR; - } - - const auto& col = stmt->resultColumns[ColumnNumber - 1]; - LOGF("Column name: %s", col.name); - - switch (FieldIdentifier) { + SQLHSTMT StatementHandle, + SQLUSMALLINT ColumnNumber, + SQLUSMALLINT FieldIdentifier, + SQLPOINTER CharacterAttribute, + SQLSMALLINT BufferLength, + SQLSMALLINT *StringLength, + SQLLEN *NumericAttribute) { + auto *stmt = static_cast(StatementHandle); + LOGF("SQLColAttribute called - Column: %u, Field: %u, Total Size: %u", ColumnNumber, FieldIdentifier, stmt->resultColumns.size()); + if (!stmt) { + LOG("Invalid statement handle"); + return SQL_INVALID_HANDLE; + } + + if (!stmt->hasResult) { + LOG("No result set available"); + return SQL_ERROR; + } + + LOGF("Accessing column %u of %zu", ColumnNumber, stmt->resultColumns.size()); + if (ColumnNumber <= 0 || ColumnNumber > stmt->resultColumns.size()) { + LOG("Invalid column number"); + return SQL_ERROR; + } + + const auto &col = stmt->resultColumns[ColumnNumber - 1]; + + switch (FieldIdentifier) { // Basic column metadata - case SQL_COLUMN_COUNT: // 0 + case SQL_COLUMN_COUNT: // 0 if (NumericAttribute) { *NumericAttribute = stmt->resultColumns.size(); } break; - case SQL_COLUMN_NAME: // 1 - case SQL_DESC_NAME: // SQL_COLUMN_NAME - case SQL_DESC_LABEL: // SQL_COLUMN_LABEL + case SQL_COLUMN_NAME: // 1 + case SQL_DESC_NAME: // SQL_COLUMN_NAME + if (CharacterAttribute && BufferLength > 0) { + LOGF("Returning column name: %s, %u", col.name, col.nameLength); + strncpy((char *) CharacterAttribute, col.name, BufferLength); + if (StringLength) { + *StringLength = col.nameLength; + } + } else { + LOG("Did not return column name"); + } + break; + + case SQL_DESC_LABEL: // SQL_COLUMN_LABEL if (CharacterAttribute && BufferLength > 0) { - strncpy((char*)CharacterAttribute, col.name, BufferLength); + strncpy((char *) CharacterAttribute, col.label, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(col.name)); + *StringLength = col.labelLength; } } break; // Type information - case SQL_COLUMN_TYPE: // 2 - case SQL_DESC_TYPE: // 1002 - // case SQL_DESC_CONCISE_TYPE: // SQL_COLUMN_TYPE + case SQL_COLUMN_TYPE: // 2 + case SQL_DESC_TYPE: // 1002 + // case SQL_DESC_CONCISE_TYPE: // SQL_COLUMN_TYPE if (NumericAttribute) { *NumericAttribute = col.sqlType; } break; // Size and length information - case SQL_COLUMN_LENGTH: // 3 - case SQL_DESC_LENGTH: // SQL_COLUMN_LENGTH - case SQL_DESC_OCTET_LENGTH: // 1013 + case SQL_COLUMN_LENGTH: // 3 + case SQL_DESC_LENGTH: // SQL_COLUMN_LENGTH + case SQL_DESC_OCTET_LENGTH: // 1013 if (NumericAttribute) { *NumericAttribute = col.columnSize; } break; - case SQL_COLUMN_DISPLAY_SIZE: // 6 - case SQL_COLUMN_PRECISION: // 4 - case SQL_DESC_PRECISION: // SQL_COLUMN_PRECISION - case SQL_COLUMN_SCALE: // 5 - case SQL_DESC_SCALE: // SQL_COLUMN_SCALE + case SQL_COLUMN_DISPLAY_SIZE: // 6 + if (NumericAttribute) { + *NumericAttribute = col.displaySize; + } + break; + + case SQL_COLUMN_PRECISION: // 4 + case SQL_DESC_PRECISION: // SQL_COLUMN_PRECISION if (NumericAttribute) { - *NumericAttribute = 0; // VARCHAR doesn't have scale + *NumericAttribute = col.decimalDigits; + } + break; + + case SQL_COLUMN_SCALE: // 5 + case SQL_DESC_SCALE: // SQL_COLUMN_SCALE + if (NumericAttribute) { + *NumericAttribute = col.scale; // VARCHAR doesn't have scale } break; // Nullability - case SQL_COLUMN_NULLABLE: // 7 - case SQL_DESC_NULLABLE: // SQL_COLUMN_NULLABLE + case SQL_COLUMN_NULLABLE: // 7 + case SQL_DESC_NULLABLE: // SQL_COLUMN_NULLABLE if (NumericAttribute) { *NumericAttribute = col.nullable; } break; // Type characteristics - case SQL_COLUMN_UNSIGNED: // 8 - // case SQL_DESC_UNSIGNED: // SQL_COLUMN_UNSIGNED + case SQL_COLUMN_UNSIGNED: // 8 + // case SQL_DESC_UNSIGNED: // SQL_COLUMN_UNSIGNED if (NumericAttribute) { - *NumericAttribute = SQL_TRUE; // VARCHAR is unsigned + *NumericAttribute = SQL_TRUE; // VARCHAR is unsigned } break; - case SQL_COLUMN_MONEY: // 9 - case SQL_COLUMN_AUTO_INCREMENT: // 11 - // case SQL_DESC_AUTO_UNIQUE_VALUE: // SQL_COLUMN_AUTO_INCREMENT - if (NumericAttribute) { - *NumericAttribute = SQL_FALSE; // VARCHAR is not auto-increment - } - break; + case SQL_COLUMN_MONEY: // 9 + case SQL_COLUMN_AUTO_INCREMENT: // 11 + // case SQL_DESC_AUTO_UNIQUE_VALUE: // SQL_COLUMN_AUTO_INCREMENT + if (NumericAttribute) { + *NumericAttribute = SQL_FALSE; // VARCHAR is not auto-increment + } + break; - case SQL_COLUMN_UPDATABLE: // 10 - // case SQL_DESC_UPDATABLE: // SQL_COLUMN_UPDATABLE + case SQL_COLUMN_UPDATABLE: // 10 + // case SQL_DESC_UPDATABLE: // SQL_COLUMN_UPDATABLE if (NumericAttribute) { - *NumericAttribute = SQL_ATTR_READONLY; // Catalog results are read-only + *NumericAttribute = SQL_ATTR_READONLY; // Catalog results are read-only } break; - case SQL_COLUMN_CASE_SENSITIVE: // 12 - // case SQL_DESC_CASE_SENSITIVE: // SQL_COLUMN_CASE_SENSITIVE + case SQL_COLUMN_CASE_SENSITIVE: // 12 + // case SQL_DESC_CASE_SENSITIVE: // SQL_COLUMN_CASE_SENSITIVE if (NumericAttribute) { - *NumericAttribute = SQL_TRUE; // VARCHAR is case sensitive + *NumericAttribute = SQL_TRUE; // VARCHAR is case sensitive } break; - case SQL_COLUMN_SEARCHABLE: // 13 - // case SQL_DESC_SEARCHABLE: // SQL_COLUMN_SEARCHABLE + case SQL_COLUMN_SEARCHABLE: // 13 + // case SQL_DESC_SEARCHABLE: // SQL_COLUMN_SEARCHABLE if (NumericAttribute) { - *NumericAttribute = SQL_SEARCHABLE; // VARCHAR is fully searchable + *NumericAttribute = SQL_SEARCHABLE; // VARCHAR is fully searchable } break; // Type names and descriptions - case SQL_COLUMN_TYPE_NAME: // 14 - // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME + case SQL_COLUMN_TYPE_NAME: // 14 + // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME if (CharacterAttribute && BufferLength > 0) { - const char* typeName = "VARCHAR"; - strncpy((char*)CharacterAttribute, typeName, BufferLength); + strncpy((char *) CharacterAttribute, "VARCHAR", BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(typeName)); + *StringLength = 7; } } break; // Table information - case SQL_COLUMN_TABLE_NAME: // 15 - // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME + case SQL_COLUMN_TABLE_NAME: // 15 + // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME if (CharacterAttribute && BufferLength > 0) { - const char* tableName = ""; // Catalog functions typically don't set this - strncpy((char*)CharacterAttribute, tableName, BufferLength); + strncpy((char *) CharacterAttribute, col.tableName, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(tableName)); + *StringLength = col.tableNameLength; } } break; - case SQL_COLUMN_OWNER_NAME: // 16 - // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME + case SQL_COLUMN_OWNER_NAME: // 16 + // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME if (CharacterAttribute && BufferLength > 0) { - const char* schemaName = ""; // Catalog functions typically don't set this - strncpy((char*)CharacterAttribute, schemaName, BufferLength); + strncpy((char *) CharacterAttribute, col.schemaName, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(schemaName)); + *StringLength = col.schemaNameLength; } } break; - case SQL_COLUMN_QUALIFIER_NAME: // 17 - // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME + case SQL_COLUMN_QUALIFIER_NAME: // 17 + // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME if (CharacterAttribute && BufferLength > 0) { - const char* catalogName = ""; // Catalog functions typically don't set this - strncpy((char*)CharacterAttribute, catalogName, BufferLength); + strncpy((char *) CharacterAttribute, col.catalogName, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(catalogName)); + *StringLength = col.catalogNameLength; } } break; // SQL literal formatting - case SQL_DESC_LITERAL_PREFIX: // 27 + case SQL_DESC_LITERAL_PREFIX: // 27 if (CharacterAttribute && BufferLength > 0) { - const char* prefix = "'"; // VARCHAR uses single quotes - strncpy((char*)CharacterAttribute, prefix, BufferLength); + strncpy((char *) CharacterAttribute, col.literalPrefix, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(prefix)); + *StringLength = col.literalPrefixLength; } } break; - case SQL_DESC_LITERAL_SUFFIX: // 28 + case SQL_DESC_LITERAL_SUFFIX: // 28 if (CharacterAttribute && BufferLength > 0) { - const char* suffix = "'"; // VARCHAR uses single quotes - strncpy((char*)CharacterAttribute, suffix, BufferLength); + strncpy((char *) CharacterAttribute, col.literalSuffix, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(suffix)); + *StringLength = col.literalSuffixLength; } } break; - // // Additional descriptors - // case SQL_DESC_FIXED_PREC_SCALE: // 1108 - // if (NumericAttribute) { - // *NumericAttribute = SQL_FALSE; // VARCHAR is not fixed precision - // } - // break; - - case SQL_DESC_LOCAL_TYPE_NAME: // 29 + case SQL_DESC_LOCAL_TYPE_NAME: // 29 if (CharacterAttribute && BufferLength > 0) { - const char* localTypeName = "VARCHAR"; // Usually same as type name - strncpy((char*)CharacterAttribute, localTypeName, BufferLength); + strncpy((char *) CharacterAttribute, col.localTypeName, BufferLength); if (StringLength) { - *StringLength = static_cast(strlen(localTypeName)); + *StringLength = col.localTypeNameLength; } } break; - case SQL_DESC_NUM_PREC_RADIX: // 32 + case SQL_DESC_NUM_PREC_RADIX: // 32 if (NumericAttribute) { - *NumericAttribute = 0; // Not applicable for VARCHAR + *NumericAttribute = 0; // Not applicable for VARCHAR } break; - case SQL_DESC_UNNAMED: // 1089 + case SQL_DESC_UNNAMED: // 1089 if (NumericAttribute) { - *NumericAttribute = SQL_NAMED; // Column has a name + *NumericAttribute = col.unnamed; + } + break; + + case SQL_DESC_BASE_COLUMN_NAME: // 1025 + if (CharacterAttribute && BufferLength > 0) { + strncpy((char *) CharacterAttribute, col.baseColumnName, BufferLength); + if (StringLength) { + *StringLength = col.baseColumnNameLength; + } + } + break; + + case SQL_DESC_BASE_TABLE_NAME: // 1026 + if (CharacterAttribute && BufferLength > 0) { + strncpy((char *) CharacterAttribute, col.baseTableName, BufferLength); + if (StringLength) { + *StringLength = col.baseTableNameLength; + } } break; @@ -238,87 +260,88 @@ SQLRETURN SQL_API SQLColAttribute( *StringLength = 0; } break; - } + } - return SQL_SUCCESS; + return SQL_SUCCESS; } SQLRETURN SQL_API SQLColAttributeW( - SQLHSTMT StatementHandle, - SQLUSMALLINT ColumnNumber, - SQLUSMALLINT FieldIdentifier, - SQLPOINTER CharacterAttribute, - SQLSMALLINT BufferLength, - SQLSMALLINT* StringLength, - SQLLEN* NumericAttribute) { - - LOG("SQLColAttributeW called"); - - // For non-string attributes, just delegate to the ANSI version - switch (FieldIdentifier) { - // String attributes that need Unicode conversion - case SQL_COLUMN_NAME: // 1 - case SQL_COLUMN_TYPE_NAME: // 14 - case SQL_COLUMN_TABLE_NAME: // 15 - case SQL_COLUMN_OWNER_NAME: // 16 - case SQL_COLUMN_QUALIFIER_NAME: // 17 - case SQL_COLUMN_LABEL: // 18 - case SQL_DESC_NAME: // SQL_COLUMN_NAME - // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME - // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME - // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME - // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME - // case SQL_DESC_LABEL: // SQL_COLUMN_LABEL - case SQL_DESC_LITERAL_PREFIX: // 27 - case SQL_DESC_LITERAL_SUFFIX: // 28 - case SQL_DESC_LOCAL_TYPE_NAME: // 29 - break; // Handle these below with Unicode conversion - - // All other attributes can go directly to ANSI version - default: - return SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier, - CharacterAttribute, BufferLength, StringLength, - NumericAttribute); - } - - // Get the ANSI string - char ansiBuffer[SQL_MAX_MESSAGE_LENGTH]; - SQLSMALLINT ansiLength = 0; - - SQLRETURN ret = SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier, - ansiBuffer, sizeof(ansiBuffer), &ansiLength, NumericAttribute); - - if (!SQL_SUCCEEDED(ret)) { - return ret; - } - - // Convert to wide char if we have a buffer - if (SQL_SUCCEEDED(ret) && CharacterAttribute && BufferLength > 0) { - if (ansiLength > 0) { - size_t numChars = 0; - mbstowcs_s(&numChars, (wchar_t*)CharacterAttribute, - BufferLength/sizeof(wchar_t), ansiBuffer, _TRUNCATE); - if (StringLength) { - *StringLength = static_cast(numChars * sizeof(wchar_t)); - } - LOGF("Converted string to Unicode. Original length: %d, Wide length: %zu", - ansiLength, numChars); - } else { - // Empty string case - if (BufferLength >= sizeof(wchar_t)) { - *((wchar_t*)CharacterAttribute) = L'\0'; - } - if (StringLength) { - *StringLength = 0; - } - LOG("Set empty Unicode string"); - } - } else if (StringLength) { - // Just return required length if no buffer provided - *StringLength = ansiLength * sizeof(wchar_t); - LOGF("Returning required buffer length: %d bytes", *StringLength); - } - - return ret; + SQLHSTMT StatementHandle, + SQLUSMALLINT ColumnNumber, + SQLUSMALLINT FieldIdentifier, + SQLPOINTER CharacterAttribute, + SQLSMALLINT BufferLength, + SQLSMALLINT *StringLength, + SQLLEN *NumericAttribute) { + LOG("SQLColAttributeW called"); + + // For non-string attributes, just delegate to the ANSI version + switch (FieldIdentifier) { + // String attributes that need Unicode conversion + case SQL_COLUMN_NAME: // 1 + case SQL_COLUMN_TYPE_NAME: // 14 + case SQL_COLUMN_TABLE_NAME: // 15 + case SQL_COLUMN_OWNER_NAME: // 16 + case SQL_COLUMN_QUALIFIER_NAME: // 17 + case SQL_COLUMN_LABEL: // 18 + case SQL_DESC_NAME: // SQL_COLUMN_NAME + // case SQL_DESC_TYPE_NAME: // SQL_COLUMN_TYPE_NAME + // case SQL_DESC_TABLE_NAME: // SQL_COLUMN_TABLE_NAME + // case SQL_DESC_SCHEMA_NAME: // SQL_COLUMN_OWNER_NAME + // case SQL_DESC_CATALOG_NAME: // SQL_COLUMN_QUALIFIER_NAME + // case SQL_DESC_LABEL: // SQL_COLUMN_LABEL + case SQL_DESC_LITERAL_PREFIX: // 27 + case SQL_DESC_LITERAL_SUFFIX: // 28 + case SQL_DESC_LOCAL_TYPE_NAME: // 29 + case SQL_DESC_BASE_COLUMN_NAME: // 1025 + case SQL_DESC_BASE_TABLE_NAME: // 1026 + break; // Handle these below with Unicode conversion + + // All other attributes can go directly to ANSI version + default: + return SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier, + CharacterAttribute, BufferLength, StringLength, + NumericAttribute); + } + + // Get the ANSI string + char ansiBuffer[SQL_MAX_MESSAGE_LENGTH]; + SQLSMALLINT ansiLength = 0; + + SQLRETURN ret = SQLColAttribute(StatementHandle, ColumnNumber, FieldIdentifier, + ansiBuffer, sizeof(ansiBuffer), &ansiLength, NumericAttribute); + + if (!SQL_SUCCEEDED(ret)) { + return ret; + } + + // Convert to wide char if we have a buffer + if (SQL_SUCCEEDED(ret) && CharacterAttribute && BufferLength > 0) { + if (ansiLength > 0) { + size_t numChars = 0; + errno_t err = mbstowcs_s(&numChars, (wchar_t *)CharacterAttribute, + BufferLength / sizeof(wchar_t), ansiBuffer, _TRUNCATE); + if (err == 0 && StringLength) { + *StringLength = static_cast(numChars * sizeof(wchar_t)); + } + LOGF("Converted string to Unicode. Original length: %d, Wide length: %zu", ansiLength, numChars); + + } else { + // Empty string case + if (BufferLength >= sizeof(wchar_t)) { + *((wchar_t *) CharacterAttribute) = L'\0'; + } + if (StringLength) { + *StringLength = 0; + } + LOG("Set empty Unicode string"); + } + } else if (StringLength) { + // Just return required length if no buffer provided + *StringLength = ansiLength * sizeof(wchar_t); + LOGF("Returning required buffer length: %d bytes", *StringLength); + } + + return ret; } -} // extern "C" \ No newline at end of file +} // extern "C" diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp index 25341ff..960dff5 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLDriverConnect.cpp @@ -3,7 +3,6 @@ #define WIN32_LEAN_AND_MEAN #include "../include/connection.hpp" #include "../include/logging.hpp" -#include "../include/httplib.h" #include "../include/globals.hpp" #include "../include/environment.hpp" #include "../include/statement.hpp" diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp index 47f3fff..173d261 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLExecDirect.cpp @@ -10,63 +10,54 @@ extern "C" { - SQLRETURN SQL_API SQLExecDirect_A( - SQLHSTMT hstmt, - SQLCHAR* szSqlStr, - SQLINTEGER cbSqlStr) - { - std::string sqlStr(reinterpret_cast(szSqlStr), - cbSqlStr == SQL_NTS ? strlen(reinterpret_cast(szSqlStr)) : cbSqlStr); + SQLRETURN SQL_API SQLExecDirect( + SQLHSTMT StatementHandle, + SQLCHAR* StatementText, + SQLINTEGER TextLength) { - auto stmt = static_cast(hstmt); - if (!stmt) { + LOGF("SQLExecDirect called with query: %s", StatementText); + + auto* stmt = static_cast(StatementHandle); + if (!stmt || !stmt->conn) { + LOG("Invalid statement handle or connection"); return SQL_INVALID_HANDLE; } - return SQL_ERROR; - } - - SQLRETURN SQL_API SQLExecDirect_W( - SQLHSTMT hstmt, - SQLWCHAR* szSqlStr, - SQLINTEGER cbSqlStr) - { - // Get the actual string length - size_t actualLength = 0; - while (szSqlStr[actualLength] != L'\0' && - !(szSqlStr[actualLength] == L'\\' && szSqlStr[actualLength + 1] == L'0')) { - actualLength++; - } - - // Create wide string excluding the "\ 0" terminator - std::wstring wsqlStr(szSqlStr, actualLength); - std::string sqlStr = WideStringToString(wsqlStr); + // Convert to string based on TextLength + std::string query; + if (TextLength == SQL_NTS) { + query = reinterpret_cast(StatementText); + } else { + query = std::string(reinterpret_cast(StatementText), TextLength); + } - LOGF("SQLExecDirect_W: Original SQL: '%s'", sqlStr.c_str()); + LOGF("Executing query: %s", query.c_str()); + return stmt->conn->Query(query, stmt); + } - // Transform the query - std::string transformedSql = sqlStr; + // And the Unicode version + SQLRETURN SQL_API SQLExecDirectW( + SQLHSTMT StatementHandle, + SQLWCHAR* StatementText, + SQLINTEGER TextLength) { - // Replace INFORMATION_SCHEMA.TABLES - size_t pos = transformedSql.find("INFORMATION_SCHEMA.TABLES"); - if (pos != std::string::npos) { - transformedSql.replace(pos, std::string("INFORMATION_SCHEMA.TABLES").length(), "metadata.TABLES"); + auto* stmt = static_cast(StatementHandle); + if (!stmt || !stmt->conn) { + LOG("Invalid statement handle or connection"); + return SQL_INVALID_HANDLE; } - // Replace INFORMATION_SCHEMA.COLUMNS - pos = transformedSql.find("INFORMATION_SCHEMA.COLUMNS"); - if (pos != std::string::npos) { - transformedSql.replace(pos, std::string("INFORMATION_SCHEMA.COLUMNS").length(), "metadata.COLUMNS"); + // Convert wide string to UTF-8 + std::wstring wquery; + if (TextLength == SQL_NTS) { + wquery = reinterpret_cast(StatementText); + } else { + wquery = std::wstring(reinterpret_cast(StatementText), TextLength); } - LOGF("SQLExecDirect_W: Transformed SQL: '%s'", transformedSql.c_str()); - - auto stmt = static_cast(hstmt); - if (!stmt) { - LOG("SQLExecDirect_W: Invalid statement handle"); - return SQL_INVALID_HANDLE; - } + std::string query = WideStringToString(wquery); + LOGF("SQLExecDirectW executing query: %s", query.c_str()); - return SQL_ERROR; + return stmt->conn->Query(query, stmt); } } // extern "C" \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp index 473564d..61e8e5a 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/SQLSetDescField.cpp @@ -52,31 +52,125 @@ SQLRETURN SQL_API SQLSetDescField( } switch (FieldIdentifier) { + case SQL_COLUMN_TYPE: case SQL_DESC_TYPE: - case SQL_DESC_CONCISE_TYPE: + // case SQL_DESC_CONCISE_TYPE: stmt->resultColumns[colIdx].sqlType = reinterpret_cast(Value); break; + case SQL_COLUMN_NAME: case SQL_DESC_NAME: if (Value) { + LOGF("Setting column name to: %s", static_cast(Value)); stmt->resultColumns[colIdx].name = static_cast(Value); stmt->resultColumns[colIdx].nameLength = (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].name) : BufferLength; + LOGF("Set column name to: %s", stmt->resultColumns[colIdx].name); } break; + // case SQL_COLUMN_LABEL: + case SQL_DESC_LABEL: + if (Value) { + stmt->resultColumns[colIdx].label = static_cast(Value); + stmt->resultColumns[colIdx].labelLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].label) : BufferLength; + } + break; + + case SQL_COLUMN_NULLABLE: case SQL_DESC_NULLABLE: stmt->resultColumns[colIdx].nullable = reinterpret_cast(Value); break; - + case SQL_COLUMN_LENGTH: case SQL_DESC_LENGTH: stmt->resultColumns[colIdx].columnSize = reinterpret_cast(Value); break; + case SQL_COLUMN_PRECISION: case SQL_DESC_PRECISION: stmt->resultColumns[colIdx].decimalDigits = reinterpret_cast(Value); break; + case SQL_COLUMN_SCALE: + case SQL_DESC_SCALE: + stmt->resultColumns[colIdx].scale = reinterpret_cast(Value); + break; + + case SQL_CATALOG_NAME: + case SQL_DESC_CATALOG_NAME: + if (Value) { + stmt->resultColumns[colIdx].catalogName = static_cast(Value); + stmt->resultColumns[colIdx].catalogNameLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].catalogName) : BufferLength; + } + break; + + case SQL_DESC_SCHEMA_NAME: + if (Value) { + stmt->resultColumns[colIdx].schemaName = static_cast(Value); + stmt->resultColumns[colIdx].schemaNameLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].schemaName) : BufferLength; + } + break; + + // case SQL_COLUMN_TABLE_NAME: + case SQL_DESC_TABLE_NAME: + if (Value) { + stmt->resultColumns[colIdx].tableName = static_cast(Value); + stmt->resultColumns[colIdx].tableNameLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].tableName) : BufferLength; + } + break; + + case SQL_DESC_BASE_COLUMN_NAME: + if (Value) { + stmt->resultColumns[colIdx].baseColumnName = static_cast(Value); + stmt->resultColumns[colIdx].baseColumnNameLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].baseColumnName) : BufferLength; + } + break; + + case SQL_DESC_BASE_TABLE_NAME: + if (Value) { + stmt->resultColumns[colIdx].baseTableName = static_cast(Value); + stmt->resultColumns[colIdx].baseTableNameLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].baseTableName) : BufferLength; + } + break; + + case SQL_DESC_LITERAL_PREFIX: + if (Value) { + stmt->resultColumns[colIdx].literalPrefix = static_cast(Value); + stmt->resultColumns[colIdx].literalPrefixLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].literalPrefix) : BufferLength; + } + break; + + case SQL_DESC_LITERAL_SUFFIX: + if (Value) { + stmt->resultColumns[colIdx].literalSuffix = static_cast(Value); + stmt->resultColumns[colIdx].literalSuffixLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].literalSuffix) : BufferLength; + } + break; + + case SQL_DESC_LOCAL_TYPE_NAME: + if (Value) { + stmt->resultColumns[colIdx].localTypeName = static_cast(Value); + stmt->resultColumns[colIdx].localTypeNameLength = + (BufferLength == SQL_NTS) ? strlen(stmt->resultColumns[colIdx].localTypeName) : BufferLength; + } + break; + + case SQL_DESC_UNNAMED: + stmt->resultColumns[colIdx].unnamed = reinterpret_cast(Value); + break; + + case SQL_DESC_DISPLAY_SIZE: + stmt->resultColumns[colIdx].displaySize = reinterpret_cast(Value); + break; + default: LOGF("Unsupported field identifier: %d", FieldIdentifier); return SQL_ERROR; @@ -85,14 +179,26 @@ SQLRETURN SQL_API SQLSetDescField( return SQL_SUCCESS; } - SQLRETURN SQL_API SQLSetDescFieldW( - SQLHDESC DescriptorHandle, - SQLSMALLINT RecNumber, - SQLSMALLINT FieldIdentifier, - SQLPOINTER Value, - SQLINTEGER BufferLength) { +SQLRETURN SQL_API SQLSetDescFieldW( + SQLHDESC DescriptorHandle, + SQLSMALLINT RecNumber, + SQLSMALLINT FieldIdentifier, + SQLPOINTER Value, + SQLINTEGER BufferLength) { if (FieldIdentifier == SQL_DESC_NAME && Value != nullptr) { + // Convert wide string to narrow for internal storage + auto wstr = static_cast(Value); + std::string name = WideStringToString(wstr); + LOGF("Converted name: %s", name.c_str()); + auto result = SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(name.c_str()), SQL_NTS); + Statement *s = static_cast(DescriptorHandle); + LOGF("Post SQLSetDescField: %s", s->resultColumns[RecNumber -1].name); + return result; + } + + if (FieldIdentifier == SQL_COLUMN_NAME && Value != nullptr) { // Convert wide string to narrow for internal storage const wchar_t* wstr = static_cast(Value); std::string name = WideStringToString(std::wstring(wstr)); @@ -100,6 +206,56 @@ SQLRETURN SQL_API SQLSetDescField( const_cast(name.c_str()), SQL_NTS); } + if (FieldIdentifier == SQL_DESC_CATALOG_NAME && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string catalogName = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(catalogName.c_str()), SQL_NTS); + } + + if (FieldIdentifier == SQL_DESC_SCHEMA_NAME && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string schemaName = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(schemaName.c_str()), SQL_NTS); + } + + if (FieldIdentifier == SQL_DESC_TABLE_NAME && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string tableName = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(tableName.c_str()), SQL_NTS); + } + + if (FieldIdentifier == SQL_DESC_BASE_COLUMN_NAME && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string baseColumnName = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(baseColumnName.c_str()), SQL_NTS); + } + + if (FieldIdentifier == SQL_DESC_BASE_TABLE_NAME && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string baseTableName = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(baseTableName.c_str()), SQL_NTS); + } + + if (FieldIdentifier == SQL_DESC_LOCAL_TYPE_NAME && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string localTypeName = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(localTypeName.c_str()), SQL_NTS); + } + + if (FieldIdentifier == SQL_DESC_LABEL && Value != nullptr) { + const wchar_t* wstr = static_cast(Value); + std::string label = WideStringToString(std::wstring(wstr)); + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, + const_cast(label.c_str()), SQL_NTS); + } + + return SQLSetDescField(DescriptorHandle, RecNumber, FieldIdentifier, Value, BufferLength); } } \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp index d453a01..fc37a31 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/connection.cpp @@ -1,182 +1,174 @@ -// Windows includes must come first -#define WIN32_LEAN_AND_MEAN -#include -#include - -// ODBC includes -#include - -// JNI includes -#include +#include "../include/connection.hpp" -// Standard includes -#include -#include -#include -#include +#include -// Project includes -#include "../include/connection.hpp" #include "../include/statement.hpp" -#include "../include/globals.hpp" #include "../include/logging.hpp" +#include +#include -bool Connection::initJVM() { - LOG("Initializing JVM"); +#include "JniParam.hpp" - JavaVMInitArgs vm_args; - JavaVMOption options[5]; - // Add the path to your JAR files - std::string classPath = "-Djava.class.path="; - classPath += dllPath + "\\jni-arrow-1.0.0-jar-with-dependencies.jar"; - LOGF("Java classpath: %s", classPath.c_str()); +HMODULE GetCurrentModule() { + HMODULE hModule = nullptr; + GetModuleHandleEx( + GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(GetCurrentModule), + &hModule); + return hModule; +} - options[0].optionString = const_cast(classPath.c_str()); +std::string GetModuleDirectory() { + char path[MAX_PATH]; + HMODULE hModule = GetCurrentModule(); + GetModuleFileNameA(hModule, path, MAX_PATH); + std::string modulePath(path); + size_t pos = modulePath.find_last_of("\\/"); + return (std::string::npos == pos) ? "" : modulePath.substr(0, pos); +} - // Add any necessary JVM options - std::string extraOptions = "-Xmx512m"; - options[1].optionString = const_cast(extraOptions.c_str()); - options[2].optionString = const_cast("--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED"); - options[3].optionString = const_cast("-Dotel.java.global-autoconfigure.enabled=true"); - options[4].optionString = const_cast("-Dlog_level=debug"); +Connection::~Connection() { + if (connected) { + disconnect(); + } +} - // Set the JNI version to 1.8 for OpenJDK 11 - vm_args.version = JNI_VERSION_1_8; - vm_args.nOptions = 2; - vm_args.options = options; - vm_args.ignoreUnrecognized = JNI_TRUE; +// std::string WideStringToString(const std::wstring& wstr) { +// LOG("WideStringToString: Input length: " + std::to_string(wstr.length())); +// +// if (wstr.empty()) { +// LOG("WideStringToString: Empty input string"); +// return {}; +// } +// +// // Print first few wide chars as numbers for debugging +// char intbuf[100] = {0}; +// char* intptr = intbuf; +// size_t maxChars = wstr.length() > 5 ? 5 : wstr.length(); +// for (size_t i = 0; i < maxChars; ++i) { +// intptr += sprintf_s(intptr, 20, "%d ", static_cast(wstr[i])); +// } +// LOG("WideStringToString: First 5 chars (as ints): " + std::string(intbuf)); +// +// // Get required buffer size +// const int size_needed = WideCharToMultiByte(CP_UTF8, 0, +// wstr.data(), static_cast(wstr.length()), +// nullptr, 0, +// nullptr, nullptr); +// +// if (size_needed <= 0) { +// const DWORD error = GetLastError(); +// LOG("WideStringToString: Error calculating buffer size. Error code: " + std::to_string(error)); +// return std::string(); +// } +// +// LOG("WideStringToString: Allocating buffer of size: " + std::to_string(size_needed)); +// std::string strTo(size_needed, 0); +// +// int result = WideCharToMultiByte(CP_UTF8, 0, +// wstr.data(), static_cast(wstr.length()), +// &strTo[0], size_needed, +// nullptr, nullptr); +// +// if (result <= 0) { +// DWORD error = GetLastError(); +// LOG("WideStringToString: Conversion failed. Error code: " + std::to_string(error)); +// return std::string(); +// } +// +// LOG("WideStringToString: Converted string length: " + std::to_string(strTo.length())); +// LOG("WideStringToString: Converted string: '" + strTo + "'"); +// +// return strTo; +// } + +#include // For wcstombs +#include +#include + +// Function to convert std::wstring to std::string +std::string WideStringToString(const std::wstring& wstr) { + // Determine the required buffer size for the multibyte string + size_t len = wcstombs(nullptr, wstr.c_str(), 0); + if (len == static_cast(-1)) { + throw std::runtime_error("Conversion error"); + } - if (const jint rc = JNI_CreateJavaVM(&jvm, reinterpret_cast(&env), &vm_args); rc != JNI_OK) { - LOGF("Failed to create JVM: %d", rc); - LOGF("JNI_CreateJavaVM return code: %d", rc); - return false; + // Create a buffer of the appropriate size + std::vector buffer(len + 1); + wcstombs(buffer.data(), wstr.c_str(), buffer.size()); + + // Convert the buffer to std::string + return std::string(buffer.data()); +} + +// Function to convert std::string to std::wstring +std::wstring StringToWideString(const std::string& str) { + // Determine the required buffer size for the wide character string + size_t len = mbstowcs(nullptr, str.c_str(), 0); + if (len == static_cast(-1)) { + throw std::runtime_error("Conversion error"); } - LOG("JVM created successfully"); - return true; + // Create a buffer of the appropriate size + std::vector buffer(len + 1); + mbstowcs(buffer.data(), str.c_str(), buffer.size()); + + // Convert the buffer to std::wstring + return std::wstring(buffer.data()); } void Connection::setConnectionString(const std::string& dsn, const std::string& uid, const std::string& authStr) { - // Store connection string std::ostringstream ss; ss << "DSN=" << dsn << ";UID=" << uid << ";PWD=" << authStr; this->connectionString = ss.str(); } void Connection::setConnectionString(const std::string& dsn) { - this->connectionString =dsn; -} - -bool Connection::initWrapper(const ConnectionParams& params) { - try { - // Find the wrapper class - wrapperClass = env->FindClass("com/hasura/ArrowJdbcWrapper"); - if (!wrapperClass) { - LOG("Failed to find ArrowJdbcWrapper class"); - return false; - } - LOG("Found ArrowJdbcWrapper class"); - - // Get the method ID for the constructor - jmethodID constructor = env->GetMethodID(wrapperClass, "", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V"); - if (constructor == nullptr) { - LOG("Failed to find ArrowJdbcWrapper constructor"); - return false; - } - LOG("Found ArrowJdbcWrapper constructor"); - - // Get the method ID for healthCheck - jmethodID healthCheckMethod = env->GetMethodID(wrapperClass, "healthCheck", "()Z"); - if (healthCheckMethod == nullptr) { - LOG("Failed to find healthCheck method"); - return false; - } - LOG("Found healthCheck method"); - - // Create Java strings for the constructor arguments - const std::string jdbcUrl = buildJdbcUrl(params); - jstring jvm_jdbcUrl = env->NewStringUTF(jdbcUrl.c_str()); // Replace JDBC_URL with your actual URL - jstring username = env->NewStringUTF(params.uid.c_str()); // Replace USERNAME with your actual username - jstring password = env->NewStringUTF(params.pwd.c_str()); // Replace PASSWORD with your actual password - - // Create a new instance of ArrowJdbcWrapper - LOGF("jdbcUrl %s, username %s, password, %s", jdbcUrl.c_str(), params.uid.c_str(), params.pwd.c_str()); - wrapperInstance = env->NewObject(wrapperClass, constructor, jvm_jdbcUrl, username, password); - if (wrapperInstance == nullptr) { - LOG("Failed to construct ArrowJdbcWrapper"); - // Check if an exception was thrown during the object creation - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); // Print the exception details to stderr - env->ExceptionClear(); // Clear the exception so that we can continue - LOG("Failed to construct ArrowJdbcWrapper due to an exception"); - return false; - } - return false; - } - LOG("Initialized ArrowJdbcWrapper"); - - // Call the healthCheck method - const jboolean healthStatus = env->CallBooleanMethod(wrapperInstance, healthCheckMethod); - if (env->ExceptionCheck()) { - env->ExceptionDescribe(); - env->ExceptionClear(); - LOG("Exception occurred while calling healthCheck method"); - return false; - } - bool isHealthy = (healthStatus == JNI_TRUE); - LOGF("Health check status: %d", isHealthy); - - return isHealthy; - } - catch (const std::exception& e) { - LOGF("Exception in initWrapper: %s", e.what()); - return false; - } + this->connectionString = dsn; } SQLRETURN Connection::connect() { - if (isConnected) { + if (connected) { LOG("Already connected"); return SQL_ERROR; } - // Parse connection parameters ConnectionParams params; if (!parseConnectionString(connectionString, params)) { LOG("Failed to parse connection string"); return SQL_ERROR; } - // Initialize JVM if not already initialized if (!jvm && !initJVM()) { LOG("Failed to initialize JVM"); return SQL_ERROR; } - // Initialize wrapper with connection parameters if (!initWrapper(params)) { LOG("Failed to initialize wrapper"); return SQL_ERROR; } - isConnected = true; - LOG("Connected!"); + connected = true; + LOG("Connected successfully"); return SQL_SUCCESS; } SQLRETURN Connection::disconnect() { - if (!isConnected) { + if (!connected) { LOG("Not connected"); return SQL_ERROR; } + cleanupActiveStmts(); + if (env && wrapperInstance) { - // Call close on the wrapper instance jmethodID closeMethod = env->GetMethodID(wrapperClass, "close", "()V"); env->CallVoidMethod(wrapperInstance, closeMethod); - // Delete global references env->DeleteGlobalRef(wrapperInstance); env->DeleteGlobalRef(wrapperClass); @@ -190,100 +182,169 @@ SQLRETURN Connection::disconnect() { env = nullptr; } - isConnected = false; + connected = false; return SQL_SUCCESS; } -Connection::~Connection() { - if (isConnected) { - disconnect(); +SQLRETURN Connection::Query(const std::string& query, Statement* stmt) { + stmt->setOriginalQuery(query); + std::string interpolatedQuery = stmt->buildInterpolatedQuery(); + return executeAndGetArrowResult("executeQuery", {JniParam(interpolatedQuery)}, stmt); +} + +SQLRETURN Connection::GetTables( + const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType, + Statement* stmt) const { + + std::vector types; + if (!tableType.empty()) { + std::istringstream ss(tableType); + std::string type; + while (std::getline(ss, type, ',')) { + type.erase(0, type.find_first_not_of(' ')); + type.erase(type.find_last_not_of(' ') + 1); + if (!type.empty()) { + types.push_back(type); + } + } } + + return executeAndGetArrowResult("getTables", { + JniParam(catalogName), + JniParam(schemaName), + JniParam(tableName), + JniParam(types) // This will be converted to String[] + }, stmt); } -// Helper functions -std::string GetModuleDirectory(HMODULE hModule) { - char path[MAX_PATH]; - GetModuleFileNameA(hModule, path, MAX_PATH); - std::string modulePath(path); - size_t pos = modulePath.find_last_of("\\/"); - return (std::string::npos == pos) ? "" : modulePath.substr(0, pos); +SQLRETURN Connection::GetColumns( + const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName, + Statement* stmt) const { + + return executeAndGetArrowResult("getColumns", { + JniParam(catalogName), + JniParam(schemaName), + JniParam(tableName), + JniParam(columnName) + }, stmt); } -std::string WideStringToString(const std::wstring& wstr) { - LOG("WideStringToString: Input length: " + std::to_string(wstr.length())); - - if (wstr.empty()) { - LOG("WideStringToString: Empty input string"); - return {}; - } +bool Connection::hasActiveStmts() const { + return !activeStmts.empty(); +} - // Print first few wide chars as numbers for debugging - char intbuf[100] = {0}; - char* intptr = intbuf; - size_t maxChars = wstr.length() > 5 ? 5 : wstr.length(); - for (size_t i = 0; i < maxChars; ++i) { - intptr += sprintf_s(intptr, 20, "%d ", static_cast(wstr[i])); - } - LOG("WideStringToString: First 5 chars (as ints): " + std::string(intbuf)); - - // Get required buffer size - const int size_needed = WideCharToMultiByte(CP_UTF8, 0, - wstr.data(), static_cast(wstr.length()), - nullptr, 0, - nullptr, nullptr); - - if (size_needed <= 0) { - const DWORD error = GetLastError(); - LOG("WideStringToString: Error calculating buffer size. Error code: " + std::to_string(error)); - return std::string(); +void Connection::cleanupActiveStmts() { + for (auto stmt : activeStmts) { + if (stmt) { + stmt->clearResults(); + } } - - LOG("WideStringToString: Allocating buffer of size: " + std::to_string(size_needed)); - std::string strTo(size_needed, 0); - - int result = WideCharToMultiByte(CP_UTF8, 0, - wstr.data(), static_cast(wstr.length()), - &strTo[0], size_needed, - nullptr, nullptr); - - if (result <= 0) { - DWORD error = GetLastError(); - LOG("WideStringToString: Conversion failed. Error code: " + std::to_string(error)); - return std::string(); + activeStmts.clear(); +} + +bool Connection::initJVM() { + LOG("Initializing JVM"); + + JavaVMInitArgs vm_args; + JavaVMOption options[5]; + + std::string classPath = "-Djava.class.path="; + classPath += GetModuleDirectory() + "\\jni-arrow-1.0.0-jar-with-dependencies.jar"; + LOGF("Java classpath: %s", classPath.c_str()); + + options[0].optionString = const_cast(classPath.c_str()); + options[1].optionString = const_cast("-Xmx512m"); + options[2].optionString = const_cast("--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED"); + options[3].optionString = const_cast("-Dotel.java.global-autoconfigure.enabled=true"); + options[4].optionString = const_cast("-Dlog_level=warn"); + + vm_args.version = JNI_VERSION_1_8; + vm_args.nOptions = 5; + vm_args.options = options; + vm_args.ignoreUnrecognized = JNI_TRUE; + + jint rc = JNI_CreateJavaVM(&jvm, reinterpret_cast(&env), &vm_args); + if (rc != JNI_OK) { + LOGF("Failed to create JVM: %d", rc); + return false; } - - LOG("WideStringToString: Converted string length: " + std::to_string(strTo.length())); - LOG("WideStringToString: Converted string: '" + strTo + "'"); - - return strTo; + + LOG("JVM created successfully"); + return true; } -bool Connection::parseConnectionString(const std::string& connStr, ConnectionParams& params) { - if (connStr.empty()) { - LOG("Empty connection string provided"); +bool Connection::initWrapper(const ConnectionParams& params) { + try { + wrapperClass = env->FindClass("com/hasura/ArrowJdbcWrapper"); + if (!wrapperClass) { + LOG("Failed to find ArrowJdbcWrapper class"); + return false; + } + + wrapperClass = reinterpret_cast(env->NewGlobalRef(wrapperClass)); + + jmethodID constructor = env->GetMethodID(wrapperClass, "", + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V"); + if (!constructor) { + LOG("Failed to find constructor"); + return false; + } + + const std::string jdbcUrl = buildJdbcUrl(params); + jstring jvm_jdbcUrl = env->NewStringUTF(jdbcUrl.c_str()); + jstring username = env->NewStringUTF(params.uid.c_str()); + jstring password = env->NewStringUTF(params.pwd.c_str()); + + jobject localInstance = env->NewObject(wrapperClass, constructor, jvm_jdbcUrl, username, password); + if (!localInstance) { + LOG("Failed to create wrapper instance"); + return false; + } + + wrapperInstance = env->NewGlobalRef(localInstance); + env->DeleteLocalRef(localInstance); + + env->DeleteLocalRef(jvm_jdbcUrl); + env->DeleteLocalRef(username); + env->DeleteLocalRef(password); + + jmethodID healthCheckMethod = env->GetMethodID(wrapperClass, "healthCheck", "()Z"); + if (!healthCheckMethod) { + LOG("Failed to find healthCheck method"); + return false; + } + + jboolean health = env->CallBooleanMethod(wrapperInstance, healthCheckMethod); + return health == JNI_TRUE; + } + catch (const std::exception& e) { + LOGF("Exception in initWrapper: %s", e.what()); return false; } +} - LOGF("Parsing connection string: %s", connStr.c_str()); +bool Connection::parseConnectionString(const std::string& connStr, ConnectionParams& params) { std::istringstream ss(connStr); std::string token; while (std::getline(ss, token, ';')) { - auto pos = token.find('='); - if (pos == std::string::npos) { - LOGF("Invalid token: %s", token.c_str()); - continue; - } + size_t pos = token.find('='); + if (pos == std::string::npos) continue; std::string key = token.substr(0, pos); std::string value = token.substr(pos + 1); - while (!key.empty() && key[0] == ' ') key.erase(0, 1); - while (!key.empty() && key.back() == ' ') key.pop_back(); - while (!value.empty() && value[0] == ' ') value.erase(0, 1); - while (!value.empty() && value.back() == ' ') value.pop_back(); - - LOGF("Parsed key-value pair: %s=%s", key.c_str(), value.c_str()); + // Trim whitespace + key.erase(0, key.find_first_not_of(' ')); + key.erase(key.find_last_not_of(' ') + 1); + value.erase(0, value.find_first_not_of(' ')); + value.erase(value.find_last_not_of(' ') + 1); if (key == "Server") params.server = value; else if (key == "Port") params.port = value; @@ -294,45 +355,14 @@ bool Connection::parseConnectionString(const std::string& connStr, ConnectionPar else if (key == "PWD" || key == "Password") params.pwd = value; else if (key == "Encrypt") params.encrypt = value; else if (key == "Timeout") params.timeout = value; - else if (key == "DRIVER") { - LOGF("Ignoring key: %s", key.c_str()); - } - else { - LOGF("Unknown key: %s", key.c_str()); - } } - LOGF("Server: %s, Port: %s, Database: %s", - params.server.c_str(), params.port.c_str(), params.database.c_str()); - - if (!params.isValid()) { - LOG("Connection parameters are not valid"); - return false; - } - - return true; -} - -bool Connection::hasActiveStmts() const { - // Return true if there are any active statements - return !activeStmts.empty(); -} - -void Connection::cleanupActiveStmts() { - // Clean up any active statements - for (auto stmt : activeStmts) { - if (stmt) { - stmt->clearResults(); - // You might want to also free the statement here - // depending on your memory management strategy - } - } - activeStmts.clear(); + return params.isValid(); } std::string Connection::buildJdbcUrl(const ConnectionParams& params) { std::string protocol = (params.encrypt == "yes") ? "https" : "http"; - + std::ostringstream url; url << "jdbc:graphql:" << protocol << "://" << params.server << ":" << params.port << "/" @@ -356,263 +386,221 @@ std::string Connection::buildJdbcUrl(const ConnectionParams& params) { return url.str(); } -SQLRETURN Connection::GetTables( - const std::string& catalogName, - const std::string& schemaName, - const std::string& tableName, - const std::string& tableType, - Statement* stmt) const { - - LOGF("GetTables called with params - catalog: '%s', schema: '%s', table: '%s', types: '%s'", - catalogName.c_str(), schemaName.c_str(), tableName.c_str(), tableType.c_str()); - - if (!isConnected || !env || !wrapperInstance || !wrapperClass) { - LOG("Connection validation failed:"); - LOGF(" - isConnected: %d", isConnected); - LOGF(" - env: %p", (void*)env); - LOGF(" - wrapperInstance: %p", (void*)wrapperInstance); - LOGF(" - wrapperClass: %p", (void*)wrapperClass); - return SQL_ERROR; - } - +SQLRETURN Connection::populateColumnDescriptors(jobject schemaRoot, Statement* stmt) const { try { - // Helper lambda to convert std::string to jstring - auto createJavaString = [this](const std::string& str) -> jstring { - if (str.empty()) { - LOG("Creating null jstring for empty input"); - return nullptr; - } - LOGF("Creating jstring for input: '%s'", str.c_str()); - return env->NewStringUTF(str.c_str()); - }; - - // Convert input strings to Java strings - LOG("Converting input parameters to Java strings"); - jstring jCatalog = createJavaString(catalogName); - LOGF("Created jCatalog: %p", (void*)jCatalog); - - jstring jSchema = createJavaString(schemaName); - LOGF("Created jSchema: %p", (void*)jSchema); - - jstring jTable = createJavaString(tableName); - LOGF("Created jTable: %p", (void*)jTable); - - // Handle table types - jobjectArray jTypes = nullptr; - if (!tableType.empty()) { - LOG("Processing table types string"); - std::vector typeList; - std::istringstream typeStream(tableType); - std::string type; - - // Parse comma-separated types - while (std::getline(typeStream, type, ',')) { - // Trim whitespace - type.erase(0, type.find_first_not_of(' ')); - type.erase(type.find_last_not_of(' ') + 1); - if (!type.empty()) { - LOGF("Found table type: '%s'", type.c_str()); - typeList.push_back(type); - } - } - - if (!typeList.empty()) { - LOGF("Creating Java string array with %zu types", typeList.size()); - jclass stringClass = env->FindClass("java/lang/String"); - if (!stringClass) { - LOG("Failed to find java.lang.String class"); - return SQL_ERROR; - } - - jTypes = env->NewObjectArray(typeList.size(), stringClass, env->NewStringUTF("")); - if (!jTypes) { - LOG("Failed to create Java string array for types"); - return SQL_ERROR; - } - - for (size_t i = 0; i < typeList.size(); i++) { - LOGF("Setting array element %zu to '%s'", i, typeList[i].c_str()); - jstring typeStr = env->NewStringUTF(typeList[i].c_str()); - env->SetObjectArrayElement(jTypes, i, typeStr); - env->DeleteLocalRef(typeStr); - } - LOGF("Successfully created jTypes array: %p", (void*)jTypes); - } - } else { - LOG("No table types specified"); - } - - // Find getTables method - LOG("Looking up getTables method"); - jmethodID getTablesMethod = env->GetMethodID(wrapperClass, "getTables", - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;[Ljava/lang/String;)Lorg/apache/arrow/vector/VectorSchemaRoot;"); - - if (!getTablesMethod) { - LOG("Failed to find getTables method"); + // Extract schema and get the number of fields + jclass rootClass = env->GetObjectClass(schemaRoot); + jmethodID getSchemaMethod = env->GetMethodID(rootClass, "getSchema", + "()Lorg/apache/arrow/vector/types/pojo/Schema;"); + jobject schema = env->CallObjectMethod(schemaRoot, getSchemaMethod); + + jclass schemaClass = env->GetObjectClass(schema); + jmethodID getFieldsMethod = env->GetMethodID(schemaClass, "getFields", "()Ljava/util/List;"); + jobject fieldsList = env->CallObjectMethod(schema, getFieldsMethod); + + jclass listClass = env->GetObjectClass(fieldsList); + jmethodID sizeMethod = env->GetMethodID(listClass, "size", "()I"); + jint numFields = env->CallIntMethod(fieldsList, sizeMethod); + LOGF("Schema contains %d fields", numFields); + + // Get the IRD handle + SQLHDESC hIRD = nullptr; + SQLRETURN ret = SQLGetStmtAttr(stmt, SQL_ATTR_IMP_ROW_DESC, &hIRD, 0, nullptr); + if (!SQL_SUCCEEDED(ret) || !hIRD) { + LOG("Failed to get IRD handle"); return SQL_ERROR; } - LOGF("Found getTables method: %p", (void*)getTablesMethod); - - // Call getTables - LOG("Calling getTables method"); - jobject schemaRoot = env->CallObjectMethod(wrapperInstance, getTablesMethod, - jCatalog, jSchema, jTable, jTypes); - - // Clean up local references - LOG("Cleaning up local references"); - if (jCatalog) { - env->DeleteLocalRef(jCatalog); - LOG("Deleted jCatalog reference"); - } - if (jSchema) { - env->DeleteLocalRef(jSchema); - LOG("Deleted jSchema reference"); - } - if (jTable) { - env->DeleteLocalRef(jTable); - LOG("Deleted jTable reference"); - } - if (jTypes) { - env->DeleteLocalRef(jTypes); - LOG("Deleted jTypes reference"); + + // Set up descriptors for each column + stmt->resultColumns.resize(numFields); + for (jint i = 0; i < numFields; i++) { + jobject field = env->CallObjectMethod(fieldsList, env->GetMethodID(listClass, "get", "(I)Ljava/lang/Object;"), i); + jclass fieldClass = env->GetObjectClass(field); + + // Get field name + jmethodID getNameMethod = env->GetMethodID(fieldClass, "getName", "()Ljava/lang/String;"); + jstring fieldName = (jstring)env->CallObjectMethod(field, getNameMethod); + const char* nameChars = env->GetStringUTFChars(fieldName, nullptr); + + // Get field type + jmethodID getTypeMethod = env->GetMethodID(fieldClass, "getType", + "()Lorg/apache/arrow/vector/types/pojo/ArrowType;"); + jobject arrowType = env->CallObjectMethod(field, getTypeMethod); + + SQLSMALLINT sqlType = mapArrowTypeToSQL(env, arrowType); + SQLULEN columnSize = getSQLTypeSize(sqlType); + + stmt->resultColumns[i].name = _strdup(nameChars); + stmt->resultColumns[i].nameLength = (SQLSMALLINT)strlen(nameChars); + stmt->resultColumns[i].nullable = SQL_NULLABLE; + stmt->resultColumns[i].columnSize = columnSize; + stmt->resultColumns[i].sqlType = sqlType; + + env->ReleaseStringUTFChars(fieldName, nameChars); + env->DeleteLocalRef(fieldName); + env->DeleteLocalRef(arrowType); + env->DeleteLocalRef(field); + env->DeleteLocalRef(fieldClass); } + env->DeleteLocalRef(listClass); + env->DeleteLocalRef(fieldsList); + env->DeleteLocalRef(schemaClass); + env->DeleteLocalRef(schema); + env->DeleteLocalRef(rootClass); + + return SQL_SUCCESS; + } catch (const std::exception& e) { + LOGF("Exception in populateColumnDescriptors: %s", e.what()); if (env->ExceptionCheck()) { - LOG("Java exception detected during getTables call"); env->ExceptionDescribe(); env->ExceptionClear(); - LOG("Exception cleared"); - return SQL_ERROR; } - - if (!schemaRoot) { - LOG("getTables returned null schemaRoot"); - return SQL_ERROR; - } - LOGF("Successfully got schemaRoot: %p", (void*)schemaRoot); - - // Set up the ODBC result set columns in the statement - LOG("Setting up table result columns in Statement"); - auto tableDefs = stmt->setupTableResultColumns(); - LOG("Successfully set up table result columns"); - - // Convert Arrow VectorSchemaRoot to ODBC result set - LOG("Converting Arrow VectorSchemaRoot to ODBC result set"); - SQLRETURN result = stmt->setArrowResult(schemaRoot, tableDefs); - LOGF("setArrowResult returned: %d", result); - - // Clean up schema root - env->DeleteLocalRef(schemaRoot); - LOG("Cleaned up schemaRoot reference"); - - if (result == SQL_SUCCESS) { - LOG("GetTables completed successfully"); - } else { - LOG("GetTables completed with errors"); - } - return result; - - } catch (const std::exception& e) { - LOGF("Exception in GetTables: %s", e.what()); - LOG("Stack trace (if available):"); - LOG(e.what()); return SQL_ERROR; } catch (...) { - LOG("Unknown exception in GetTables"); + LOG("Unknown exception in populateColumnDescriptors"); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + } return SQL_ERROR; } } -SQLRETURN Connection::GetColumns( - const std::string& catalogName, - const std::string& schemaName, - const std::string& tableName, - const std::string& columnName, +SQLRETURN Connection::executeAndGetArrowResult( + const char* methodName, + const std::vector& params, Statement* stmt) const { - LOGF("GetColumns called with params - catalog: '%s', schema: '%s', table: '%s', column: '%s'", - catalogName.c_str(), schemaName.c_str(), tableName.c_str(), columnName.c_str()); - - if (!isConnected || !env || !wrapperInstance || !wrapperClass) { - LOG("Connection validation failed:"); - LOGF(" - isConnected: %d", isConnected); - LOGF(" - env: %p", (void*)env); - LOGF(" - wrapperInstance: %p", (void*)wrapperInstance); - LOGF(" - wrapperClass: %p", (void*)wrapperClass); + if (!connected || !env || !wrapperInstance || !wrapperClass) { + LOG("Connection not properly initialized"); return SQL_ERROR; } try { - // Helper lambda to convert std::string to jstring - auto createJavaString = [this](const std::string& str) -> jstring { - if (str.empty()) { - LOG("Creating null jstring for empty input"); - return nullptr; - } - LOGF("Creating jstring for input: '%s'", str.c_str()); - return env->NewStringUTF(str.c_str()); - }; - - // Convert input strings to Java strings - LOG("Converting input parameters to Java strings"); - jstring jCatalog = createJavaString(catalogName); - jstring jSchema = createJavaString(schemaName); - jstring jTable = createJavaString(tableName); - jstring jColumn = createJavaString(columnName); - - // Find getColumns method - LOG("Looking up getColumns method"); - jmethodID getColumnsMethod = env->GetMethodID(wrapperClass, "getColumns", - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Lorg/apache/arrow/vector/VectorSchemaRoot;"); - - if (!getColumnsMethod) { - LOG("Failed to find getColumns method"); + // Build method signature + std::string signature = "("; + for (const auto& param : params) { + signature += param.getSignature(); + } + signature += ")Lorg/apache/arrow/vector/VectorSchemaRoot;"; + + LOGF("Looking for method %s with signature %s", methodName, signature.c_str()); + + jmethodID method = env->GetMethodID(wrapperClass, methodName, signature.c_str()); + if (!method) { + LOGF("Failed to find method: %s with signature: %s", methodName, signature.c_str()); return SQL_ERROR; } - // Call getColumns - LOG("Calling getColumns method"); - jobject schemaRoot = env->CallObjectMethod(wrapperInstance, getColumnsMethod, - jCatalog, jSchema, jTable, jColumn); + // Convert parameters to JNI values + std::vector jniValues; + for (const auto& param : params) { + jniValues.push_back(param.toJValue(env)); + } - // Clean up local references - LOG("Cleaning up local references"); - if (jCatalog) env->DeleteLocalRef(jCatalog); - if (jSchema) env->DeleteLocalRef(jSchema); - if (jTable) env->DeleteLocalRef(jTable); - if (jColumn) env->DeleteLocalRef(jColumn); + // Call the method + jobject schemaRoot; + if (params.empty()) { + schemaRoot = env->CallObjectMethod(wrapperInstance, method); + } else { + schemaRoot = env->CallObjectMethodA(wrapperInstance, method, jniValues.data()); + } + + // Clean up parameters + for (size_t i = 0; i < params.size(); i++) { + params[i].cleanup(env, jniValues[i]); + } if (env->ExceptionCheck()) { - LOG("Java exception detected during getColumns call"); env->ExceptionDescribe(); env->ExceptionClear(); return SQL_ERROR; } if (!schemaRoot) { - LOG("getColumns returned null schemaRoot"); + LOG("Method returned null result"); return SQL_ERROR; } - // Set up the ODBC result set columns in the statement - LOG("Setting up column result columns in Statement"); - auto columnDefs = stmt->setupColumnResultColumns(); - - // Convert Arrow VectorSchemaRoot to ODBC result set - LOG("Converting Arrow VectorSchemaRoot to ODBC result set"); - SQLRETURN result = stmt->setArrowResult(schemaRoot, columnDefs); + // Extract schema and create column descriptors + SQLRETURN ret = populateColumnDescriptors(schemaRoot, stmt); + if (!SQL_SUCCEEDED(ret)) { + return ret; + } - // Clean up schema root + // Process the Arrow data + ret = stmt->setArrowResult(schemaRoot, stmt->resultColumns); + jclass schemaRootClass = env->GetObjectClass(schemaRoot); + jmethodID closeMethod = env->GetMethodID(schemaRootClass, "close", "()V"); + if (closeMethod != nullptr) { + env->CallVoidMethod(schemaRoot, closeMethod); + } env->DeleteLocalRef(schemaRoot); - - return result; - + return ret; } catch (const std::exception& e) { - LOGF("Exception in GetColumns: %s", e.what()); - return SQL_ERROR; - } catch (...) { - LOG("Unknown exception in GetColumns"); + LOGF("Exception in %s: %s", methodName, e.what()); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + } return SQL_ERROR; } } +SQLSMALLINT Connection::mapArrowTypeToSQL(JNIEnv* env, jobject arrowType) { + jclass typeClass = env->GetObjectClass(arrowType); + jmethodID getTypeIDMethod = env->GetMethodID(typeClass, "getTypeID", + "()Lorg/apache/arrow/vector/types/pojo/ArrowType$ArrowTypeID;"); + jobject typeId = env->CallObjectMethod(arrowType, getTypeIDMethod); + + jclass enumClass = env->GetObjectClass(typeId); + jmethodID nameMethod = env->GetMethodID(enumClass, "name", "()Ljava/lang/String;"); + auto typeName = reinterpret_cast(env->CallObjectMethod(typeId, nameMethod)); + + const char* typeNameStr = env->GetStringUTFChars(typeName, nullptr); + SQLSMALLINT sqlType; + + // Map Arrow types to SQL types + if (strcmp(typeNameStr, "Int") == 0) sqlType = SQL_INTEGER; + else if (strcmp(typeNameStr, "FloatingPoint") == 0) sqlType = SQL_DOUBLE; + // else if (strcmp(typeNameStr, "Utf8") == 0) sqlType = SQL_VARCHAR; + else if (strcmp(typeNameStr, "Bool") == 0) sqlType = SQL_BIT; + else if (strcmp(typeNameStr, "Date") == 0) sqlType = SQL_TYPE_DATE; + else if (strcmp(typeNameStr, "Time") == 0) sqlType = SQL_TYPE_TIME; + else if (strcmp(typeNameStr, "Timestamp") == 0) sqlType = SQL_TYPE_TIMESTAMP; + else if (strcmp(typeNameStr, "Decimal") == 0) sqlType = SQL_DECIMAL; + else if (strcmp(typeNameStr, "Binary") == 0) sqlType = SQL_BINARY; + else sqlType = SQL_VARCHAR; // Default fallback + + env->ReleaseStringUTFChars(typeName, typeNameStr); + env->DeleteLocalRef(typeName); + env->DeleteLocalRef(typeId); + env->DeleteLocalRef(enumClass); + env->DeleteLocalRef(typeClass); + + return sqlType; +} + +SQLULEN Connection::getSQLTypeSize(SQLSMALLINT sqlType) { + switch (sqlType) { + case SQL_INTEGER: return sizeof(SQLINTEGER); + case SQL_SMALLINT: return sizeof(SQLSMALLINT); + case SQL_BIGINT: return sizeof(SQLBIGINT); + case SQL_DOUBLE: return sizeof(SQLDOUBLE); + case SQL_REAL: return sizeof(SQLREAL); + case SQL_DECIMAL: return 38; // Max precision + case SQL_BIT: return 1; + case SQL_TINYINT: return sizeof(SQLSCHAR); + case SQL_TYPE_DATE: return SQL_DATE_LEN; + case SQL_TYPE_TIME: return SQL_TIME_LEN; + case SQL_TYPE_TIMESTAMP: return SQL_TIMESTAMP_LEN; + case SQL_BINARY: + case SQL_VARBINARY: return 8000; // Max binary length + case SQL_VARCHAR: + case SQL_CHAR: return 8000; // Max string length + case SQL_WVARCHAR: + case SQL_WCHAR: return 4000; // Max Unicode string length (in characters) + default: return 8000; // Default to max string length + } +} \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp index a6dcfa0..9151650 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/main.cpp @@ -16,7 +16,7 @@ BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserv { switch (ul_reason_for_call) { case DLL_PROCESS_ATTACH: - dllPath = GetModuleDirectory(hModule); + dllPath = GetModuleDirectory(); break; case DLL_THREAD_ATTACH: case DLL_THREAD_DETACH: diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp index 81d0f20..5b2bfb6 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp +++ b/calcite-rs-jni/odbc/DDN-ODBC-Driver/src/statement.cpp @@ -1,5 +1,7 @@ // First, include Windows headers +#include #include +#define WIN32_LEAN_AND_MEAN // Then SQL headers in this specific order #include @@ -13,38 +15,12 @@ // Your project headers #include "../include/statement.hpp" +#include +#include + #include "connection.hpp" #include "../include/logging.hpp" -static const std::vector TABLE_COLUMNS = { - {"TABLE_CAT", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE}, - {"TABLE_SCHEM", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE}, - {"TABLE_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS}, - {"TABLE_TYPE", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS}, - {"REMARKS", 0, SQL_VARCHAR, 254, 0, SQL_NULLABLE} -}; - -static const std::vector COLUMN_COLUMNS = { - {"TABLE_CAT", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE}, - {"TABLE_SCHEM", 0, SQL_VARCHAR, 128, 0, SQL_NULLABLE}, - {"TABLE_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS}, - {"COLUMN_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS}, - {"DATA_TYPE", 0, SQL_SMALLINT, 5, 0, SQL_NO_NULLS}, - {"TYPE_NAME", 0, SQL_VARCHAR, 128, 0, SQL_NO_NULLS}, - {"COLUMN_SIZE", 0, SQL_INTEGER, 10, 0, SQL_NULLABLE}, - {"BUFFER_LENGTH", 0, SQL_INTEGER, 10, 0, SQL_NULLABLE}, - {"DECIMAL_DIGITS", 0, SQL_SMALLINT, 5, 0, SQL_NULLABLE}, - {"NUM_PREC_RADIX", 0, SQL_SMALLINT, 5, 0, SQL_NULLABLE}, - {"NULLABLE", 0, SQL_SMALLINT, 5, 0, SQL_NO_NULLS}, - {"REMARKS", 0, SQL_VARCHAR, 254, 0, SQL_NULLABLE}, - {"COLUMN_DEF", 0, SQL_VARCHAR, 254, 0, SQL_NULLABLE}, - {"SQL_DATA_TYPE", 0, SQL_SMALLINT, 5, 0, SQL_NO_NULLS}, - {"SQL_DATETIME_SUB", 0, SQL_SMALLINT, 5, 0, SQL_NULLABLE}, - {"CHAR_OCTET_LENGTH", 0, SQL_INTEGER, 10, 0, SQL_NULLABLE}, - {"ORDINAL_POSITION", 0, SQL_INTEGER, 10, 0, SQL_NO_NULLS}, - {"IS_NULLABLE", 0, SQL_VARCHAR, 3, 0, SQL_NO_NULLS} -}; - SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vector& columnDescriptors) { LOGF("Starting setArrowResult with %zu columns", columnDescriptors.size()); @@ -55,21 +31,58 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->GetObjectClass(schemaRoot); + if (!rootClass) { + LOG("Failed to get root class"); + return SQL_ERROR; + } jmethodID getRowCountMethod = conn->env->GetMethodID(rootClass, "getRowCount", "()I"); + if (!getRowCountMethod) { + LOG("Failed to get getRowCount method"); + conn->env->DeleteLocalRef(rootClass); + return SQL_ERROR; + } jmethodID getFieldVectorsMethod = conn->env->GetMethodID(rootClass, "getFieldVectors", "()Ljava/util/List;"); + if (!getFieldVectorsMethod) { + LOG("Failed to get getFieldVectors method"); + conn->env->DeleteLocalRef(rootClass); + return SQL_ERROR; + } jint rowCount = conn->env->CallIntMethod(schemaRoot, getRowCountMethod); LOGF("Row count: %d", rowCount); jobject vectorsList = conn->env->CallObjectMethod(schemaRoot, getFieldVectorsMethod); + if (!vectorsList) { + LOG("Failed to get vector list"); + conn->env->DeleteLocalRef(rootClass); + return SQL_ERROR; + } jclass listClass = conn->env->GetObjectClass(vectorsList); + if (!listClass) { + LOG("Failed to get list class"); + conn->env->DeleteLocalRef(vectorsList); + conn->env->DeleteLocalRef(rootClass); + return SQL_ERROR; + } jmethodID getMethod = conn->env->GetMethodID(listClass, "get", "(I)Ljava/lang/Object;"); + if (!getMethod) { + LOG("Failed to get get method"); + conn->env->DeleteLocalRef(listClass); + conn->env->DeleteLocalRef(vectorsList); + conn->env->DeleteLocalRef(rootClass); + return SQL_ERROR; + } jmethodID sizeMethod = conn->env->GetMethodID(listClass, "size", "()I"); + if (!sizeMethod) { + LOG("Failed to get size method"); + conn->env->DeleteLocalRef(listClass); + conn->env->DeleteLocalRef(vectorsList); + conn->env->DeleteLocalRef(rootClass); + return SQL_ERROR; + } jint vectorCount = conn->env->CallIntMethod(vectorsList, sizeMethod); LOGF("Vector count: %d", vectorCount); @@ -90,14 +103,32 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->GetObjectClass(fieldVector); + if (!vectorClass) { + LOGF("Failed to get vector class for column %d", col); + conn->env->DeleteLocalRef(fieldVector); + continue; + } jmethodID getObjectMethod = conn->env->GetMethodID(vectorClass, "getObject", "(I)Ljava/lang/Object;"); + if (!getObjectMethod) { + LOGF("Failed to get getObject method for column %d", col); + conn->env->DeleteLocalRef(vectorClass); + conn->env->DeleteLocalRef(fieldVector); + continue; + } jmethodID isNullMethod = conn->env->GetMethodID(vectorClass, "isNull", "(I)Z"); + if (!isNullMethod) { + LOGF("Failed to get isNull method for column %d", col); + conn->env->DeleteLocalRef(vectorClass); + conn->env->DeleteLocalRef(fieldVector); + continue; + } for (jint row = 0; row < rowCount; row++) { // Check if the value is null jboolean isNull = conn->env->CallBooleanMethod(fieldVector, isNullMethod, row); if (isNull || conn->env->ExceptionCheck()) { if (conn->env->ExceptionCheck()) { + conn->env->ExceptionDescribe(); conn->env->ExceptionClear(); } LOGF("Null value at row %d, col %d", row, col); @@ -117,12 +148,21 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->FindClass("java/lang/String"); + if (!stringClass) { + LOGF("Failed to find String class at row %d, col %d", row, col); + conn->env->DeleteLocalRef(value); + continue; + } jmethodID toStringMethod = conn->env->GetMethodID(stringClass, "toString", "()Ljava/lang/String;"); - jstring strValue = (jstring)conn->env->CallObjectMethod(value, toStringMethod); + if (!toStringMethod) { + LOGF("Failed to get toString method at row %d, col %d", row, col); + conn->env->DeleteLocalRef(stringClass); + conn->env->DeleteLocalRef(value); + continue; + } - if (strValue) { - const char* chars = conn->env->GetStringUTFChars(strValue, nullptr); - if (chars) { + if (auto strValue = reinterpret_cast(conn->env->CallObjectMethod(value, toStringMethod))) { + if (const char* chars = conn->env->GetStringUTFChars(strValue, nullptr)) { resultData[row][col].isNull = false; resultData[row][col].data = chars; conn->env->ReleaseStringUTFChars(strValue, chars); @@ -154,15 +194,24 @@ SQLRETURN Statement::setArrowResult(jobject schemaRoot, const std::vectorenv->ExceptionCheck()) { conn->env->ExceptionDescribe(); conn->env->ExceptionClear(); } clearResults(); return SQL_ERROR; + } catch (...) { + LOG("Unknown exception in setArrowResult"); + if (conn != nullptr && conn->env != nullptr) { + if (conn->env->ExceptionCheck()) { + conn->env->ExceptionDescribe(); + conn->env->ExceptionClear(); + } + } + clearResults(); + return SQL_ERROR; } } @@ -197,57 +246,18 @@ bool Statement::hasData() const { return hasResult && currentRow < resultData.size(); } -std::vector Statement::setupColumnResultColumns() { - // Clear any existing result set - clearResults(); - - // Create column descriptors for the COLUMNS result set - resultColumns.clear(); - for (const auto& colDef : COLUMN_COLUMNS) { - ColumnDesc col{}; - col.name = colDef.name; - col.nameLength = static_cast(strlen(colDef.name)); - col.sqlType = colDef.sqlType; - col.columnSize = colDef.columnSize; - col.decimalDigits = colDef.decimalDigits; - col.nullable = colDef.nullable; - resultColumns.push_back(col); - } - - LOGF("Set up %zu column result columns", resultColumns.size()); - return COLUMN_COLUMNS; -} - -Statement::Statement(Connection* connection) : conn(connection) { +Statement::Statement(Connection *connection) : conn(connection) { hasResult = false; currentRow = 0; resultData.clear(); } -std::vector Statement::setupTableResultColumns() { - clearResults(); - resultColumns.clear(); - - for (const auto& colDef : TABLE_COLUMNS) { - ColumnDesc col{}; - col.name = colDef.name; - col.nameLength = static_cast(strlen(colDef.name)); - col.sqlType = colDef.sqlType; - col.columnSize = colDef.columnSize; - col.decimalDigits = colDef.decimalDigits; - col.nullable = colDef.nullable; - LOGF("Setting up column %s with SQL type %d", col.name, col.sqlType); - resultColumns.push_back(col); - } - return TABLE_COLUMNS; -} - // Additional helper method for fetching data from the result set SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType, - SQLPOINTER targetValue, SQLLEN bufferLength, - SQLLEN* strLengthOrIndicator) { + SQLPOINTER targetValue, SQLLEN bufferLength, + SQLLEN *strLengthOrIndicator) { LOGF("getData called for column %d", colNum); - + // Validate state and parameters if (!hasResult || currentRow == 0 || currentRow > resultData.size() || colNum == 0 || colNum > resultColumns.size()) { @@ -255,7 +265,7 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType, return SQL_ERROR; } - const auto& colData = resultData[currentRow - 1][colNum - 1]; + const auto &colData = resultData[currentRow - 1][colNum - 1]; LOGF("Fetching data for row %d, column %d", currentRow - 1, colNum - 1); // Handle NULL values @@ -271,11 +281,11 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType, switch (targetType) { case SQL_C_WCHAR: { LOGF("Converting to WCHAR: '%s'", colData.data.c_str()); - + const int requiredSize = MultiByteToWideChar( - CP_UTF8, 0, colData.data.c_str(), -1, nullptr, 0 - ) * sizeof(WCHAR); - + CP_UTF8, 0, colData.data.c_str(), -1, nullptr, 0 + ) * sizeof(WCHAR); + if (strLengthOrIndicator) { *strLengthOrIndicator = requiredSize - sizeof(WCHAR); } @@ -301,7 +311,7 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType, // Check for truncation if (charsWritten == maxChars) { - static_cast(targetValue)[maxChars - 1] = L'\0'; + static_cast(targetValue)[maxChars - 1] = L'\0'; return SQL_SUCCESS_WITH_INFO; } @@ -315,8 +325,200 @@ SQLRETURN Statement::getData(SQLUSMALLINT colNum, SQLSMALLINT targetType, } void Statement::clearResults() { + LOG("Called clearResults()"); hasResult = false; currentRow = 0; resultData.clear(); - resultColumns.clear(); + boundParams.clear(); +} + +SQLRETURN Statement::bindParameter(SQLUSMALLINT parameterNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN *strLen_or_IndPtr) { + LOGF("Binding parameter %d of type %d", parameterNumber, valueType); + + if (!conn || !conn->env) { + LOG("Invalid connection or environment"); + return SQL_ERROR; + } + + // Parameter numbers are 1-based + if (parameterNumber < 1) { + LOG("Invalid parameter number"); + return SQL_ERROR; + } + + // Check for null indicator + if (strLen_or_IndPtr && *strLen_or_IndPtr == SQL_NULL_DATA) { + // Handle NULL parameter - could use a special JniParam constructor for NULL + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + // You might want to add a setNull method to JniParam + return SQL_SUCCESS; + } + + try { + // Convert ODBC parameter to JniParam based on valueType + switch (valueType) { + case SQL_C_CHAR: { + if (!parameterValuePtr) return SQL_ERROR; + std::string value(static_cast(parameterValuePtr)); + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + boundParams[parameterNumber - 1] = JniParam(value); + break; + } + + case SQL_C_WCHAR: { + if (!parameterValuePtr) return SQL_ERROR; + wchar_t* wstr = static_cast(parameterValuePtr); + // Convert wide string to UTF-8 + int requiredSize = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, nullptr, 0, nullptr, nullptr); + if (requiredSize == 0) return SQL_ERROR; + + std::string utf8str(requiredSize, '\0'); + if (WideCharToMultiByte(CP_UTF8, 0, wstr, -1, &utf8str[0], requiredSize, nullptr, nullptr) == 0) { + return SQL_ERROR; + } + utf8str.resize(strlen(utf8str.c_str())); // Remove trailing null + + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + boundParams[parameterNumber - 1] = JniParam(utf8str); + break; + } + + case SQL_C_LONG: + case SQL_C_SLONG: { + if (!parameterValuePtr) return SQL_ERROR; + int value = *static_cast(parameterValuePtr); + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + boundParams[parameterNumber - 1] = JniParam(value); + break; + } + + case SQL_C_FLOAT: { + if (!parameterValuePtr) return SQL_ERROR; + float value = *static_cast(parameterValuePtr); + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + boundParams[parameterNumber - 1] = JniParam(value); + break; + } + + case SQL_C_DOUBLE: { + if (!parameterValuePtr) return SQL_ERROR; + double value = *static_cast(parameterValuePtr); + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + boundParams[parameterNumber - 1] = JniParam(value); + break; + } + + case SQL_C_BIT: { + if (!parameterValuePtr) return SQL_ERROR; + bool value = (*static_cast(parameterValuePtr)) != 0; + if (parameterNumber > boundParams.size()) { + boundParams.resize(parameterNumber); + } + boundParams[parameterNumber - 1] = JniParam(value); + break; + } + + default: + LOGF("Unsupported parameter type: %d", valueType); + return SQL_ERROR; + } + + return SQL_SUCCESS; + } + catch (const std::exception& e) { + LOGF("Exception in bindParameter: %s", e.what()); + return SQL_ERROR; + } } + +std::string Statement::escapeString(const std::string& str) const { + std::string escaped; + escaped.reserve(str.length() + str.length()/8); // Reserve extra space for escapes + + for (char c : str) { + switch (c) { + case '\'': escaped += "''"; break; // Double single quotes for SQL + default: escaped += c; break; + } + } + return escaped; +} + +std::string Statement::buildInterpolatedQuery() const { + std::string result = originalQuery; + + // Find all ? parameters and replace them + size_t paramIndex = 0; + size_t pos = 0; + + while ((pos = result.find('?', pos)) != std::string::npos) { + if (paramIndex >= boundParams.size()) { + throw std::runtime_error("Not enough parameters bound for query"); + } + + const auto& param = boundParams[paramIndex]; + std::string replacement; + + // Convert parameter to string representation + switch (param.getType()) { + case JniParam::Type::String: + replacement = "'" + escapeString(param.getString()) + "'"; + break; + + case JniParam::Type::StringArray: { + replacement = "("; + bool first = true; + for (const auto& str : param.getStringArray()) { + if (!first) replacement += ","; + replacement += "'" + escapeString(str) + "'"; + first = false; + } + replacement += ")"; + break; + } + + case JniParam::Type::Integer: + replacement = std::to_string(param.getInt()); + break; + + case JniParam::Type::Float: + replacement = std::to_string(param.getFloat()); + break; + + case JniParam::Type::Double: + replacement = std::to_string(param.getDouble()); + break; + + case JniParam::Type::Boolean: + replacement = param.getBool() ? "1" : "0"; + break; + } + + result.replace(pos, 1, replacement); + pos += replacement.length(); + paramIndex++; + } + + LOGF("Interpolated query: %s", result.c_str()); + return result; +} \ No newline at end of file diff --git a/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs b/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs index f27d29d..3e1e37d 100755 --- a/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs +++ b/calcite-rs-jni/odbc/DDN-ODBC-Tester/Program.cs @@ -129,7 +129,26 @@ static void TestBasicQueries(OdbcConnection conn) if (tables.Rows.Count > 0) { string tableName = tables.Rows[0]["TABLE_NAME"].ToString(); - TestMetadata(conn, "Basic Select", $"SELECT * FROM {tableName} LIMIT 5"); + using var cmd = conn.CreateCommand(); + cmd.CommandText = $"SELECT * FROM \"{tableName}\" LIMIT 5"; + try + { + using var reader = cmd.ExecuteReader(); + var columns = Enumerable.Range(0, reader.FieldCount) + .Select(i => reader.GetName(i)); + Console.WriteLine(string.Join("\t", columns)); + + while (reader.Read()) + { + var values = Enumerable.Range(0, reader.FieldCount) + .Select(i => reader[i]?.ToString() ?? "NULL"); + Console.WriteLine(string.Join("\t", values)); + } + } + catch (Exception ex) + { + Console.WriteLine($"Parameterized query not supported: {ex.Message}"); + } } } @@ -177,14 +196,16 @@ static void TestMetadataRetrieval(OdbcConnection conn) cmd.CommandText = $"SELECT * FROM {tableName} LIMIT 1"; using var reader = cmd.ExecuteReader(); - for (int i = 0; i < reader.FieldCount; i++) - { - Console.WriteLine($"Column {i}:"); - Console.WriteLine($" Name: {reader.GetName(i)}"); - Console.WriteLine($" Type: {reader.GetFieldType(i)}"); - Console.WriteLine($" Precision: {reader.GetSchemaTable().Rows[i]["NumericPrecision"]}"); - Console.WriteLine($" Scale: {reader.GetSchemaTable().Rows[i]["NumericScale"]}"); - Console.WriteLine($" IsNullable: {reader.GetSchemaTable().Rows[i]["AllowDBNull"]}"); + while(reader.Read()) { + for (int i = 0; i < reader.FieldCount; i++) + { + Console.WriteLine($"Column {i}:"); + Console.WriteLine($" Name: {reader.GetName(i)}"); + Console.WriteLine($" Type: {reader.GetFieldType(i)}"); + Console.WriteLine($" Precision: {reader.GetSchemaTable().Rows[i]["NumericPrecision"]}"); + Console.WriteLine($" Scale: {reader.GetSchemaTable().Rows[i]["NumericScale"]}"); + Console.WriteLine($" IsNullable: {reader.GetSchemaTable().Rows[i]["AllowDBNull"]}"); + } } } } diff --git a/calcite-rs-jni/odbc/build-log.txt b/calcite-rs-jni/odbc/build-log.txt index 3f5378a..fc52de3 100755 Binary files a/calcite-rs-jni/odbc/build-log.txt and b/calcite-rs-jni/odbc/build-log.txt differ